Merge pull request #123 from apernet/wip-lookup

feat: dns lookup function
This commit is contained in:
Toby 2024-04-07 17:49:33 -07:00 committed by GitHub
commit 393c29bd2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 181 additions and 110 deletions

View File

@ -204,7 +204,7 @@ func (c *cliConfig) fillIO(config *engine.Config) error {
if err != nil { if err != nil {
return configError{Field: "io", Err: err} return configError{Field: "io", Err: err}
} }
config.IOs = []io.PacketIO{nfio} config.IO = nfio
return nil return nil
} }
@ -247,12 +247,7 @@ func runMain(cmd *cobra.Command, args []string) {
if err != nil { if err != nil {
logger.Fatal("failed to parse config", zap.Error(err)) logger.Fatal("failed to parse config", zap.Error(err))
} }
defer func() { defer engineConfig.IO.Close() // Make sure to close IO on exit
// Make sure to close all IOs on exit
for _, i := range engineConfig.IOs {
_ = i.Close()
}
}()
// Ruleset // Ruleset
rawRs, err := ruleset.ExprRulesFromYAML(args[0]) rawRs, err := ruleset.ExprRulesFromYAML(args[0])
@ -263,6 +258,7 @@ func runMain(cmd *cobra.Command, args []string) {
Logger: &rulesetLogger{}, Logger: &rulesetLogger{},
GeoSiteFilename: config.Ruleset.GeoSite, GeoSiteFilename: config.Ruleset.GeoSite,
GeoIpFilename: config.Ruleset.GeoIp, GeoIpFilename: config.Ruleset.GeoIp,
ProtectedDialContext: engineConfig.IO.ProtectedDialContext,
} }
rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig) rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers, rsConfig)
if err != nil { if err != nil {

View File

@ -15,7 +15,7 @@ var _ Engine = (*engine)(nil)
type engine struct { type engine struct {
logger Logger logger Logger
ioList []io.PacketIO io io.PacketIO
workers []*worker workers []*worker
} }
@ -42,7 +42,7 @@ func NewEngine(config Config) (Engine, error) {
} }
return &engine{ return &engine{
logger: config.Logger, logger: config.Logger,
ioList: config.IOs, io: config.IO,
workers: workers, workers: workers,
}, nil }, nil
} }
@ -58,28 +58,25 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
func (e *engine) Run(ctx context.Context) error { func (e *engine) Run(ctx context.Context) error {
ioCtx, ioCancel := context.WithCancel(ctx) ioCtx, ioCancel := context.WithCancel(ctx)
defer ioCancel() // Stop workers & IOs defer ioCancel() // Stop workers & IO
// Start workers // Start workers
for _, w := range e.workers { for _, w := range e.workers {
go w.Run(ioCtx) go w.Run(ioCtx)
} }
// Register callbacks // Register IO callback
errChan := make(chan error, len(e.ioList)) errChan := make(chan error, 1)
for _, i := range e.ioList { err := e.io.Register(ioCtx, func(p io.Packet, err error) bool {
ioEntry := i // Make sure dispatch() uses the correct ioEntry
err := ioEntry.Register(ioCtx, func(p io.Packet, err error) bool {
if err != nil { if err != nil {
errChan <- err errChan <- err
return false return false
} }
return e.dispatch(ioEntry, p) return e.dispatch(p)
}) })
if err != nil { if err != nil {
return err return err
} }
}
// Block until IO errors or context is cancelled // Block until IO errors or context is cancelled
select { select {
@ -91,8 +88,7 @@ func (e *engine) Run(ctx context.Context) error {
} }
// dispatch dispatches a packet to a worker. // dispatch dispatches a packet to a worker.
// This must be safe for concurrent use, as it may be called from multiple IOs. func (e *engine) dispatch(p io.Packet) bool {
func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
data := p.Data() data := p.Data()
ipVersion := data[0] >> 4 ipVersion := data[0] >> 4
var layerType gopacket.LayerType var layerType gopacket.LayerType
@ -102,7 +98,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
layerType = layers.LayerTypeIPv6 layerType = layers.LayerTypeIPv6
} else { } else {
// Unsupported network layer // Unsupported network layer
_ = ioEntry.SetVerdict(p, io.VerdictAcceptStream, nil) _ = e.io.SetVerdict(p, io.VerdictAcceptStream, nil)
return true return true
} }
// Load balance by stream ID // Load balance by stream ID
@ -112,7 +108,7 @@ func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
StreamID: p.StreamID(), StreamID: p.StreamID(),
Packet: packet, Packet: packet,
SetVerdict: func(v io.Verdict, b []byte) error { SetVerdict: func(v io.Verdict, b []byte) error {
return ioEntry.SetVerdict(p, v, b) return e.io.SetVerdict(p, v, b)
}, },
}) })
return true return true

