feat: added protected dial support, removed multi-IO support for simplicity

This commit is contained in:
Toby 2024-04-06 14:42:45 -07:00
parent ae34b4856a
commit 9c0893c512
8 changed files with 88 additions and 72 deletions

View File

@ -204,7 +204,7 @@ func (c *cliConfig) fillIO(config *engine.Config) error {
if err != nil { if err != nil {
return configError{Field: "io", Err: err} return configError{Field: "io", Err: err}
} }
config.IOs = []io.PacketIO{nfio} config.IO = nfio
return nil return nil
} }
@ -247,12 +247,7 @@ func runMain(cmd *cobra.Command, args []string) {
if err != nil { if err != nil {
logger.Fatal("failed to parse config", zap.Error(err)) logger.Fatal("failed to parse config", zap.Error(err))
} }
defer func() { defer engineConfig.IO.Close() // Make sure to close IO on exit
// Make sure to close all IOs on exit
for _, i := range engineConfig.IOs {
_ = i.Close()
}
}()
// Ruleset // Ruleset
rawRs, err := ruleset.ExprRulesFromYAML(args[0]) rawRs, err := ruleset.ExprRulesFromYAML(args[0])
@ -260,9 +255,10 @@ func runMain(cmd *cobra.Command, args []string) {
logger.Fatal("failed to load rules", zap.Error(err)) logger.Fatal("failed to load rules", zap.Error(err))
} }
rsConfig := &ruleset.BuiltinConfig{ rsConfig := &ruleset.BuiltinConfig{
Logger: &rulesetLogger{}, Logger: &rulesetLogger{},
GeoSiteFilename: config.Ruleset.GeoSite, GeoSiteFilename: config.Ruleset.GeoSite,
GeoIpFilename: config.Ruleset.GeoIp, GeoIpFilename: config.Ruleset.GeoIp,
ProtectedDialContext: engineConfig.IO.ProtectedDialContext,
} }
rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig) rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig)
if err != nil { if err != nil {

View File

@ -15,7 +15,7 @@ var _ Engine = (*engine)(nil)
type engine struct { type engine struct {
logger Logger logger Logger
ioList []io.PacketIO io io.PacketIO
workers []*worker workers []*worker
} }
@ -42,7 +42,7 @@ func NewEngine(config Config) (Engine, error) {
} }
return &engine{ return &engine{
logger: config.Logger, logger: config.Logger,
ioList: config.IOs, io: config.IO,
workers: workers, workers: workers,
}, nil }, nil
} }
@ -58,27 +58,24 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
func (e *engine) Run(ctx context.Context) error { func (e *engine) Run(ctx context.Context) error {
ioCtx, ioCancel := context.WithCancel(ctx) ioCtx, ioCancel := context.WithCancel(ctx)
defer ioCancel() // Stop workers & IOs defer ioCancel() // Stop workers & IO
// Start workers // Start workers
for _, w := range e.workers { for _, w := range e.workers {
go w.Run(ioCtx) go w.Run(ioCtx)
} }
// Register callbacks // Register IO callback
errChan := make(chan error, len(e.ioList)) errChan := make(chan error, 1)
for _, i := range e.ioList { err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
ioEntry := i // Make sure dispatch() uses the correct ioEntry
err := ioEntry.Register(ioCtx, func(p io.Packet, err error) bool {
if err != nil {
errChan <- err
return false
}
return e.dispatch(ioEntry, p)
})
if err != nil { if err != nil {
return err errChan <- err
return false
} }
return e.dispatch(p)
})
if err != nil {
return err
} }
// Block until IO errors or context is cancelled // Block until IO errors or context is cancelled
@ -91,8 +88,7 @@ func (e *engine) Run(ctx context.Context) error {
} }
// dispatch dispatches a packet to a worker. // dispatch dispatches a packet to a worker.
// This must be safe for concurrent use, as it may be called from multiple IOs. func (e *engine) dispatch(p io.Packet) bool {
func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
data := p.Data() data := p.Data()
ipVersion := data[0] >> 4 ipVersion := data[0] >> 4
var layerType gopacket.LayerType var layerType gopacket.LayerType
@ -102,7 +98,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
layerType = layers.LayerTypeIPv6 layerType = layers.LayerTypeIPv6
} else { } else {
// Unsupported network layer // Unsupported network layer
_ = ioEntry.SetVerdict(p, io.VerdictAcceptStream, nil) _ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
return true return true
} }
// Load balance by stream ID // Load balance by stream ID
@ -112,7 +108,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
StreamID: p.StreamID(), StreamID: p.StreamID(),
Packet: packet, Packet: packet,
SetVerdict: func(v io.Verdict, b []byte) error { SetVerdict: func(v io.Verdict, b []byte) error {
return ioEntry.SetVerdict(p, v, b) return e.io.SetVerdict(p, v, b)
}, },
}) })
return true return true

