Merge pull request #109 from apernet/wip-io-rst

feat: io tcp reset support (forward only)
This commit is contained in:
Toby 2024-03-21 18:43:43 -07:00 committed by GitHub
commit bf2988116a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 109 additions and 69 deletions

View File

@ -78,6 +78,7 @@ io:
rcvBuf: 4194304
sndBuf: 4194304
local: true # FORWARD チェーンで OpenGFW を実行したい場合は false に設定する
rst: false # ブロックされたTCP接続に対してRSTを送信する場合はtrueに設定してください。local=falseのみです
workers:
count: 4

View File

@ -82,6 +82,7 @@ io:
rcvBuf: 4194304
sndBuf: 4194304
local: true # set to false if you want to run OpenGFW on FORWARD chain
rst: false # set to true if you want to send RST for blocked TCP connections, local=false only
workers:
count: 4

View File

@ -78,6 +78,7 @@ io:
rcvBuf: 4194304
sndBuf: 4194304
local: true # 如果需要在 FORWARD 链上运行 OpenGFW请设置为 false
rst: false # 是否对要阻断的 TCP 连接发送 RST。仅在 local=false 时有效
workers:
count: 4

View File

@ -171,6 +171,7 @@ type cliConfigIO struct {
ReadBuffer int `mapstructure:"rcvBuf"`
WriteBuffer int `mapstructure:"sndBuf"`
Local bool `mapstructure:"local"`
RST bool `mapstructure:"rst"`
}
type cliConfigWorkers struct {
@ -197,6 +198,7 @@ func (c *cliConfig) fillIO(config *engine.Config) error {
ReadBuffer: c.IO.ReadBuffer,
WriteBuffer: c.IO.WriteBuffer,
Local: c.IO.Local,
RST: c.IO.RST,
})
if err != nil {
return configError{Field: "io", Err: err}

View File

@ -27,59 +27,60 @@ const (
nftTable = "opengfw"
)
var nftRulesForward = fmt.Sprintf(`
define ACCEPT_CTMARK=%d
define DROP_CTMARK=%d
define QUEUE_NUM=%d
table %s %s {
chain FORWARD {
type filter hook forward priority filter; policy accept;
ct mark $ACCEPT_CTMARK counter accept
ct mark $DROP_CTMARK counter drop
counter queue num $QUEUE_NUM bypass
}
}
`, nfqueueConnMarkAccept, nfqueueConnMarkDrop, nfqueueNum, nftFamily, nftTable)
var nftRulesLocal = fmt.Sprintf(`
define ACCEPT_CTMARK=%d
define DROP_CTMARK=%d
define QUEUE_NUM=%d
table %s %s {
chain INPUT {
type filter hook input priority filter; policy accept;
ct mark $ACCEPT_CTMARK counter accept
ct mark $DROP_CTMARK counter drop
counter queue num $QUEUE_NUM bypass
}
chain OUTPUT {
type filter hook output priority filter; policy accept;
ct mark $ACCEPT_CTMARK counter accept
ct mark $DROP_CTMARK counter drop
counter queue num $QUEUE_NUM bypass
}
}
`, nfqueueConnMarkAccept, nfqueueConnMarkDrop, nfqueueNum, nftFamily, nftTable)
var iptRulesForward = []iptRule{
{"filter", "FORWARD", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}},
{"filter", "FORWARD", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}},
{"filter", "FORWARD", []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}},
func generateNftRules(local, rst bool) (*nftTableSpec, error) {
if local && rst {
return nil, errors.New("tcp rst is not supported in local mode")
}
table := &nftTableSpec{
Family: nftFamily,
Table: nftTable,
}
table.Defines = append(table.Defines, fmt.Sprintf("define ACCEPT_CTMARK=%d", nfqueueConnMarkAccept))
table.Defines = append(table.Defines, fmt.Sprintf("define DROP_CTMARK=%d", nfqueueConnMarkDrop))
table.Defines = append(table.Defines, fmt.Sprintf("define QUEUE_NUM=%d", nfqueueNum))
if local {
table.Chains = []nftChainSpec{
{Chain: "INPUT", Header: "type filter hook input priority filter; policy accept;"},
{Chain: "OUTPUT", Header: "type filter hook output priority filter; policy accept;"},
}
} else {
table.Chains = []nftChainSpec{
{Chain: "FORWARD", Header: "type filter hook forward priority filter; policy accept;"},
}
}
for i := range table.Chains {
c := &table.Chains[i]
c.Rules = append(c.Rules, "ct mark $ACCEPT_CTMARK counter accept")
if rst {
c.Rules = append(c.Rules, "ip protocol tcp ct mark $DROP_CTMARK counter reject with tcp reset")
}
c.Rules = append(c.Rules, "ct mark $DROP_CTMARK counter drop")
c.Rules = append(c.Rules, "counter queue num $QUEUE_NUM bypass")
}
return table, nil
}
var iptRulesLocal = []iptRule{
{"filter", "INPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}},
{"filter", "INPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}},
{"filter", "INPUT", []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}},
func generateIptRules(local, rst bool) ([]iptRule, error) {
if local && rst {
return nil, errors.New("tcp rst is not supported in local mode")
}
var chains []string
if local {
chains = []string{"INPUT", "OUTPUT"}
} else {
chains = []string{"FORWARD"}
}
rules := make([]iptRule, 0, 4*len(chains))
for _, chain := range chains {
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}})
if rst {
rules = append(rules, iptRule{"filter", chain, []string{"-p", "tcp", "-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "REJECT", "--reject-with", "tcp-reset"}})
}
rules = append(rules, iptRule{"filter", chain, []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}})
rules = append(rules, iptRule{"filter", chain, []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}})
}
{"filter", "OUTPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"}},
{"filter", "OUTPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}},
{"filter", "OUTPUT", []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}},
return rules, nil
}
var _ PacketIO = (*nfqueuePacketIO)(nil)
@ -89,6 +90,7 @@ var errNotNFQueuePacket = errors.New("not an NFQueue packet")
type nfqueuePacketIO struct {
n *nfqueue.Nfqueue
local bool
rst bool
rSet bool // whether the nftables/iptables rules have been set
// iptables not nil = use iptables instead of nftables
@ -101,6 +103,7 @@ type NFQueuePacketIOConfig struct {
ReadBuffer int
WriteBuffer int
Local bool
RST bool
}
func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
@ -147,6 +150,7 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
return &nfqueuePacketIO{
n: n,
local: config.Local,
rst: config.RST,
ipt4: ipt4,
ipt6: ipt6,
}, nil
@ -182,9 +186,9 @@ func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error
}
if !n.rSet {
if n.ipt4 != nil {
err = n.setupIpt(n.local, false)
err = n.setupIpt(n.local, n.rst, false)
} else {
err = n.setupNft(n.local, false)
err = n.setupNft(n.local, n.rst, false)
}
if err != nil {
return err
@ -238,29 +242,27 @@ func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) erro
func (n *nfqueuePacketIO) Close() error {
if n.rSet {
if n.ipt4 != nil {
_ = n.setupIpt(n.local, true)
_ = n.setupIpt(n.local, n.rst, true)
} else {
_ = n.setupNft(n.local, true)
_ = n.setupNft(n.local, n.rst, true)
}
n.rSet = false
}
return n.n.Close()
}
func (n *nfqueuePacketIO) setupNft(local, remove bool) error {
var rules string
if local {
rules = nftRulesLocal
} else {
rules = nftRulesForward
func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error {
rules, err := generateNftRules(local, rst)
if err != nil {
return err
}
var err error
rulesText := rules.String()
if remove {
err = nftDelete(nftFamily, nftTable)
} else {
// Delete first to make sure no leftover rules
_ = nftDelete(nftFamily, nftTable)
err = nftAdd(rules)
err = nftAdd(rulesText)
}
if err != nil {
return err
@ -268,14 +270,11 @@ func (n *nfqueuePacketIO) setupNft(local, remove bool) error {
return nil
}
func (n *nfqueuePacketIO) setupIpt(local, remove bool) error {
var rules []iptRule
if local {
rules = iptRulesLocal
} else {
rules = iptRulesForward
func (n *nfqueuePacketIO) setupIpt(local, rst, remove bool) error {
rules, err := generateIptRules(local, rst)
if err != nil {
return err
}
var err error
if remove {
err = iptsBatchDeleteIfExists([]*iptables.IPTables{n.ipt4, n.ipt6}, rules)
} else {
@ -330,6 +329,42 @@ func nftDelete(family, table string) error {
return cmd.Run()
}
type nftTableSpec struct {
Defines []string
Family, Table string
Chains []nftChainSpec
}
func (t *nftTableSpec) String() string {
chains := make([]string, 0, len(t.Chains))
for _, c := range t.Chains {
chains = append(chains, c.String())
}
return fmt.Sprintf(`
%s
table %s %s {
%s
}
`, strings.Join(t.Defines, "\n"), t.Family, t.Table, strings.Join(chains, ""))
}
type nftChainSpec struct {
Chain string
Header string
Rules []string
}
func (c *nftChainSpec) String() string {
return fmt.Sprintf(`
chain %s {
%s
%s
}
`, c.Chain, c.Header, strings.Join(c.Rules, "\n\x20\x20\x20\x20"))
}
type iptRule struct {
Table, Chain string
RuleSpec []string