// Source: OpenGFW/io/nfqueue.go

package io
import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"os/exec"
	"strconv"
	"strings"

	"github.com/coreos/go-iptables/iptables"
	"github.com/florianl/go-nfqueue"
	"github.com/mdlayher/netlink"
	"golang.org/x/sys/unix"
)
// NFQUEUE parameters and the identifiers used when installing firewall rules.
const (
	// nfqueueNum is the queue number opened by NewNFQueuePacketIO and targeted
	// by the nftables/iptables rules.
	nfqueueNum = 100
	// nfqueueMaxPacketLen is passed as MaxPacketLen when opening the queue.
	nfqueueMaxPacketLen = 0xFFFF
	// nfqueueDefaultQueueSize replaces a zero NFQueuePacketIOConfig.QueueSize.
	nfqueueDefaultQueueSize = 128

	// Connection marks set together with a stream verdict. The firewall rules
	// match on them before the queue rule, so marked connections are accepted
	// or dropped without being queued again.
	nfqueueConnMarkAccept = 1001
	nfqueueConnMarkDrop   = 1002

	// nftFamily and nftTable identify the nftables table created and deleted
	// by this file.
	nftFamily = "inet"
	nftTable  = "opengfw"
)
// nftRulesForward is the nftables ruleset used when the local flag is false.
// It declares table nftFamily/nftTable with a single FORWARD chain hooked on
// "forward" that accepts packets whose conntrack mark is
// nfqueueConnMarkAccept, drops those marked nfqueueConnMarkDrop, and sends
// everything else to queue nfqueueNum with the "bypass" option.
var nftRulesForward = fmt.Sprintf(`
define ACCEPT_CTMARK=%d
define DROP_CTMARK=%d
define QUEUE_NUM=%d
table %s %s {
chain FORWARD {
type filter hook forward priority filter; policy accept;
ct mark $ACCEPT_CTMARK counter accept
ct mark $DROP_CTMARK counter drop
counter queue num $QUEUE_NUM bypass
}
}
`, nfqueueConnMarkAccept, nfqueueConnMarkDrop, nfqueueNum, nftFamily, nftTable)
// nftRulesLocal is the nftables ruleset used when the local flag is true.
// It declares table nftFamily/nftTable with an INPUT chain (hook input) and an
// OUTPUT chain (hook output). Each chain accepts packets whose conntrack mark
// is nfqueueConnMarkAccept, drops those marked nfqueueConnMarkDrop, and sends
// everything else to queue nfqueueNum with the "bypass" option.
var nftRulesLocal = fmt.Sprintf(`
define ACCEPT_CTMARK=%d
define DROP_CTMARK=%d
define QUEUE_NUM=%d
table %s %s {
chain INPUT {
type filter hook input priority filter; policy accept;
ct mark $ACCEPT_CTMARK counter accept
ct mark $DROP_CTMARK counter drop
counter queue num $QUEUE_NUM bypass
}
chain OUTPUT {
type filter hook output priority filter; policy accept;
ct mark $ACCEPT_CTMARK counter accept
ct mark $DROP_CTMARK counter drop
counter queue num $QUEUE_NUM bypass
}
}
`, nfqueueConnMarkAccept, nfqueueConnMarkDrop, nfqueueNum, nftFamily, nftTable)
// iptRulesForward lists the iptables rules used when the local flag is false:
// on filter/FORWARD, accept connections marked nfqueueConnMarkAccept, drop
// those marked nfqueueConnMarkDrop, and queue the rest to nfqueueNum.
var iptRulesForward = []iptRule{
	{
		Table:    "filter",
		Chain:    "FORWARD",
		RuleSpec: []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"},
	},
	{
		Table:    "filter",
		Chain:    "FORWARD",
		RuleSpec: []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"},
	},
	{
		Table:    "filter",
		Chain:    "FORWARD",
		RuleSpec: []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"},
	},
}
// iptRulesLocal lists the iptables rules used when the local flag is true.
// The same three rules (accept marked, drop marked, queue the rest) are
// installed on filter/INPUT and then on filter/OUTPUT.
var iptRulesLocal = []iptRule{
	{
		Table:    "filter",
		Chain:    "INPUT",
		RuleSpec: []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"},
	},
	{
		Table:    "filter",
		Chain:    "INPUT",
		RuleSpec: []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"},
	},
	{
		Table:    "filter",
		Chain:    "INPUT",
		RuleSpec: []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"},
	},
	{
		Table:    "filter",
		Chain:    "OUTPUT",
		RuleSpec: []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkAccept), "-j", "ACCEPT"},
	},
	{
		Table:    "filter",
		Chain:    "OUTPUT",
		RuleSpec: []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"},
	},
	{
		Table:    "filter",
		Chain:    "OUTPUT",
		RuleSpec: []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"},
	},
}
// Compile-time check that *nfqueuePacketIO satisfies PacketIO.
var _ PacketIO = (*nfqueuePacketIO)(nil)

