diff --git a/cmd/root.go b/cmd/root.go index 1ccf025..93a4791 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -174,6 +174,8 @@ type cliConfig struct { type cliConfigIO struct { QueueSize uint32 `mapstructure:"queueSize"` + QueueNum uint16 `mapstructure:"queueNum"` + Table string `mapstructure:"table"` ReadBuffer int `mapstructure:"rcvBuf"` WriteBuffer int `mapstructure:"sndBuf"` Local bool `mapstructure:"local"` @@ -221,6 +223,8 @@ func (c *cliConfig) fillIO(config *engine.Config) error { WriteBuffer: c.IO.WriteBuffer, Local: c.IO.Local, RST: c.IO.RST, + QueueNum: c.IO.QueueNum, + Table: c.IO.Table, }) } diff --git a/io/nfqueue.go b/io/nfqueue.go index f1a64df..eeca6d7 100644 --- a/io/nfqueue.go +++ b/io/nfqueue.go @@ -19,18 +19,18 @@ import ( ) const ( - nfqueueNum = 100 + nfqueueDefaultQueueNum = 100 nfqueueMaxPacketLen = 0xFFFF nfqueueDefaultQueueSize = 128 nfqueueConnMarkAccept = 1001 nfqueueConnMarkDrop = 1002 - nftFamily = "inet" - nftTable = "opengfw" + nftFamily = "inet" + nftDefaultTable = "opengfw" ) -func generateNftRules(local, rst bool) (*nftTableSpec, error) { +func generateNftRules(local, rst bool, nfqueueNum int, nftTable string) (*nftTableSpec, error) { if local && rst { return nil, errors.New("tcp rst is not supported in local mode") } @@ -64,7 +64,7 @@ func generateNftRules(local, rst bool) (*nftTableSpec, error) { return table, nil } -func generateIptRules(local, rst bool) ([]iptRule, error) { +func generateIptRules(local, rst bool, nfqueueNum int) ([]iptRule, error) { if local && rst { return nil, errors.New("tcp rst is not supported in local mode") } @@ -94,10 +94,12 @@ var _ PacketIO = (*nfqueuePacketIO)(nil) var errNotNFQueuePacket = errors.New("not an NFQueue packet") type nfqueuePacketIO struct { - n *nfqueue.Nfqueue - local bool - rst bool - rSet bool // whether the nftables/iptables rules have been set + n *nfqueue.Nfqueue + local bool + rst bool + rSet bool // whether the nftables/iptables rules have been set + queueNum int + table string // nftable name // iptables not nil = use iptables instead of nftables ipt4 *iptables.IPTables @@ -108,6 +110,8 @@ type nfqueuePacketIO struct { type NFQueuePacketIOConfig struct { QueueSize uint32 + QueueNum uint16 + Table string ReadBuffer int WriteBuffer int Local bool @@ -118,6 +122,12 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { if config.QueueSize == 0 { config.QueueSize = nfqueueDefaultQueueSize } + if config.QueueNum == 0 { + config.QueueNum = nfqueueDefaultQueueNum + } + if config.Table == "" { + config.Table = nftDefaultTable + } var ipt4, ipt6 *iptables.IPTables var err error if nftCheck() != nil { @@ -132,7 +142,7 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { } } n, err := nfqueue.Open(&nfqueue.Config{ - NfQueue: nfqueueNum, + NfQueue: config.QueueNum, MaxPacketLen: nfqueueMaxPacketLen, MaxQueueLen: config.QueueSize, Copymode: nfqueue.NfQnlCopyPacket, @@ -156,11 +166,13 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { } } return &nfqueuePacketIO{ - n: n, - local: config.Local, - rst: config.RST, - ipt4: ipt4, - ipt6: ipt6, + n: n, + local: config.Local, + rst: config.RST, + queueNum: int(config.QueueNum), + table: config.Table, + ipt4: ipt4, + ipt6: ipt6, protectedDialer: &net.Dialer{ Control: func(network, address string, c syscall.RawConn) error { var err error @@ -214,7 +226,7 @@ func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error if n.ipt4 != nil { err = n.setupIpt(n.local, n.rst, false) } else { - err = n.setupNft(n.local, n.rst, false) + err = n.setupNft(n.local, n.rst, false, n.queueNum) } if err != nil { return err @@ -274,7 +286,7 @@ func (n *nfqueuePacketIO) Close() error { if n.ipt4 != nil { _ = n.setupIpt(n.local, n.rst, true) } else { - _ = n.setupNft(n.local, n.rst, true) + _ = n.setupNft(n.local, n.rst, true, n.queueNum) } n.rSet = false } @@ -286,17 +298,17 @@ func (n *nfqueuePacketIO) SetCancelFunc(cancelFunc context.CancelFunc) error { return nil } -func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error { - rules, err := generateNftRules(local, rst) +func (n *nfqueuePacketIO) setupNft(local, rst, remove bool, nfqueueNum int) error { + rules, err := generateNftRules(local, rst, nfqueueNum, n.table) if err != nil { return err } rulesText := rules.String() if remove { - err = nftDelete(nftFamily, nftTable) + err = nftDelete(nftFamily, n.table) } else { // Delete first to make sure no leftover rules - _ = nftDelete(nftFamily, nftTable) + _ = nftDelete(nftFamily, n.table) err = nftAdd(rulesText) } if err != nil { @@ -306,7 +318,7 @@ func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error { } func (n *nfqueuePacketIO) setupIpt(local, rst, remove bool) error { - rules, err := generateIptRules(local, rst) + rules, err := generateIptRules(local, rst, n.queueNum) if err != nil { return err }