mirror of
https://github.com/apernet/OpenGFW.git
synced 2024-11-13 13:59:24 +08:00
feat: added protected dial support, removed multi-IO support for simplicity
This commit is contained in:
parent
ae34b4856a
commit
9c0893c512
16
cmd/root.go
16
cmd/root.go
@ -204,7 +204,7 @@ func (c *cliConfig) fillIO(config *engine.Config) error {
|
||||
if err != nil {
|
||||
return configError{Field: "io", Err: err}
|
||||
}
|
||||
config.IOs = []io.PacketIO{nfio}
|
||||
config.IO = nfio
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -247,12 +247,7 @@ func runMain(cmd *cobra.Command, args []string) {
|
||||
if err != nil {
|
||||
logger.Fatal("failed to parse config", zap.Error(err))
|
||||
}
|
||||
defer func() {
|
||||
// Make sure to close all IOs on exit
|
||||
for _, i := range engineConfig.IOs {
|
||||
_ = i.Close()
|
||||
}
|
||||
}()
|
||||
defer engineConfig.IO.Close() // Make sure to close IO on exit
|
||||
|
||||
// Ruleset
|
||||
rawRs, err := ruleset.ExprRulesFromYAML(args[0])
|
||||
@ -260,9 +255,10 @@ func runMain(cmd *cobra.Command, args []string) {
|
||||
logger.Fatal("failed to load rules", zap.Error(err))
|
||||
}
|
||||
rsConfig := &ruleset.BuiltinConfig{
|
||||
Logger: &rulesetLogger{},
|
||||
GeoSiteFilename: config.Ruleset.GeoSite,
|
||||
GeoIpFilename: config.Ruleset.GeoIp,
|
||||
Logger: &rulesetLogger{},
|
||||
GeoSiteFilename: config.Ruleset.GeoSite,
|
||||
GeoIpFilename: config.Ruleset.GeoIp,
|
||||
ProtectedDialContext: engineConfig.IO.ProtectedDialContext,
|
||||
}
|
||||
rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig)
|
||||
if err != nil {
|
||||
|
@ -15,7 +15,7 @@ var _ Engine = (*engine)(nil)
|
||||
|
||||
type engine struct {
|
||||
logger Logger
|
||||
ioList []io.PacketIO
|
||||
io io.PacketIO
|
||||
workers []*worker
|
||||
}
|
||||
|
||||
@ -42,7 +42,7 @@ func NewEngine(config Config) (Engine, error) {
|
||||
}
|
||||
return &engine{
|
||||
logger: config.Logger,
|
||||
ioList: config.IOs,
|
||||
io: config.IO,
|
||||
workers: workers,
|
||||
}, nil
|
||||
}
|
||||
@ -58,27 +58,24 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
||||
|
||||
func (e *engine) Run(ctx context.Context) error {
|
||||
ioCtx, ioCancel := context.WithCancel(ctx)
|
||||
defer ioCancel() // Stop workers & IOs
|
||||
defer ioCancel() // Stop workers & IO
|
||||
|
||||
// Start workers
|
||||
for _, w := range e.workers {
|
||||
go w.Run(ioCtx)
|
||||
}
|
||||
|
||||
// Register callbacks
|
||||
errChan := make(chan error, len(e.ioList))
|
||||
for _, i := range e.ioList {
|
||||
ioEntry := i // Make sure dispatch() uses the correct ioEntry
|
||||
err := ioEntry.Register(ioCtx, func(p io.Packet, err error) bool {
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return false
|
||||
}
|
||||
return e.dispatch(ioEntry, p)
|
||||
})
|
||||
// Register IO callback
|
||||
errChan := make(chan error, 1)
|
||||
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
|
||||
if err != nil {
|
||||
return err
|
||||
errChan <- err
|
||||
return false
|
||||
}
|
||||
return e.dispatch(p)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Block until IO errors or context is cancelled
|
||||
@ -91,8 +88,7 @@ func (e *engine) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// dispatch dispatches a packet to a worker.
|
||||
// This must be safe for concurrent use, as it may be called from multiple IOs.
|
||||
func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
|
||||
func (e *engine) dispatch(p io.Packet) bool {
|
||||
data := p.Data()
|
||||
ipVersion := data[0] >> 4
|
||||
var layerType gopacket.LayerType
|
||||
@ -102,7 +98,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
|
||||
layerType = layers.LayerTypeIPv6
|
||||
} else {
|
||||
// Unsupported network layer
|
||||
_ = ioEntry.SetVerdict(p, io.VerdictAcceptStream, nil)
|
||||
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
|
||||
return true
|
||||
}
|
||||
// Load balance by stream ID
|
||||
@ -112,7 +108,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
|
||||
StreamID: p.StreamID(),
|
||||
Packet: packet,
|
||||
SetVerdict: func(v io.Verdict, b []byte) error {
|
||||
return ioEntry.SetVerdict(p, v, b)
|
||||
return e.io.SetVerdict(p, v, b)
|
||||
},
|
||||
})
|
||||
return true
|
||||
|
@ -18,7 +18,7 @@ type Engine interface {
|
||||
// Config is the configuration for the engine.
|
||||
type Config struct {
|
||||
Logger Logger
|
||||
IOs []io.PacketIO
|
||||
IO io.PacketIO
|
||||
Ruleset ruleset.Ruleset
|
||||
|
||||
Workers int // Number of workers. Zero or negative means auto (number of CPU cores).
|
||||
|
@ -2,6 +2,7 @@ package io
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
type Verdict int
|
||||
@ -29,7 +30,6 @@ type Packet interface {
|
||||
|
||||
// PacketCallback is called for each packet received.
|
||||
// Return false to "unregister" and stop receiving packets.
|
||||
// It must be safe for concurrent use.
|
||||
type PacketCallback func(Packet, error) bool
|
||||
|
||||
type PacketIO interface {
|
||||
@ -39,6 +39,10 @@ type PacketIO interface {
|
||||
Register(context.Context, PacketCallback) error
|
||||
// SetVerdict sets the verdict for a packet.
|
||||
SetVerdict(Packet, Verdict, []byte) error
|
||||
// ProtectedDialContext is like net.DialContext, but the connection is "protected"
|
||||
// in the sense that the packets sent/received through the connection must bypass
|
||||
// the packet IO and not be processed by the callback.
|
||||
ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
// Close closes the packet IO.
|
||||
Close() error
|
||||
}
|
||||
|
@ -5,9 +5,11 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/florianl/go-nfqueue"
|
||||
@ -50,12 +52,13 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
|
||||
}
|
||||
for i := range table.Chains {
|
||||
c := &table.Chains[i]
|
||||
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections
|
||||
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
|
||||
if rst {
|
||||
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
|
||||
}
|
||||
c.Rules = append(c.Rules, "ct mark $DROP_CTMARK counter drop")
|
||||
c.Rules = append(c.Rules, "ip protocol tcp counter queue num $QUEUE_NUM bypass")
|
||||
c.Rules = append(c.Rules, "counter queue num $QUEUE_NUM bypass")
|
||||
}
|
||||
return table, nil
|
||||
}
|
||||
@ -72,6 +75,8 @@ func generateIptRules(local, rst bool) ([]iptRule, error) {
|
||||
}
|
||||
rules := make([]iptRule, 0, 4*len(chains))
|
||||
for _, chain := range chains {
|
||||
// Bypass protected connections
|
||||
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
|
||||
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
|
||||
if rst {
|
||||
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
|
||||
@ -96,6 +101,8 @@ type nfqueuePacketIO struct {
|
||||
// iptables not nil = use iptables instead of nftables
|
||||
ipt4 *iptables.IPTables
|
||||
ipt6 *iptables.IPTables
|
||||
|
||||
protectedDialer *net.Dialer
|
||||
}
|
||||
|
||||
type NFQueuePacketIOConfig struct {
|
||||
@ -153,6 +160,18 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
||||
rst: config.RST,
|
||||
ipt4: ipt4,
|
||||
ipt6: ipt6,
|
||||
protectedDialer: &net.Dialer{
|
||||
Control: func(network, address string, c syscall.RawConn) error {
|
||||
var err error
|
||||
cErr := c.Control(func(fd uintptr) {
|
||||
err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept)
|
||||
})
|
||||
if cErr != nil {
|
||||
return cErr
|
||||
}
|
||||
return err
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -239,6 +258,10 @@ func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) erro
|
||||
}
|
||||
}
|
||||
|
||||
func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return n.protectedDialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
func (n *nfqueuePacketIO) Close() error {
|
||||
if n.rSet {
|
||||
if n.ipt4 != nil {
|
||||
|
@ -14,14 +14,12 @@ type GeoMatcher struct {
|
||||
ipMatcherLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) (*GeoMatcher, error) {
|
||||
geoLoader := NewDefaultGeoLoader(geoSiteFilename, geoIpFilename)
|
||||
|
||||
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher {
|
||||
return &GeoMatcher{
|
||||
geoLoader: geoLoader,
|
||||
geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename),
|
||||
geoSiteMatcher: make(map[string]hostMatcher),
|
||||
geoIpMatcher: make(map[string]hostMatcher),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {
|
||||
|
@ -59,10 +59,9 @@ type compiledExprRule struct {
|
||||
var _ Ruleset = (*exprRuleset)(nil)
|
||||
|
||||
type exprRuleset struct {
|
||||
Rules []compiledExprRule
|
||||
Ans []analyzer.Analyzer
|
||||
Logger Logger
|
||||
GeoMatcher *geo.GeoMatcher
|
||||
Rules []compiledExprRule
|
||||
Ans []analyzer.Analyzer
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
|
||||
@ -104,11 +103,7 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
||||
fullAnMap := analyzersToMap(ans)
|
||||
fullModMap := modifiersToMap(mods)
|
||||
depAnMap := make(map[string]analyzer.Analyzer)
|
||||
geoMatcher, err := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
funcMap := buildFunctionMap(geoMatcher)
|
||||
funcMap := buildFunctionMap(config)
|
||||
// Compile all rules and build a map of analyzers that are used by the rules.
|
||||
for _, rule := range rules {
|
||||
if rule.Action == "" && !rule.Log {
|
||||
@ -186,10 +181,9 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
|
||||
depAns = append(depAns, a)
|
||||
}
|
||||
return &exprRuleset{
|
||||
Rules: compiledRules,
|
||||
Ans: depAns,
|
||||
Logger: config.Logger,
|
||||
GeoMatcher: geoMatcher,
|
||||
Rules: compiledRules,
|
||||
Ans: depAns,
|
||||
Logger: config.Logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -307,7 +301,8 @@ type Function struct {
|
||||
Types []reflect.Type
|
||||
}
|
||||
|
||||
func buildFunctionMap(geoMatcher *geo.GeoMatcher) map[string]*Function {
|
||||
func buildFunctionMap(config *BuiltinConfig) map[string]*Function {
|
||||
geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
|
||||
return map[string]*Function{
|
||||
"geoip": {
|
||||
InitFunc: geoMatcher.LoadGeoIP,
|
||||
@ -342,39 +337,41 @@ func buildFunctionMap(geoMatcher *geo.GeoMatcher) map[string]*Function {
|
||||
Func: func(params ...any) (any, error) {
|
||||
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
|
||||
},
|
||||
Types: []reflect.Type{reflect.TypeOf((func(string, string) bool)(nil)), reflect.TypeOf(builtins.MatchCIDR)},
|
||||
Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
|
||||
},
|
||||
"lookup": {
|
||||
InitFunc: nil,
|
||||
PatchFunc: func(args *[]ast.Node) error {
|
||||
if len(*args) < 2 {
|
||||
// Second argument (DNS server) is optional
|
||||
return nil
|
||||
}
|
||||
serverStr, ok := (*args)[1].(*ast.StringNode)
|
||||
if !ok {
|
||||
return fmt.Errorf("lookup: invalid argument type")
|
||||
var serverStr *ast.StringNode
|
||||
if len(*args) > 1 {
|
||||
// Has the optional server argument
|
||||
var ok bool
|
||||
serverStr, ok = (*args)[1].(*ast.StringNode)
|
||||
if !ok {
|
||||
return fmt.Errorf("lookup: invalid argument type")
|
||||
}
|
||||
}
|
||||
r := &net.Resolver{
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return net.Dial(network, serverStr.Value)
|
||||
if serverStr != nil {
|
||||
address = serverStr.Value
|
||||
}
|
||||
return config.ProtectedDialContext(ctx, network, address)
|
||||
},
|
||||
}
|
||||
(*args)[1] = &ast.ConstantNode{Value: r}
|
||||
if len(*args) > 1 {
|
||||
(*args)[1] = &ast.ConstantNode{Value: r}
|
||||
} else {
|
||||
*args = append(*args, &ast.ConstantNode{Value: r})
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Func: func(params ...any) (any, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
|
||||
defer cancel()
|
||||
if len(params) < 2 {
|
||||
return net.DefaultResolver.LookupHost(ctx, params[0].(string))
|
||||
} else {
|
||||
return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
|
||||
}
|
||||
return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
|
||||
},
|
||||
Types: []reflect.Type{
|
||||
reflect.TypeOf((func(string, string) []string)(nil)),
|
||||
reflect.TypeOf((func(string) []string)(nil)),
|
||||
reflect.TypeOf((func(string, *net.Resolver) []string)(nil)),
|
||||
},
|
||||
},
|
||||
|
@ -1,6 +1,7 @@
|
||||
package ruleset
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
@ -100,7 +101,8 @@ type Logger interface {
|
||||
}
|
||||
|
||||
type BuiltinConfig struct {
|
||||
Logger Logger
|
||||
GeoSiteFilename string
|
||||
GeoIpFilename string
|
||||
Logger Logger
|
||||
GeoSiteFilename string
|
||||
GeoIpFilename string
|
||||
ProtectedDialContext func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user