View File

@ -18,7 +18,7 @@ type Engine interface {
// Config is the configuration for the engine. // Config is the configuration for the engine.
type Config struct { type Config struct {
Logger Logger Logger Logger
IOs []io.PacketIO IO io.PacketIO
Ruleset ruleset.Ruleset Ruleset ruleset.Ruleset
Workers int // Number of workers. Zero or negative means auto (number of CPU cores). Workers int // Number of workers. Zero or negative means auto (number of CPU cores).

2
go.mod
View File

@ -5,7 +5,7 @@ go 1.21
require ( require (
github.com/bwmarrin/snowflake v0.3.0 github.com/bwmarrin/snowflake v0.3.0
github.com/coreos/go-iptables v0.7.0 github.com/coreos/go-iptables v0.7.0
github.com/expr-lang/expr v1.15.7 github.com/expr-lang/expr v1.16.3
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf
github.com/google/gopacket v1.1.20-0.20220810144506-32ee38206866 github.com/google/gopacket v1.1.20-0.20220810144506-32ee38206866
github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/hashicorp/golang-lru/v2 v2.0.7

4
go.sum
View File

@ -7,8 +7,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/expr-lang/expr v1.15.7 h1:BK0JcWUkoW6nrbLBo6xCKhz4BvH5DSOOu1Gx5lucyZo= github.com/expr-lang/expr v1.16.3 h1:NLldf786GffptcXNxxJx5dQ+FzeWDKChBDqOOwyK8to=
github.com/expr-lang/expr v1.15.7/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ= github.com/expr-lang/expr v1.16.3/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ=
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf h1:NqGS3vTHzVENbIfd87cXZwdpO6MB2R1PjHMJLi4Z3ow= github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf h1:NqGS3vTHzVENbIfd87cXZwdpO6MB2R1PjHMJLi4Z3ow=
github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf/go.mod h1:eSnAor2YCfMCVYrVNEhkLGN/r1L+J4uDjc0EUy0tfq4= github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf/go.mod h1:eSnAor2YCfMCVYrVNEhkLGN/r1L+J4uDjc0EUy0tfq4=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=

View File

@ -2,6 +2,7 @@ package io
import ( import (
"context" "context"
"net"
) )
type Verdict int type Verdict int
@ -29,7 +30,6 @@ type Packet interface {
// PacketCallback is called for each packet received. // PacketCallback is called for each packet received.
// Return false to "unregister" and stop receiving packets. // Return false to "unregister" and stop receiving packets.
// It must be safe for concurrent use.
type PacketCallback func(Packet, error) bool type PacketCallback func(Packet, error) bool
type PacketIO interface { type PacketIO interface {
@ -39,6 +39,10 @@ type PacketIO interface {
Register(context.Context, PacketCallback) error Register(context.Context, PacketCallback) error
// SetVerdict sets the verdict for a packet. // SetVerdict sets the verdict for a packet.
SetVerdict(Packet, Verdict, []byte) error SetVerdict(Packet, Verdict, []byte) error
// ProtectedDialContext is like net.DialContext, but the connection is "protected"
// in the sense that the packets sent/received through the connection must bypass
// the packet IO and not be processed by the callback.
ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error)
// Close closes the packet IO. // Close closes the packet IO.
Close() error Close() error
} }

View File

@ -5,9 +5,11 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"net"
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
"syscall"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/florianl/go-nfqueue" "github.com/florianl/go-nfqueue"
@ -50,6 +52,7 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) {
} }
for i := range table.Chains { for i := range table.Chains {
c := &table.Chains[i] c := &table.Chains[i]
c.Rules = append(c.Rules, "meta mark $ACCEPT_CTMARK ct mark set $ACCEPT_CTMARK") // Bypass protected connections
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept") c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
if rst { if rst {
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset") c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
@ -72,6 +75,8 @@ func generateIptRules(local, rst bool) ([]iptRule, error) {
} }
rules := make([]iptRule, 0, 4*len(chains)) rules := make([]iptRule, 0, 4*len(chains))
for _, chain := range chains { for _, chain := range chains {
// Bypass protected connections
rules = append(rules, iptRule{"filter", chain, []string{"-m", "mark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "CONNMARK", "--set-mark", strconv.Itoa(nfqueueConnMarkAccept)}})
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}) rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
if rst { if rst {
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}}) rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
@ -96,6 +101,8 @@ type nfqueuePacketIO struct {
// iptables not nil = use iptables instead of nftables // iptables not nil = use iptables instead of nftables
ipt4 *iptables.IPTables ipt4 *iptables.IPTables
ipt6 *iptables.IPTables ipt6 *iptables.IPTables
protectedDialer *net.Dialer
} }
type NFQueuePacketIOConfig struct { type NFQueuePacketIOConfig struct {
@ -153,6 +160,18 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
rst: config.RST, rst: config.RST,
ipt4: ipt4, ipt4: ipt4,
ipt6: ipt6, ipt6: ipt6,
protectedDialer: &net.Dialer{
Control: func(network, address string, c syscall.RawConn) error {
var err error
cErr := c.Control(func(fd uintptr) {
err = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, nfqueueConnMarkAccept)
})
if cErr != nil {
return cErr
}
return err
},
},
}, nil }, nil
} }
@ -239,6 +258,10 @@ func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) erro
} }
} }
func (n *nfqueuePacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
return n.protectedDialer.DialContext(ctx, network, address)
}
func (n *nfqueuePacketIO) Close() error { func (n *nfqueuePacketIO) Close() error {
if n.rSet { if n.rSet {
if n.ipt4 != nil { if n.ipt4 != nil {

View File

@ -14,14 +14,12 @@ type GeoMatcher struct {
ipMatcherLock sync.Mutex ipMatcherLock sync.Mutex
} }
func NewGeoMatcher(geoSiteFilename, geoIpFilename string) (*GeoMatcher, error) { func NewGeoMatcher(geoSiteFilename, geoIpFilename string) *GeoMatcher {
geoLoader := NewDefaultGeoLoader(geoSiteFilename, geoIpFilename)
return &GeoMatcher{ return &GeoMatcher{
geoLoader: geoLoader, geoLoader: NewDefaultGeoLoader(geoSiteFilename, geoIpFilename),
geoSiteMatcher: make(map[string]hostMatcher), geoSiteMatcher: make(map[string]hostMatcher),
geoIpMatcher: make(map[string]hostMatcher), geoIpMatcher: make(map[string]hostMatcher),
}, nil }
} }
func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool { func (g *GeoMatcher) MatchGeoIp(ip, condition string) bool {

View File

@ -1,11 +1,15 @@
package ruleset package ruleset
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"os" "os"
"reflect" "reflect"
"strings" "strings"
"time"
"github.com/expr-lang/expr/builtin"
"github.com/expr-lang/expr" "github.com/expr-lang/expr"
"github.com/expr-lang/expr/ast" "github.com/expr-lang/expr/ast"
@ -58,7 +62,6 @@ type exprRuleset struct {
Rules []compiledExprRule Rules []compiledExprRule
Ans []analyzer.Analyzer Ans []analyzer.Analyzer
Logger Logger Logger Logger
GeoMatcher *geo.GeoMatcher
} }
func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer { func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer {
@ -100,10 +103,7 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
fullAnMap := analyzersToMap(ans) fullAnMap := analyzersToMap(ans)
fullModMap := modifiersToMap(mods) fullModMap := modifiersToMap(mods)
depAnMap := make(map[string]analyzer.Analyzer) depAnMap := make(map[string]analyzer.Analyzer)
geoMatcher, err := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename) funcMap := buildFunctionMap(config)
if err != nil {
return nil, err
}
// Compile all rules and build a map of analyzers that are used by the rules. // Compile all rules and build a map of analyzers that are used by the rules.
for _, rule := range rules { for _, rule := range rules {
if rule.Action == "" && !rule.Log { if rule.Action == "" && !rule.Log {
@ -118,13 +118,19 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
action = &a action = &a
} }
visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)} visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)}
patcher := &idPatcher{} patcher := &idPatcher{FuncMap: funcMap}
program, err := expr.Compile(rule.Expr, program, err := expr.Compile(rule.Expr,
func(c *conf.Config) { func(c *conf.Config) {
c.Strict = false c.Strict = false
c.Expect = reflect.Bool c.Expect = reflect.Bool
c.Visitors = append(c.Visitors, visitor, patcher) c.Visitors = append(c.Visitors, visitor, patcher)
registerBuiltinFunctions(c.Functions, geoMatcher) for name, f := range funcMap {
c.Functions[name] = &builtin.Function{
Name: name,
Func: f.Func,
Types: f.Types,
}
}
}, },
) )
if err != nil { if err != nil {
@ -138,24 +144,15 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
if isBuiltInAnalyzer(name) || visitor.Variables[name] { if isBuiltInAnalyzer(name) || visitor.Variables[name] {
continue continue
} }
// Check if it's one of the built-in functions, and if so, if f, ok := funcMap[name]; ok {
// skip it as an analyzer & do initialization if necessary. // Built-in function, initialize if necessary
switch name { if f.InitFunc != nil {
case "geoip": if err := f.InitFunc(); err != nil {
if err := geoMatcher.LoadGeoIP(); err != nil { return nil, fmt.Errorf("rule %q failed to initialize function %q: %w", rule.Name, name, err)
return nil, fmt.Errorf("rule %q failed to load geoip: %w", rule.Name, err)
} }
case "geosite":
if err := geoMatcher.LoadGeoSite(); err != nil {
return nil, fmt.Errorf("rule %q failed to load geosite: %w", rule.Name, err)
}
case "cidr":
// No initialization needed for CIDR.
default:
a, ok := fullAnMap[name]
if !ok {
return nil, fmt.Errorf("rule %q uses unknown analyzer %q", rule.Name, name)
} }
} else if a, ok := fullAnMap[name]; ok {
// Analyzer, add to dependency map
depAnMap[name] = a depAnMap[name] = a
} }
} }
@ -187,34 +184,9 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier
Rules: compiledRules, Rules: compiledRules,
Ans: depAns, Ans: depAns,
Logger: config.Logger, Logger: config.Logger,
GeoMatcher: geoMatcher,
}, nil }, nil
} }
func registerBuiltinFunctions(funcMap map[string]*ast.Function, geoMatcher *geo.GeoMatcher) {
funcMap["geoip"] = &ast.Function{
Name: "geoip",
Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
},
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
}
funcMap["geosite"] = &ast.Function{
Name: "geosite",
Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
},
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
}
funcMap["cidr"] = &ast.Function{
Name: "cidr",
Func: func(params ...any) (any, error) {
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
},
Types: []reflect.Type{reflect.TypeOf((func(string, string) bool)(nil)), reflect.TypeOf(builtins.MatchCIDR)},
}
}
func streamInfoToExprEnv(info StreamInfo) map[string]interface{} { func streamInfoToExprEnv(info StreamInfo) map[string]interface{} {
m := map[string]interface{}{ m := map[string]interface{}{
"id": info.ID, "id": info.ID,
@ -299,6 +271,7 @@ func (v *idVisitor) Visit(node *ast.Node) {
// idPatcher patches the AST during expr compilation, replacing certain values with // idPatcher patches the AST during expr compilation, replacing certain values with
// their internal representations for better runtime performance. // their internal representations for better runtime performance.
type idPatcher struct { type idPatcher struct {
FuncMap map[string]*Function
Err error Err error
} }
@ -306,22 +279,101 @@ func (p *idPatcher) Visit(node *ast.Node) {
switch (*node).(type) { switch (*node).(type) {
case *ast.CallNode: case *ast.CallNode:
callNode := (*node).(*ast.CallNode) callNode := (*node).(*ast.CallNode)
if callNode.Func == nil { if callNode.Callee == nil {
// Ignore invalid call nodes // Ignore invalid call nodes
return return
} }
switch callNode.Func.Name { if f, ok := p.FuncMap[callNode.Callee.String()]; ok {
case "cidr": if f.PatchFunc != nil {
cidrStringNode, ok := callNode.Arguments[1].(*ast.StringNode) if err := f.PatchFunc(&callNode.Arguments); err != nil {
if !ok {
return
}
cidr, err := builtins.CompileCIDR(cidrStringNode.Value)
if err != nil {
p.Err = err p.Err = err
return return
} }
callNode.Arguments[1] = &ast.ConstantNode{Value: cidr} }
} }
} }
} }
type Function struct {
InitFunc func() error
PatchFunc func(args *[]ast.Node) error
Func func(params ...any) (any, error)
Types []reflect.Type
}
func buildFunctionMap(config *BuiltinConfig) map[string]*Function {
geoMatcher := geo.NewGeoMatcher(config.GeoSiteFilename, config.GeoIpFilename)
return map[string]*Function{
"geoip": {
InitFunc: geoMatcher.LoadGeoIP,
PatchFunc: nil,
Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
},
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)},
},
"geosite": {
InitFunc: geoMatcher.LoadGeoSite,
PatchFunc: nil,
Func: func(params ...any) (any, error) {
return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
},
Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)},
},
"cidr": {
InitFunc: nil,
PatchFunc: func(args *[]ast.Node) error {
cidrStringNode, ok := (*args)[1].(*ast.StringNode)
if !ok {
return fmt.Errorf("cidr: invalid argument type")
}
cidr, err := builtins.CompileCIDR(cidrStringNode.Value)
if err != nil {
return err
}
(*args)[1] = &ast.ConstantNode{Value: cidr}
return nil
},
Func: func(params ...any) (any, error) {
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
},
Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
},
"lookup": {
InitFunc: nil,
PatchFunc: func(args *[]ast.Node) error {
var serverStr *ast.StringNode
if len(*args) > 1 {
// Has the optional server argument
var ok bool
serverStr, ok = (*args)[1].(*ast.StringNode)
if !ok {
return fmt.Errorf("lookup: invalid argument type")
}
}
r := &net.Resolver{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
if serverStr != nil {
address = serverStr.Value
}
return config.ProtectedDialContext(ctx, network, address)
},
}
if len(*args) > 1 {
(*args)[1] = &ast.ConstantNode{Value: r}
} else {
*args = append(*args, &ast.ConstantNode{Value: r})
}
return nil
},
Func: func(params ...any) (any, error) {
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
defer cancel()
return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
},
Types: []reflect.Type{
reflect.TypeOf((func(string, *net.Resolver) []string)(nil)),
},
},
}
}

View File

@ -1,6 +1,7 @@
package ruleset package ruleset
import ( import (
"context"
"net" "net"
"strconv" "strconv"
@ -103,4 +104,5 @@ type BuiltinConfig struct {
Logger Logger Logger Logger
GeoSiteFilename string GeoSiteFilename string
GeoIpFilename string GeoIpFilename string
ProtectedDialContext func(ctx context.Context, network, address string) (net.Conn, error)
} }