diff --git a/.github/workflows/check.yaml b/.github/workflows/check.yaml new file mode 100644 index 0000000..ac9e66d --- /dev/null +++ b/.github/workflows/check.yaml @@ -0,0 +1,47 @@ +name: Quality check +on: + push: + branches: + - "*" + pull_request: + +permissions: + contents: read + +jobs: + static-analysis: + name: Static analysis + runs-on: ubuntu-latest + timeout-minutes: 10 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: 'stable' + + - run: go vet ./... + + - name: staticcheck + uses: dominikh/staticcheck-action@v1.3.0 + with: + install-go: false + + tests: + name: Tests + runs-on: ubuntu-latest + timeout-minutes: 10 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: 'stable' + + - run: go test ./... diff --git a/README.ja.md b/README.ja.md index 2247118..3d6af93 100644 --- a/README.ja.md +++ b/README.ja.md @@ -1,5 +1,6 @@ # ![OpenGFW](docs/logo.png) +[![Quality check status](https://github.com/apernet/OpenGFW/actions/workflows/check.yml/badge.svg)](https://github.com/apernet/OpenGFW/actions/workflows/check.yml) [![License][1]][2] [1]: https://img.shields.io/badge/License-MPL_2.0-brightgreen.svg diff --git a/README.md b/README.md index 36e5817..1466e42 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ # ![OpenGFW](docs/logo.png) +[![Quality check status](https://github.com/apernet/OpenGFW/actions/workflows/check.yml/badge.svg)](https://github.com/apernet/OpenGFW/actions/workflows/check.yml) [![License][1]][2] [1]: https://img.shields.io/badge/License-MPL_2.0-brightgreen.svg diff --git a/README.zh.md b/README.zh.md index 4639e71..74e4a8e 100644 --- a/README.zh.md +++ b/README.zh.md @@ -1,5 +1,6 @@ # ![OpenGFW](docs/logo.png) +[![Quality check status](https://github.com/apernet/OpenGFW/actions/workflows/check.yml/badge.svg)](https://github.com/apernet/OpenGFW/actions/workflows/check.yml) [![License][1]][2] [1]: https://img.shields.io/badge/License-MPL_2.0-brightgreen.svg diff --git a/analyzer/internal/tls.go b/analyzer/internal/tls.go index 810780a..c25605f 100644 --- a/analyzer/internal/tls.go +++ b/analyzer/internal/tls.go @@ -5,7 +5,26 @@ import ( "github.com/apernet/OpenGFW/analyzer/utils" ) -func ParseTLSClientHello(chBuf *utils.ByteBuffer) analyzer.PropMap { +// TLS record types. +const ( + RecordTypeHandshake = 0x16 +) + +// TLS handshake message types. +const ( + TypeClientHello = 0x01 + TypeServerHello = 0x02 +) + +// TLS extension numbers. +const ( + extServerName = 0x0000 + extALPN = 0x0010 + extSupportedVersions = 0x002b + extEncryptedClientHello = 0xfe0d +) + +func ParseTLSClientHelloMsgData(chBuf *utils.ByteBuffer) analyzer.PropMap { var ok bool m := make(analyzer.PropMap) // Version, random & session ID length combined are within 35 bytes, @@ -76,7 +95,7 @@ func ParseTLSClientHello(chBuf *utils.ByteBuffer) analyzer.PropMap { return m } -func ParseTLSServerHello(shBuf *utils.ByteBuffer) analyzer.PropMap { +func ParseTLSServerHelloMsgData(shBuf *utils.ByteBuffer) analyzer.PropMap { var ok bool m := make(analyzer.PropMap) // Version, random & session ID length combined are within 35 bytes, @@ -133,7 +152,7 @@ func ParseTLSServerHello(shBuf *utils.ByteBuffer) analyzer.PropMap { func parseTLSExtensions(extType uint16, extDataBuf *utils.ByteBuffer, m analyzer.PropMap) bool { switch extType { - case 0x0000: // SNI + case extServerName: ok := extDataBuf.Skip(2) // Ignore list length, we only care about the first entry for now if !ok { // Not enough data for list length @@ -154,7 +173,7 @@ func parseTLSExtensions(extType uint16, extDataBuf *utils.ByteBuffer, m analyzer // Not enough data for SNI return false } - case 0x0010: // ALPN + case extALPN: ok := extDataBuf.Skip(2) // Ignore list length, as we read until the end if !ok { // Not enough data for list length @@ -175,7 +194,7 @@ func parseTLSExtensions(extType uint16, extDataBuf *utils.ByteBuffer, m analyzer alpnList = append(alpnList, alpn) } m["alpn"] = alpnList - case 0x002b: // Supported Versions + case extSupportedVersions: if extDataBuf.Len() == 2 { // Server only selects one version m["supported_versions"], _ = extDataBuf.GetUint16(false, true) @@ -197,7 +216,7 @@ func parseTLSExtensions(extType uint16, extDataBuf *utils.ByteBuffer, m analyzer } m["supported_versions"] = versions } - case 0xfe0d: // ECH + case extEncryptedClientHello: // We can't parse ECH for now, just set a flag m["ech"] = true } diff --git a/analyzer/tcp/http_test.go b/analyzer/tcp/http_test.go new file mode 100644 index 0000000..dee4f57 --- /dev/null +++ b/analyzer/tcp/http_test.go @@ -0,0 +1,64 @@ +package tcp + +import ( + "reflect" + "strings" + "testing" + + "github.com/apernet/OpenGFW/analyzer" +) + +func TestHTTPParsing_Request(t *testing.T) { + testCases := map[string]analyzer.PropMap{ + "GET / HTTP/1.1\r\n": { + "method": "GET", "path": "/", "version": "HTTP/1.1", + }, + "POST /hello?a=1&b=2 HTTP/1.0\r\n": { + "method": "POST", "path": "/hello?a=1&b=2", "version": "HTTP/1.0", + }, + "PUT /world HTTP/1.1\r\nContent-Length: 4\r\n\r\nbody": { + "method": "PUT", "path": "/world", "version": "HTTP/1.1", "headers": analyzer.PropMap{"content-length": "4"}, + }, + "DELETE /goodbye HTTP/2.0\r\n": { + "method": "DELETE", "path": "/goodbye", "version": "HTTP/2.0", + }, + } + + for tc, want := range testCases { + t.Run(strings.Split(tc, " ")[0], func(t *testing.T) { + tc, want := tc, want + t.Parallel() + + u, _ := newHTTPStream(nil).Feed(false, false, false, 0, []byte(tc)) + got := u.M.Get("req") + if !reflect.DeepEqual(got, want) { + t.Errorf("\"%s\" parsed = %v, want %v", tc, got, want) + } + }) + } +} + +func TestHTTPParsing_Response(t *testing.T) { + testCases := map[string]analyzer.PropMap{ + "HTTP/1.0 200 OK\r\nContent-Length: 4\r\n\r\nbody": { + "version": "HTTP/1.0", "status": 200, + "headers": analyzer.PropMap{"content-length": "4"}, + }, + "HTTP/2.0 204 No Content\r\n\r\n": { + "version": "HTTP/2.0", "status": 204, + }, + } + + for tc, want := range testCases { + t.Run(strings.Split(tc, " ")[0], func(t *testing.T) { + tc, want := tc, want + t.Parallel() + + u, _ := newHTTPStream(nil).Feed(true, false, false, 0, []byte(tc)) + got := u.M.Get("resp") + if !reflect.DeepEqual(got, want) { + t.Errorf("\"%s\" parsed = %v, want %v", tc, got, want) + } + }) + } +} diff --git a/analyzer/tcp/socks.go b/analyzer/tcp/socks.go index 18ffc2b..a242069 100644 --- a/analyzer/tcp/socks.go +++ b/analyzer/tcp/socks.go @@ -208,10 +208,10 @@ func (s *socksStream) parseSocks5ReqMethod() utils.LSMAction { switch method { case Socks5AuthNotRequired: s.authReqMethod = Socks5AuthNotRequired - break + return utils.LSMActionNext case Socks5AuthPassword: s.authReqMethod = Socks5AuthPassword - break + return utils.LSMActionNext default: // TODO: more auth method to support } diff --git a/analyzer/tcp/tls.go b/analyzer/tcp/tls.go index 74c21f2..c5f1ea9 100644 --- a/analyzer/tcp/tls.go +++ b/analyzer/tcp/tls.go @@ -44,12 +44,12 @@ type tlsStream struct { func newTLSStream(logger analyzer.Logger) *tlsStream { s := &tlsStream{logger: logger, reqBuf: &utils.ByteBuffer{}, respBuf: &utils.ByteBuffer{}} s.reqLSM = utils.NewLinearStateMachine( - s.tlsClientHelloSanityCheck, - s.parseClientHello, + s.tlsClientHelloPreprocess, + s.parseClientHelloData, ) s.respLSM = utils.NewLinearStateMachine( - s.tlsServerHelloSanityCheck, - s.parseServerHello, + s.tlsServerHelloPreprocess, + s.parseServerHelloData, ) return s } @@ -89,61 +89,105 @@ func (s *tlsStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyz return update, cancelled || (s.reqDone && s.respDone) } -func (s *tlsStream) tlsClientHelloSanityCheck() utils.LSMAction { - data, ok := s.reqBuf.Get(9, true) +// tlsClientHelloPreprocess validates ClientHello message. +// +// During validation, message header and first handshake header may be removed +// from `s.reqBuf`. +func (s *tlsStream) tlsClientHelloPreprocess() utils.LSMAction { + // headers size: content type (1 byte) + legacy protocol version (2 bytes) + + // + content length (2 bytes) + message type (1 byte) + + // + handshake length (3 bytes) + const headersSize = 9 + + // minimal data size: protocol version (2 bytes) + random (32 bytes) + + // + session ID (1 byte) + cipher suites (4 bytes) + + // + compression methods (2 bytes) + no extensions + const minDataSize = 41 + + header, ok := s.reqBuf.Get(headersSize, true) if !ok { + // not a full header yet return utils.LSMActionPause } - if data[0] != 0x16 || data[5] != 0x01 { - // Not a TLS handshake, or not a client hello + + if header[0] != internal.RecordTypeHandshake || header[5] != internal.TypeClientHello { return utils.LSMActionCancel } - s.clientHelloLen = int(data[6])<<16 | int(data[7])<<8 | int(data[8]) - if s.clientHelloLen < 41 { - // 2 (Protocol Version) + - // 32 (Random) + - // 1 (Session ID Length) + - // 2 (Cipher Suites Length) +_ws.col.protocol == "TLSv1.3" - // 2 (Cipher Suite) + - // 1 (Compression Methods Length) + - // 1 (Compression Method) + - // No extensions - // This should be the bare minimum for a client hello + + s.clientHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8]) + if s.clientHelloLen < minDataSize { return utils.LSMActionCancel } + + // TODO: something is missing. See: + // const messageHeaderSize = 4 + // fullMessageLen := int(header[3])<<8 | int(header[4]) + // msgNo := fullMessageLen / int(messageHeaderSize+s.serverHelloLen) + // if msgNo != 1 { + // // what here? + // } + // if messageNo != int(messageNo) { + // // what here? + // } + return utils.LSMActionNext } -func (s *tlsStream) tlsServerHelloSanityCheck() utils.LSMAction { - data, ok := s.respBuf.Get(9, true) +// tlsServerHelloPreprocess validates ServerHello message. +// +// During validation, message header and first handshake header may be removed +// from `s.reqBuf`. +func (s *tlsStream) tlsServerHelloPreprocess() utils.LSMAction { + // header size: content type (1 byte) + legacy protocol version (2 byte) + + // + content length (2 byte) + message type (1 byte) + + // + handshake length (3 byte) + const headersSize = 9 + + // minimal data size: server version (2 byte) + random (32 byte) + + // + session ID (>=1 byte) + cipher suite (2 byte) + + // + compression method (1 byte) + no extensions + const minDataSize = 38 + + header, ok := s.respBuf.Get(headersSize, true) if !ok { + // not a full header yet return utils.LSMActionPause } - if data[0] != 0x16 || data[5] != 0x02 { - // Not a TLS handshake, or not a server hello + + if header[0] != internal.RecordTypeHandshake || header[5] != internal.TypeServerHello { return utils.LSMActionCancel } - s.serverHelloLen = int(data[6])<<16 | int(data[7])<<8 | int(data[8]) - if s.serverHelloLen < 38 { - // 2 (Protocol Version) + - // 32 (Random) + - // 1 (Session ID Length) + - // 2 (Cipher Suite) + - // 1 (Compression Method) + - // No extensions - // This should be the bare minimum for a server hello + + s.serverHelloLen = int(header[6])<<16 | int(header[7])<<8 | int(header[8]) + if s.serverHelloLen < minDataSize { return utils.LSMActionCancel } + + // TODO: something is missing. See example: + // const messageHeaderSize = 4 + // fullMessageLen := int(header[3])<<8 | int(header[4]) + // msgNo := fullMessageLen / int(messageHeaderSize+s.serverHelloLen) + // if msgNo != 1 { + // // what here? + // } + // if messageNo != int(messageNo) { + // // what here? + // } + return utils.LSMActionNext } -func (s *tlsStream) parseClientHello() utils.LSMAction { +// parseClientHelloData converts valid ClientHello message data (without +// headers) into `analyzer.PropMap`. +// +// Parsing error may leave `s.reqBuf` in an unusable state. +func (s *tlsStream) parseClientHelloData() utils.LSMAction { chBuf, ok := s.reqBuf.GetSubBuffer(s.clientHelloLen, true) if !ok { // Not a full client hello yet return utils.LSMActionPause } - m := internal.ParseTLSClientHello(chBuf) + m := internal.ParseTLSClientHelloMsgData(chBuf) if m == nil { return utils.LSMActionCancel } else { @@ -153,13 +197,17 @@ func (s *tlsStream) parseClientHello() utils.LSMAction { } } -func (s *tlsStream) parseServerHello() utils.LSMAction { +// parseServerHelloData converts valid ServerHello message data (without +// headers) into `analyzer.PropMap`. +// +// Parsing error may leave `s.respBuf` in an unusable state. +func (s *tlsStream) parseServerHelloData() utils.LSMAction { shBuf, ok := s.respBuf.GetSubBuffer(s.serverHelloLen, true) if !ok { // Not a full server hello yet return utils.LSMActionPause } - m := internal.ParseTLSServerHello(shBuf) + m := internal.ParseTLSServerHelloMsgData(shBuf) if m == nil { return utils.LSMActionCancel } else { diff --git a/analyzer/tcp/tls_test.go b/analyzer/tcp/tls_test.go new file mode 100644 index 0000000..1ebb86b --- /dev/null +++ b/analyzer/tcp/tls_test.go @@ -0,0 +1,69 @@ +package tcp + +import ( + "reflect" + "testing" + + "github.com/apernet/OpenGFW/analyzer" +) + +func TestTlsStreamParsing_ClientHello(t *testing.T) { + // example packet taken from + clientHello := []byte{ + 0x16, 0x03, 0x01, 0x00, 0xa5, 0x01, 0x00, 0x00, 0xa1, 0x03, 0x03, 0x00, + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, + 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x00, 0x00, 0x20, 0xcc, 0xa8, + 0xcc, 0xa9, 0xc0, 0x2f, 0xc0, 0x30, 0xc0, 0x2b, 0xc0, 0x2c, 0xc0, 0x13, + 0xc0, 0x09, 0xc0, 0x14, 0xc0, 0x0a, 0x00, 0x9c, 0x00, 0x9d, 0x00, 0x2f, + 0x00, 0x35, 0xc0, 0x12, 0x00, 0x0a, 0x01, 0x00, 0x00, 0x58, 0x00, 0x00, + 0x00, 0x18, 0x00, 0x16, 0x00, 0x00, 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, + 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65, 0x69, 0x6d, 0x2e, 0x6e, + 0x65, 0x74, 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x0a, 0x00, 0x08, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, + 0x19, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, 0x0d, 0x00, 0x12, 0x00, + 0x10, 0x04, 0x01, 0x04, 0x03, 0x05, 0x01, 0x05, 0x03, 0x06, 0x01, 0x06, + 0x03, 0x02, 0x01, 0x02, 0x03, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x12, + 0x00, 0x00, + } + want := analyzer.PropMap{ + "ciphers": []uint16{52392, 52393, 49199, 49200, 49195, 49196, 49171, 49161, 49172, 49162, 156, 157, 47, 53, 49170, 10}, + "compression": []uint8{0}, + "random": []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + "session": []uint8{}, + "sni": "example.ulfheim.net", + "version": uint16(771), + } + + s := newTLSStream(nil) + u, _ := s.Feed(false, false, false, 0, clientHello) + got := u.M.Get("req") + if !reflect.DeepEqual(got, want) { + t.Errorf("%d B parsed = %v, want %v", len(clientHello), got, want) + } +} + +func TestTlsStreamParsing_ServerHello(t *testing.T) { + // example packet taken from + serverHello := []byte{ + 0x16, 0x03, 0x03, 0x00, 0x31, 0x02, 0x00, 0x00, 0x2d, 0x03, 0x03, 0x70, + 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, + 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, + 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x00, 0xc0, 0x13, 0x00, 0x00, + 0x05, 0xff, 0x01, 0x00, 0x01, 0x00, + } + want := analyzer.PropMap{ + "cipher": uint16(49171), + "compression": uint8(0), + "random": []uint8{112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143}, + "session": []uint8{}, + "version": uint16(771), + } + + s := newTLSStream(nil) + u, _ := s.Feed(true, false, false, 0, serverHello) + got := u.M.Get("resp") + if !reflect.DeepEqual(got, want) { + t.Errorf("%d B parsed = %v, want %v", len(serverHello), got, want) + } +} diff --git a/analyzer/udp/quic.go b/analyzer/udp/quic.go index 3954192..a1a9ef0 100644 --- a/analyzer/udp/quic.go +++ b/analyzer/udp/quic.go @@ -36,41 +36,40 @@ type quicStream struct { } func (s *quicStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done bool) { + // minimal data size: protocol version (2 bytes) + random (32 bytes) + + // + session ID (1 byte) + cipher suites (4 bytes) + + // + compression methods (2 bytes) + no extensions + const minDataSize = 41 + if rev { // We don't support server direction for now s.invalidCount++ return nil, s.invalidCount >= quicInvalidCountThreshold } + pl, err := quic.ReadCryptoPayload(data) - if err != nil || len(pl) < 4 { + if err != nil || len(pl) < 4 { // FIXME: isn't length checked inside quic.ReadCryptoPayload? Also, what about error handling? s.invalidCount++ return nil, s.invalidCount >= quicInvalidCountThreshold } - // Should be a TLS client hello - if pl[0] != 0x01 { - // Not a client hello + + if pl[0] != internal.TypeClientHello { s.invalidCount++ return nil, s.invalidCount >= quicInvalidCountThreshold } + chLen := int(pl[1])<<16 | int(pl[2])<<8 | int(pl[3]) - if chLen < 41 { - // 2 (Protocol Version) + - // 32 (Random) + - // 1 (Session ID Length) + - // 2 (Cipher Suites Length) +_ws.col.protocol == "TLSv1.3" - // 2 (Cipher Suite) + - // 1 (Compression Methods Length) + - // 1 (Compression Method) + - // No extensions - // This should be the bare minimum for a client hello + if chLen < minDataSize { s.invalidCount++ return nil, s.invalidCount >= quicInvalidCountThreshold } - m := internal.ParseTLSClientHello(&utils.ByteBuffer{Buf: pl[4:]}) + + m := internal.ParseTLSClientHelloMsgData(&utils.ByteBuffer{Buf: pl[4:]}) if m == nil { s.invalidCount++ return nil, s.invalidCount >= quicInvalidCountThreshold } + return &analyzer.PropUpdate{ Type: analyzer.PropUpdateMerge, M: analyzer.PropMap{"req": m}, diff --git a/analyzer/udp/quic_test.go b/analyzer/udp/quic_test.go new file mode 100644 index 0000000..a00c67d --- /dev/null +++ b/analyzer/udp/quic_test.go @@ -0,0 +1,58 @@ +package udp + +import ( + "reflect" + "testing" + + "github.com/apernet/OpenGFW/analyzer" +) + +func TestQuicStreamParsing_ClientHello(t *testing.T) { + // example packet taken from + clientHello := make([]byte, 1200) + clientInitial := []byte{ + 0xcd, 0x00, 0x00, 0x00, 0x01, 0x08, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, + 0x06, 0x07, 0x05, 0x63, 0x5f, 0x63, 0x69, 0x64, 0x00, 0x41, 0x03, 0x98, + 0x1c, 0x36, 0xa7, 0xed, 0x78, 0x71, 0x6b, 0xe9, 0x71, 0x1b, 0xa4, 0x98, + 0xb7, 0xed, 0x86, 0x84, 0x43, 0xbb, 0x2e, 0x0c, 0x51, 0x4d, 0x4d, 0x84, + 0x8e, 0xad, 0xcc, 0x7a, 0x00, 0xd2, 0x5c, 0xe9, 0xf9, 0xaf, 0xa4, 0x83, + 0x97, 0x80, 0x88, 0xde, 0x83, 0x6b, 0xe6, 0x8c, 0x0b, 0x32, 0xa2, 0x45, + 0x95, 0xd7, 0x81, 0x3e, 0xa5, 0x41, 0x4a, 0x91, 0x99, 0x32, 0x9a, 0x6d, + 0x9f, 0x7f, 0x76, 0x0d, 0xd8, 0xbb, 0x24, 0x9b, 0xf3, 0xf5, 0x3d, 0x9a, + 0x77, 0xfb, 0xb7, 0xb3, 0x95, 0xb8, 0xd6, 0x6d, 0x78, 0x79, 0xa5, 0x1f, + 0xe5, 0x9e, 0xf9, 0x60, 0x1f, 0x79, 0x99, 0x8e, 0xb3, 0x56, 0x8e, 0x1f, + 0xdc, 0x78, 0x9f, 0x64, 0x0a, 0xca, 0xb3, 0x85, 0x8a, 0x82, 0xef, 0x29, + 0x30, 0xfa, 0x5c, 0xe1, 0x4b, 0x5b, 0x9e, 0xa0, 0xbd, 0xb2, 0x9f, 0x45, + 0x72, 0xda, 0x85, 0xaa, 0x3d, 0xef, 0x39, 0xb7, 0xef, 0xaf, 0xff, 0xa0, + 0x74, 0xb9, 0x26, 0x70, 0x70, 0xd5, 0x0b, 0x5d, 0x07, 0x84, 0x2e, 0x49, + 0xbb, 0xa3, 0xbc, 0x78, 0x7f, 0xf2, 0x95, 0xd6, 0xae, 0x3b, 0x51, 0x43, + 0x05, 0xf1, 0x02, 0xaf, 0xe5, 0xa0, 0x47, 0xb3, 0xfb, 0x4c, 0x99, 0xeb, + 0x92, 0xa2, 0x74, 0xd2, 0x44, 0xd6, 0x04, 0x92, 0xc0, 0xe2, 0xe6, 0xe2, + 0x12, 0xce, 0xf0, 0xf9, 0xe3, 0xf6, 0x2e, 0xfd, 0x09, 0x55, 0xe7, 0x1c, + 0x76, 0x8a, 0xa6, 0xbb, 0x3c, 0xd8, 0x0b, 0xbb, 0x37, 0x55, 0xc8, 0xb7, + 0xeb, 0xee, 0x32, 0x71, 0x2f, 0x40, 0xf2, 0x24, 0x51, 0x19, 0x48, 0x70, + 0x21, 0xb4, 0xb8, 0x4e, 0x15, 0x65, 0xe3, 0xca, 0x31, 0x96, 0x7a, 0xc8, + 0x60, 0x4d, 0x40, 0x32, 0x17, 0x0d, 0xec, 0x28, 0x0a, 0xee, 0xfa, 0x09, + 0x5d, 0x08, 0xb3, 0xb7, 0x24, 0x1e, 0xf6, 0x64, 0x6a, 0x6c, 0x86, 0xe5, + 0xc6, 0x2c, 0xe0, 0x8b, 0xe0, 0x99, + } + copy(clientHello, clientInitial) + + want := analyzer.PropMap{ + "alpn": []string{"ping/1.0"}, + "ciphers": []uint16{4865, 4866, 4867}, + "compression": []uint8{0}, + "random": []uint8{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + "session": []uint8{}, + "sni": "example.ulfheim.net", + "supported_versions": []uint16{772}, + "version": uint16(771), + } + + s := quicStream{} + u, _ := s.Feed(false, clientHello) + got := u.M.Get("req") + if !reflect.DeepEqual(got, want) { + t.Errorf("%d B parsed = %v, want %v", len(clientHello), got, want) + } +} diff --git a/cmd/root.go b/cmd/root.go index aec1650..0c8fa77 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,7 +5,6 @@ import ( "fmt" "os" "os/signal" - "strconv" "strings" "syscall" @@ -278,15 +277,15 @@ func runMain(cmd *cobra.Command, args []string) { ctx, cancelFunc := context.WithCancel(context.Background()) go func() { // Graceful shutdown - shutdownChan := make(chan os.Signal) - signal.Notify(shutdownChan, os.Interrupt, os.Kill) + shutdownChan := make(chan os.Signal, 1) + signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM) <-shutdownChan logger.Info("shutting down gracefully...") cancelFunc() }() go func() { // Rule reload - reloadChan := make(chan os.Signal) + reloadChan := make(chan os.Signal, 1) signal.Notify(reloadChan, syscall.SIGHUP) for { <-reloadChan @@ -431,11 +430,3 @@ func envOrDefaultString(key, def string) string { } return def } - -func envOrDefaultBool(key string, def bool) bool { - if v := os.Getenv(key); v != "" { - b, _ := strconv.ParseBool(v) - return b - } - return def -} diff --git a/go.mod b/go.mod index e3e739e..75e54ef 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,6 @@ require ( github.com/quic-go/quic-go v0.41.0 github.com/spf13/cobra v1.8.0 github.com/spf13/viper v1.18.2 - github.com/stretchr/testify v1.8.4 go.uber.org/zap v1.26.0 golang.org/x/crypto v0.19.0 golang.org/x/sys v0.17.0 @@ -22,7 +21,6 @@ require ( ) require ( - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/hashicorp/hcl v1.0.0 // indirect @@ -32,7 +30,6 @@ require ( github.com/mdlayher/socket v0.1.1 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.1.0 // indirect - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect diff --git a/ruleset/builtins/geo/geo_loader.go b/ruleset/builtins/geo/geo_loader.go index de5166a..8e16509 100644 --- a/ruleset/builtins/geo/geo_loader.go +++ b/ruleset/builtins/geo/geo_loader.go @@ -49,7 +49,7 @@ func (l *V2GeoLoader) shouldDownload(filename string) bool { if os.IsNotExist(err) { return true } - dt := time.Now().Sub(info.ModTime()) + dt := time.Since(info.ModTime()) if l.UpdateInterval == 0 { return dt > geoDefaultUpdateInterval } else { diff --git a/ruleset/builtins/geo/v2geo/load_test.go b/ruleset/builtins/geo/v2geo/load_test.go deleted file mode 100644 index e9c901a..0000000 --- a/ruleset/builtins/geo/v2geo/load_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package v2geo - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestLoadGeoIP(t *testing.T) { - m, err := LoadGeoIP("geoip.dat") - assert.NoError(t, err) - - // Exact checks since we know the data. - assert.Len(t, m, 252) - assert.Equal(t, m["cn"].CountryCode, "CN") - assert.Len(t, m["cn"].Cidr, 10407) - assert.Equal(t, m["us"].CountryCode, "US") - assert.Len(t, m["us"].Cidr, 193171) - assert.Equal(t, m["private"].CountryCode, "PRIVATE") - assert.Len(t, m["private"].Cidr, 18) - assert.Contains(t, m["private"].Cidr, &CIDR{ - Ip: []byte("\xc0\xa8\x00\x00"), - Prefix: 16, - }) -} - -func TestLoadGeoSite(t *testing.T) { - m, err := LoadGeoSite("geosite.dat") - assert.NoError(t, err) - - // Exact checks since we know the data. - assert.Len(t, m, 1204) - assert.Equal(t, m["netflix"].CountryCode, "NETFLIX") - assert.Len(t, m["netflix"].Domain, 25) - assert.Contains(t, m["netflix"].Domain, &Domain{ - Type: Domain_Full, - Value: "netflix.com.edgesuite.net", - }) - assert.Contains(t, m["netflix"].Domain, &Domain{ - Type: Domain_RootDomain, - Value: "fast.com", - }) - assert.Len(t, m["google"].Domain, 1066) - assert.Contains(t, m["google"].Domain, &Domain{ - Type: Domain_RootDomain, - Value: "ggpht.cn", - Attribute: []*Domain_Attribute{ - { - Key: "cn", - TypedValue: &Domain_Attribute_BoolValue{BoolValue: true}, - }, - }, - }) -}