feat: added protected dial support, removed multi-IO support for simplicity

This commit is contained in:
Toby 2024-04-06 14:42:45 -07:00
parent ae34b4856a
commit 9c0893c512
8 changed files with 88 additions and 72 deletions

View File

@ -204,7 +204,7 @@ func (c *cliConfig) fillIO(config *engine.Config) error {
if err != nil {
return configError{Field: "io", Err: err}
}
config.IOs = []io.PacketIO{nfio}
config.IO = nfio
return nil
}
@ -247,12 +247,7 @@ func runMain(cmd *cobra.Command, args []string) {
if err != nil {
logger.Fatal("failed to parse config", zap.Error(err))
}
defer func() {
// Make sure to close all IOs on exit
for _, i := range engineConfig.IOs {
_ = i.Close()
}
}()
defer engineConfig.IO.Close() // Make sure to close IO on exit
// Ruleset
rawRs, err := ruleset.ExprRulesFromYAML(args[0])
@ -260,9 +255,10 @@ func runMain(cmd *cobra.Command, args []string) {
logger.Fatal("failed to load rules", zap.Error(err))
}
rsConfig := &ruleset.BuiltinConfig{
Logger: &rulesetLogger{},
GeoSiteFilename: config.Ruleset.GeoSite,
GeoIpFilename: config.Ruleset.GeoIp,
Logger: &rulesetLogger{},
GeoSiteFilename: config.Ruleset.GeoSite,
GeoIpFilename: config.Ruleset.GeoIp,
ProtectedDialContext: engineConfig.IO.ProtectedDialContext,
}
rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig)
if err != nil {

View File

@ -15,7 +15,7 @@ var _ Engine = (*engine)(nil)
type engine struct {
logger Logger
ioList []io.PacketIO
io io.PacketIO
workers []*worker
}
@ -42,7 +42,7 @@ func NewEngine(config Config) (Engine, error) {
}
return &engine{
logger: config.Logger,
ioList: config.IOs,
io: config.IO,
workers: workers,
}, nil
}
@ -58,27 +58,24 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
func (e *engine) Run(ctx context.Context) error {
ioCtx, ioCancel := context.WithCancel(ctx)
defer ioCancel() // Stop workers & IOs
defer ioCancel() // Stop workers & IO
// Start workers
for _, w := range e.workers {
go w.Run(ioCtx)
}
// Register callbacks
errChan := make(chan error, len(e.ioList))
for _, i := range e.ioList {
ioEntry := i // Make sure dispatch() uses the correct ioEntry
err := ioEntry.Register(ioCtx, func(p io.Packet, err error) bool {
if err != nil {
errChan <- err
return false
}
return e.dispatch(ioEntry, p)
})
// Register IO callback
errChan := make(chan error, 1)
err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
if err != nil {
return err
errChan <- err
return false
}
return e.dispatch(p)
})
if err != nil {
return err
}
// Block until IO errors or context is cancelled
@ -91,8 +88,7 @@ func (e *engine) Run(ctx context.Context) error {
}
// dispatch dispatches a packet to a worker.
// This must be safe for concurrent use, as it may be called from multiple IOs.
func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
func (e *engine) dispatch(p io.Packet) bool {
data := p.Data()
ipVersion := data[0] >> 4
var layerType gopacket.LayerType
@ -102,7 +98,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
layerType = layers.LayerTypeIPv6
} else {
// Unsupported network layer
_ = ioEntry.SetVerdict(p, io.VerdictAcceptStream, nil)
_ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
return true
}
// Load balance by stream ID
@ -112,7 +108,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
StreamID: p.StreamID(),
Packet: packet,
SetVerdict: func(v io.Verdict, b []byte) error {
return ioEntry.SetVerdict(p, v, b)
return e.io.SetVerdict(p, v, b)
},
})
return true

View File

