diff --git a/README.ja.md b/README.ja.md index b3acc30..5757cc5 100644 --- a/README.ja.md +++ b/README.ja.md @@ -78,6 +78,7 @@ io: rcvBuf: 4194304 sndBuf: 4194304 local: true # FORWARD チェーンで OpenGFW を実行したい場合は false に設定する + rst: false # ブロックされたTCP接続に対してRSTを送信する場合はtrueに設定してください。local=falseのみです workers: count: 4 diff --git a/README.md b/README.md index 50758a2..7da03e1 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,7 @@ io: rcvBuf: 4194304 sndBuf: 4194304 local: true # set to false if you want to run OpenGFW on FORWARD chain + rst: false # set to true if you want to send RST for blocked TCP connections, local=false only workers: count: 4 diff --git a/README.zh.md b/README.zh.md index 5580c16..66d00bc 100644 --- a/README.zh.md +++ b/README.zh.md @@ -78,6 +78,7 @@ io: rcvBuf: 4194304 sndBuf: 4194304 local: true # 如果需要在 FORWARD 链上运行 OpenGFW,请设置为 false + rst: false # 是否对要阻断的 TCP 连接发送 RST。仅在 local=false 时有效 workers: count: 4 diff --git a/cmd/root.go b/cmd/root.go index 0c8fa77..7e54462 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -171,6 +171,7 @@ type cliConfigIO struct { ReadBuffer int `mapstructure:"rcvBuf"` WriteBuffer int `mapstructure:"sndBuf"` Local bool `mapstructure:"local"` + RST bool `mapstructure:"rst"` } type cliConfigWorkers struct { @@ -197,6 +198,7 @@ func (c *cliConfig) fillIO(config *engine.Config) error { ReadBuffer: c.IO.ReadBuffer, WriteBuffer: c.IO.WriteBuffer, Local: c.IO.Local, + RST: c.IO.RST, }) if err != nil { return configError{Field: "io", Err: err} diff --git a/io/nfqueue.go b/io/nfqueue.go index 2c1cff2..499ff2c 100644 --- a/io/nfqueue.go +++ b/io/nfqueue.go @@ -27,59 +27,60 @@ const ( nftTable = "opengfw" ) -var nftRulesForward = fmt.Sprintf(` -define ACCEPT_CTMARK=%d -define DROP_CTMARK=%d -define QUEUE_NUM=%d - -table %s %s { - chain FORWARD { - type filter hook forward priority filter; policy accept; - - ct mark $ACCEPT_CTMARK counter accept - ct mark $DROP_CTMARK counter drop - counter queue num $QUEUE_NUM bypass - } -} -`, nfqueueConnMarkAccept, nfqueueConnMarkDrop, nfqueueNum, nftFamily, nftTable) - -var nftRulesLocal = fmt.Sprintf(` -define ACCEPT_CTMARK=%d -define DROP_CTMARK=%d -define QUEUE_NUM=%d - -table %s %s { - chain INPUT { - type filter hook input priority filter; policy accept; - - ct mark $ACCEPT_CTMARK counter accept - ct mark $DROP_CTMARK counter drop - counter queue num $QUEUE_NUM bypass - } - chain OUTPUT { - type filter hook output priority filter; policy accept; - - ct mark $ACCEPT_CTMARK counter accept - ct mark $DROP_CTMARK counter drop - counter queue num $QUEUE_NUM bypass - } -} -`, nfqueueConnMarkAccept, nfqueueConnMarkDrop, nfqueueNum, nftFamily, nftTable) - -var iptRulesForward = []iptRule{ - {"filter", "FORWARD", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}, - {"filter", "FORWARD", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}}, - {"filter", "FORWARD", []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}}, +func generateNftRules(local, rst bool) (*nftTableSpec, error) { + if local && rst { + return nil, errors.New("tcp rst is not supported in local mode") + } + table := &nftTableSpec{ + Family: nftFamily, + Table: nftTable, + } + table.Defines = append(table.Defines, fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept)) + table.Defines = append(table.Defines, fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop)) + table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNum)) + if local { + table.Chains = []nftChainSpec{ + {Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"}, + {Chain: "OUTPUT", Header: "type filter hook output priority filter; policy accept;"}, + } + } else { + table.Chains = []nftChainSpec{ + {Chain: "FORWARD", Header: "type filter hook forward priority filter; policy accept;"}, + } + } + for i := range table.Chains { + c := &table.Chains[i] + c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept") + if rst { + c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset") + } + c.Rules = append(c.Rules, "ct mark $DROP_CTMARK counter drop") + c.Rules = append(c.Rules, "counter queue num $QUEUE_NUM bypass") + } + return table, nil } -var iptRulesLocal = []iptRule{ - {"filter", "INPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}, - {"filter", "INPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}}, - {"filter", "INPUT", []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}}, +func generateIptRules(local, rst bool) ([]iptRule, error) { + if local && rst { + return nil, errors.New("tcp rst is not supported in local mode") + } + var chains []string + if local { + chains = []string{"INPUT", "OUTPUT"} + } else { + chains = []string{"FORWARD"} + } + rules := make([]iptRule, 0, 4*len(chains)) + for _, chain := range chains { + rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}) + if rst { + rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}}) + } + rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}}) + rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}}) + } - {"filter", "OUTPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}}, - {"filter", "OUTPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}}, - {"filter", "OUTPUT", []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}}, + return rules, nil } var _ PacketIO = (*nfqueuePacketIO)(nil) @@ -89,6 +90,7 @@ var errNotNFQueuePacket = errors.New("not an NFQueue packet") type nfqueuePacketIO struct { n *nfqueue.Nfqueue local bool + rst bool rSet bool // whether the nftables/iptables rules have been set // iptables not nil = use iptables instead of nftables @@ -101,6 +103,7 @@ type NFQueuePacketIOConfig struct { ReadBuffer int WriteBuffer int Local bool + RST bool } func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { @@ -147,6 +150,7 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { return &nfqueuePacketIO{ n: n, local: config.Local, + rst: config.RST, ipt4: ipt4, ipt6: ipt6, }, nil @@ -182,9 +186,9 @@ func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error } if !n.rSet { if n.ipt4 != nil { - err = n.setupIpt(n.local, false) + err = n.setupIpt(n.local, n.rst, false) } else { - err = n.setupNft(n.local, false) + err = n.setupNft(n.local, n.rst, false) } if err != nil { return err @@ -238,29 +242,27 @@ func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) erro func (n *nfqueuePacketIO) Close() error { if n.rSet { if n.ipt4 != nil { - _ = n.setupIpt(n.local, true) + _ = n.setupIpt(n.local, n.rst, true) } else { - _ = n.setupNft(n.local, true) + _ = n.setupNft(n.local, n.rst, true) } n.rSet = false } return n.n.Close() } -func (n *nfqueuePacketIO) setupNft(local, remove bool) error { - var rules string - if local { - rules = nftRulesLocal - } else { - rules = nftRulesForward +func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error { + rules, err := generateNftRules(local, rst) + if err != nil { + return err } - var err error + rulesText := rules.String() if remove { err = nftDelete(nftFamily, nftTable) } else { // Delete first to make sure no leftover rules _ = nftDelete(nftFamily, nftTable) - err = nftAdd(rules) + err = nftAdd(rulesText) } if err != nil { return err @@ -268,14 +270,11 @@ func (n *nfqueuePacketIO) setupNft(local, remove bool) error { return nil } -func (n *nfqueuePacketIO) setupIpt(local, remove bool) error { - var rules []iptRule - if local { - rules = iptRulesLocal - } else { - rules = iptRulesForward +func (n *nfqueuePacketIO) setupIpt(local, rst, remove bool) error { + rules, err := generateIptRules(local, rst) + if err != nil { + return err } - var err error if remove { err = iptsBatchDeleteIfExists([]*iptables.IPTables{n.ipt4, n.ipt6}, rules) } else { @@ -330,6 +329,42 @@ func nftDelete(family, table string) error { return cmd.Run() } +type nftTableSpec struct { + Defines []string + Family, Table string + Chains []nftChainSpec +} + +func (t *nftTableSpec) String() string { + chains := make([]string, 0, len(t.Chains)) + for _, c := range t.Chains { + chains = append(chains, c.String()) + } + + return fmt.Sprintf(` +%s + +table %s %s { +%s +} +`, strings.Join(t.Defines, "\n"), t.Family, t.Table, strings.Join(chains, "")) +} + +type nftChainSpec struct { + Chain string + Header string + Rules []string +} + +func (c *nftChainSpec) String() string { + return fmt.Sprintf(` + chain %s { + %s + %s + } +`, c.Chain, c.Header, strings.Join(c.Rules, "\n\x20\x20\x20\x20")) +} + type iptRule struct { Table, Chain string RuleSpec []string