// errNotNFQueuePacket is wrapped in ErrInvalidPacket by SetVerdict when the
// supplied Packet is not an *nfqueuePacket.
var errNotNFQueuePacket = errors.New("not an NFQueue packet")
type nfqueuePacketIO struct {
n *nfqueue.Nfqueue
local bool
rSet bool // whether the nftables/iptables rules have been set
// iptables not nil = use iptables instead of nftables
ipt4 *iptables.IPTables
ipt6 *iptables.IPTables
2024-01-20 08:45:01 +08:00
}
// NFQueuePacketIOConfig holds the options accepted by NewNFQueuePacketIO.
type NFQueuePacketIOConfig struct {
	// QueueSize is passed to the queue as its maximum length. Zero selects
	// nfqueueDefaultQueueSize.
	QueueSize uint32
	// Local selects the INPUT/OUTPUT rule set instead of the FORWARD one.
	Local bool
}
func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) {
if config.QueueSize == 0 {
config.QueueSize = nfqueueDefaultQueueSize
}
var ipt4, ipt6 *iptables.IPTables
var err error
if nftCheck() != nil {
// We prefer nftables, but if it's not available, fall back to iptables
ipt4, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, err
}
ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
return nil, err
}
2024-01-20 08:45:01 +08:00
}
n, err := nfqueue.Open(&nfqueue.Config{
NfQueue: nfqueueNum,
MaxPacketLen: nfqueueMaxPacketLen,
MaxQueueLen: config.QueueSize,
Copymode: nfqueue.NfQnlCopyPacket,
Flags: nfqueue.NfQaCfgFlagConntrack,
})
if err != nil {
return nil, err
}
2024-02-06 11:32:52 +08:00
return &nfqueuePacketIO{
2024-01-20 08:45:01 +08:00
n: n,
local: config.Local,
ipt4: ipt4,
ipt6: ipt6,
2024-02-06 11:32:52 +08:00
}, nil
2024-01-20 08:45:01 +08:00
}
func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error {
2024-02-06 11:32:52 +08:00
err := n.n.RegisterWithErrorFunc(ctx,
2024-01-20 08:45:01 +08:00
func(a nfqueue.Attribute) int {
if ok, verdict := n.packetAttributeSanityCheck(a); !ok {
if a.PacketID != nil {
_ = n.n.SetVerdict(*a.PacketID, verdict)
}
2024-01-20 08:45:01 +08:00
return 0
}
p := &nfqueuePacket{
id: *a.PacketID,
streamID: ctIDFromCtBytes(*a.Ct),
data: *a.Payload,
}
return okBoolToInt(cb(p, nil))
},
func(e error) int {
if opErr := (*netlink.OpError)(nil); errors.As(e, &opErr) {
if errors.Is(opErr.Err, unix.ENOBUFS) {
// Kernel buffer temporarily full, ignore
return 0
}
}
2024-01-20 08:45:01 +08:00
return okBoolToInt(cb(nil, e))
})
2024-02-06 11:32:52 +08:00
if err != nil {
return err
}
if !n.rSet {
if n.ipt4 != nil {
err = n.setupIpt(n.local, false)
} else {
err = n.setupNft(n.local, false)
}
2024-02-06 11:32:52 +08:00
if err != nil {
return err
}
n.rSet = true
2024-02-06 11:32:52 +08:00
}
return nil
2024-01-20 08:45:01 +08:00
}
// packetAttributeSanityCheck reports whether a carries everything needed to
// build an nfqueuePacket. When ok is false, verdict is the verdict to give the
// packet; it is -1 when there is no packet ID to give a verdict to, and also
// when ok is true.
func (n *nfqueuePacketIO) packetAttributeSanityCheck(a nfqueue.Attribute) (ok bool, verdict int) {
	switch {
	case a.PacketID == nil:
		// Re-inject to NFQUEUE is actually not possible in this condition
		return false, -1
	case a.Payload == nil || len(*a.Payload) < 20:
		// 20 is the minimum possible size of an IP packet
		return false, nfqueue.NfDrop
	case a.Ct != nil:
		return true, -1
	case n.local:
		// Multicast packets may not have a conntrack, but only appear in local mode
		return false, nfqueue.NfAccept
	default:
		return false, nfqueue.NfDrop
	}
}
// SetVerdict issues verdict v for packet p on the queue.
//
// p must be an *nfqueuePacket, otherwise an *ErrInvalidPacket wrapping
// errNotNFQueuePacket is returned. newPacket is only used for
// VerdictAcceptModify. The stream verdicts additionally set the corresponding
// connection mark. An unrecognised verdict is ignored and nil is returned.
func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error {
	pkt, isNFQ := p.(*nfqueuePacket)
	if !isNFQ {
		return &ErrInvalidPacket{Err: errNotNFQueuePacket}
	}
	id := pkt.id
	switch v {
	case VerdictAccept:
		return n.n.SetVerdict(id, nfqueue.NfAccept)
	case VerdictAcceptModify:
		return n.n.SetVerdictModPacket(id, nfqueue.NfAccept, newPacket)
	case VerdictAcceptStream:
		return n.n.SetVerdictWithConnMark(id, nfqueue.NfAccept, nfqueueConnMarkAccept)
	case VerdictDrop:
		return n.n.SetVerdict(id, nfqueue.NfDrop)
	case VerdictDropStream:
		return n.n.SetVerdictWithConnMark(id, nfqueue.NfDrop, nfqueueConnMarkDrop)
	}
	// Invalid verdict, ignore for now
	return nil
}
// Close removes the firewall rules if they were installed, ignoring any error
// from the removal, and then closes the queue handle, returning its error.
func (n *nfqueuePacketIO) Close() error {
	if n.rSet {
		teardown := n.setupNft
		if n.ipt4 != nil {
			teardown = n.setupIpt
		}
		// Best effort: the queue is closed regardless of the outcome.
		_ = teardown(n.local, true)
		n.rSet = false
	}
	return n.n.Close()
}
// setupNft installs (remove == false) or deletes (remove == true) the nftables
// table nftFamily/nftTable. local selects nftRulesLocal over nftRulesForward
// when installing.
func (n *nfqueuePacketIO) setupNft(local, remove bool) error {
	if remove {
		return nftDelete(nftFamily, nftTable)
	}
	ruleset := nftRulesForward
	if local {
		ruleset = nftRulesLocal
	}
	// Delete first to make sure no leftover rules
	_ = nftDelete(nftFamily, nftTable)
	return nftAdd(ruleset)
}
// setupIpt appends (remove == false) or deletes (remove == true) the iptables
// rules on both the IPv4 and IPv6 handles. local selects iptRulesLocal over
// iptRulesForward.
func (n *nfqueuePacketIO) setupIpt(local, remove bool) error {
	ruleSet := iptRulesForward
	if local {
		ruleSet = iptRulesLocal
	}
	handles := []*iptables.IPTables{n.ipt4, n.ipt6}
	if remove {
		return iptsBatchDeleteIfExists(handles, ruleSet)
	}
	return iptsBatchAppendUnique(handles, ruleSet)
}
// Compile-time check that *nfqueuePacket satisfies Packet.
var _ Packet = (*nfqueuePacket)(nil)
// nfqueuePacket is a packet received from the queue.
type nfqueuePacket struct {
	id       uint32 // packet ID used when issuing a verdict
	streamID uint32 // value produced by ctIDFromCtBytes for the packet
	data     []byte // packet payload as delivered by the queue
}