@ -18,7 +18,7 @@ type Engine interface {
// Config is the configuration for the engine.
type Config struct {
Logger Logger
IOs []io.PacketIO
IO io.PacketIO
Ruleset ruleset.Ruleset
Workers int // Number of workers. Zero or negative means auto (number of CPU cores).

View File

@ -2,6 +2,7 @@ package io
import (
"context"
"net"
)
type Verdict int
@ -29,7 +30,6 @@ type Packet interface {
// PacketCallback is called for each packet received.
// Return false to "unregister" and stop receiving packets.
// It must be safe for concurrent use.
type PacketCallback func(Packet, error) bool
type PacketIO interface {
@ -39,6 +39,10 @@ type PacketIO interface {
Register(context.Context, PacketCallback) error
// SetVerdict sets the verdict for a packet.
SetVerdict(Packet, Verdict, []byte) error
// ProtectedDialContext is like net.DialContext, but the connection is "protected"
// in the sense that the packets sent/received through the connection must bypass
// the packet IO and not be processed by the callback.
ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error)
// Close closes the packet IO.
Close() error
}

View File

@ -5,9 +5,11 @@ import (
"encoding/binary"
"errors"
"fmt"
"net"
"os/exec"
"strconv"
"strings"
"syscall"
"github.com/coreos/go-iptables/iptables"
"github.com/florianl/go-nfqueue"
@ -50,12 +52,13 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
}
for i := range table.Chains {
c := &table.Chains[i]
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
if rst {
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
}
c.Rules = append(c.Rules, "ct mark $DROP_CTMARK counter drop")
c.Rules = append(c.Rules, "ip protocol tcp counter queue num $QUEUE_NUM bypass")
c.Rules = append(c.Rules, "counter queue num $QUEUE_NUM bypass")
}
return table, nil
}
@ -72,6 +75,8 @@ func generateIptRules(local, rst bool) ([]iptRule, error) {
}
rules := make([]iptRule, 0, 4*len(chains))
for _, chain := range chains {
// Bypass protected connections
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
if rst {
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
@ -96,6 +101,8 @@ type nfqueuePacketIO struct {
// iptables not nil = use iptables instead of nftables
ipt4 *iptables.IPTables
ipt6 *iptables.IPTables
protectedDialer *net.Dialer
}
type NFQueuePacketIOConfig struct {
@ -153,6 +160,18 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
rst: config.RST,
ipt4: ipt4,
ipt6: ipt6,
protectedDialer: &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
var err error
cErr := c.Control(func(fd uintptr) {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept)
})
if cErr != nil {
return cErr
}
return err
},
},
}, nil
}
@ -239,6 +258,10 @@ func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) erro
}
}
func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
return n.protectedDialer.DialContext(ctx, network, address)
}
func (n *nfqueuePacketIO) Close() error {
if n.rSet {
if n.ipt4 != nil {

View File

@ -14,14 +14,12 @@ type GeoMatcher struct {
ipMatcherLock sync.Mutex
}
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) (*GeoMatcher, error) {
geoLoader := NewDefaultGeoLoader(geoSiteFilename, geoIpFilename)
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher {
return &GeoMatcher{
geoLoader: geoLoader,
geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename),
geoSiteMatcher: make(map[string]hostMatcher),
geoIpMatcher: make(map[string]hostMatcher),
}, nil
}
}
func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {

View File

@ -59,10 +59,9 @@ type compiledExprRule struct {
var _ Ruleset = (*exprRuleset)(nil)
type exprRuleset struct {
Rules []compiledExprRule
Ans []analyzer.Analyzer
Logger Logger
GeoMatcher *geo.GeoMatcher
Rules []compiledExprRule
Ans []analyzer.Analyzer
Logger Logger
}
func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
@ -104,11 +103,7 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
fullAnMap := analyzersToMap(ans)
fullModMap := modifiersToMap(mods)
depAnMap := make(map[string]analyzer.Analyzer)
geoMatcher, err := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
if err != nil {
return nil, err
}
funcMap := buildFunctionMap(geoMatcher)
funcMap := buildFunctionMap(config)
// Compile all rules and build a map of analyzers that are used by the rules.
for _, rule := range rules {
if rule.Action == "" && !rule.Log {
@ -186,10 +181,9 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
depAns = append(depAns, a)
}
return &exprRuleset{
Rules: compiledRules,
Ans: depAns,
Logger: config.Logger,
GeoMatcher: geoMatcher,
Rules: compiledRules,
Ans: depAns,
Logger: config.Logger,
}, nil
}
@ -307,7 +301,8 @@ type Function struct {
Types []reflect.Type
}
func buildFunctionMap(geoMatcher *geo.GeoMatcher) map[string]*Function {
func buildFunctionMap(config *BuiltinConfig) map[string]*Function {
geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
return map[string]*Function{
"geoip": {
InitFunc: geoMatcher.LoadGeoIP,
@ -342,39 +337,41 @@ func buildFunctionMap(geoMatcher *geo.GeoMatcher) map[string]*Function {
Func: func(params ...any) (any, error) {
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
},
Types: []reflect.Type{reflect.TypeOf((func(string, string) bool)(nil)), reflect.TypeOf(builtins.MatchCIDR)},
Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
},
"lookup": {
InitFunc: nil,
PatchFunc: func(args *[]ast.Node) error {
if len(*args) < 2 {
// Second argument (DNS server) is optional
return nil
}
serverStr, ok := (*args)[1].(*ast.StringNode)
if !ok {
return fmt.Errorf("lookup: invalid argument type")
var serverStr *ast.StringNode
if len(*args) > 1 {
// Has the optional server argument
var ok bool
serverStr, ok = (*args)[1].(*ast.StringNode)
if !ok {
return fmt.Errorf("lookup: invalid argument type")
}
}
r := &net.Resolver{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
return net.Dial(network, serverStr.Value)
if serverStr != nil {
address = serverStr.Value
}
return config.ProtectedDialContext(ctx, network, address)
},
}
(*args)[1] = &ast.ConstantNode{Value: r}
if len(*args) > 1 {
(*args)[1] = &ast.ConstantNode{Value: r}
} else {
*args = append(*args, &ast.ConstantNode{Value: r})
}
return nil
},
Func: func(params ...any) (any, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
defer cancel()
if len(params) < 2 {
return net.DefaultResolver.LookupHost(ctx, params[0].(string))
} else {
return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
}
return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
},
Types: []reflect.Type{
reflect.TypeOf((func(string, string) []string)(nil)),
reflect.TypeOf((func(string) []string)(nil)),
reflect.TypeOf((func(string, *net.Resolver) []string)(nil)),
},
},

View File

@ -1,6 +1,7 @@
package ruleset
import (
"context"
"net"
"strconv"
@ -100,7 +101,8 @@ type Logger interface {
}
type BuiltinConfig struct {
Logger Logger
GeoSiteFilename string
GeoIpFilename string
Logger Logger
GeoSiteFilename string
GeoIpFilename string
ProtectedDialContext func(ctx context.Context, network, address string) (net.Conn, error)
}