diff --git a/io/nfqueue.go b/io/nfqueue.go index 8667b8f..499ff2c 100644 --- a/io/nfqueue.go +++ b/io/nfqueue.go @@ -27,83 +27,60 @@ const ( nftTable = "opengfw" ) -var nftRulesForward = fmt.Sprintf(` -define ACCEPT_CTMARK=%d -define DROP_CTMARK=%d -define QUEUE_NUM=%d - -table %s %s { - chain FORWARD { - type filter hook forward priority filter; policy accept; - - ct mark $ACCEPT_CTMARK counter accept - ct mark $DROP_CTMARK counter drop - counter queue num $QUEUE_NUM bypass - } -} -`, nfqueueConnMarkAccept, nfqueueConnMarkDrop, nfqueueNum, nftFamily, nftTable) - -var nftRulesForwardRST = fmt.Sprintf(` -define ACCEPT_CTMARK=%d -define DROP_CTMARK=%d -define QUEUE_NUM=%d - -table %s %s { - chain FORWARD { - type filter hook forward priority filter; policy accept; - - ct mark $ACCEPT_CTMARK counter accept - ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset - ct mark $DROP_CTMARK counter drop - counter queue num $QUEUE_NUM bypass - } -} -`, nfqueueConnMarkAccept, nfqueueConnMarkDrop, nfqueueNum, nftFamily, nftTable) - -var nftRulesLocal = fmt.Sprintf(` -define ACCEPT_CTMARK=%d -define DROP_CTMARK=%d -define QUEUE_NUM=%d - -table %s %s { - chain INPUT { - type filter hook input priority filter; policy accept; - - ct mark $ACCEPT_CTMARK counter accept - ct mark $DROP_CTMARK counter drop - counter queue num $QUEUE_NUM bypass - } - chain OUTPUT { - type filter hook output priority filter; policy accept; - - ct mark $ACCEPT_CTMARK counter accept - ct mark $DROP_CTMARK counter drop - counter queue num $QUEUE_NUM bypass - } -} -`, nfqueueConnMarkAccept, nfqueueConnMarkDrop, nfqueueNum, nftFamily, nftTable) - -var iptRulesForward = []iptRule{ - {"filter", "FORWARD", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}, - {"filter", "FORWARD", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}}, - {"filter", "FORWARD", []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}}, +func generateNftRules(local, rst bool) (*nftTableSpec, error) { + if local && rst { + return nil, errors.New("tcp rst is not supported in local mode") + } + table := &nftTableSpec{ + Family: nftFamily, + Table: nftTable, + } + table.Defines = append(table.Defines, fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept)) + table.Defines = append(table.Defines, fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop)) + table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNum)) + if local { + table.Chains = []nftChainSpec{ + {Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"}, + {Chain: "OUTPUT", Header: "type filter hook output priority filter; policy accept;"}, + } + } else { + table.Chains = []nftChainSpec{ + {Chain: "FORWARD", Header: "type filter hook forward priority filter; policy accept;"}, + } + } + for i := range table.Chains { + c := &table.Chains[i] + c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept") + if rst { + c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset") + } + c.Rules = append(c.Rules, "ct mark $DROP_CTMARK counter drop") + c.Rules = append(c.Rules, "counter queue num $QUEUE_NUM bypass") + } + return table, nil } -var iptRulesForwardRST = []iptRule{ - {"filter", "FORWARD", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}, - {"filter", "FORWARD", []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}}, - {"filter", "FORWARD", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}}, - {"filter", "FORWARD", []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}}, -} +func generateIptRules(local, rst bool) ([]iptRule, error) { + if local && rst { + return nil, errors.New("tcp rst is not supported in local mode") + } + var chains []string + if local { + chains = []string{"INPUT", "OUTPUT"} + } else { + chains = []string{"FORWARD"} + } + rules := make([]iptRule, 0, 4*len(chains)) + for _, chain := range chains { + rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}) + if rst { + rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}}) + } + rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}}) + rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}}) + } -var iptRulesLocal = []iptRule{ - {"filter", "INPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}, - {"filter", "INPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}}, - {"filter", "INPUT", []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}}, - - {"filter", "OUTPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}, - {"filter", "OUTPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}}, - {"filter", "OUTPUT", []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}}, + return rules, nil } var _ PacketIO = (*nfqueuePacketIO)(nil) @@ -275,23 +252,17 @@ func (n *nfqueuePacketIO) Close() error { } func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error { - var rules string - if local { - rules = nftRulesLocal - } else { - if rst { - rules = nftRulesForwardRST - } else { - rules = nftRulesForward - } + rules, err := generateNftRules(local, rst) + if err != nil { + return err } - var err error + rulesText := rules.String() if remove { err = nftDelete(nftFamily, nftTable) } else { // Delete first to make sure no leftover rules _ = nftDelete(nftFamily, nftTable) - err = nftAdd(rules) + err = nftAdd(rulesText) } if err != nil { return err @@ -300,17 +271,10 @@ func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error { } func (n *nfqueuePacketIO) setupIpt(local, rst, remove bool) error { - var rules []iptRule - if local { - rules = iptRulesLocal - } else { - if rst { - rules = iptRulesForwardRST - } else { - rules = iptRulesForward - } + rules, err := generateIptRules(local, rst) + if err != nil { + return err } - var err error if remove { err = iptsBatchDeleteIfExists([]*iptables.IPTables{n.ipt4, n.ipt6}, rules) } else { @@ -365,6 +329,42 @@ func nftDelete(family, table string) error { return cmd.Run() } +type nftTableSpec struct { + Defines []string + Family, Table string + Chains []nftChainSpec +} + +func (t *nftTableSpec) String() string { + chains := make([]string, 0, len(t.Chains)) + for _, c := range t.Chains { + chains = append(chains, c.String()) + } + + return fmt.Sprintf(` +%s + +table %s %s { +%s +} +`, strings.Join(t.Defines, "\n"), t.Family, t.Table, strings.Join(chains, "")) +} + +type nftChainSpec struct { + Chain string + Header string + Rules []string +} + +func (c *nftChainSpec) String() string { + return fmt.Sprintf(` + chain %s { + %s + %s + } +`, c.Chain, c.Header, strings.Join(c.Rules, "\n\x20\x20\x20\x20")) +} + type iptRule struct { Table, Chain string RuleSpec []string