mirror of
https://github.com/apernet/OpenGFW.git
synced 2024-11-11 04:49:22 +08:00
fix: netlink race condition (#48)
This commit is contained in:
parent
6871244809
commit
843f17896c
@ -41,10 +41,11 @@ var _ PacketIO = (*nfqueuePacketIO)(nil)
|
|||||||
var errNotNFQueuePacket = errors.New("not an NFQueue packet")
|
var errNotNFQueuePacket = errors.New("not an NFQueue packet")
|
||||||
|
|
||||||
type nfqueuePacketIO struct {
|
type nfqueuePacketIO struct {
|
||||||
n *nfqueue.Nfqueue
|
n *nfqueue.Nfqueue
|
||||||
local bool
|
local bool
|
||||||
ipt4 *iptables.IPTables
|
ipt4 *iptables.IPTables
|
||||||
ipt6 *iptables.IPTables
|
ipt6 *iptables.IPTables
|
||||||
|
iptSet bool // whether iptables rules are set
|
||||||
}
|
}
|
||||||
|
|
||||||
type NFQueuePacketIOConfig struct {
|
type NFQueuePacketIOConfig struct {
|
||||||
@ -74,22 +75,16 @@ func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
io := &nfqueuePacketIO{
|
return &nfqueuePacketIO{
|
||||||
n: n,
|
n: n,
|
||||||
local: config.Local,
|
local: config.Local,
|
||||||
ipt4: ipt4,
|
ipt4: ipt4,
|
||||||
ipt6: ipt6,
|
ipt6: ipt6,
|
||||||
}
|
}, nil
|
||||||
err = io.setupIpt(config.Local, false)
|
|
||||||
if err != nil {
|
|
||||||
_ = n.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return io, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
|
func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
|
||||||
return n.n.RegisterWithErrorFunc(ctx,
|
err := n.n.RegisterWithErrorFunc(ctx,
|
||||||
func(a nfqueue.Attribute) int {
|
func(a nfqueue.Attribute) int {
|
||||||
if a.PacketID == nil || a.Ct == nil || a.Payload == nil || len(*a.Payload) < 20 {
|
if a.PacketID == nil || a.Ct == nil || a.Payload == nil || len(*a.Payload) < 20 {
|
||||||
// Invalid packet, ignore
|
// Invalid packet, ignore
|
||||||
@ -106,6 +101,17 @@ func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error
|
|||||||
func(e error) int {
|
func(e error) int {
|
||||||
return okBoolToInt(cb(nil, e))
|
return okBoolToInt(cb(nil, e))
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !n.iptSet {
|
||||||
|
err = n.setupIpt(n.local, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n.iptSet = true
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
|
func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
|
||||||
@ -150,9 +156,13 @@ func (n *nfqueuePacketIO) setupIpt(local, remove bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (n *nfqueuePacketIO) Close() error {
|
func (n *nfqueuePacketIO) Close() error {
|
||||||
err := n.setupIpt(n.local, true)
|
if n.iptSet {
|
||||||
_ = n.n.Close()
|
err := n.setupIpt(n.local, true)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n.n.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Packet = (*nfqueuePacket)(nil)
|
var _ Packet = (*nfqueuePacket)(nil)
|
||||||
|
Loading…
Reference in New Issue
Block a user