// StreamID returns the stream identifier stored in the packet.
func (p *nfqueuePacket) StreamID() uint32 { return p.streamID }

// Data returns the stored payload slice without copying it.
func (p *nfqueuePacket) Data() []byte { return p.data }
// okBoolToInt maps true to 0 and false to 1.
func okBoolToInt(ok bool) int {
	if !ok {
		return 1
	}
	return 0
}
// nftCheck returns nil when an "nft" executable can be found via
// exec.LookPath, and the lookup error otherwise.
func nftCheck() error {
	_, lookupErr := exec.LookPath("nft")
	return lookupErr
}
// nftAdd feeds input to "nft -f -" on standard input.
//
// On failure the returned error wraps the underlying exec error and, when nft
// printed anything, includes that output, so the caller sees nft's own
// diagnostic rather than only an exit status.
func nftAdd(input string) error {
	cmd := exec.Command("nft", "-f", "-")
	cmd.Stdin = strings.NewReader(input)
	out, err := cmd.CombinedOutput()
	if err == nil {
		return nil
	}
	if msg := strings.TrimSpace(string(out)); msg != "" {
		return fmt.Errorf("nft -f -: %w: %s", err, msg)
	}
	return fmt.Errorf("nft -f -: %w", err)
}
// nftDelete runs "nft delete table <family> <table>" and returns the error
// from running the command.
func nftDelete(family, table string) error {
	return exec.Command("nft", "delete", "table", family, table).Run()
}
// iptRule is one iptables rule: the table and chain it belongs to and the
// rule specification arguments.
type iptRule struct {
	Table    string
	Chain    string
	RuleSpec []string
}
// iptsBatchAppendUnique calls AppendUnique for every rule on every handle,
// rule by rule, and stops at the first error, which it returns.
func iptsBatchAppendUnique(ipts []*iptables.IPTables, rules []iptRule) error {
	for i := range rules {
		rule := &rules[i]
		for _, handle := range ipts {
			if err := handle.AppendUnique(rule.Table, rule.Chain, rule.RuleSpec...); err != nil {
				return err
			}
		}
	}
	return nil
}
// iptsBatchDeleteIfExists calls DeleteIfExists for every rule on every handle.
//
// This is the cleanup path, so a failure does not stop the loop: every
// deletion is attempted and the first error encountered, if any, is returned.
// Stopping early would leave the remaining rules installed.
func iptsBatchDeleteIfExists(ipts []*iptables.IPTables, rules []iptRule) error {
	var firstErr error
	for _, r := range rules {
		for _, ipt := range ipts {
			err := ipt.DeleteIfExists(r.Table, r.Chain, r.RuleSpec...)
			if err != nil && firstErr == nil {
				firstErr = err
			}
		}
	}
	return firstErr
}
// ctIDFromCtBytes parses ct as a sequence of netlink attributes and returns
// the big-endian uint32 carried by the first attribute of type 12 (CTA_ID).
//
// It returns 0 when ct cannot be parsed, when no such attribute is present,
// or when the attribute holds fewer than 4 bytes.
func ctIDFromCtBytes(ct []byte) uint32 {
	const ctaID = 12 // CTA_ID

	ctAttrs, err := netlink.UnmarshalAttributes(ct)
	if err != nil {
		return 0
	}
	for _, attr := range ctAttrs {
		if attr.Type != ctaID {
			continue
		}
		// Uint32 panics on a slice shorter than 4 bytes; this runs inside the
		// packet callback, so treat a short attribute as "no ID" instead.
		if len(attr.Data) < 4 {
			return 0
		}
		return binary.BigEndian.Uint32(attr.Data)
	}
	return 0
}