OpenGFW/analyzer/udp/wireguard.go
Haruue 8d94400855
Add WireGuard analyzer (#41)
* feat: add WireGuard analyzer

* chore(wg): reduce map creating for non wg packets

* chore: import format

* docs: add wg usage

---------

Co-authored-by: Toby <tobyxdd@gmail.com>
2024-01-30 18:05:51 -08:00

218 lines
5.5 KiB
Go

package udp
import (
"container/ring"
"encoding/binary"
"slices"
"sync"
"github.com/apernet/OpenGFW/analyzer"
)
var (
_ analyzer.UDPAnalyzer = (*WireGuardAnalyzer)(nil)
_ analyzer.UDPStream = (*wireGuardUDPStream)(nil)
)
const (
wireguardUDPInvalidCountThreshold = 4
wireguardRememberedIndexCount = 6
wireguardPropKeyMessageType = "message_type"
)
const (
wireguardTypeHandshakeInitiation = 1
wireguardTypeHandshakeResponse = 2
wireguardTypeData = 4
wireguardTypeCookieReply = 3
)
const (
wireguardSizeHandshakeInitiation = 148
wireguardSizeHandshakeResponse = 92
wireguardMinSizePacketData = 32 // 16 bytes header + 16 bytes AEAD overhead
wireguardSizePacketCookieReply = 64
)
type WireGuardAnalyzer struct{}
func (a *WireGuardAnalyzer) Name() string {
return "wireguard"
}
func (a *WireGuardAnalyzer) Limit() int {
return 0
}
func (a *WireGuardAnalyzer) NewUDP(info analyzer.UDPInfo, logger analyzer.Logger) analyzer.UDPStream {
return newWireGuardUDPStream(logger)
}
type wireGuardUDPStream struct {
logger analyzer.Logger
invalidCount int
rememberedIndexes *ring.Ring
rememberedIndexesLock sync.RWMutex
}
func newWireGuardUDPStream(logger analyzer.Logger) *wireGuardUDPStream {
return &wireGuardUDPStream{
logger: logger,
rememberedIndexes: ring.New(wireguardRememberedIndexCount),
}
}
func (s *wireGuardUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done bool) {
m := s.parseWireGuardPacket(rev, data)
if m == nil {
s.invalidCount++
return nil, s.invalidCount >= wireguardUDPInvalidCountThreshold
}
s.invalidCount = 0 // Reset invalid count on valid WireGuard packet
messageType := m[wireguardPropKeyMessageType].(byte)
propUpdateType := analyzer.PropUpdateMerge
if messageType == wireguardTypeHandshakeInitiation {
propUpdateType = analyzer.PropUpdateReplace
}
return &analyzer.PropUpdate{
Type: propUpdateType,
M: m,
}, false
}
func (s *wireGuardUDPStream) Close(limited bool) *analyzer.PropUpdate {
return nil
}
func (s *wireGuardUDPStream) parseWireGuardPacket(rev bool, data []byte) analyzer.PropMap {
if len(data) < 4 {
return nil
}
if slices.Max(data[1:4]) != 0 {
return nil
}
messageType := data[0]
var propKey string
var propValue analyzer.PropMap
switch messageType {
case wireguardTypeHandshakeInitiation:
propKey = "handshake_initiation"
propValue = s.parseWireGuardHandshakeInitiation(rev, data)
case wireguardTypeHandshakeResponse:
propKey = "handshake_response"
propValue = s.parseWireGuardHandshakeResponse(rev, data)
case wireguardTypeData:
propKey = "packet_data"
propValue = s.parseWireGuardPacketData(rev, data)
case wireguardTypeCookieReply:
propKey = "packet_cookie_reply"
propValue = s.parseWireGuardPacketCookieReply(rev, data)
}
if propValue == nil {
return nil
}
m := make(analyzer.PropMap)
m[wireguardPropKeyMessageType] = messageType
m[propKey] = propValue
return m
}
func (s *wireGuardUDPStream) parseWireGuardHandshakeInitiation(rev bool, data []byte) analyzer.PropMap {
if len(data) != wireguardSizeHandshakeInitiation {
return nil
}
m := make(analyzer.PropMap)
senderIndex := binary.LittleEndian.Uint32(data[4:8])
m["sender_index"] = senderIndex
s.putSenderIndex(rev, senderIndex)
return m
}
func (s *wireGuardUDPStream) parseWireGuardHandshakeResponse(rev bool, data []byte) analyzer.PropMap {
if len(data) != wireguardSizeHandshakeResponse {
return nil
}
m := make(analyzer.PropMap)
senderIndex := binary.LittleEndian.Uint32(data[4:8])
m["sender_index"] = senderIndex
s.putSenderIndex(rev, senderIndex)
receiverIndex := binary.LittleEndian.Uint32(data[8:12])
m["receiver_index"] = receiverIndex
m["receiver_index_matched"] = s.matchReceiverIndex(rev, receiverIndex)
return m
}
func (s *wireGuardUDPStream) parseWireGuardPacketData(rev bool, data []byte) analyzer.PropMap {
if len(data) < wireguardMinSizePacketData {
return nil
}
if len(data)%16 != 0 {
// WireGuard zero padding the packet to make the length a multiple of 16
return nil
}
m := make(analyzer.PropMap)
receiverIndex := binary.LittleEndian.Uint32(data[4:8])
m["receiver_index"] = receiverIndex
m["receiver_index_matched"] = s.matchReceiverIndex(rev, receiverIndex)
m["counter"] = binary.LittleEndian.Uint64(data[8:16])
return m
}
func (s *wireGuardUDPStream) parseWireGuardPacketCookieReply(rev bool, data []byte) analyzer.PropMap {
if len(data) != wireguardSizePacketCookieReply {
return nil
}
m := make(analyzer.PropMap)
receiverIndex := binary.LittleEndian.Uint32(data[4:8])
m["receiver_index"] = receiverIndex
m["receiver_index_matched"] = s.matchReceiverIndex(rev, receiverIndex)
return m
}
type wireGuardIndex struct {
SenderIndex uint32
Reverse bool
}
func (s *wireGuardUDPStream) putSenderIndex(rev bool, senderIndex uint32) {
s.rememberedIndexesLock.Lock()
defer s.rememberedIndexesLock.Unlock()
s.rememberedIndexes.Value = &wireGuardIndex{
SenderIndex: senderIndex,
Reverse: rev,
}
s.rememberedIndexes = s.rememberedIndexes.Prev()
}
func (s *wireGuardUDPStream) matchReceiverIndex(rev bool, receiverIndex uint32) bool {
s.rememberedIndexesLock.RLock()
defer s.rememberedIndexesLock.RUnlock()
var found bool
ris := s.rememberedIndexes
for it := ris.Next(); it != ris; it = it.Next() {
if it.Value == nil {
break
}
wgidx := it.Value.(*wireGuardIndex)
if wgidx.Reverse == !rev && wgidx.SenderIndex == receiverIndex {
found = true
break
}
}
return found
}