View File

@ -18,7 +18,7 @@ type Engine interface {
// Config is the configuration for the engine. // Config is the configuration for the engine.
type Config struct { type Config struct {
Logger Logger Logger Logger
IOs []io.PacketIO IO io.PacketIO
Ruleset ruleset.Ruleset Ruleset ruleset.Ruleset
Workers int // Number of workers. Zero or negative means auto (number of CPU cores). Workers int // Number of workers. Zero or negative means auto (number of CPU cores).

View File

@ -2,6 +2,7 @@ package io
import ( import (
"context" "context"
"net"
) )
type Verdict int type Verdict int
@ -29,7 +30,6 @@ type Packet interface {
// PacketCallback is called for each packet received. // PacketCallback is called for each packet received.
// Return false to "unregister" and stop receiving packets. // Return false to "unregister" and stop receiving packets.
// It must be safe for concurrent use.
type PacketCallback func(Packet, error) bool type PacketCallback func(Packet, error) bool
type PacketIO interface { type PacketIO interface {
@ -39,6 +39,10 @@ type PacketIO interface {
Register(context.Context, PacketCallback) error Register(context.Context, PacketCallback) error
// SetVerdict sets the verdict for a packet. // SetVerdict sets the verdict for a packet.
SetVerdict(Packet, Verdict, []byte) error SetVerdict(Packet, Verdict, []byte) error
// ProtectedDialContext is like net.DialContext, but the connection is "protected"
// in the sense that the packets sent/received through the connection must bypass
// the packet IO and not be processed by the callback.
ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error)
// Close closes the packet IO. // Close closes the packet IO.
Close() error Close() error
} }

View File

@ -5,9 +5,11 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"net"
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
"syscall"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/florianl/go-nfqueue" "github.com/florianl/go-nfqueue"
@ -50,12 +52,13 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
} }
for i := range table.Chains { for i := range table.Chains {
c := &table.Chains[i] c := &table.Chains[i]
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept") c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
if rst { if rst {
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset") c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
} }
c.Rules = append(c.Rules, "ct mark $DROP_CTMARK counter drop") c.Rules = append(c.Rules, "ct mark $DROP_CTMARK counter drop")
c.Rules = append(c.Rules, "ip protocol tcp counter queue num $QUEUE_NUM bypass") c.Rules = append(c.Rules, "counter queue num $QUEUE_NUM bypass")
} }
return table, nil return table, nil
} }
@ -72,6 +75,8 @@ func generateIptRules(local, rst bool) ([]iptRule, error) {
} }
rules := make([]iptRule, 0, 4*len(chains)) rules := make([]iptRule, 0, 4*len(chains))
for _, chain := range chains { for _, chain := range chains {
// Bypass protected connections
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}) rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
if rst { if rst {
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}}) rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
@ -96,6 +101,8 @@ type nfqueuePacketIO struct {
// iptables not nil = use iptables instead of nftables // iptables not nil = use iptables instead of nftables
ipt4 *iptables.IPTables ipt4 *iptables.IPTables
ipt6 *iptables.IPTables ipt6 *iptables.IPTables
protectedDialer *net.Dialer
} }
type NFQueuePacketIOConfig struct { type NFQueuePacketIOConfig struct {
@ -153,6 +160,18 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
rst: config.RST, rst: config.RST,
ipt4: ipt4, ipt4: ipt4,
ipt6: ipt6, ipt6: ipt6,
protectedDialer: &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
var err error
cErr := c.Control(func(fd uintptr) {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept)
})
if cErr != nil {
return cErr
}
return err
},
},
}, nil }, nil
} }
@ -239,6 +258,10 @@ func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) erro
} }
} }
func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
return n.protectedDialer.DialContext(ctx, network, address)
}
func (n *nfqueuePacketIO) Close() error { func (n *nfqueuePacketIO) Close() error {
if n.rSet { if n.rSet {
if n.ipt4 != nil { if n.ipt4 != nil {

View File

@ -14,14 +14,12 @@ type GeoMatcher struct {
ipMatcherLock sync.Mutex ipMatcherLock sync.Mutex
} }
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) (*GeoMatcher, error) { func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher {
geoLoader := NewDefaultGeoLoader(geoSiteFilename, geoIpFilename)
return &GeoMatcher{ return &GeoMatcher{
geoLoader: geoLoader, geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename),
geoSiteMatcher: make(map[string]hostMatcher), geoSiteMatcher: make(map[string]hostMatcher),
geoIpMatcher: make(map[string]hostMatcher), geoIpMatcher: make(map[string]hostMatcher),
}, nil }
} }
func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool { func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {

View File

@ -59,10 +59,9 @@ type compiledExprRule struct {
var _ Ruleset = (*exprRuleset)(nil) var _ Ruleset = (*exprRuleset)(nil)
type exprRuleset struct { type exprRuleset struct {
Rules []compiledExprRule Rules []compiledExprRule
Ans []analyzer.Analyzer Ans []analyzer.Analyzer
Logger Logger Logger Logger
GeoMatcher *geo.GeoMatcher
} }
func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer { func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
@ -104,11 +103,7 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
fullAnMap := analyzersToMap(ans) fullAnMap := analyzersToMap(ans)
fullModMap := modifiersToMap(mods) fullModMap := modifiersToMap(mods)
depAnMap := make(map[string]analyzer.Analyzer) depAnMap := make(map[string]analyzer.Analyzer)
geoMatcher, err := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename) funcMap := buildFunctionMap(config)
if err != nil {
return nil, err
}
funcMap := buildFunctionMap(geoMatcher)
// Compile all rules and build a map of analyzers that are used by the rules. // Compile all rules and build a map of analyzers that are used by the rules.
for _, rule := range rules { for _, rule := range rules {
if rule.Action == "" && !rule.Log { if rule.Action == "" && !rule.Log {
@ -186,10 +181,9 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
depAns = append(depAns, a) depAns = append(depAns, a)
} }
return &exprRuleset{ return &exprRuleset{
Rules: compiledRules, Rules: compiledRules,
Ans: depAns, Ans: depAns,
Logger: config.Logger, Logger: config.Logger,
GeoMatcher: geoMatcher,
}, nil }, nil
} }
@ -307,7 +301,8 @@ type Function struct {
Types []reflect.Type Types []reflect.Type
} }
func buildFunctionMap(geoMatcher *geo.GeoMatcher) map[string]*Function { func buildFunctionMap(config *BuiltinConfig) map[string]*Function {
geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
return map[string]*Function{ return map[string]*Function{
"geoip": { "geoip": {
InitFunc: geoMatcher.LoadGeoIP, InitFunc: geoMatcher.LoadGeoIP,
@ -342,39 +337,41 @@ func buildFunctionMap(geoMatcher *geo.GeoMatcher) map[string]*Function {
Func: func(params ...any) (any, error) { Func: func(params ...any) (any, error) {
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
}, },
Types: []reflect.Type{reflect.TypeOf((func(string, string) bool)(nil)), reflect.TypeOf(builtins.MatchCIDR)}, Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
}, },
"lookup": { "lookup": {
InitFunc: nil, InitFunc: nil,
PatchFunc: func(args *[]ast.Node) error { PatchFunc: func(args *[]ast.Node) error {
if len(*args) < 2 { var serverStr *ast.StringNode
// Second argument (DNS server) is optional if len(*args) > 1 {
return nil // Has the optional server argument
} var ok bool
serverStr, ok := (*args)[1].(*ast.StringNode) serverStr, ok = (*args)[1].(*ast.StringNode)
if !ok { if !ok {
return fmt.Errorf("lookup: invalid argument type") return fmt.Errorf("lookup: invalid argument type")
}
} }
r := &net.Resolver{ r := &net.Resolver{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) { Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, serverStr.Value) if serverStr != nil {
address = serverStr.Value
}
return config.ProtectedDialContext(ctx, network, address)
}, },
} }
(*args)[1] = &ast.ConstantNode{Value: r} if len(*args) > 1 {
(*args)[1] = &ast.ConstantNode{Value: r}
} else {
*args = append(*args, &ast.ConstantNode{Value: r})
}
return nil return nil
}, },
Func: func(params ...any) (any, error) { Func: func(params ...any) (any, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
defer cancel() defer cancel()
if len(params) < 2 { return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
return net.DefaultResolver.LookupHost(ctx, params[0].(string))
} else {
return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
}
}, },
Types: []reflect.Type{ Types: []reflect.Type{
reflect.TypeOf((func(string, string) []string)(nil)),
reflect.TypeOf((func(string) []string)(nil)),
reflect.TypeOf((func(string, *net.Resolver) []string)(nil)), reflect.TypeOf((func(string, *net.Resolver) []string)(nil)),
}, },
}, },

View File

@ -1,6 +1,7 @@
package ruleset package ruleset
import ( import (
"context"
"net" "net"
"strconv" "strconv"
@ -100,7 +101,8 @@ type Logger interface {
} }
type BuiltinConfig struct { type BuiltinConfig struct {
Logger Logger Logger Logger
GeoSiteFilename string GeoSiteFilename string
GeoIpFilename string GeoIpFilename string
ProtectedDialContext func(ctx context.Context, network, address string) (net.Conn, error)
} }