refactor: Improve parsing docs

Reveal intentions by:
- extracting magic numbers into constants
- changing function names with >1 responsibilities
- documenting non-obvious behaviors.
This commit is contained in:
macie 2024-03-11 20:35:01 +01:00
parent 4257788f33
commit 3bd02ed46e
No known key found for this signature in database
3 changed files with 123 additions and 57 deletions

View File

@ -5,7 +5,26 @@ import (
"github.com/apernet/OpenGFW/analyzer/utils" "github.com/apernet/OpenGFW/analyzer/utils"
) )
func ParseTLSClientHello(chBuf *utils.ByteBuffer) analyzer.PropMap { // TLS record types.
const (
RecordTypeHandshake = 0x16
)
// TLS handshake message types.
const (
TypeClientHello = 0x01
TypeServerHello = 0x02
)
// TLS extension numbers.
const (
extServerName = 0x0000
extALPN = 0x0010
extSupportedVersions = 0x002b
extEncryptedClientHello = 0xfe0d
)
func ParseTLSClientHelloMsgData(chBuf *utils.ByteBuffer) analyzer.PropMap {
var ok bool var ok bool
m := make(analyzer.PropMap) m := make(analyzer.PropMap)
// Version, random & session ID length combined are within 35 bytes, // Version, random & session ID length combined are within 35 bytes,
@ -76,7 +95,7 @@ func ParseTLSClientHello(chBuf *utils.ByteBuffer) analyzer.PropMap {
return m return m
} }
func ParseTLSServerHello(shBuf *utils.ByteBuffer) analyzer.PropMap { func ParseTLSServerHelloMsgData(shBuf *utils.ByteBuffer) analyzer.PropMap {
var ok bool var ok bool
m := make(analyzer.PropMap) m := make(analyzer.PropMap)
// Version, random & session ID length combined are within 35 bytes, // Version, random & session ID length combined are within 35 bytes,
@ -133,7 +152,7 @@ func ParseTLSServerHello(shBuf *utils.ByteBuffer) analyzer.PropMap {
func parseTLSExtensions(extType uint16, extDataBuf *utils.ByteBuffer, m analyzer.PropMap) bool { func parseTLSExtensions(extType uint16, extDataBuf *utils.ByteBuffer, m analyzer.PropMap) bool {
switch extType { switch extType {
case 0x0000: // SNI case extServerName:
ok := extDataBuf.Skip(2) // Ignore list length, we only care about the first entry for now ok := extDataBuf.Skip(2) // Ignore list length, we only care about the first entry for now
if !ok { if !ok {
// Not enough data for list length // Not enough data for list length
@ -154,7 +173,7 @@ func parseTLSExtensions(extType uint16, extDataBuf *utils.ByteBuffer, m analyzer
// Not enough data for SNI // Not enough data for SNI
return false return false
} }
case 0x0010: // ALPN case extALPN:
ok := extDataBuf.Skip(2) // Ignore list length, as we read until the end ok := extDataBuf.Skip(2) // Ignore list length, as we read until the end
if !ok { if !ok {
// Not enough data for list length // Not enough data for list length
@ -175,7 +194,7 @@ func parseTLSExtensions(extType uint16, extDataBuf *utils.ByteBuffer, m analyzer
alpnList = append(alpnList, alpn) alpnList = append(alpnList, alpn)
} }
m["alpn"] = alpnList m["alpn"] = alpnList
case 0x002b: // Supported Versions case extSupportedVersions:
if extDataBuf.Len() == 2 { if extDataBuf.Len() == 2 {
// Server only selects one version // Server only selects one version
m["supported_versions"], _ = extDataBuf.GetUint16(false, true) m["supported_versions"], _ = extDataBuf.GetUint16(false, true)
@ -197,7 +216,7 @@ func parseTLSExtensions(extType uint16, extDataBuf *utils.ByteBuffer, m analyzer
} }
m["supported_versions"] = versions m["supported_versions"] = versions
} }
case 0xfe0d: // ECH case extEncryptedClientHello:
// We can't parse ECH for now, just set a flag // We can't parse ECH for now, just set a flag
m["ech"] = true m["ech"] = true
} }

View File

@ -44,12 +44,12 @@ type tlsStream struct {
func newTLSStream(logger analyzer.Logger) *tlsStream { func newTLSStream(logger analyzer.Logger) *tlsStream {
s := &tlsStream{logger: logger, reqBuf: &utils.ByteBuffer{}, respBuf: &utils.ByteBuffer{}} s := &tlsStream{logger: logger, reqBuf: &utils.ByteBuffer{}, respBuf: &utils.ByteBuffer{}}
s.reqLSM = utils.NewLinearStateMachine( s.reqLSM = utils.NewLinearStateMachine(
s.tlsClientHelloSanityCheck, s.tlsClientHelloPreprocess,
s.parseClientHello, s.parseClientHelloData,
) )
s.respLSM = utils.NewLinearStateMachine( s.respLSM = utils.NewLinearStateMachine(
s.tlsServerHelloSanityCheck, s.tlsServerHelloPreprocess,
s.parseServerHello, s.parseServerHelloData,
) )
return s return s
} }
@ -89,61 +89,105 @@ func (s *tlsStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyz
return update, cancelled || (s.reqDone && s.respDone) return update, cancelled || (s.reqDone && s.respDone)
} }
func (s *tlsStream) tlsClientHelloSanityCheck() utils.LSMAction { // tlsClientHelloPreprocess validates ClientHello message.
data, ok := s.reqBuf.Get(9, true) //
// During validation, message header and first handshake header may be removed
// from `s.reqBuf`.
func (s *tlsStream) tlsClientHelloPreprocess() utils.LSMAction {
// headers size: content type (1 byte) + legacy protocol version (2 bytes) +
// + content length (2 bytes) + message type (1 byte) +
// + handshake length (3 bytes)
const headersSize = 9
// minimal data size: protocol version (2 bytes) + random (32 bytes) +
// + session ID (1 byte) + cipher suites (4 bytes) +
// + compression methods (2 bytes) + no extensions
const minDataSize = 41
header, ok := s.reqBuf.Get(headersSize, true)
if !ok { if !ok {
// not a full header yet
return utils.LSMActionPause return utils.LSMActionPause
} }
if data[0] != 0x16 || data[5] != 0x01 {
// Not a TLS handshake, or not a client hello if header[0] != internal.RecordTypeHandshake || header[5] != internal.TypeClientHello {
return utils.LSMActionCancel return utils.LSMActionCancel
} }
s.clientHelloLen = int(data[6])<<16 | int(data[7])<<8 | int(data[8])
if s.clientHelloLen < 41 { s.clientHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
// 2 (Protocol Version) + if s.clientHelloLen < minDataSize {
// 32 (Random) +
// 1 (Session ID Length) +
// 2 (Cipher Suites Length) +_ws.col.protocol == "TLSv1.3"
// 2 (Cipher Suite) +
// 1 (Compression Methods Length) +
// 1 (Compression Method) +
// No extensions
// This should be the bare minimum for a client hello
return utils.LSMActionCancel return utils.LSMActionCancel
} }
// TODO: something is missing. See:
// const messageHeaderSize = 4
// fullMessageLen := int(header[3])<<8 | int(header[4])
// msgNo := fullMessageLen / int(messageHeaderSize+s.serverHelloLen)
// if msgNo != 1 {
// // what here?
// }
// if messageNo != int(messageNo) {
// // what here?
// }
return utils.LSMActionNext return utils.LSMActionNext
} }
func (s *tlsStream) tlsServerHelloSanityCheck() utils.LSMAction { // tlsServerHelloPreprocess validates ServerHello message.
data, ok := s.respBuf.Get(9, true) //
// During validation, message header and first handshake header may be removed
// from `s.reqBuf`.
func (s *tlsStream) tlsServerHelloPreprocess() utils.LSMAction {
// header size: content type (1 byte) + legacy protocol version (2 byte) +
// + content length (2 byte) + message type (1 byte) +
// + handshake length (3 byte)
const headersSize = 9
// minimal data size: server version (2 byte) + random (32 byte) +
// + session ID (>=1 byte) + cipher suite (2 byte) +
// + compression method (1 byte) + no extensions
const minDataSize = 38
header, ok := s.respBuf.Get(headersSize, true)
if !ok { if !ok {
// not a full header yet
return utils.LSMActionPause return utils.LSMActionPause
} }
if data[0] != 0x16 || data[5] != 0x02 {
// Not a TLS handshake, or not a server hello if header[0] != internal.RecordTypeHandshake || header[5] != internal.TypeServerHello {
return utils.LSMActionCancel return utils.LSMActionCancel
} }
s.serverHelloLen = int(data[6])<<16 | int(data[7])<<8 | int(data[8])
if s.serverHelloLen < 38 { s.serverHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8])
// 2 (Protocol Version) + if s.serverHelloLen < minDataSize {
// 32 (Random) +
// 1 (Session ID Length) +
// 2 (Cipher Suite) +
// 1 (Compression Method) +
// No extensions
// This should be the bare minimum for a server hello
return utils.LSMActionCancel return utils.LSMActionCancel
} }
// TODO: something is missing. See example:
// const messageHeaderSize = 4
// fullMessageLen := int(header[3])<<8 | int(header[4])
// msgNo := fullMessageLen / int(messageHeaderSize+s.serverHelloLen)
// if msgNo != 1 {
// // what here?
// }
// if messageNo != int(messageNo) {
// // what here?
// }
return utils.LSMActionNext return utils.LSMActionNext
} }
func (s *tlsStream) parseClientHello() utils.LSMAction { // parseClientHelloData converts valid ClientHello message data (without
// headers) into `analyzer.PropMap`.
//
// Parsing error may leave `s.reqBuf` in an unusable state.
func (s *tlsStream) parseClientHelloData() utils.LSMAction {
chBuf, ok := s.reqBuf.GetSubBuffer(s.clientHelloLen, true) chBuf, ok := s.reqBuf.GetSubBuffer(s.clientHelloLen, true)
if !ok { if !ok {
// Not a full client hello yet // Not a full client hello yet
return utils.LSMActionPause return utils.LSMActionPause
} }
m := internal.ParseTLSClientHello(chBuf) m := internal.ParseTLSClientHelloMsgData(chBuf)
if m == nil { if m == nil {
return utils.LSMActionCancel return utils.LSMActionCancel
} else { } else {
@ -153,13 +197,17 @@ func (s *tlsStream) parseClientHello() utils.LSMAction {
} }
} }
func (s *tlsStream) parseServerHello() utils.LSMAction { // parseServerHelloData converts valid ServerHello message data (without
// headers) into `analyzer.PropMap`.
//
// Parsing error may leave `s.respBuf` in an unusable state.
func (s *tlsStream) parseServerHelloData() utils.LSMAction {
shBuf, ok := s.respBuf.GetSubBuffer(s.serverHelloLen, true) shBuf, ok := s.respBuf.GetSubBuffer(s.serverHelloLen, true)
if !ok { if !ok {
// Not a full server hello yet // Not a full server hello yet
return utils.LSMActionPause return utils.LSMActionPause
} }
m := internal.ParseTLSServerHello(shBuf) m := internal.ParseTLSServerHelloMsgData(shBuf)
if m == nil { if m == nil {
return utils.LSMActionCancel return utils.LSMActionCancel
} else { } else {

View File

@ -36,41 +36,40 @@ type quicStream struct {
} }
func (s *quicStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done bool) { func (s *quicStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done bool) {
// minimal data size: protocol version (2 bytes) + random (32 bytes) +
// + session ID (1 byte) + cipher suites (4 bytes) +
// + compression methods (2 bytes) + no extensions
const minDataSize = 41
if rev { if rev {
// We don't support server direction for now // We don't support server direction for now
s.invalidCount++ s.invalidCount++
return nil, s.invalidCount >= quicInvalidCountThreshold return nil, s.invalidCount >= quicInvalidCountThreshold
} }
pl, err := quic.ReadCryptoPayload(data) pl, err := quic.ReadCryptoPayload(data)
if err != nil || len(pl) < 4 { if err != nil || len(pl) < 4 { // FIXME: isn't length checked inside quic.ReadCryptoPayload? Also, what about error handling?
s.invalidCount++ s.invalidCount++
return nil, s.invalidCount >= quicInvalidCountThreshold return nil, s.invalidCount >= quicInvalidCountThreshold
} }
// Should be a TLS client hello
if pl[0] != 0x01 { if pl[0] != internal.TypeClientHello {
// Not a client hello
s.invalidCount++ s.invalidCount++
return nil, s.invalidCount >= quicInvalidCountThreshold return nil, s.invalidCount >= quicInvalidCountThreshold
} }
chLen := int(pl[1])<<16 | int(pl[2])<<8 | int(pl[3]) chLen := int(pl[1])<<16 | int(pl[2])<<8 | int(pl[3])
if chLen < 41 { if chLen < minDataSize {
// 2 (Protocol Version) +
// 32 (Random) +
// 1 (Session ID Length) +
// 2 (Cipher Suites Length) +_ws.col.protocol == "TLSv1.3"
// 2 (Cipher Suite) +
// 1 (Compression Methods Length) +
// 1 (Compression Method) +
// No extensions
// This should be the bare minimum for a client hello
s.invalidCount++ s.invalidCount++
return nil, s.invalidCount >= quicInvalidCountThreshold return nil, s.invalidCount >= quicInvalidCountThreshold
} }
m := internal.ParseTLSClientHello(&utils.ByteBuffer{Buf: pl[4:]})
m := internal.ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: pl[4:]})
if m == nil { if m == nil {
s.invalidCount++ s.invalidCount++
return nil, s.invalidCount >= quicInvalidCountThreshold return nil, s.invalidCount >= quicInvalidCountThreshold
} }
return &analyzer.PropUpdate{ return &analyzer.PropUpdate{
Type: analyzer.PropUpdateMerge, Type: analyzer.PropUpdateMerge,
M: analyzer.PropMap{"req": m}, M: analyzer.PropMap{"req": m},