fix: netlink race condition (#48)

This commit is contained in:
Toby 2024-02-05 19:32:52 -08:00 committed by GitHub
parent 6871244809
commit 843f17896c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -41,10 +41,11 @@ var _ PacketIO = (*nfqueuePacketIO)(nil)
var errNotNFQueuePacket = errors.New("not an NFQueue packet") var errNotNFQueuePacket = errors.New("not an NFQueue packet")
type nfqueuePacketIO struct { type nfqueuePacketIO struct {
n *nfqueue.Nfqueue n *nfqueue.Nfqueue
local bool local bool
ipt4 *iptables.IPTables ipt4 *iptables.IPTables
ipt6 *iptables.IPTables ipt6 *iptables.IPTables
iptSet bool // whether iptables rules are set
} }
type NFQueuePacketIOConfig struct { type NFQueuePacketIOConfig struct {
@ -74,22 +75,16 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
io := &nfqueuePacketIO{ return &nfqueuePacketIO{
n: n, n: n,
local: config.Local, local: config.Local,
ipt4: ipt4, ipt4: ipt4,
ipt6: ipt6, ipt6: ipt6,
} }, nil
err = io.setupIpt(config.Local, false)
if err != nil {
_ = n.Close()
return nil, err
}
return io, nil
} }
func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error { func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
return n.n.RegisterWithErrorFunc(ctx, err := n.n.RegisterWithErrorFunc(ctx,
func(a nfqueue.Attribute) int { func(a nfqueue.Attribute) int {
if a.PacketID == nil || a.Ct == nil || a.Payload == nil || len(*a.Payload) < 20 { if a.PacketID == nil || a.Ct == nil || a.Payload == nil || len(*a.Payload) < 20 {
// Invalid packet, ignore // Invalid packet, ignore
@ -106,6 +101,17 @@ func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error
func(e error) int { func(e error) int {
return okBoolToInt(cb(nil, e)) return okBoolToInt(cb(nil, e))
}) })
if err != nil {
return err
}
if !n.iptSet {
err = n.setupIpt(n.local, false)
if err != nil {
return err
}
n.iptSet = true
}
return nil
} }
func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error { func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
@ -150,9 +156,13 @@ func (n *nfqueuePacketIO) setupIpt(local, remove bool) error {
} }
func (n *nfqueuePacketIO) Close() error { func (n *nfqueuePacketIO) Close() error {
err := n.setupIpt(n.local, true) if n.iptSet {
_ = n.n.Close() err := n.setupIpt(n.local, true)
return err if err != nil {
return err
}
}
return n.n.Close()
} }
var _ Packet = (*nfqueuePacket)(nil) var _ Packet = (*nfqueuePacket)(nil)