diff --git a/io/nfqueue.go b/io/nfqueue.go index 224159f..9f0ddc2 100644 --- a/io/nfqueue.go +++ b/io/nfqueue.go @@ -41,10 +41,11 @@ var _ PacketIO = (*nfqueuePacketIO)(nil) var errNotNFQueuePacket = errors.New("not an NFQueue packet") type nfqueuePacketIO struct { - n *nfqueue.Nfqueue - local bool - ipt4 *iptables.IPTables - ipt6 *iptables.IPTables + n *nfqueue.Nfqueue + local bool + ipt4 *iptables.IPTables + ipt6 *iptables.IPTables + iptSet bool // whether iptables rules are set } type NFQueuePacketIOConfig struct { @@ -74,22 +75,16 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { if err != nil { return nil, err } - io := &nfqueuePacketIO{ + return &nfqueuePacketIO{ n: n, local: config.Local, ipt4: ipt4, ipt6: ipt6, - } - err = io.setupIpt(config.Local, false) - if err != nil { - _ = n.Close() - return nil, err - } - return io, nil + }, nil } func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error { - return n.n.RegisterWithErrorFunc(ctx, + err := n.n.RegisterWithErrorFunc(ctx, func(a nfqueue.Attribute) int { if a.PacketID == nil || a.Ct == nil || a.Payload == nil || len(*a.Payload) < 20 { // Invalid packet, ignore @@ -106,6 +101,17 @@ func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error func(e error) int { return okBoolToInt(cb(nil, e)) }) + if err != nil { + return err + } + if !n.iptSet { + err = n.setupIpt(n.local, false) + if err != nil { + return err + } + n.iptSet = true + } + return nil } func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error { @@ -150,9 +156,13 @@ func (n *nfqueuePacketIO) setupIpt(local, remove bool) error { } func (n *nfqueuePacketIO) Close() error { - err := n.setupIpt(n.local, true) - _ = n.n.Close() - return err + if n.iptSet { + err := n.setupIpt(n.local, true) + if err != nil { + return err + } + } + return n.n.Close() } var _ Packet = (*nfqueuePacket)(nil)