diff --git a/cmd/root.go b/cmd/root.go index 288e3d7..1ccf025 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -43,6 +43,7 @@ var logger *zap.Logger // Flags var ( cfgFile string + pcapFile string logLevel string logFormat string ) @@ -118,6 +119,7 @@ func init() { func initFlags() { rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file") + rootCmd.PersistentFlags().StringVarP(&pcapFile, "pcap", "p", "", "pcap file (optional)") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", envOrDefaultString(appLogLevelEnv, "info"), "log level") rootCmd.PersistentFlags().StringVarP(&logFormat, "log-format", "f", envOrDefaultString(appLogFormatEnv, "console"), "log format") } @@ -167,6 +169,7 @@ type cliConfig struct { IO cliConfigIO `mapstructure:"io"` Workers cliConfigWorkers `mapstructure:"workers"` Ruleset cliConfigRuleset `mapstructure:"ruleset"` + Replay cliConfigReplay `mapstructure:"replay"` } type cliConfigIO struct { @@ -177,6 +180,10 @@ type cliConfigIO struct { RST bool `mapstructure:"rst"` } +type cliConfigReplay struct { + Realtime bool `mapstructure:"realtime"` +} + type cliConfigWorkers struct { Count int `mapstructure:"count"` QueueSize int `mapstructure:"queueSize"` @@ -197,17 +204,30 @@ func (c *cliConfig) fillLogger(config *engine.Config) error { } func (c *cliConfig) fillIO(config *engine.Config) error { - nfio, err := io.NewNFQueuePacketIO(io.NFQueuePacketIOConfig{ - QueueSize: c.IO.QueueSize, - ReadBuffer: c.IO.ReadBuffer, - WriteBuffer: c.IO.WriteBuffer, - Local: c.IO.Local, - RST: c.IO.RST, - }) + var ioImpl io.PacketIO + var err error + if pcapFile != "" { + // Setup IO for pcap file replay + logger.Info("replaying from pcap file", zap.String("pcap file", pcapFile)) + ioImpl, err = io.NewPcapPacketIO(io.PcapPacketIOConfig{ + PcapFile: pcapFile, + Realtime: c.Replay.Realtime, + }) + } else { + // Setup IO for nfqueue + ioImpl, err = io.NewNFQueuePacketIO(io.NFQueuePacketIOConfig{ + QueueSize: c.IO.QueueSize, + ReadBuffer: c.IO.ReadBuffer, + WriteBuffer: c.IO.WriteBuffer, + Local: c.IO.Local, + RST: c.IO.RST, + }) + } + if err != nil { return configError{Field: "io", Err: err} } - config.IO = nfio + config.IO = ioImpl return nil } diff --git a/engine/engine.go b/engine/engine.go index 56f5ed3..1270efb 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -58,12 +58,17 @@ func (e *engine) UpdateRuleset(r ruleset.Ruleset) error { } func (e *engine) Run(ctx context.Context) error { + workerCtx, workerCancel := context.WithCancel(ctx) + defer workerCancel() // Stop workers + + // Register IO shutdown ioCtx, ioCancel := context.WithCancel(ctx) - defer ioCancel() // Stop workers & IO + e.io.SetCancelFunc(ioCancel) + defer ioCancel() // Stop IO // Start workers for _, w := range e.workers { - go w.Run(ioCtx) + go w.Run(workerCtx) } // Register IO callback @@ -85,6 +90,8 @@ func (e *engine) Run(ctx context.Context) error { return err case <-ctx.Done(): return nil + case <-ioCtx.Done(): + return nil } } diff --git a/io/interface.go b/io/interface.go index af7e1e7..f996789 100644 --- a/io/interface.go +++ b/io/interface.go @@ -48,6 +48,9 @@ type PacketIO interface { ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) // Close closes the packet IO. Close() error + // SetCancelFunc gives packet IO access to context cancel function, enabling it to + // trigger a shutdown + SetCancelFunc(cancelFunc context.CancelFunc) error } type ErrInvalidPacket struct { diff --git a/io/nfqueue.go b/io/nfqueue.go index e84a0bb..f1a64df 100644 --- a/io/nfqueue.go +++ b/io/nfqueue.go @@ -281,6 +281,11 @@ func (n *nfqueuePacketIO) Close() error { return n.n.Close() } +// nfqueue IO does not issue shutdown +func (n *nfqueuePacketIO) SetCancelFunc(cancelFunc context.CancelFunc) error { + return nil +} + func (n *nfqueuePacketIO) setupNft(local, rst, remove bool) error { rules, err := generateNftRules(local, rst) if err != nil { diff --git a/io/pcap.go b/io/pcap.go new file mode 100644 index 0000000..9801f9c --- /dev/null +++ b/io/pcap.go @@ -0,0 +1,136 @@ +package io + +import ( + "context" + "hash/crc32" + "io" + "net" + "os" + "sort" + "strings" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/pcapgo" +) + +var _ PacketIO = (*pcapPacketIO)(nil) + +type pcapPacketIO struct { + pcapFile io.ReadCloser + pcap *pcapgo.Reader + timeOffset *time.Duration + ioCancel context.CancelFunc + config PcapPacketIOConfig + + dialer *net.Dialer +} + +type PcapPacketIOConfig struct { + PcapFile string + Realtime bool +} + +func NewPcapPacketIO(config PcapPacketIOConfig) (PacketIO, error) { + pcapFile, err := os.Open(config.PcapFile) + if err != nil { + return nil, err + } + + handle, err := pcapgo.NewReader(pcapFile) + if err != nil { + return nil, err + } + + return &pcapPacketIO{ + pcapFile: pcapFile, + pcap: handle, + timeOffset: nil, + ioCancel: nil, + config: config, + dialer: &net.Dialer{}, + }, nil +} + +func (p *pcapPacketIO) Register(ctx context.Context, cb PacketCallback) error { + go func() { + packetSource := gopacket.NewPacketSource(p.pcap, p.pcap.LinkType()) + for packet := range packetSource.Packets() { + p.wait(packet) + + networkLayer := packet.NetworkLayer() + if networkLayer != nil { + src, dst := networkLayer.NetworkFlow().Endpoints() + endpoints := []string{src.String(), dst.String()} + sort.Strings(endpoints) + id := crc32.Checksum([]byte(strings.Join(endpoints, ",")), crc32.IEEETable) + + cb(&pcapPacket{ + streamID: id, + timestamp: packet.Metadata().Timestamp, + data: packet.LinkLayer().LayerPayload(), + }, nil) + } + } + // Give the workers a chance to finish everything + time.Sleep(time.Second) + // Stop the engine when all packets are finished + p.ioCancel() + }() + + return nil +} + +// A normal dialer is sufficient as pcap IO does not mess up with the networking +func (p *pcapPacketIO) ProtectedDialContext(ctx context.Context, network, address string) (net.Conn, error) { + return p.dialer.DialContext(ctx, network, address) +} + +func (p *pcapPacketIO) SetVerdict(pkt Packet, v Verdict, newPacket []byte) error { + return nil +} + +func (p *pcapPacketIO) SetCancelFunc(cancelFunc context.CancelFunc) error { + p.ioCancel = cancelFunc + return nil +} + +func (p *pcapPacketIO) Close() error { + return p.pcapFile.Close() +} + +// Intentionally slow down the replay +// In realtime mode, this is to match the timestamps in the capture +func (p *pcapPacketIO) wait(packet gopacket.Packet) { + if !p.config.Realtime { + return + } + + if p.timeOffset == nil { + offset := time.Since(packet.Metadata().Timestamp) + p.timeOffset = &offset + } else { + t := time.Until(packet.Metadata().Timestamp.Add(*p.timeOffset)) + time.Sleep(t) + } +} + +var _ Packet = (*pcapPacket)(nil) + +type pcapPacket struct { + streamID uint32 + timestamp time.Time + data []byte +} + +func (p *pcapPacket) StreamID() uint32 { + return p.streamID +} + +func (p *pcapPacket) Timestamp() time.Time { + return p.timestamp +} + +func (p *pcapPacket) Data() []byte { + return p.data +}