diff --git a/README.ja.md b/README.ja.md index 973d635..a8e9d16 100644 --- a/README.ja.md +++ b/README.ja.md @@ -130,6 +130,10 @@ workers: - name: block CN geoip action: block expr: geoip(string(ip.dst), "cn") + +- name: block cidr + action: block + expr: cidr(string(ip.dst), "192.168.0.0/16") ``` #### サポートされるアクション diff --git a/README.md b/README.md index b41ce72..890bcc7 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,10 @@ to [Expr Language Definition](https://expr-lang.org/docs/language-definition). - name: block CN geoip action: block expr: geoip(string(ip.dst), "cn") + +- name: block cidr + action: block + expr: cidr(string(ip.dst), "192.168.0.0/16") ``` #### Supported actions diff --git a/README.zh.md b/README.zh.md index f3efd39..25dc2e7 100644 --- a/README.zh.md +++ b/README.zh.md @@ -131,6 +131,10 @@ workers: - name: block CN geoip action: block expr: geoip(string(ip.dst), "cn") + +- name: block cidr + action: block + expr: cidr(string(ip.dst), "192.168.0.0/16") ``` #### 支持的 action diff --git a/ruleset/builtins/cidr.go b/ruleset/builtins/cidr.go new file mode 100644 index 0000000..669d469 --- /dev/null +++ b/ruleset/builtins/cidr.go @@ -0,0 +1,18 @@ +package builtins + +import ( + "net" +) + +func MatchCIDR(ip string, cidr *net.IPNet) bool { + ipAddr := net.ParseIP(ip) + if ipAddr == nil { + return false + } + return cidr.Contains(ipAddr) +} + +func CompileCIDR(cidr string) (*net.IPNet, error) { + _, ipNet, err := net.ParseCIDR(cidr) + return ipNet, err +} diff --git a/ruleset/expr.go b/ruleset/expr.go index 738ed7c..4512e1f 100644 --- a/ruleset/expr.go +++ b/ruleset/expr.go @@ -2,6 +2,7 @@ package ruleset import ( "fmt" + "net" "os" "reflect" "strings" @@ -14,6 +15,7 @@ import ( "github.com/apernet/OpenGFW/analyzer" "github.com/apernet/OpenGFW/modifier" + "github.com/apernet/OpenGFW/ruleset/builtins" "github.com/apernet/OpenGFW/ruleset/builtins/geo" ) @@ -100,17 +102,21 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier return nil, fmt.Errorf("rule %q has invalid action %q", rule.Name, rule.Action) } visitor := &idVisitor{Identifiers: make(map[string]bool)} + patcher := &idPatcher{} program, err := expr.Compile(rule.Expr, func(c *conf.Config) { c.Strict = false c.Expect = reflect.Bool - c.Visitors = append(c.Visitors, visitor) + c.Visitors = append(c.Visitors, visitor, patcher) registerBuiltinFunctions(c.Functions, geoMatcher) }, ) if err != nil { return nil, fmt.Errorf("rule %q has invalid expression: %w", rule.Name, err) } + if patcher.Err != nil { + return nil, fmt.Errorf("rule %q failed to patch expression: %w", rule.Name, patcher.Err) + } for name := range visitor.Identifiers { if isBuiltInAnalyzer(name) { continue @@ -126,6 +132,8 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier if err := geoMatcher.LoadGeoSite(); err != nil { return nil, fmt.Errorf("rule %q failed to load geosite: %w", rule.Name, err) } + case "cidr": + // No initialization needed for CIDR. default: a, ok := fullAnMap[name] if !ok { @@ -179,6 +187,13 @@ func registerBuiltinFunctions(funcMap map[string]*ast.Function, geoMatcher *geo. }, Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)}, } + funcMap["cidr"] = &ast.Function{ + Name: "cidr", + Func: func(params ...any) (any, error) { + return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil + }, + Types: []reflect.Type{reflect.TypeOf((func(string, string) bool)(nil)), reflect.TypeOf(builtins.MatchCIDR)}, + } } func streamInfoToExprEnv(info StreamInfo) map[string]interface{} { @@ -247,6 +262,8 @@ func modifiersToMap(mods []modifier.Modifier) map[string]modifier.Modifier { return modMap } +// idVisitor is a visitor that collects all identifiers in an expression. +// This is for determining which analyzers are used by the expression. type idVisitor struct { Identifiers map[string]bool } @@ -256,3 +273,29 @@ func (v *idVisitor) Visit(node *ast.Node) { v.Identifiers[idNode.Value] = true } } + +// idPatcher patches the AST during expr compilation, replacing certain values with +// their internal representations for better runtime performance. +type idPatcher struct { + Err error +} + +func (p *idPatcher) Visit(node *ast.Node) { + switch (*node).(type) { + case *ast.CallNode: + callNode := (*node).(*ast.CallNode) + switch callNode.Func.Name { + case "cidr": + cidrStringNode, ok := callNode.Arguments[1].(*ast.StringNode) + if !ok { + return + } + cidr, err := builtins.CompileCIDR(cidrStringNode.Value) + if err != nil { + p.Err = err + return + } + callNode.Arguments[1] = &ast.ConstantNode{Value: cidr} + } + } +}