OpenGFW/ruleset/expr.go

378 lines
10 KiB
Go
Raw Normal View History

2024-01-20 08:45:01 +08:00
package ruleset
import (
2024-04-04 11:02:57 +08:00
"context"
2024-01-20 08:45:01 +08:00
"fmt"
"net"
2024-01-20 08:45:01 +08:00
"os"
"reflect"
"strings"
2024-04-04 11:02:57 +08:00
"time"
"github.com/expr-lang/expr/builtin"
2024-01-20 08:45:01 +08:00
"github.com/expr-lang/expr"
"github.com/expr-lang/expr/ast"
"github.com/expr-lang/expr/conf"
"github.com/expr-lang/expr/vm"
"gopkg.in/yaml.v3"
"github.com/apernet/OpenGFW/analyzer"
"github.com/apernet/OpenGFW/modifier"
"github.com/apernet/OpenGFW/ruleset/builtins"
2024-01-20 08:45:01 +08:00
)
// ExprRule is the external representation of an expression rule.
type ExprRule struct {
Name string `yaml:"name"`
Action string `yaml:"action"`
2024-02-24 06:13:35 +08:00
Log bool `yaml:"log"`
2024-01-20 08:45:01 +08:00
Modifier ModifierEntry `yaml:"modifier"`
Expr string `yaml:"expr"`
}
// ModifierEntry names the modifier used by a "modify" rule and carries the
// arguments that are passed as-is to that modifier's New method.
type ModifierEntry struct {
	// Name must match the Name() of one of the modifiers given to
	// CompileExprRules.
	Name string `yaml:"name"`
	// Args are the modifier's arguments keyed by name.
	// NOTE(review): every value is decoded as a YAML sequence, so a scalar
	// argument (e.g. `a: 0.0.0.0`) would likely fail to unmarshal. Confirm this
	// element type matches the parameter of modifier.Modifier.New.
	Args map[string][]any `yaml:"args"`
}
// ExprRulesFromYAML reads file and decodes it as a YAML list of expression
// rules. It only decodes; the rules are validated by CompileExprRules.
//
// On error the returned slice is nil. Decoding errors are wrapped with the
// file name, since the YAML decoder's messages do not include it.
func ExprRulesFromYAML(file string) ([]ExprRule, error) {
	bs, err := os.ReadFile(file)
	if err != nil {
		// The *os.PathError from ReadFile already names the file.
		return nil, err
	}
	var rules []ExprRule
	if err := yaml.Unmarshal(bs, &rules); err != nil {
		// Don't hand back a partially decoded rule list alongside the error.
		return nil, fmt.Errorf("parsing rules file %q: %w", file, err)
	}
	return rules, nil
}
// compiledExprRule is the internal, compiled representation of an expression rule.
type compiledExprRule struct {
Name string
2024-02-24 06:13:35 +08:00
Action *Action // fallthrough if nil
Log bool
2024-01-20 08:45:01 +08:00
ModInstance modifier.Instance
Program *vm.Program
}
var _ Ruleset = (*exprRuleset)(nil)
type exprRuleset struct {
Rules []compiledExprRule
Ans []analyzer.Analyzer
Logger Logger
2024-01-20 08:45:01 +08:00
}
// Analyzers returns the analyzers the compiled rules depend on. The list is
// fixed when the rules are compiled, so the stream info is not consulted.
func (r *exprRuleset) Analyzers(_ StreamInfo) []analyzer.Analyzer {
	deps := r.Ans
	return deps
}
2024-02-24 06:13:35 +08:00
func (r *exprRuleset) Match(info StreamInfo) MatchResult {
2024-01-20 08:45:01 +08:00
env := streamInfoToExprEnv(info)
for _, rule := range r.Rules {
v, err := vm.Run(rule.Program, env)
if err != nil {
2024-02-24 06:13:35 +08:00
// Log the error and continue to the next rule.
r.Logger.MatchError(info, rule.Name, err)
continue
2024-01-20 08:45:01 +08:00
}
if vBool, ok := v.(bool); ok && vBool {
2024-02-24 06:13:35 +08:00
if rule.Log {
r.Logger.Log(info, rule.Name)
}
if rule.Action != nil {
return MatchResult{
Action: *rule.Action,
ModInstance: rule.ModInstance,
}
}
2024-01-20 08:45:01 +08:00
}
}
2024-02-24 06:13:35 +08:00
// No match
2024-01-20 08:45:01 +08:00
return MatchResult{
Action: ActionMaybe,
2024-02-24 06:13:35 +08:00
}
2024-01-20 08:45:01 +08:00
}
// CompileExprRules compiles a list of expression rules into a ruleset.
// It returns an error if any of the rules are invalid, or if any of the analyzers
// used by the rules are unknown (not provided in the analyzer list).
func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier.Modifier, config *BuiltinConfig) (Ruleset, error) {
2024-01-20 08:45:01 +08:00
var compiledRules []compiledExprRule
fullAnMap := analyzersToMap(ans)
fullModMap := modifiersToMap(mods)
depAnMap := make(map[string]analyzer.Analyzer)
funcMap := buildFunctionMap(config)
2024-01-20 08:45:01 +08:00
// Compile all rules and build a map of analyzers that are used by the rules.
for _, rule := range rules {
2024-02-24 06:13:35 +08:00
if rule.Action == "" && !rule.Log {
return nil, fmt.Errorf("rule %q must have at least one of action or log", rule.Name)
}
var action *Action
if rule.Action != "" {
a, ok := actionStringToAction(rule.Action)
if !ok {
return nil, fmt.Errorf("rule %q has invalid action %q", rule.Name, rule.Action)
}
action = &a
2024-01-20 08:45:01 +08:00
}
visitor := &idVisitor{Variables: make(map[string]bool), Identifiers: make(map[string]bool)}
2024-04-04 11:02:57 +08:00
patcher := &idPatcher{FuncMap: funcMap}
2024-01-20 08:45:01 +08:00
program, err := expr.Compile(rule.Expr,
func(c *conf.Config) {
c.Strict = false
c.Expect = reflect.Bool
c.Visitors = append(c.Visitors, visitor, patcher)
2024-04-04 11:02:57 +08:00
for name, f := range funcMap {
c.Functions[name] = &builtin.Function{
Name: name,
Func: f.Func,
Types: f.Types,
}
}
2024-01-20 08:45:01 +08:00
},
)
if err != nil {
return nil, fmt.Errorf("rule %q has invalid expression: %w", rule.Name, err)
}
if patcher.Err != nil {
return nil, fmt.Errorf("rule %q failed to patch expression: %w", rule.Name, patcher.Err)
}
for name := range visitor.Identifiers {
// Skip built-in analyzers & user-defined variables
if isBuiltInAnalyzer(name) || visitor.Variables[name] {
continue
2024-01-20 08:45:01 +08:00
}
2024-04-04 11:02:57 +08:00
if f, ok := funcMap[name]; ok {
// Built-in function, initialize if necessary
if f.InitFunc != nil {
if err := f.InitFunc(); err != nil {
return nil, fmt.Errorf("rule %q failed to initialize function %q: %w", rule.Name, name, err)
}
}
2024-04-04 11:02:57 +08:00
} else if a, ok := fullAnMap[name]; ok {
// Analyzer, add to dependency map
depAnMap[name] = a
}
}
2024-01-20 08:45:01 +08:00
cr := compiledExprRule{
Name: rule.Name,
Action: action,
2024-02-24 06:13:35 +08:00
Log: rule.Log,
Program: program,
2024-01-20 08:45:01 +08:00
}
2024-02-24 06:13:35 +08:00
if action != nil && *action == ActionModify {
2024-01-20 08:45:01 +08:00
mod, ok := fullModMap[rule.Modifier.Name]
if !ok {
return nil, fmt.Errorf("rule %q uses unknown modifier %q", rule.Name, rule.Modifier.Name)
}
modInst, err := mod.New(rule.Modifier.Args)
if err != nil {
return nil, fmt.Errorf("rule %q failed to create modifier instance: %w", rule.Name, err)
}
cr.ModInstance = modInst
}
compiledRules = append(compiledRules, cr)
}
// Convert the analyzer map to a list.
var depAns []analyzer.Analyzer
for _, a := range depAnMap {
depAns = append(depAns, a)
}
return &exprRuleset{
Rules: compiledRules,
Ans: depAns,
Logger: config.Logger,
2024-01-20 08:45:01 +08:00
}, nil
}
// streamInfoToExprEnv builds the variables a rule expression is evaluated
// against:
//   - the built-in fields "id", "proto", "ip" (src/dst) and "port" (src/dst);
//   - one entry per analyzer that reported at least one property.
func streamInfoToExprEnv(info StreamInfo) map[string]interface{} {
	env := make(map[string]interface{}, 4+len(info.Props))
	env["id"] = info.ID
	env["proto"] = info.Protocol.String()
	env["ip"] = map[string]string{
		"src": info.SrcIP.String(),
		"dst": info.DstIP.String(),
	}
	env["port"] = map[string]uint16{
		"src": info.SrcPort,
		"dst": info.DstPort,
	}
	for name, props := range info.Props {
		// Analyzers with empty properties are left out of the environment.
		if len(props) == 0 {
			continue
		}
		env[name] = props
	}
	return env
}
// isBuiltInAnalyzer reports whether name is one of the fields that every
// expression environment provides ("id", "proto", "ip", "port"), as opposed
// to an analyzer name.
func isBuiltInAnalyzer(name string) bool {
	return name == "id" || name == "proto" || name == "ip" || name == "port"
}
// actionStringToAction parses a rule's action name, ignoring case.
// For an unrecognized name it returns ActionMaybe and false.
func actionStringToAction(action string) (Action, bool) {
	var parsed Action
	switch strings.ToLower(action) {
	case "allow":
		parsed = ActionAllow
	case "block":
		parsed = ActionBlock
	case "drop":
		parsed = ActionDrop
	case "modify":
		parsed = ActionModify
	default:
		return ActionMaybe, false
	}
	return parsed, true
}
// analyzersToMap indexes analyzers by their Name() so rules can look them up
// while compiling. If two analyzers share a name, the later one wins.
func analyzersToMap(ans []analyzer.Analyzer) map[string]analyzer.Analyzer {
	byName := make(map[string]analyzer.Analyzer, len(ans))
	for i := range ans {
		byName[ans[i].Name()] = ans[i]
	}
	return byName
}
// modifiersToMap indexes modifiers by their Name() so rules can look them up
// while compiling. If two modifiers share a name, the later one wins.
func modifiersToMap(mods []modifier.Modifier) map[string]modifier.Modifier {
	byName := make(map[string]modifier.Modifier, len(mods))
	for i := range mods {
		byName[mods[i].Name()] = mods[i]
	}
	return byName
}
// idVisitor is a visitor that collects all identifiers in an expression.
// This is for determining which analyzers are used by the expression.
type idVisitor struct {
Variables map[string]bool
Identifiers map[string]bool
2024-01-20 08:45:01 +08:00
}
func (v *idVisitor) Visit(node *ast.Node) {
if varNode, ok := (*node).(*ast.VariableDeclaratorNode); ok {
v.Variables[varNode.Name] = true
} else if idNode, ok := (*node).(*ast.IdentifierNode); ok {
v.Identifiers[idNode.Value] = true
2024-01-20 08:45:01 +08:00
}
}
// idPatcher patches the AST during expr compilation, replacing certain values with
// their internal representations for better runtime performance.
type idPatcher struct {
2024-04-04 11:02:57 +08:00
FuncMap map[string]*Function
Err error
}
func (p *idPatcher) Visit(node *ast.Node) {
switch (*node).(type) {
case *ast.CallNode:
callNode := (*node).(*ast.CallNode)
2024-04-04 11:02:57 +08:00
if callNode.Callee == nil {
2024-02-24 06:13:35 +08:00
// Ignore invalid call nodes
return
}
2024-04-04 11:02:57 +08:00
if f, ok := p.FuncMap[callNode.Callee.String()]; ok {
if f.PatchFunc != nil {
if err := f.PatchFunc(&callNode.Arguments); err != nil {
p.Err = err
return
}
}
}
}
}
2024-04-04 11:02:57 +08:00
type Function struct {
InitFunc func() error
PatchFunc func(args *[]ast.Node) error
Func func(params ...any) (any, error)
Types []reflect.Type
}
func buildFunctionMap(config *BuiltinConfig) map[string]*Function {
2024-04-04 11:02:57 +08:00
return map[string]*Function{
"geoip": {
InitFunc: config.GeoMatcher.LoadGeoIP,
2024-04-04 11:02:57 +08:00
PatchFunc: nil,
Func: func(params ...any) (any, error) {
return config.GeoMatcher.MatchGeoIp(params[0].(string), params[1].(string)), nil
2024-04-04 11:02:57 +08:00
},
Types: []reflect.Type{reflect.TypeOf(config.GeoMatcher.MatchGeoIp)},
2024-04-04 11:02:57 +08:00
},
"geosite": {
InitFunc: config.GeoMatcher.LoadGeoSite,
2024-04-04 11:02:57 +08:00
PatchFunc: nil,
Func: func(params ...any) (any, error) {
return config.GeoMatcher.MatchGeoSite(params[0].(string), params[1].(string)), nil
2024-04-04 11:02:57 +08:00
},
Types: []reflect.Type{reflect.TypeOf(config.GeoMatcher.MatchGeoSite)},
2024-04-04 11:02:57 +08:00
},
"cidr": {
InitFunc: nil,
PatchFunc: func(args *[]ast.Node) error {
cidrStringNode, ok := (*args)[1].(*ast.StringNode)
if !ok {
return fmt.Errorf("cidr: invalid argument type")
}
cidr, err := builtins.CompileCIDR(cidrStringNode.Value)
if err != nil {
return err
}
(*args)[1] = &ast.ConstantNode{Value: cidr}
return nil
},
Func: func(params ...any) (any, error) {
return builtins.MatchCIDR(params[0].(string), params[1].(*net.IPNet)), nil
},
Types: []reflect.Type{reflect.TypeOf(builtins.MatchCIDR)},
2024-04-04 11:02:57 +08:00
},
"lookup": {
InitFunc: nil,
PatchFunc: func(args *[]ast.Node) error {
var serverStr *ast.StringNode
if len(*args) > 1 {
// Has the optional server argument
var ok bool
serverStr, ok = (*args)[1].(*ast.StringNode)
if !ok {
return fmt.Errorf("lookup: invalid argument type")
}
2024-04-04 11:02:57 +08:00
}
r := &net.Resolver{
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
if serverStr != nil {
address = serverStr.Value
}
return config.ProtectedDialContext(ctx, network, address)
2024-04-04 11:02:57 +08:00
},
}
if len(*args) > 1 {
(*args)[1] = &ast.ConstantNode{Value: r}
} else {
*args = append(*args, &ast.ConstantNode{Value: r})
}
2024-04-04 11:02:57 +08:00
return nil
},
Func: func(params ...any) (any, error) {
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
2024-04-04 11:02:57 +08:00
defer cancel()
return params[1].(*net.Resolver).LookupHost(ctx, params[0].(string))
2024-04-04 11:02:57 +08:00
},
Types: []reflect.Type{
reflect.TypeOf((func(string, *net.Resolver) []string)(nil)),
},
},
}
}