mirror of
https://github.com/apernet/OpenGFW.git
synced 2024-11-14 14:29:22 +08:00
120 lines
2.8 KiB
Go
120 lines
2.8 KiB
Go
|
package engine
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"runtime"
|
||
|
|
||
|
"github.com/apernet/OpenGFW/io"
|
||
|
"github.com/apernet/OpenGFW/ruleset"
|
||
|
|
||
|
"github.com/google/gopacket"
|
||
|
"github.com/google/gopacket/layers"
|
||
|
)
|
||
|
|
||
|
// Compile-time check that *engine satisfies the Engine interface.
var _ Engine = (*engine)(nil)
|
||
|
|
||
|
type engine struct {
|
||
|
logger Logger
|
||
|
ioList []io.PacketIO
|
||
|
workers []*worker
|
||
|
}
|
||
|
|
||
|
func NewEngine(config Config) (Engine, error) {
|
||
|
workerCount := config.Workers
|
||
|
if workerCount <= 0 {
|
||
|
workerCount = runtime.NumCPU()
|
||
|
}
|
||
|
var err error
|
||
|
workers := make([]*worker, workerCount)
|
||
|
for i := range workers {
|
||
|
workers[i], err = newWorker(workerConfig{
|
||
|
ID: i,
|
||
|
ChanSize: config.WorkerQueueSize,
|
||
|
Logger: config.Logger,
|
||
|
Ruleset: config.Ruleset,
|
||
|
TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal,
|
||
|
TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn,
|
||
|
UDPMaxStreams: config.WorkerUDPMaxStreams,
|
||
|
})
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
return &engine{
|
||
|
logger: config.Logger,
|
||
|
ioList: config.IOs,
|
||
|
workers: workers,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
func (e *engine) UpdateRuleset(r ruleset.Ruleset) error {
|
||
|
for _, w := range e.workers {
|
||
|
if err := w.UpdateRuleset(r); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (e *engine) Run(ctx context.Context) error {
|
||
|
ioCtx, ioCancel := context.WithCancel(ctx)
|
||
|
defer ioCancel() // Stop workers & IOs
|
||
|
|
||
|
// Start workers
|
||
|
for _, w := range e.workers {
|
||
|
go w.Run(ioCtx)
|
||
|
}
|
||
|
|
||
|
// Register callbacks
|
||
|
errChan := make(chan error, len(e.ioList))
|
||
|
for _, i := range e.ioList {
|
||
|
ioEntry := i // Make sure dispatch() uses the correct ioEntry
|
||
|
err := ioEntry.Register(ioCtx, func(p io.Packet, err error) bool {
|
||
|
if err != nil {
|
||
|
errChan <- err
|
||
|
return false
|
||
|
}
|
||
|
return e.dispatch(ioEntry, p)
|
||
|
})
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Block until IO errors or context is cancelled
|
||
|
select {
|
||
|
case err := <-errChan:
|
||
|
return err
|
||
|
case <-ctx.Done():
|
||
|
return nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// dispatch dispatches a packet to a worker.
|
||
|
// This must be safe for concurrent use, as it may be called from multiple IOs.
|
||
|
func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool {
|
||
|
data := p.Data()
|
||
|
ipVersion := data[0] >> 4
|
||
|
var layerType gopacket.LayerType
|
||
|
if ipVersion == 4 {
|
||
|
layerType = layers.LayerTypeIPv4
|
||
|
} else if ipVersion == 6 {
|
||
|
layerType = layers.LayerTypeIPv6
|
||
|
} else {
|
||
|
// Unsupported network layer
|
||
|
_ = ioEntry.SetVerdict(p, io.VerdictAcceptStream, nil)
|
||
|
return true
|
||
|
}
|
||
|
// Load balance by stream ID
|
||
|
index := p.StreamID() % uint32(len(e.workers))
|
||
|
packet := gopacket.NewPacket(data, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true})
|
||
|
e.workers[index].Feed(&workerPacket{
|
||
|
StreamID: p.StreamID(),
|
||
|
Packet: packet,
|
||
|
SetVerdict: func(v io.Verdict, b []byte) error {
|
||
|
return ioEntry.SetVerdict(p, v, b)
|
||
|
},
|
||
|
})
|
||
|
return true
|
||
|
}
|