OpenGFW/engine/worker.go
2024-01-19 16:45:01 -08:00

183 lines
4.7 KiB
Go

package engine
import (
"context"
"github.com/apernet/OpenGFW/io"
"github.com/apernet/OpenGFW/ruleset"
"github.com/bwmarrin/snowflake"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/gopacket/reassembly"
)
const (
defaultChanSize = 64
defaultTCPMaxBufferedPagesTotal = 4096
defaultTCPMaxBufferedPagesPerConnection = 64
defaultUDPMaxStreams = 4096
)
type workerPacket struct {
StreamID uint32
Packet gopacket.Packet
SetVerdict func(io.Verdict, []byte) error
}
type worker struct {
id int
packetChan chan *workerPacket
logger Logger
tcpStreamFactory *tcpStreamFactory
tcpStreamPool *reassembly.StreamPool
tcpAssembler *reassembly.Assembler
udpStreamFactory *udpStreamFactory
udpStreamManager *udpStreamManager
modSerializeBuffer gopacket.SerializeBuffer
}
type workerConfig struct {
ID int
ChanSize int
Logger Logger
Ruleset ruleset.Ruleset
TCPMaxBufferedPagesTotal int
TCPMaxBufferedPagesPerConn int
UDPMaxStreams int
}
func (c *workerConfig) fillDefaults() {
if c.ChanSize <= 0 {
c.ChanSize = defaultChanSize
}
if c.TCPMaxBufferedPagesTotal <= 0 {
c.TCPMaxBufferedPagesTotal = defaultTCPMaxBufferedPagesTotal
}
if c.TCPMaxBufferedPagesPerConn <= 0 {
c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection
}
if c.UDPMaxStreams <= 0 {
c.UDPMaxStreams = defaultUDPMaxStreams
}
}
func newWorker(config workerConfig) (*worker, error) {
config.fillDefaults()
sfNode, err := snowflake.NewNode(int64(config.ID))
if err != nil {
return nil, err
}
tcpSF := &tcpStreamFactory{
WorkerID: config.ID,
Logger: config.Logger,
Node: sfNode,
Ruleset: config.Ruleset,
}
tcpStreamPool := reassembly.NewStreamPool(tcpSF)
tcpAssembler := reassembly.NewAssembler(tcpStreamPool)
tcpAssembler.MaxBufferedPagesTotal = config.TCPMaxBufferedPagesTotal
tcpAssembler.MaxBufferedPagesPerConnection = config.TCPMaxBufferedPagesPerConn
udpSF := &udpStreamFactory{
WorkerID: config.ID,
Logger: config.Logger,
Node: sfNode,
Ruleset: config.Ruleset,
}
udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams)
if err != nil {
return nil, err
}
return &worker{
id: config.ID,
packetChan: make(chan *workerPacket, config.ChanSize),
logger: config.Logger,
tcpStreamFactory: tcpSF,
tcpStreamPool: tcpStreamPool,
tcpAssembler: tcpAssembler,
udpStreamFactory: udpSF,
udpStreamManager: udpSM,
modSerializeBuffer: gopacket.NewSerializeBuffer(),
}, nil
}
func (w *worker) Feed(p *workerPacket) {
w.packetChan <- p
}
func (w *worker) Run(ctx context.Context) {
w.logger.WorkerStart(w.id)
defer w.logger.WorkerStop(w.id)
for {
select {
case <-ctx.Done():
return
case wPkt := <-w.packetChan:
if wPkt == nil {
// Closed
return
}
v, b := w.handle(wPkt.StreamID, wPkt.Packet)
_ = wPkt.SetVerdict(v, b)
}
}
}
func (w *worker) UpdateRuleset(r ruleset.Ruleset) error {
return w.tcpStreamFactory.UpdateRuleset(r)
}
func (w *worker) handle(streamID uint32, p gopacket.Packet) (io.Verdict, []byte) {
netLayer, trLayer := p.NetworkLayer(), p.TransportLayer()
if netLayer == nil || trLayer == nil {
// Invalid packet
return io.VerdictAccept, nil
}
ipFlow := netLayer.NetworkFlow()
switch tr := trLayer.(type) {
case *layers.TCP:
return w.handleTCP(ipFlow, p.Metadata(), tr), nil
case *layers.UDP:
v, modPayload := w.handleUDP(streamID, ipFlow, tr)
if v == io.VerdictAcceptModify && modPayload != nil {
tr.Payload = modPayload
_ = tr.SetNetworkLayerForChecksum(netLayer)
_ = w.modSerializeBuffer.Clear()
err := gopacket.SerializePacket(w.modSerializeBuffer,
gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}, p)
if err != nil {
// Just accept without modification for now
return io.VerdictAccept, nil
}
return v, w.modSerializeBuffer.Bytes()
}
return v, nil
default:
// Unsupported protocol
return io.VerdictAccept, nil
}
}
func (w *worker) handleTCP(ipFlow gopacket.Flow, pMeta *gopacket.PacketMetadata, tcp *layers.TCP) io.Verdict {
ctx := &tcpContext{
PacketMetadata: pMeta,
Verdict: tcpVerdictAccept,
}
w.tcpAssembler.AssembleWithContext(ipFlow, tcp, ctx)
return io.Verdict(ctx.Verdict)
}
func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP) (io.Verdict, []byte) {
ctx := &udpContext{
Verdict: udpVerdictAccept,
}
w.udpStreamManager.MatchWithContext(streamID, ipFlow, udp, ctx)
return io.Verdict(ctx.Verdict), ctx.Packet
}