mirror of
https://github.com/apernet/OpenGFW.git
synced 2024-11-11 04:49:22 +08:00
8d94400855
* feat: add WireGuard analyzer * chore(wg): reduce map creating for non wg packets * chore: import format * docs: add wg usage --------- Co-authored-by: Toby <tobyxdd@gmail.com>
218 lines
5.5 KiB
Go
218 lines
5.5 KiB
Go
package udp
|
|
|
|
import (
|
|
"container/ring"
|
|
"encoding/binary"
|
|
"slices"
|
|
"sync"
|
|
|
|
"github.com/apernet/OpenGFW/analyzer"
|
|
)
|
|
|
|
var (
|
|
_ analyzer.UDPAnalyzer = (*WireGuardAnalyzer)(nil)
|
|
_ analyzer.UDPStream = (*wireGuardUDPStream)(nil)
|
|
)
|
|
|
|
const (
|
|
wireguardUDPInvalidCountThreshold = 4
|
|
wireguardRememberedIndexCount = 6
|
|
wireguardPropKeyMessageType = "message_type"
|
|
)
|
|
|
|
const (
|
|
wireguardTypeHandshakeInitiation = 1
|
|
wireguardTypeHandshakeResponse = 2
|
|
wireguardTypeData = 4
|
|
wireguardTypeCookieReply = 3
|
|
)
|
|
|
|
const (
|
|
wireguardSizeHandshakeInitiation = 148
|
|
wireguardSizeHandshakeResponse = 92
|
|
wireguardMinSizePacketData = 32 // 16 bytes header + 16 bytes AEAD overhead
|
|
wireguardSizePacketCookieReply = 64
|
|
)
|
|
|
|
type WireGuardAnalyzer struct{}
|
|
|
|
func (a *WireGuardAnalyzer) Name() string {
|
|
return "wireguard"
|
|
}
|
|
|
|
func (a *WireGuardAnalyzer) Limit() int {
|
|
return 0
|
|
}
|
|
|
|
func (a *WireGuardAnalyzer) NewUDP(info analyzer.UDPInfo, logger analyzer.Logger) analyzer.UDPStream {
|
|
return newWireGuardUDPStream(logger)
|
|
}
|
|
|
|
type wireGuardUDPStream struct {
|
|
logger analyzer.Logger
|
|
invalidCount int
|
|
rememberedIndexes *ring.Ring
|
|
rememberedIndexesLock sync.RWMutex
|
|
}
|
|
|
|
func newWireGuardUDPStream(logger analyzer.Logger) *wireGuardUDPStream {
|
|
return &wireGuardUDPStream{
|
|
logger: logger,
|
|
rememberedIndexes: ring.New(wireguardRememberedIndexCount),
|
|
}
|
|
}
|
|
|
|
func (s *wireGuardUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done bool) {
|
|
m := s.parseWireGuardPacket(rev, data)
|
|
if m == nil {
|
|
s.invalidCount++
|
|
return nil, s.invalidCount >= wireguardUDPInvalidCountThreshold
|
|
}
|
|
s.invalidCount = 0 // Reset invalid count on valid WireGuard packet
|
|
messageType := m[wireguardPropKeyMessageType].(byte)
|
|
propUpdateType := analyzer.PropUpdateMerge
|
|
if messageType == wireguardTypeHandshakeInitiation {
|
|
propUpdateType = analyzer.PropUpdateReplace
|
|
}
|
|
return &analyzer.PropUpdate{
|
|
Type: propUpdateType,
|
|
M: m,
|
|
}, false
|
|
}
|
|
|
|
func (s *wireGuardUDPStream) Close(limited bool) *analyzer.PropUpdate {
|
|
return nil
|
|
}
|
|
|
|
func (s *wireGuardUDPStream) parseWireGuardPacket(rev bool, data []byte) analyzer.PropMap {
|
|
if len(data) < 4 {
|
|
return nil
|
|
}
|
|
if slices.Max(data[1:4]) != 0 {
|
|
return nil
|
|
}
|
|
|
|
messageType := data[0]
|
|
var propKey string
|
|
var propValue analyzer.PropMap
|
|
switch messageType {
|
|
case wireguardTypeHandshakeInitiation:
|
|
propKey = "handshake_initiation"
|
|
propValue = s.parseWireGuardHandshakeInitiation(rev, data)
|
|
case wireguardTypeHandshakeResponse:
|
|
propKey = "handshake_response"
|
|
propValue = s.parseWireGuardHandshakeResponse(rev, data)
|
|
case wireguardTypeData:
|
|
propKey = "packet_data"
|
|
propValue = s.parseWireGuardPacketData(rev, data)
|
|
case wireguardTypeCookieReply:
|
|
propKey = "packet_cookie_reply"
|
|
propValue = s.parseWireGuardPacketCookieReply(rev, data)
|
|
}
|
|
if propValue == nil {
|
|
return nil
|
|
}
|
|
|
|
m := make(analyzer.PropMap)
|
|
m[wireguardPropKeyMessageType] = messageType
|
|
m[propKey] = propValue
|
|
return m
|
|
}
|
|
|
|
func (s *wireGuardUDPStream) parseWireGuardHandshakeInitiation(rev bool, data []byte) analyzer.PropMap {
|
|
if len(data) != wireguardSizeHandshakeInitiation {
|
|
return nil
|
|
}
|
|
m := make(analyzer.PropMap)
|
|
|
|
senderIndex := binary.LittleEndian.Uint32(data[4:8])
|
|
m["sender_index"] = senderIndex
|
|
s.putSenderIndex(rev, senderIndex)
|
|
|
|
return m
|
|
}
|
|
|
|
func (s *wireGuardUDPStream) parseWireGuardHandshakeResponse(rev bool, data []byte) analyzer.PropMap {
|
|
if len(data) != wireguardSizeHandshakeResponse {
|
|
return nil
|
|
}
|
|
m := make(analyzer.PropMap)
|
|
|
|
senderIndex := binary.LittleEndian.Uint32(data[4:8])
|
|
m["sender_index"] = senderIndex
|
|
s.putSenderIndex(rev, senderIndex)
|
|
|
|
receiverIndex := binary.LittleEndian.Uint32(data[8:12])
|
|
m["receiver_index"] = receiverIndex
|
|
m["receiver_index_matched"] = s.matchReceiverIndex(rev, receiverIndex)
|
|
|
|
return m
|
|
}
|
|
|
|
func (s *wireGuardUDPStream) parseWireGuardPacketData(rev bool, data []byte) analyzer.PropMap {
|
|
if len(data) < wireguardMinSizePacketData {
|
|
return nil
|
|
}
|
|
if len(data)%16 != 0 {
|
|
// WireGuard zero padding the packet to make the length a multiple of 16
|
|
return nil
|
|
}
|
|
m := make(analyzer.PropMap)
|
|
|
|
receiverIndex := binary.LittleEndian.Uint32(data[4:8])
|
|
m["receiver_index"] = receiverIndex
|
|
m["receiver_index_matched"] = s.matchReceiverIndex(rev, receiverIndex)
|
|
|
|
m["counter"] = binary.LittleEndian.Uint64(data[8:16])
|
|
|
|
return m
|
|
}
|
|
|
|
func (s *wireGuardUDPStream) parseWireGuardPacketCookieReply(rev bool, data []byte) analyzer.PropMap {
|
|
if len(data) != wireguardSizePacketCookieReply {
|
|
return nil
|
|
}
|
|
m := make(analyzer.PropMap)
|
|
|
|
receiverIndex := binary.LittleEndian.Uint32(data[4:8])
|
|
m["receiver_index"] = receiverIndex
|
|
m["receiver_index_matched"] = s.matchReceiverIndex(rev, receiverIndex)
|
|
|
|
return m
|
|
}
|
|
|
|
type wireGuardIndex struct {
|
|
SenderIndex uint32
|
|
Reverse bool
|
|
}
|
|
|
|
func (s *wireGuardUDPStream) putSenderIndex(rev bool, senderIndex uint32) {
|
|
s.rememberedIndexesLock.Lock()
|
|
defer s.rememberedIndexesLock.Unlock()
|
|
|
|
s.rememberedIndexes.Value = &wireGuardIndex{
|
|
SenderIndex: senderIndex,
|
|
Reverse: rev,
|
|
}
|
|
s.rememberedIndexes = s.rememberedIndexes.Prev()
|
|
}
|
|
|
|
func (s *wireGuardUDPStream) matchReceiverIndex(rev bool, receiverIndex uint32) bool {
|
|
s.rememberedIndexesLock.RLock()
|
|
defer s.rememberedIndexesLock.RUnlock()
|
|
|
|
var found bool
|
|
ris := s.rememberedIndexes
|
|
for it := ris.Next(); it != ris; it = it.Next() {
|
|
if it.Value == nil {
|
|
break
|
|
}
|
|
wgidx := it.Value.(*wireGuardIndex)
|
|
if wgidx.Reverse == !rev && wgidx.SenderIndex == receiverIndex {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
return found
|
|
}
|