diff --git a/ruleset/expr.go b/ruleset/expr.go index 5bcfb4d..738ed7c 100644 --- a/ruleset/expr.go +++ b/ruleset/expr.go @@ -46,7 +46,6 @@ type compiledExprRule struct { Action Action ModInstance modifier.Instance Program *vm.Program - Analyzers map[string]struct{} } var _ Ruleset = (*exprRuleset)(nil) @@ -100,55 +99,45 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier if !ok { return nil, fmt.Errorf("rule %q has invalid action %q", rule.Name, rule.Action) } - visitor := &depVisitor{Analyzers: make(map[string]struct{})} - geoip := expr.Function( - "geoip", - func(params ...any) (any, error) { - return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil - }, - new(func(string, string) bool), - ) - geosite := expr.Function( - "geosite", - func(params ...any) (any, error) { - return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil - }, - new(func(string, string) bool), - ) + visitor := &idVisitor{Identifiers: make(map[string]bool)} program, err := expr.Compile(rule.Expr, func(c *conf.Config) { c.Strict = false c.Expect = reflect.Bool c.Visitors = append(c.Visitors, visitor) + registerBuiltinFunctions(c.Functions, geoMatcher) }, - geoip, - geosite, ) if err != nil { return nil, fmt.Errorf("rule %q has invalid expression: %w", rule.Name, err) } - for name := range visitor.Analyzers { - a, ok := fullAnMap[name] - if !ok && !isBuiltInAnalyzer(name) { - return nil, fmt.Errorf("rule %q uses unknown analyzer %q", rule.Name, name) + for name := range visitor.Identifiers { + if isBuiltInAnalyzer(name) { + continue } - depAnMap[name] = a - } - if visitor.UseGeoSite { - if err := geoMatcher.LoadGeoSite(); err != nil { - return nil, fmt.Errorf("rule %q failed to load geosite: %w", rule.Name, err) - } - } - if visitor.UseGeoIp { - if err := geoMatcher.LoadGeoIP(); err != nil { - return nil, fmt.Errorf("rule %q failed to load geoip: %w", rule.Name, err) + // Check if it's one of the built-in functions, and if so, + // skip it as an analyzer & do initialization if necessary. + switch name { + case "geoip": + if err := geoMatcher.LoadGeoIP(); err != nil { + return nil, fmt.Errorf("rule %q failed to load geoip: %w", rule.Name, err) + } + case "geosite": + if err := geoMatcher.LoadGeoSite(); err != nil { + return nil, fmt.Errorf("rule %q failed to load geosite: %w", rule.Name, err) + } + default: + a, ok := fullAnMap[name] + if !ok { + return nil, fmt.Errorf("rule %q uses unknown analyzer %q", rule.Name, name) + } + depAnMap[name] = a } } cr := compiledExprRule{ - Name: rule.Name, - Action: action, - Program: program, - Analyzers: visitor.Analyzers, + Name: rule.Name, + Action: action, + Program: program, } if action == ActionModify { mod, ok := fullModMap[rule.Modifier.Name] @@ -175,6 +164,23 @@ func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier }, nil } +func registerBuiltinFunctions(funcMap map[string]*ast.Function, geoMatcher *geo.GeoMatcher) { + funcMap["geoip"] = &ast.Function{ + Name: "geoip", + Func: func(params ...any) (any, error) { + return geoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil + }, + Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoIp)}, + } + funcMap["geosite"] = &ast.Function{ + Name: "geosite", + Func: func(params ...any) (any, error) { + return geoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil + }, + Types: []reflect.Type{reflect.TypeOf(geoMatcher.MatchGeoSite)}, + } +} + func streamInfoToExprEnv(info StreamInfo) map[string]interface{} { m := map[string]interface{}{ "id": info.ID, @@ -241,22 +247,12 @@ func modifiersToMap(mods []modifier.Modifier) map[string]modifier.Modifier { return modMap } -type depVisitor struct { - Analyzers map[string]struct{} - - UseGeoSite bool - UseGeoIp bool +type idVisitor struct { + Identifiers map[string]bool } -func (v *depVisitor) Visit(node *ast.Node) { +func (v *idVisitor) Visit(node *ast.Node) { if idNode, ok := (*node).(*ast.IdentifierNode); ok { - switch idNode.Value { - case "geosite": - v.UseGeoSite = true - case "geoip": - v.UseGeoIp = true - default: - v.Analyzers[idNode.Value] = struct{}{} - } + v.Identifiers[idNode.Value] = true } }