commit 4f86f91a153b7cae7103db1274b0ec3ea6fa0e19 Author: Toby Date: Fri Jan 19 16:45:01 2024 -0800 first diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a26175e --- /dev/null +++ b/.gitignore @@ -0,0 +1,207 @@ +# Created by https://www.toptal.com/developers/gitignore/api/windows,macos,linux,go,goland+all,visualstudiocode +# Edit at https://www.toptal.com/developers/gitignore?templates=windows,macos,linux,go,goland+all,visualstudiocode + +### Go ### +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work + +### GoLand+all ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. 
+# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### GoLand+all Patch ### +# Ignore everything but code style settings and run configurations +# that are supposed to be shared within teams. + +.idea/* + +!.idea/codeStyles +!.idea/runConfigurations + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### macOS ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### macOS Patch ### +# iCloud generated files +*.icloud + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for 
Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +# End of https://www.toptal.com/developers/gitignore/api/windows,macos,linux,go,goland+all,visualstudiocode diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..fa0086a --- /dev/null +++ b/LICENSE @@ -0,0 +1,373 @@ +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. 
"Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. 
Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. 
Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. 
Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. 
However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. 
Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. 
* +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. 
Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. 
+ +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..888eab5 --- /dev/null +++ b/README.md @@ -0,0 +1,110 @@ +# ![OpenGFW](docs/logo.png) + +[![License][1]][2] + +[1]: https://img.shields.io/badge/License-MPL_2.0-brightgreen.svg + +[2]: LICENSE + +**[中文文档](README.zh.md)** + +OpenGFW is a flexible, easy-to-use, open source implementation of [GFW](https://en.wikipedia.org/wiki/Great_Firewall) on +Linux that's in many ways more powerful than the real thing. It's cyber sovereignty you can have on a home router. + +> [!CAUTION] +> This project is still in very early stages of development. Use at your own risk. + +> [!NOTE] +> We are looking for contributors to help us with this project, especially implementing analyzers for more protocols!!! + +## Features + +- Full IP/TCP reassembly, various protocol analyzers + - HTTP, TLS, DNS, SSH, and many more to come + - "Fully encrypted traffic" detection for Shadowsocks, + etc. 
(https://gfw.report/publications/usenixsecurity23/data/paper/paper.pdf) + - [WIP] Machine learning based traffic classification +- Flow-based multicore load balancing +- Connection offloading +- Powerful rule engine based on [expr](https://github.com/expr-lang/expr) +- Flexible analyzer & modifier framework +- Extensible IO implementation (only NFQueue for now) +- [WIP] Web UI + +## Use cases + +- Ad blocking +- Parental control +- Malware protection +- Abuse prevention for VPN/proxy services +- Traffic analysis (log only mode) + +## Usage + +### Build + +```shell +go build +``` + +### Run + +```shell +export OPENGFW_LOG_LEVEL=debug +./OpenGFW -c config.yaml rules.yaml +``` + +### Example config + +```yaml +io: + queueSize: 1024 + local: true # set to false if you want to run OpenGFW on FORWARD chain + +workers: + count: 4 + queueSize: 16 + tcpMaxBufferedPagesTotal: 4096 + tcpMaxBufferedPagesPerConn: 64 + udpMaxStreams: 4096 +``` + +### Example rules + +Documentation on all supported protocols and what field each one has is not yet ready. For now, you have to check the +code under "analyzer" directory directly. + +For syntax of the expression language, please refer +to [Expr Language Definition](https://expr-lang.org/docs/language-definition). + +```yaml +- name: block v2ex http + action: block + expr: string(http?.req?.headers?.host) endsWith "v2ex.com" + +- name: block v2ex https + action: block + expr: string(tls?.req?.sni) endsWith "v2ex.com" + +- name: block shadowsocks + action: block + expr: fet != nil && fet.yes + +- name: v2ex dns poisoning + action: modify + modifier: + name: dns + args: + a: "0.0.0.0" + aaaa: "::" + expr: dns != nil && dns.qr && any(dns.questions, {.name endsWith "v2ex.com"}) +``` + +#### Supported actions + +- `allow`: Allow the connection, no further processing. +- `block`: Block the connection, no further processing. Send a TCP RST if it's a TCP connection. 
+- `drop`: For UDP, drop the packet that triggered the rule, continue processing future packets in the same flow. For + TCP, same as `block`. +- `modify`: For UDP, modify the packet that triggered the rule using the given modifier, continue processing future + packets in the same flow. For TCP, same as `allow`. \ No newline at end of file diff --git a/README.zh.md b/README.zh.md new file mode 100644 index 0000000..4914fb8 --- /dev/null +++ b/README.zh.md @@ -0,0 +1,103 @@ +# ![OpenGFW](docs/logo.png) + +[![License][1]][2] + +[1]: https://img.shields.io/badge/License-MPL_2.0-brightgreen.svg + +[2]: LICENSE + +OpenGFW 是一个 Linux 上灵活、易用、开源的 [GFW](https://zh.wikipedia.org/wiki/%E9%98%B2%E7%81%AB%E9%95%BF%E5%9F%8E) +实现,并且在许多方面比真正的 GFW 更强大。可以部署在家用路由器上的网络主权。 + +> [!CAUTION] +> 本项目仍处于早期开发阶段。测试时自行承担风险。 + +> [!NOTE] +> 我们正在寻求贡献者一起完善本项目,尤其是实现更多协议的解析器! + +## 功能 + +- 完整的 IP/TCP 重组,各种协议解析器 + - HTTP, TLS, DNS, SSH, 更多协议正在开发中 + - Shadowsocks 等 "全加密流量" 检测 (https://gfw.report/publications/usenixsecurity23/data/paper/paper.pdf) + - [开发中] 基于机器学习的流量分类 +- 基于流的多核负载均衡 +- 连接 offloading +- 基于 [expr](https://github.com/expr-lang/expr) 的强大规则引擎 +- 灵活的协议解析和修改框架 +- 可扩展的 IO 实现 (目前只有 NFQueue) +- [开发中] Web UI + +## 使用场景 + +- 广告拦截 +- 家长控制 +- 恶意软件防护 +- VPN/代理服务滥用防护 +- 流量分析 (纯日志模式) + +## 使用 + +### 构建 + +```shell +go build +``` + +### 运行 + +```shell +export OPENGFW_LOG_LEVEL=debug +./OpenGFW -c config.yaml rules.yaml +``` + +### 样例配置 + +```yaml +io: + queueSize: 1024 + local: true # 如果需要在 FORWARD 链上运行 OpenGFW,请设置为 false + +workers: + count: 4 + queueSize: 16 + tcpMaxBufferedPagesTotal: 4096 + tcpMaxBufferedPagesPerConn: 64 + udpMaxStreams: 4096 +``` + +### 样例规则 + +关于规则具体支持哪些协议,以及每个协议包含哪些字段的文档还没有写。目前请直接参考 "analyzer" 目录下的代码。 + +规则的语法请参考 [Expr Language Definition](https://expr-lang.org/docs/language-definition)。 + +```yaml +- name: block v2ex http + action: block + expr: string(http?.req?.headers?.host) endsWith "v2ex.com" + +- name: block v2ex https + action: block + expr: string(tls?.req?.sni) endsWith 
"v2ex.com" + +- name: block shadowsocks + action: block + expr: fet != nil && fet.yes + +- name: v2ex dns poisoning + action: modify + modifier: + name: dns + args: + a: "0.0.0.0" + aaaa: "::" + expr: dns != nil && dns.qr && any(dns.questions, {.name endsWith "v2ex.com"}) +``` + +#### 支持的 action + +- `allow`: 放行连接,不再处理后续的包。 +- `block`: 阻断连接,不再处理后续的包。如果是 TCP 连接,会发送 RST 包。 +- `drop`: 对于 UDP,丢弃触发规则的包,但继续处理同一流中的后续包。对于 TCP,效果同 `block`。 +- `modify`: 对于 UDP,用指定的修改器修改触发规则的包,然后继续处理同一流中的后续包。对于 TCP,效果同 `allow`。 \ No newline at end of file diff --git a/analyzer/interface.go b/analyzer/interface.go new file mode 100644 index 0000000..80ad418 --- /dev/null +++ b/analyzer/interface.go @@ -0,0 +1,131 @@ +package analyzer + +import ( + "net" + "strings" +) + +type Analyzer interface { + // Name returns the name of the analyzer. + Name() string + // Limit returns the byte limit for this analyzer. + // For example, an analyzer can return 1000 to indicate that it only ever needs + // the first 1000 bytes of a stream to do its job. If the stream is still not + // done after 1000 bytes, the engine will stop feeding it data and close it. + // An analyzer can return 0 or a negative number to indicate that it does not + // have a hard limit. + // Note: for UDP streams, the engine always feeds entire packets, even if + // the packet is larger than the remaining quota or the limit itself. + Limit() int +} + +type Logger interface { + Debugf(format string, args ...interface{}) + Infof(format string, args ...interface{}) + Errorf(format string, args ...interface{}) +} + +type TCPAnalyzer interface { + Analyzer + // NewTCP returns a new TCPStream. + NewTCP(TCPInfo, Logger) TCPStream +} + +type TCPInfo struct { + // SrcIP is the source IP address. + SrcIP net.IP + // DstIP is the destination IP address. + DstIP net.IP + // SrcPort is the source port. + SrcPort uint16 + // DstPort is the destination port. 
+ DstPort uint16 +} + +type TCPStream interface { + // Feed feeds a chunk of reassembled data to the stream. + // It returns a prop update containing the information extracted from the stream (can be nil), + // and whether the analyzer is "done" with this stream (i.e. no more data should be fed). + Feed(rev, start, end bool, skip int, data []byte) (u *PropUpdate, done bool) + // Close indicates that the stream is closed. + // Either the connection is closed, or the stream has reached its byte limit. + // Like Feed, it optionally returns a prop update. + Close(limited bool) *PropUpdate +} + +type UDPAnalyzer interface { + Analyzer + // NewUDP returns a new UDPStream. + NewUDP(UDPInfo, Logger) UDPStream +} + +type UDPInfo struct { + // SrcIP is the source IP address. + SrcIP net.IP + // DstIP is the destination IP address. + DstIP net.IP + // SrcPort is the source port. + SrcPort uint16 + // DstPort is the destination port. + DstPort uint16 +} + +type UDPStream interface { + // Feed feeds a new packet to the stream. + // It returns a prop update containing the information extracted from the stream (can be nil), + // and whether the analyzer is "done" with this stream (i.e. no more data should be fed). + Feed(rev bool, data []byte) (u *PropUpdate, done bool) + // Close indicates that the stream is closed. + // Either the connection is closed, or the stream has reached its byte limit. + // Like Feed, it optionally returns a prop update. + Close(limited bool) *PropUpdate +} + +type ( + PropMap map[string]interface{} + CombinedPropMap map[string]PropMap +) + +// Get returns the value of the property with the given key. +// The key can be a nested key, e.g. "foo.bar.baz". +// Returns nil if the key does not exist. 
+func (m PropMap) Get(key string) interface{} { + keys := strings.Split(key, ".") + if len(keys) == 0 { + return nil + } + var current interface{} = m + for _, k := range keys { + currentMap, ok := current.(PropMap) + if !ok { + return nil + } + current = currentMap[k] + } + return current +} + +// Get returns the value of the property with the given analyzer & key. +// The key can be a nested key, e.g. "foo.bar.baz". +// Returns nil if the key does not exist. +func (cm CombinedPropMap) Get(an string, key string) interface{} { + m, ok := cm[an] + if !ok { + return nil + } + return m.Get(key) +} + +type PropUpdateType int + +const ( + PropUpdateNone PropUpdateType = iota + PropUpdateMerge + PropUpdateReplace + PropUpdateDelete +) + +type PropUpdate struct { + Type PropUpdateType + M PropMap +} diff --git a/analyzer/tcp/fet.go b/analyzer/tcp/fet.go new file mode 100644 index 0000000..2022d7a --- /dev/null +++ b/analyzer/tcp/fet.go @@ -0,0 +1,159 @@ +package tcp + +import "github.com/apernet/OpenGFW/analyzer" + +var _ analyzer.TCPAnalyzer = (*FETAnalyzer)(nil) + +// FETAnalyzer stands for "Fully Encrypted Traffic" analyzer. 
+// It implements an algorithm to detect fully encrypted proxy protocols +// such as Shadowsocks, mentioned in the following paper: +// https://gfw.report/publications/usenixsecurity23/data/paper/paper.pdf +type FETAnalyzer struct{} + +func (a *FETAnalyzer) Name() string { + return "fet" +} + +func (a *FETAnalyzer) Limit() int { + // We only really look at the first packet + return 8192 +} + +func (a *FETAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream { + return newFETStream(logger) +} + +type fetStream struct { + logger analyzer.Logger +} + +func newFETStream(logger analyzer.Logger) *fetStream { + return &fetStream{logger: logger} +} + +func (s *fetStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, done bool) { + if skip != 0 { + return nil, true + } + if len(data) == 0 { + return nil, false + } + ex1 := averagePopCount(data) + ex2 := isFirstSixPrintable(data) + ex3 := printablePercentage(data) + ex4 := contiguousPrintable(data) + ex5 := isTLSorHTTP(data) + exempt := (ex1 <= 3.4 || ex1 >= 4.6) || ex2 || ex3 > 0.5 || ex4 > 20 || ex5 + return &analyzer.PropUpdate{ + Type: analyzer.PropUpdateReplace, + M: analyzer.PropMap{ + "ex1": ex1, + "ex2": ex2, + "ex3": ex3, + "ex4": ex4, + "ex5": ex5, + "yes": !exempt, + }, + }, true +} + +func (s *fetStream) Close(limited bool) *analyzer.PropUpdate { + return nil +} + +func popCount(b byte) int { + count := 0 + for b != 0 { + count += int(b & 1) + b >>= 1 + } + return count +} + +// averagePopCount returns the average popcount of the given bytes. +// This is the "Ex1" metric in the paper. +func averagePopCount(bytes []byte) float32 { + if len(bytes) == 0 { + return 0 + } + total := 0 + for _, b := range bytes { + total += popCount(b) + } + return float32(total) / float32(len(bytes)) +} + +// isFirstSixPrintable returns true if the first six bytes are printable ASCII. +// This is the "Ex2" metric in the paper. 
+func isFirstSixPrintable(bytes []byte) bool { + if len(bytes) < 6 { + return false + } + for i := range bytes[:6] { + if !isPrintable(bytes[i]) { + return false + } + } + return true +} + +// printablePercentage returns the percentage of printable ASCII bytes. +// This is the "Ex3" metric in the paper. +func printablePercentage(bytes []byte) float32 { + if len(bytes) == 0 { + return 0 + } + count := 0 + for i := range bytes { + if isPrintable(bytes[i]) { + count++ + } + } + return float32(count) / float32(len(bytes)) +} + +// contiguousPrintable returns the length of the longest contiguous sequence of +// printable ASCII bytes. +// This is the "Ex4" metric in the paper. +func contiguousPrintable(bytes []byte) int { + if len(bytes) == 0 { + return 0 + } + maxCount := 0 + current := 0 + for i := range bytes { + if isPrintable(bytes[i]) { + current++ + } else { + if current > maxCount { + maxCount = current + } + current = 0 + } + } + if current > maxCount { + maxCount = current + } + return maxCount +} + +// isTLSorHTTP returns true if the given bytes look like TLS or HTTP. +// This is the "Ex5" metric in the paper. 
+func isTLSorHTTP(bytes []byte) bool { + if len(bytes) < 3 { + return false + } + if bytes[0] == 0x16 && bytes[1] == 0x03 && bytes[2] <= 0x03 { + // TLS handshake for TLS 1.0-1.3 + return true + } + // HTTP request + str := string(bytes[:3]) + return str == "GET" || str == "HEA" || str == "POS" || + str == "PUT" || str == "DEL" || str == "CON" || + str == "OPT" || str == "TRA" || str == "PAT" +} + +func isPrintable(b byte) bool { + return b >= 0x20 && b <= 0x7e +} diff --git a/analyzer/tcp/http.go b/analyzer/tcp/http.go new file mode 100644 index 0000000..9d5289f --- /dev/null +++ b/analyzer/tcp/http.go @@ -0,0 +1,193 @@ +package tcp + +import ( + "bytes" + "strconv" + "strings" + + "github.com/apernet/OpenGFW/analyzer" + "github.com/apernet/OpenGFW/analyzer/utils" +) + +var _ analyzer.TCPAnalyzer = (*HTTPAnalyzer)(nil) + +type HTTPAnalyzer struct{} + +func (a *HTTPAnalyzer) Name() string { + return "http" +} + +func (a *HTTPAnalyzer) Limit() int { + return 8192 +} + +func (a *HTTPAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream { + return newHTTPStream(logger) +} + +type httpStream struct { + logger analyzer.Logger + + reqBuf *utils.ByteBuffer + reqMap analyzer.PropMap + reqUpdated bool + reqLSM *utils.LinearStateMachine + reqDone bool + + respBuf *utils.ByteBuffer + respMap analyzer.PropMap + respUpdated bool + respLSM *utils.LinearStateMachine + respDone bool +} + +func newHTTPStream(logger analyzer.Logger) *httpStream { + s := &httpStream{logger: logger, reqBuf: &utils.ByteBuffer{}, respBuf: &utils.ByteBuffer{}} + s.reqLSM = utils.NewLinearStateMachine( + s.parseRequestLine, + s.parseRequestHeaders, + ) + s.respLSM = utils.NewLinearStateMachine( + s.parseResponseLine, + s.parseResponseHeaders, + ) + return s +} + +func (s *httpStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, d bool) { + if skip != 0 { + return nil, true + } + if len(data) == 0 { + return nil, false + } + var update 
*analyzer.PropUpdate + var cancelled bool + if rev { + s.respBuf.Append(data) + s.respUpdated = false + cancelled, s.respDone = s.respLSM.Run() + if s.respUpdated { + update = &analyzer.PropUpdate{ + Type: analyzer.PropUpdateMerge, + M: analyzer.PropMap{"resp": s.respMap}, + } + s.respUpdated = false + } + } else { + s.reqBuf.Append(data) + s.reqUpdated = false + cancelled, s.reqDone = s.reqLSM.Run() + if s.reqUpdated { + update = &analyzer.PropUpdate{ + Type: analyzer.PropUpdateMerge, + M: analyzer.PropMap{"req": s.reqMap}, + } + s.reqUpdated = false + } + } + return update, cancelled || (s.reqDone && s.respDone) +} + +func (s *httpStream) parseRequestLine() utils.LSMAction { + // Find the end of the request line + line, ok := s.reqBuf.GetUntil([]byte("\r\n"), true, true) + if !ok { + // No end of line yet, but maybe we just need more data + return utils.LSMActionPause + } + fields := strings.Fields(string(line[:len(line)-2])) // Strip \r\n + if len(fields) != 3 { + // Invalid request line + return utils.LSMActionCancel + } + method := fields[0] + path := fields[1] + version := fields[2] + if !strings.HasPrefix(version, "HTTP/") { + // Invalid version + return utils.LSMActionCancel + } + s.reqMap = analyzer.PropMap{ + "method": method, + "path": path, + "version": version, + } + s.reqUpdated = true + return utils.LSMActionNext +} + +func (s *httpStream) parseResponseLine() utils.LSMAction { + // Find the end of the response line + line, ok := s.respBuf.GetUntil([]byte("\r\n"), true, true) + if !ok { + // No end of line yet, but maybe we just need more data + return utils.LSMActionPause + } + fields := strings.Fields(string(line[:len(line)-2])) // Strip \r\n + if len(fields) < 2 { + // Invalid response line + return utils.LSMActionCancel + } + version := fields[0] + status, _ := strconv.Atoi(fields[1]) + if !strings.HasPrefix(version, "HTTP/") || status == 0 { + // Invalid version + return utils.LSMActionCancel + } + s.respMap = analyzer.PropMap{ + "version": 
version, + "status": status, + } + s.respUpdated = true + return utils.LSMActionNext +} + +func (s *httpStream) parseHeaders(buf *utils.ByteBuffer) (utils.LSMAction, analyzer.PropMap) { + // Find the end of headers + headers, ok := buf.GetUntil([]byte("\r\n\r\n"), true, true) + if !ok { + // No end of headers yet, but maybe we just need more data + return utils.LSMActionPause, nil + } + headers = headers[:len(headers)-4] // Strip \r\n\r\n + headerMap := make(analyzer.PropMap) + for _, line := range bytes.Split(headers, []byte("\r\n")) { + fields := bytes.SplitN(line, []byte(":"), 2) + if len(fields) != 2 { + // Invalid header + return utils.LSMActionCancel, nil + } + key := string(bytes.TrimSpace(fields[0])) + value := string(bytes.TrimSpace(fields[1])) + // Normalize header keys to lowercase + headerMap[strings.ToLower(key)] = value + } + return utils.LSMActionNext, headerMap +} + +func (s *httpStream) parseRequestHeaders() utils.LSMAction { + action, headerMap := s.parseHeaders(s.reqBuf) + if action == utils.LSMActionNext { + s.reqMap["headers"] = headerMap + s.reqUpdated = true + } + return action +} + +func (s *httpStream) parseResponseHeaders() utils.LSMAction { + action, headerMap := s.parseHeaders(s.respBuf) + if action == utils.LSMActionNext { + s.respMap["headers"] = headerMap + s.respUpdated = true + } + return action +} + +func (s *httpStream) Close(limited bool) *analyzer.PropUpdate { + s.reqBuf.Reset() + s.respBuf.Reset() + s.reqMap = nil + s.respMap = nil + return nil +} diff --git a/analyzer/tcp/ssh.go b/analyzer/tcp/ssh.go new file mode 100644 index 0000000..a636441 --- /dev/null +++ b/analyzer/tcp/ssh.go @@ -0,0 +1,147 @@ +package tcp + +import ( + "strings" + + "github.com/apernet/OpenGFW/analyzer" + "github.com/apernet/OpenGFW/analyzer/utils" +) + +var _ analyzer.TCPAnalyzer = (*SSHAnalyzer)(nil) + +type SSHAnalyzer struct{} + +func (a *SSHAnalyzer) Name() string { + return "ssh" +} + +func (a *SSHAnalyzer) Limit() int { + return 1024 +} + +func 
(a *SSHAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream { + return newSSHStream(logger) +} + +type sshStream struct { + logger analyzer.Logger + + clientBuf *utils.ByteBuffer + clientMap analyzer.PropMap + clientUpdated bool + clientLSM *utils.LinearStateMachine + clientDone bool + + serverBuf *utils.ByteBuffer + serverMap analyzer.PropMap + serverUpdated bool + serverLSM *utils.LinearStateMachine + serverDone bool +} + +func newSSHStream(logger analyzer.Logger) *sshStream { + s := &sshStream{logger: logger, clientBuf: &utils.ByteBuffer{}, serverBuf: &utils.ByteBuffer{}} + s.clientLSM = utils.NewLinearStateMachine( + s.parseClientExchangeLine, + ) + s.serverLSM = utils.NewLinearStateMachine( + s.parseServerExchangeLine, + ) + return s +} + +func (s *sshStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, done bool) { + if skip != 0 { + return nil, true + } + if len(data) == 0 { + return nil, false + } + var update *analyzer.PropUpdate + var cancelled bool + if rev { + s.serverBuf.Append(data) + s.serverUpdated = false + cancelled, s.serverDone = s.serverLSM.Run() + if s.serverUpdated { + update = &analyzer.PropUpdate{ + Type: analyzer.PropUpdateMerge, + M: analyzer.PropMap{"server": s.serverMap}, + } + s.serverUpdated = false + } + } else { + s.clientBuf.Append(data) + s.clientUpdated = false + cancelled, s.clientDone = s.clientLSM.Run() + if s.clientUpdated { + update = &analyzer.PropUpdate{ + Type: analyzer.PropUpdateMerge, + M: analyzer.PropMap{"client": s.clientMap}, + } + s.clientUpdated = false + } + } + return update, cancelled || (s.clientDone && s.serverDone) +} + +// parseExchangeLine parses the SSH Protocol Version Exchange string. +// See RFC 4253, section 4.2. +// "SSH-protoversion-softwareversion SP comments CR LF" +// The "comments" part (along with the SP) is optional. 
+func (s *sshStream) parseExchangeLine(buf *utils.ByteBuffer) (utils.LSMAction, analyzer.PropMap) { + // Find the end of the line + line, ok := buf.GetUntil([]byte("\r\n"), true, true) + if !ok { + // No end of line yet, but maybe we just need more data + return utils.LSMActionPause, nil + } + if !strings.HasPrefix(string(line), "SSH-") { + // Not SSH + return utils.LSMActionCancel, nil + } + fields := strings.Fields(string(line[:len(line)-2])) // Strip \r\n + if len(fields) < 1 || len(fields) > 2 { + // Invalid line + return utils.LSMActionCancel, nil + } + sshFields := strings.SplitN(fields[0], "-", 3) + if len(sshFields) != 3 { + // Invalid SSH version format + return utils.LSMActionCancel, nil + } + sMap := analyzer.PropMap{ + "protocol": sshFields[1], + "software": sshFields[2], + } + if len(fields) == 2 { + sMap["comments"] = fields[1] + } + return utils.LSMActionNext, sMap +} + +func (s *sshStream) parseClientExchangeLine() utils.LSMAction { + action, sMap := s.parseExchangeLine(s.clientBuf) + if action == utils.LSMActionNext { + s.clientMap = sMap + s.clientUpdated = true + } + return action +} + +func (s *sshStream) parseServerExchangeLine() utils.LSMAction { + action, sMap := s.parseExchangeLine(s.serverBuf) + if action == utils.LSMActionNext { + s.serverMap = sMap + s.serverUpdated = true + } + return action +} + +func (s *sshStream) Close(limited bool) *analyzer.PropUpdate { + s.clientBuf.Reset() + s.serverBuf.Reset() + s.clientMap = nil + s.serverMap = nil + return nil +} diff --git a/analyzer/tcp/tls.go b/analyzer/tcp/tls.go new file mode 100644 index 0000000..a4f62ab --- /dev/null +++ b/analyzer/tcp/tls.go @@ -0,0 +1,354 @@ +package tcp + +import ( + "github.com/apernet/OpenGFW/analyzer" + "github.com/apernet/OpenGFW/analyzer/utils" +) + +var _ analyzer.TCPAnalyzer = (*TLSAnalyzer)(nil) + +type TLSAnalyzer struct{} + +func (a *TLSAnalyzer) Name() string { + return "tls" +} + +func (a *TLSAnalyzer) Limit() int { + return 8192 +} + +func (a 
*TLSAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream { + return newTLSStream(logger) +} + +type tlsStream struct { + logger analyzer.Logger + + reqBuf *utils.ByteBuffer + reqMap analyzer.PropMap + reqUpdated bool + reqLSM *utils.LinearStateMachine + reqDone bool + + respBuf *utils.ByteBuffer + respMap analyzer.PropMap + respUpdated bool + respLSM *utils.LinearStateMachine + respDone bool + + clientHelloLen int + serverHelloLen int +} + +func newTLSStream(logger analyzer.Logger) *tlsStream { + s := &tlsStream{logger: logger, reqBuf: &utils.ByteBuffer{}, respBuf: &utils.ByteBuffer{}} + s.reqLSM = utils.NewLinearStateMachine( + s.tlsClientHelloSanityCheck, + s.parseClientHello, + ) + s.respLSM = utils.NewLinearStateMachine( + s.tlsServerHelloSanityCheck, + s.parseServerHello, + ) + return s +} + +func (s *tlsStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, done bool) { + if skip != 0 { + return nil, true + } + if len(data) == 0 { + return nil, false + } + var update *analyzer.PropUpdate + var cancelled bool + if rev { + s.respBuf.Append(data) + s.respUpdated = false + cancelled, s.respDone = s.respLSM.Run() + if s.respUpdated { + update = &analyzer.PropUpdate{ + Type: analyzer.PropUpdateMerge, + M: analyzer.PropMap{"resp": s.respMap}, + } + s.respUpdated = false + } + } else { + s.reqBuf.Append(data) + s.reqUpdated = false + cancelled, s.reqDone = s.reqLSM.Run() + if s.reqUpdated { + update = &analyzer.PropUpdate{ + Type: analyzer.PropUpdateMerge, + M: analyzer.PropMap{"req": s.reqMap}, + } + s.reqUpdated = false + } + } + return update, cancelled || (s.reqDone && s.respDone) +} + +func (s *tlsStream) tlsClientHelloSanityCheck() utils.LSMAction { + data, ok := s.reqBuf.Get(9, true) + if !ok { + return utils.LSMActionPause + } + if data[0] != 0x16 || data[5] != 0x01 { + // Not a TLS handshake, or not a client hello + return utils.LSMActionCancel + } + s.clientHelloLen = int(data[6])<<16 | 
int(data[7])<<8 | int(data[8])
+	if s.clientHelloLen < 41 {
+		// 2 (Protocol Version) +
+		// 32 (Random) +
+		// 1 (Session ID Length) +
+		// 2 (Cipher Suites Length) +
+		// 2 (Cipher Suite) +
+		// 1 (Compression Methods Length) +
+		// 1 (Compression Method) +
+		// No extensions
+		// This should be the bare minimum for a client hello
+		return utils.LSMActionCancel
+	}
+	return utils.LSMActionNext
+}
+
+func (s *tlsStream) tlsServerHelloSanityCheck() utils.LSMAction {
+	data, ok := s.respBuf.Get(9, true)
+	if !ok {
+		return utils.LSMActionPause
+	}
+	if data[0] != 0x16 || data[5] != 0x02 {
+		// Not a TLS handshake, or not a server hello
+		return utils.LSMActionCancel
+	}
+	s.serverHelloLen = int(data[6])<<16 | int(data[7])<<8 | int(data[8])
+	if s.serverHelloLen < 38 {
+		// 2 (Protocol Version) +
+		// 32 (Random) +
+		// 1 (Session ID Length) +
+		// 2 (Cipher Suite) +
+		// 1 (Compression Method) +
+		// No extensions
+		// This should be the bare minimum for a server hello
+		return utils.LSMActionCancel
+	}
+	return utils.LSMActionNext
+}
+
+func (s *tlsStream) parseClientHello() utils.LSMAction {
+	chBuf, ok := s.reqBuf.GetSubBuffer(s.clientHelloLen, true)
+	if !ok {
+		// Not a full client hello yet
+		return utils.LSMActionPause
+	}
+	s.reqUpdated = true
+	s.reqMap = make(analyzer.PropMap)
+	// Version, random & session ID length combined are within 35 bytes,
+	// so no need for bounds checking
+	s.reqMap["version"], _ = chBuf.GetUint16(false, true)
+	s.reqMap["random"], _ = chBuf.Get(32, true)
+	sessionIDLen, _ := chBuf.GetByte(true)
+	s.reqMap["session"], ok = chBuf.Get(int(sessionIDLen), true)
+	if !ok {
+		// Not enough data for session ID
+		return utils.LSMActionCancel
+	}
+	cipherSuitesLen, ok := chBuf.GetUint16(false, true)
+	if !ok {
+		// Not enough data for cipher suites length
+		return utils.LSMActionCancel
+	}
+	if cipherSuitesLen%2 != 0 {
+		// Cipher suites are 2 bytes each, so must be even
+		return utils.LSMActionCancel
+	}
+	
ciphers := make([]uint16, cipherSuitesLen/2) + for i := range ciphers { + ciphers[i], ok = chBuf.GetUint16(false, true) + if !ok { + return utils.LSMActionCancel + } + } + s.reqMap["ciphers"] = ciphers + compressionMethodsLen, ok := chBuf.GetByte(true) + if !ok { + // Not enough data for compression methods length + return utils.LSMActionCancel + } + // Compression methods are 1 byte each, we just put a byte slice here + s.reqMap["compression"], ok = chBuf.Get(int(compressionMethodsLen), true) + if !ok { + // Not enough data for compression methods + return utils.LSMActionCancel + } + extsLen, ok := chBuf.GetUint16(false, true) + if !ok { + // No extensions, I guess it's possible? + return utils.LSMActionNext + } + extBuf, ok := chBuf.GetSubBuffer(int(extsLen), true) + if !ok { + // Not enough data for extensions + return utils.LSMActionCancel + } + for extBuf.Len() > 0 { + extType, ok := extBuf.GetUint16(false, true) + if !ok { + // Not enough data for extension type + return utils.LSMActionCancel + } + extLen, ok := extBuf.GetUint16(false, true) + if !ok { + // Not enough data for extension length + return utils.LSMActionCancel + } + extDataBuf, ok := extBuf.GetSubBuffer(int(extLen), true) + if !ok || !s.handleExtensions(extType, extDataBuf, s.reqMap) { + // Not enough data for extension data, or invalid extension + return utils.LSMActionCancel + } + } + return utils.LSMActionNext +} + +func (s *tlsStream) parseServerHello() utils.LSMAction { + shBuf, ok := s.respBuf.GetSubBuffer(s.serverHelloLen, true) + if !ok { + // Not a full server hello yet + return utils.LSMActionPause + } + s.respUpdated = true + s.respMap = make(analyzer.PropMap) + // Version, random & session ID length combined are within 35 bytes, + // so no need for bounds checking + s.respMap["version"], _ = shBuf.GetUint16(false, true) + s.respMap["random"], _ = shBuf.Get(32, true) + sessionIDLen, _ := shBuf.GetByte(true) + s.respMap["session"], ok = shBuf.Get(int(sessionIDLen), true) + if !ok { + 
// Not enough data for session ID + return utils.LSMActionCancel + } + cipherSuite, ok := shBuf.GetUint16(false, true) + if !ok { + // Not enough data for cipher suite + return utils.LSMActionCancel + } + s.respMap["cipher"] = cipherSuite + compressionMethod, ok := shBuf.GetByte(true) + if !ok { + // Not enough data for compression method + return utils.LSMActionCancel + } + s.respMap["compression"] = compressionMethod + extsLen, ok := shBuf.GetUint16(false, true) + if !ok { + // No extensions, I guess it's possible? + return utils.LSMActionNext + } + extBuf, ok := shBuf.GetSubBuffer(int(extsLen), true) + if !ok { + // Not enough data for extensions + return utils.LSMActionCancel + } + for extBuf.Len() > 0 { + extType, ok := extBuf.GetUint16(false, true) + if !ok { + // Not enough data for extension type + return utils.LSMActionCancel + } + extLen, ok := extBuf.GetUint16(false, true) + if !ok { + // Not enough data for extension length + return utils.LSMActionCancel + } + extDataBuf, ok := extBuf.GetSubBuffer(int(extLen), true) + if !ok || !s.handleExtensions(extType, extDataBuf, s.respMap) { + // Not enough data for extension data, or invalid extension + return utils.LSMActionCancel + } + } + return utils.LSMActionNext +} + +func (s *tlsStream) handleExtensions(extType uint16, extDataBuf *utils.ByteBuffer, m analyzer.PropMap) bool { + switch extType { + case 0x0000: // SNI + ok := extDataBuf.Skip(2) // Ignore list length, we only care about the first entry for now + if !ok { + // Not enough data for list length + return false + } + sniType, ok := extDataBuf.GetByte(true) + if !ok || sniType != 0 { + // Not enough data for SNI type, or not hostname + return false + } + sniLen, ok := extDataBuf.GetUint16(false, true) + if !ok { + // Not enough data for SNI length + return false + } + m["sni"], ok = extDataBuf.GetString(int(sniLen), true) + if !ok { + // Not enough data for SNI + return false + } + case 0x0010: // ALPN + ok := extDataBuf.Skip(2) // Ignore list 
length, as we read until the end + if !ok { + // Not enough data for list length + return false + } + var alpnList []string + for extDataBuf.Len() > 0 { + alpnLen, ok := extDataBuf.GetByte(true) + if !ok { + // Not enough data for ALPN length + return false + } + alpn, ok := extDataBuf.GetString(int(alpnLen), true) + if !ok { + // Not enough data for ALPN + return false + } + alpnList = append(alpnList, alpn) + } + m["alpn"] = alpnList + case 0x002b: // Supported Versions + if extDataBuf.Len() == 2 { + // Server only selects one version + m["supported_versions"], _ = extDataBuf.GetUint16(false, true) + } else { + // Client sends a list of versions + ok := extDataBuf.Skip(1) // Ignore list length, as we read until the end + if !ok { + // Not enough data for list length + return false + } + var versions []uint16 + for extDataBuf.Len() > 0 { + ver, ok := extDataBuf.GetUint16(false, true) + if !ok { + // Not enough data for version + return false + } + versions = append(versions, ver) + } + m["supported_versions"] = versions + } + case 0xfe0d: // ECH + // We can't parse ECH for now, just set a flag + m["ech"] = true + } + return true +} + +func (s *tlsStream) Close(limited bool) *analyzer.PropUpdate { + s.reqBuf.Reset() + s.respBuf.Reset() + s.reqMap = nil + s.respMap = nil + return nil +} diff --git a/analyzer/udp/dns.go b/analyzer/udp/dns.go new file mode 100644 index 0000000..f307a85 --- /dev/null +++ b/analyzer/udp/dns.go @@ -0,0 +1,254 @@ +package udp + +import ( + "github.com/apernet/OpenGFW/analyzer" + "github.com/apernet/OpenGFW/analyzer/utils" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +// DNSAnalyzer is for both DNS over UDP and TCP. 
+var ( + _ analyzer.UDPAnalyzer = (*DNSAnalyzer)(nil) + _ analyzer.TCPAnalyzer = (*DNSAnalyzer)(nil) +) + +type DNSAnalyzer struct{} + +func (a *DNSAnalyzer) Name() string { + return "dns" +} + +func (a *DNSAnalyzer) Limit() int { + // DNS is a stateless protocol, with unlimited amount + // of back-and-forth exchanges. Don't limit it here. + return 0 +} + +func (a *DNSAnalyzer) NewUDP(info analyzer.UDPInfo, logger analyzer.Logger) analyzer.UDPStream { + return &dnsUDPStream{logger: logger} +} + +func (a *DNSAnalyzer) NewTCP(info analyzer.TCPInfo, logger analyzer.Logger) analyzer.TCPStream { + s := &dnsTCPStream{logger: logger, reqBuf: &utils.ByteBuffer{}, respBuf: &utils.ByteBuffer{}} + s.reqLSM = utils.NewLinearStateMachine( + s.getReqMessageLength, + s.getReqMessage, + ) + s.respLSM = utils.NewLinearStateMachine( + s.getRespMessageLength, + s.getRespMessage, + ) + return s +} + +type dnsUDPStream struct { + logger analyzer.Logger +} + +func (s *dnsUDPStream) Feed(rev bool, data []byte) (u *analyzer.PropUpdate, done bool) { + m := parseDNSMessage(data) + if m == nil { + return nil, false + } + return &analyzer.PropUpdate{ + Type: analyzer.PropUpdateReplace, + M: m, + }, false +} + +func (s *dnsUDPStream) Close(limited bool) *analyzer.PropUpdate { + return nil +} + +type dnsTCPStream struct { + logger analyzer.Logger + + reqBuf *utils.ByteBuffer + reqMap analyzer.PropMap + reqUpdated bool + reqLSM *utils.LinearStateMachine + reqDone bool + + respBuf *utils.ByteBuffer + respMap analyzer.PropMap + respUpdated bool + respLSM *utils.LinearStateMachine + respDone bool + + reqMsgLen int + respMsgLen int +} + +func (s *dnsTCPStream) Feed(rev, start, end bool, skip int, data []byte) (u *analyzer.PropUpdate, done bool) { + if skip != 0 { + return nil, true + } + if len(data) == 0 { + return nil, false + } + var update *analyzer.PropUpdate + var cancelled bool + if rev { + s.respBuf.Append(data) + s.respUpdated = false + cancelled, s.respDone = s.respLSM.Run() + if 
s.respUpdated { + update = &analyzer.PropUpdate{ + Type: analyzer.PropUpdateReplace, + M: s.respMap, + } + s.respUpdated = false + } + } else { + s.reqBuf.Append(data) + s.reqUpdated = false + cancelled, s.reqDone = s.reqLSM.Run() + if s.reqUpdated { + update = &analyzer.PropUpdate{ + Type: analyzer.PropUpdateReplace, + M: s.reqMap, + } + s.reqUpdated = false + } + } + return update, cancelled || (s.reqDone && s.respDone) +} + +func (s *dnsTCPStream) Close(limited bool) *analyzer.PropUpdate { + s.reqBuf.Reset() + s.respBuf.Reset() + s.reqMap = nil + s.respMap = nil + return nil +} + +func (s *dnsTCPStream) getReqMessageLength() utils.LSMAction { + bs, ok := s.reqBuf.Get(2, true) + if !ok { + return utils.LSMActionPause + } + s.reqMsgLen = int(bs[0])<<8 | int(bs[1]) + return utils.LSMActionNext +} + +func (s *dnsTCPStream) getRespMessageLength() utils.LSMAction { + bs, ok := s.respBuf.Get(2, true) + if !ok { + return utils.LSMActionPause + } + s.respMsgLen = int(bs[0])<<8 | int(bs[1]) + return utils.LSMActionNext +} + +func (s *dnsTCPStream) getReqMessage() utils.LSMAction { + bs, ok := s.reqBuf.Get(s.reqMsgLen, true) + if !ok { + return utils.LSMActionPause + } + m := parseDNSMessage(bs) + if m == nil { + // Invalid DNS message + return utils.LSMActionCancel + } + s.reqMap = m + s.reqUpdated = true + return utils.LSMActionReset +} + +func (s *dnsTCPStream) getRespMessage() utils.LSMAction { + bs, ok := s.respBuf.Get(s.respMsgLen, true) + if !ok { + return utils.LSMActionPause + } + m := parseDNSMessage(bs) + if m == nil { + // Invalid DNS message + return utils.LSMActionCancel + } + s.respMap = m + s.respUpdated = true + return utils.LSMActionReset +} + +func parseDNSMessage(msg []byte) analyzer.PropMap { + dns := &layers.DNS{} + err := dns.DecodeFromBytes(msg, gopacket.NilDecodeFeedback) + if err != nil { + // Not a DNS packet + return nil + } + m := analyzer.PropMap{ + "id": dns.ID, + "qr": dns.QR, + "opcode": dns.OpCode, + "aa": dns.AA, + "tc": dns.TC, + "rd": 
dns.RD, + "ra": dns.RA, + "z": dns.Z, + "rcode": dns.ResponseCode, + } + if len(dns.Questions) > 0 { + mQuestions := make([]analyzer.PropMap, len(dns.Questions)) + for i, q := range dns.Questions { + mQuestions[i] = analyzer.PropMap{ + "name": string(q.Name), + "type": q.Type, + "class": q.Class, + } + } + m["questions"] = mQuestions + } + if len(dns.Answers) > 0 { + mAnswers := make([]analyzer.PropMap, len(dns.Answers)) + for i, rr := range dns.Answers { + mAnswers[i] = dnsRRToPropMap(rr) + } + m["answers"] = mAnswers + } + if len(dns.Authorities) > 0 { + mAuthorities := make([]analyzer.PropMap, len(dns.Authorities)) + for i, rr := range dns.Authorities { + mAuthorities[i] = dnsRRToPropMap(rr) + } + m["authorities"] = mAuthorities + } + if len(dns.Additionals) > 0 { + mAdditionals := make([]analyzer.PropMap, len(dns.Additionals)) + for i, rr := range dns.Additionals { + mAdditionals[i] = dnsRRToPropMap(rr) + } + m["additionals"] = mAdditionals + } + return m +} + +func dnsRRToPropMap(rr layers.DNSResourceRecord) analyzer.PropMap { + m := analyzer.PropMap{ + "name": string(rr.Name), + "type": rr.Type, + "class": rr.Class, + "ttl": rr.TTL, + } + switch rr.Type { + // These are not everything, but is + // all we decided to support for now. 
+ case layers.DNSTypeA: + m["a"] = rr.IP.String() + case layers.DNSTypeAAAA: + m["aaaa"] = rr.IP.String() + case layers.DNSTypeNS: + m["ns"] = string(rr.NS) + case layers.DNSTypeCNAME: + m["cname"] = string(rr.CNAME) + case layers.DNSTypePTR: + m["ptr"] = string(rr.PTR) + case layers.DNSTypeTXT: + m["txt"] = utils.ByteSlicesToStrings(rr.TXTs) + case layers.DNSTypeMX: + m["mx"] = string(rr.MX.Name) + } + return m +} diff --git a/analyzer/utils/bytebuffer.go b/analyzer/utils/bytebuffer.go new file mode 100644 index 0000000..495e1e2 --- /dev/null +++ b/analyzer/utils/bytebuffer.go @@ -0,0 +1,99 @@ +package utils + +import "bytes" + +type ByteBuffer struct { + Buf []byte +} + +func (b *ByteBuffer) Append(data []byte) { + b.Buf = append(b.Buf, data...) +} + +func (b *ByteBuffer) Len() int { + return len(b.Buf) +} + +func (b *ByteBuffer) Index(sep []byte) int { + return bytes.Index(b.Buf, sep) +} + +func (b *ByteBuffer) Get(length int, consume bool) (data []byte, ok bool) { + if len(b.Buf) < length { + return nil, false + } + data = b.Buf[:length] + if consume { + b.Buf = b.Buf[length:] + } + return data, true +} + +func (b *ByteBuffer) GetString(length int, consume bool) (string, bool) { + data, ok := b.Get(length, consume) + if !ok { + return "", false + } + return string(data), true +} + +func (b *ByteBuffer) GetByte(consume bool) (byte, bool) { + data, ok := b.Get(1, consume) + if !ok { + return 0, false + } + return data[0], true +} + +func (b *ByteBuffer) GetUint16(littleEndian, consume bool) (uint16, bool) { + data, ok := b.Get(2, consume) + if !ok { + return 0, false + } + if littleEndian { + return uint16(data[0]) | uint16(data[1])<<8, true + } + return uint16(data[1]) | uint16(data[0])<<8, true +} + +func (b *ByteBuffer) GetUint32(littleEndian, consume bool) (uint32, bool) { + data, ok := b.Get(4, consume) + if !ok { + return 0, false + } + if littleEndian { + return uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16 | uint32(data[3])<<24, true + } + 
return uint32(data[3]) | uint32(data[2])<<8 | uint32(data[1])<<16 | uint32(data[0])<<24, true +} + +func (b *ByteBuffer) GetUntil(sep []byte, includeSep, consume bool) (data []byte, ok bool) { + index := b.Index(sep) + if index == -1 { + return nil, false + } + if includeSep { + index += len(sep) + } + return b.Get(index, consume) +} + +func (b *ByteBuffer) GetSubBuffer(length int, consume bool) (sub *ByteBuffer, ok bool) { + data, ok := b.Get(length, consume) + if !ok { + return nil, false + } + return &ByteBuffer{Buf: data}, true +} + +func (b *ByteBuffer) Skip(length int) bool { + if len(b.Buf) < length { + return false + } + b.Buf = b.Buf[length:] + return true +} + +func (b *ByteBuffer) Reset() { + b.Buf = nil +} diff --git a/analyzer/utils/lsm.go b/analyzer/utils/lsm.go new file mode 100644 index 0000000..3b7b0b8 --- /dev/null +++ b/analyzer/utils/lsm.go @@ -0,0 +1,50 @@ +package utils + +type LSMAction int + +const ( + LSMActionPause LSMAction = iota + LSMActionNext + LSMActionReset + LSMActionCancel +) + +type LinearStateMachine struct { + Steps []func() LSMAction + + index int + cancelled bool +} + +func NewLinearStateMachine(steps ...func() LSMAction) *LinearStateMachine { + return &LinearStateMachine{ + Steps: steps, + } +} + +// Run runs the state machine until it pauses, finishes or is cancelled. 
+func (lsm *LinearStateMachine) Run() (cancelled bool, done bool) { + if lsm.index >= len(lsm.Steps) { + return lsm.cancelled, true + } + for lsm.index < len(lsm.Steps) { + action := lsm.Steps[lsm.index]() + switch action { + case LSMActionPause: + return false, false + case LSMActionNext: + lsm.index++ + case LSMActionReset: + lsm.index = 0 + case LSMActionCancel: + lsm.cancelled = true + return true, true + } + } + return false, true +} + +func (lsm *LinearStateMachine) Reset() { + lsm.index = 0 + lsm.cancelled = false +} diff --git a/analyzer/utils/string.go b/analyzer/utils/string.go new file mode 100644 index 0000000..9d278fb --- /dev/null +++ b/analyzer/utils/string.go @@ -0,0 +1,9 @@ +package utils + +func ByteSlicesToStrings(bss [][]byte) []string { + ss := make([]string, len(bss)) + for i, bs := range bss { + ss[i] = string(bs) + } + return ss +} diff --git a/cmd/errors.go b/cmd/errors.go new file mode 100644 index 0000000..3d0234a --- /dev/null +++ b/cmd/errors.go @@ -0,0 +1,18 @@ +package cmd + +import ( + "fmt" +) + +type configError struct { + Field string + Err error +} + +func (e configError) Error() string { + return fmt.Sprintf("invalid config: %s: %s", e.Field, e.Err) +} + +func (e configError) Unwrap() error { + return e.Err +} diff --git a/cmd/root.go b/cmd/root.go new file mode 100644 index 0000000..d4fe10d --- /dev/null +++ b/cmd/root.go @@ -0,0 +1,381 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "os/signal" + "strconv" + "strings" + + "github.com/apernet/OpenGFW/analyzer" + "github.com/apernet/OpenGFW/analyzer/tcp" + "github.com/apernet/OpenGFW/analyzer/udp" + "github.com/apernet/OpenGFW/engine" + "github.com/apernet/OpenGFW/io" + "github.com/apernet/OpenGFW/modifier" + modUDP "github.com/apernet/OpenGFW/modifier/udp" + "github.com/apernet/OpenGFW/ruleset" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +const ( + appLogo = ` +░█▀█░█▀█░█▀▀░█▀█░█▀▀░█▀▀░█░█ 
+░█░█░█▀▀░█▀▀░█░█░█░█░█▀▀░█▄█ +░▀▀▀░▀░░░▀▀▀░▀░▀░▀▀▀░▀░░░▀░▀ +` + appDesc = "Open source network filtering and analysis software" + appAuthors = "Aperture Internet Laboratory " + + appLogLevelEnv = "OPENGFW_LOG_LEVEL" + appLogFormatEnv = "OPENGFW_LOG_FORMAT" +) + +var logger *zap.Logger + +// Flags +var ( + cfgFile string + logLevel string + logFormat string +) + +var rootCmd = &cobra.Command{ + Use: "OpenGFW [flags] rule_file", + Short: appDesc, + Args: cobra.ExactArgs(1), + Run: runMain, +} + +var logLevelMap = map[string]zapcore.Level{ + "debug": zapcore.DebugLevel, + "info": zapcore.InfoLevel, + "warn": zapcore.WarnLevel, + "error": zapcore.ErrorLevel, +} + +var logFormatMap = map[string]zapcore.EncoderConfig{ + "console": { + TimeKey: "time", + LevelKey: "level", + NameKey: "logger", + MessageKey: "msg", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalColorLevelEncoder, + EncodeTime: zapcore.RFC3339TimeEncoder, + EncodeDuration: zapcore.SecondsDurationEncoder, + }, + "json": { + TimeKey: "time", + LevelKey: "level", + NameKey: "logger", + MessageKey: "msg", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.LowercaseLevelEncoder, + EncodeTime: zapcore.EpochMillisTimeEncoder, + EncodeDuration: zapcore.SecondsDurationEncoder, + }, +} + +// Analyzers & modifiers + +var analyzers = []analyzer.Analyzer{ + &tcp.FETAnalyzer{}, + &tcp.HTTPAnalyzer{}, + &tcp.SSHAnalyzer{}, + &tcp.TLSAnalyzer{}, + &udp.DNSAnalyzer{}, +} + +var modifiers = []modifier.Modifier{ + &modUDP.DNSModifier{}, +} + +func Execute() { + err := rootCmd.Execute() + if err != nil { + os.Exit(1) + } +} + +func init() { + initFlags() + cobra.OnInitialize(initConfig) + cobra.OnInitialize(initLogger) // initLogger must come after initConfig as it depends on config +} + +func initFlags() { + rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file") + rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", envOrDefaultString(appLogLevelEnv, 
"info"), "log level") + rootCmd.PersistentFlags().StringVarP(&logFormat, "log-format", "f", envOrDefaultString(appLogFormatEnv, "console"), "log format") +} + +func initConfig() { + if cfgFile != "" { + viper.SetConfigFile(cfgFile) + } else { + viper.SetConfigName("config") + viper.SetConfigType("yaml") + viper.SupportedExts = append([]string{"yaml", "yml"}, viper.SupportedExts...) + viper.AddConfigPath(".") + viper.AddConfigPath("$HOME/.opengfw") + viper.AddConfigPath("/etc/opengfw") + } +} + +func initLogger() { + level, ok := logLevelMap[strings.ToLower(logLevel)] + if !ok { + fmt.Printf("unsupported log level: %s\n", logLevel) + os.Exit(1) + } + enc, ok := logFormatMap[strings.ToLower(logFormat)] + if !ok { + fmt.Printf("unsupported log format: %s\n", logFormat) + os.Exit(1) + } + c := zap.Config{ + Level: zap.NewAtomicLevelAt(level), + DisableCaller: true, + DisableStacktrace: true, + Encoding: strings.ToLower(logFormat), + EncoderConfig: enc, + OutputPaths: []string{"stderr"}, + ErrorOutputPaths: []string{"stderr"}, + } + var err error + logger, err = c.Build() + if err != nil { + fmt.Printf("failed to initialize logger: %s\n", err) + os.Exit(1) + } +} + +type cliConfig struct { + IO cliConfigIO `mapstructure:"io"` + Workers cliConfigWorkers `mapstructure:"workers"` +} + +type cliConfigIO struct { + QueueSize uint32 `mapstructure:"queueSize"` + Local bool `mapstructure:"local"` +} + +type cliConfigWorkers struct { + Count int `mapstructure:"count"` + QueueSize int `mapstructure:"queueSize"` + TCPMaxBufferedPagesTotal int `mapstructure:"tcpMaxBufferedPagesTotal"` + TCPMaxBufferedPagesPerConn int `mapstructure:"tcpMaxBufferedPagesPerConn"` + UDPMaxStreams int `mapstructure:"udpMaxStreams"` +} + +func (c *cliConfig) fillLogger(config *engine.Config) error { + config.Logger = &engineLogger{} + return nil +} + +func (c *cliConfig) fillIO(config *engine.Config) error { + nfio, err := io.NewNFQueuePacketIO(io.NFQueuePacketIOConfig{ + QueueSize: c.IO.QueueSize, + 
Local: c.IO.Local, + }) + if err != nil { + return configError{Field: "io", Err: err} + } + config.IOs = []io.PacketIO{nfio} + return nil +} + +func (c *cliConfig) fillWorkers(config *engine.Config) error { + config.Workers = c.Workers.Count + config.WorkerQueueSize = c.Workers.QueueSize + config.WorkerTCPMaxBufferedPagesTotal = c.Workers.TCPMaxBufferedPagesTotal + config.WorkerTCPMaxBufferedPagesPerConn = c.Workers.TCPMaxBufferedPagesPerConn + config.WorkerUDPMaxStreams = c.Workers.UDPMaxStreams + return nil +} + +// Config validates the fields and returns a ready-to-use engine config. +// This does not include the ruleset. +func (c *cliConfig) Config() (*engine.Config, error) { + engineConfig := &engine.Config{} + fillers := []func(*engine.Config) error{ + c.fillLogger, + c.fillIO, + c.fillWorkers, + } + for _, f := range fillers { + if err := f(engineConfig); err != nil { + return nil, err + } + } + return engineConfig, nil +} + +func runMain(cmd *cobra.Command, args []string) { + // Config + if err := viper.ReadInConfig(); err != nil { + logger.Fatal("failed to read config", zap.Error(err)) + } + var config cliConfig + if err := viper.Unmarshal(&config); err != nil { + logger.Fatal("failed to parse config", zap.Error(err)) + } + engineConfig, err := config.Config() + if err != nil { + logger.Fatal("failed to parse config", zap.Error(err)) + } + defer func() { + // Make sure to close all IOs on exit + for _, i := range engineConfig.IOs { + _ = i.Close() + } + }() + + // Ruleset + rawRs, err := ruleset.ExprRulesFromYAML(args[0]) + if err != nil { + logger.Fatal("failed to load rules", zap.Error(err)) + } + rs, err := ruleset.CompileExprRules(rawRs, analyzers, modifiers) + if err != nil { + logger.Fatal("failed to compile rules", zap.Error(err)) + } + engineConfig.Ruleset = rs + + // Engine + en, err := engine.NewEngine(*engineConfig) + if err != nil { + logger.Fatal("failed to initialize engine", zap.Error(err)) + } + + ctx, cancelFunc := 
context.WithCancel(context.Background()) + go func() { + sigChan := make(chan os.Signal) + signal.Notify(sigChan, os.Interrupt, os.Kill) + <-sigChan + logger.Info("shutting down gracefully...") + cancelFunc() + }() + logger.Info("engine started") + logger.Info("engine exited", zap.Error(en.Run(ctx))) +} + +type engineLogger struct{} + +func (l *engineLogger) WorkerStart(id int) { + logger.Debug("worker started", zap.Int("id", id)) +} + +func (l *engineLogger) WorkerStop(id int) { + logger.Debug("worker stopped", zap.Int("id", id)) +} + +func (l *engineLogger) TCPStreamNew(workerID int, info ruleset.StreamInfo) { + logger.Debug("new TCP stream", + zap.Int("workerID", workerID), + zap.Int64("id", info.ID), + zap.String("src", info.SrcString()), + zap.String("dst", info.DstString())) +} + +func (l *engineLogger) TCPStreamPropUpdate(info ruleset.StreamInfo, close bool) { + logger.Debug("TCP stream property update", + zap.Int64("id", info.ID), + zap.String("src", info.SrcString()), + zap.String("dst", info.DstString()), + zap.Any("props", info.Props), + zap.Bool("close", close)) +} + +func (l *engineLogger) TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) { + logger.Info("TCP stream action", + zap.Int64("id", info.ID), + zap.String("src", info.SrcString()), + zap.String("dst", info.DstString()), + zap.String("action", action.String()), + zap.Bool("noMatch", noMatch)) +} + +func (l *engineLogger) UDPStreamNew(workerID int, info ruleset.StreamInfo) { + logger.Debug("new UDP stream", + zap.Int("workerID", workerID), + zap.Int64("id", info.ID), + zap.String("src", info.SrcString()), + zap.String("dst", info.DstString())) +} + +func (l *engineLogger) UDPStreamPropUpdate(info ruleset.StreamInfo, close bool) { + logger.Debug("UDP stream property update", + zap.Int64("id", info.ID), + zap.String("src", info.SrcString()), + zap.String("dst", info.DstString()), + zap.Any("props", info.Props), + zap.Bool("close", close)) +} + +func (l *engineLogger) 
// envOrDefaultBool reads a boolean from the environment variable key,
// falling back to def when the variable is unset or does not parse as a
// boolean (previously a parse error silently returned false, ignoring def).
func envOrDefaultBool(key string, def bool) bool {
	if v := os.Getenv(key); v != "" {
		if b, err := strconv.ParseBool(v); err == nil {
			return b
		}
	}
	return def
}
new file mode 100644 index 0000000..e8c1bfd --- /dev/null +++ b/engine/engine.go @@ -0,0 +1,119 @@ +package engine + +import ( + "context" + "runtime" + + "github.com/apernet/OpenGFW/io" + "github.com/apernet/OpenGFW/ruleset" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +var _ Engine = (*engine)(nil) + +type engine struct { + logger Logger + ioList []io.PacketIO + workers []*worker +} + +func NewEngine(config Config) (Engine, error) { + workerCount := config.Workers + if workerCount <= 0 { + workerCount = runtime.NumCPU() + } + var err error + workers := make([]*worker, workerCount) + for i := range workers { + workers[i], err = newWorker(workerConfig{ + ID: i, + ChanSize: config.WorkerQueueSize, + Logger: config.Logger, + Ruleset: config.Ruleset, + TCPMaxBufferedPagesTotal: config.WorkerTCPMaxBufferedPagesTotal, + TCPMaxBufferedPagesPerConn: config.WorkerTCPMaxBufferedPagesPerConn, + UDPMaxStreams: config.WorkerUDPMaxStreams, + }) + if err != nil { + return nil, err + } + } + return &engine{ + logger: config.Logger, + ioList: config.IOs, + workers: workers, + }, nil +} + +func (e *engine) UpdateRuleset(r ruleset.Ruleset) error { + for _, w := range e.workers { + if err := w.UpdateRuleset(r); err != nil { + return err + } + } + return nil +} + +func (e *engine) Run(ctx context.Context) error { + ioCtx, ioCancel := context.WithCancel(ctx) + defer ioCancel() // Stop workers & IOs + + // Start workers + for _, w := range e.workers { + go w.Run(ioCtx) + } + + // Register callbacks + errChan := make(chan error, len(e.ioList)) + for _, i := range e.ioList { + ioEntry := i // Make sure dispatch() uses the correct ioEntry + err := ioEntry.Register(ioCtx, func(p io.Packet, err error) bool { + if err != nil { + errChan <- err + return false + } + return e.dispatch(ioEntry, p) + }) + if err != nil { + return err + } + } + + // Block until IO errors or context is cancelled + select { + case err := <-errChan: + return err + case <-ctx.Done(): + 
return nil + } +} + +// dispatch dispatches a packet to a worker. +// This must be safe for concurrent use, as it may be called from multiple IOs. +func (e *engine) dispatch(ioEntry io.PacketIO, p io.Packet) bool { + data := p.Data() + ipVersion := data[0] >> 4 + var layerType gopacket.LayerType + if ipVersion == 4 { + layerType = layers.LayerTypeIPv4 + } else if ipVersion == 6 { + layerType = layers.LayerTypeIPv6 + } else { + // Unsupported network layer + _ = ioEntry.SetVerdict(p, io.VerdictAcceptStream, nil) + return true + } + // Load balance by stream ID + index := p.StreamID() % uint32(len(e.workers)) + packet := gopacket.NewPacket(data, layerType, gopacket.DecodeOptions{Lazy: true, NoCopy: true}) + e.workers[index].Feed(&workerPacket{ + StreamID: p.StreamID(), + Packet: packet, + SetVerdict: func(v io.Verdict, b []byte) error { + return ioEntry.SetVerdict(p, v, b) + }, + }) + return true +} diff --git a/engine/interface.go b/engine/interface.go new file mode 100644 index 0000000..6975834 --- /dev/null +++ b/engine/interface.go @@ -0,0 +1,50 @@ +package engine + +import ( + "context" + + "github.com/apernet/OpenGFW/io" + "github.com/apernet/OpenGFW/ruleset" +) + +// Engine is the main engine for OpenGFW. +type Engine interface { + // UpdateRuleset updates the ruleset. + UpdateRuleset(ruleset.Ruleset) error + // Run runs the engine, until an error occurs or the context is cancelled. + Run(context.Context) error +} + +// Config is the configuration for the engine. +type Config struct { + Logger Logger + IOs []io.PacketIO + Ruleset ruleset.Ruleset + + Workers int // Number of workers. Zero or negative means auto (number of CPU cores). + WorkerQueueSize int + WorkerTCPMaxBufferedPagesTotal int + WorkerTCPMaxBufferedPagesPerConn int + WorkerUDPMaxStreams int +} + +// Logger is the combined logging interface for the engine, workers and analyzers. 
+type Logger interface { + WorkerStart(id int) + WorkerStop(id int) + + TCPStreamNew(workerID int, info ruleset.StreamInfo) + TCPStreamPropUpdate(info ruleset.StreamInfo, close bool) + TCPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) + + UDPStreamNew(workerID int, info ruleset.StreamInfo) + UDPStreamPropUpdate(info ruleset.StreamInfo, close bool) + UDPStreamAction(info ruleset.StreamInfo, action ruleset.Action, noMatch bool) + + MatchError(info ruleset.StreamInfo, err error) + ModifyError(info ruleset.StreamInfo, err error) + + AnalyzerDebugf(streamID int64, name string, format string, args ...interface{}) + AnalyzerInfof(streamID int64, name string, format string, args ...interface{}) + AnalyzerErrorf(streamID int64, name string, format string, args ...interface{}) +} diff --git a/engine/tcp.go b/engine/tcp.go new file mode 100644 index 0000000..067702b --- /dev/null +++ b/engine/tcp.go @@ -0,0 +1,225 @@ +package engine + +import ( + "net" + "sync" + + "github.com/apernet/OpenGFW/analyzer" + "github.com/apernet/OpenGFW/io" + "github.com/apernet/OpenGFW/ruleset" + + "github.com/bwmarrin/snowflake" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/reassembly" +) + +// tcpVerdict is a subset of io.Verdict for TCP streams. +// We don't allow modifying or dropping a single packet +// for TCP streams for now, as it doesn't make much sense. 
+type tcpVerdict io.Verdict + +const ( + tcpVerdictAccept = tcpVerdict(io.VerdictAccept) + tcpVerdictAcceptStream = tcpVerdict(io.VerdictAcceptStream) + tcpVerdictDropStream = tcpVerdict(io.VerdictDropStream) +) + +type tcpContext struct { + *gopacket.PacketMetadata + Verdict tcpVerdict +} + +func (ctx *tcpContext) GetCaptureInfo() gopacket.CaptureInfo { + return ctx.CaptureInfo +} + +type tcpStreamFactory struct { + WorkerID int + Logger Logger + Node *snowflake.Node + + RulesetMutex sync.RWMutex + Ruleset ruleset.Ruleset +} + +func (f *tcpStreamFactory) New(ipFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream { + id := f.Node.Generate() + ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw()) + info := ruleset.StreamInfo{ + ID: id.Int64(), + Protocol: ruleset.ProtocolTCP, + SrcIP: ipSrc, + DstIP: ipDst, + SrcPort: uint16(tcp.SrcPort), + DstPort: uint16(tcp.DstPort), + Props: make(analyzer.CombinedPropMap), + } + f.Logger.TCPStreamNew(f.WorkerID, info) + f.RulesetMutex.RLock() + rs := f.Ruleset + f.RulesetMutex.RUnlock() + ans := analyzersToTCPAnalyzers(rs.Analyzers(info)) + if len(ans) == 0 { + ctx := ac.(*tcpContext) + ctx.Verdict = tcpVerdictAcceptStream + f.Logger.TCPStreamAction(info, ruleset.ActionAllow, true) + // a tcpStream with no activeEntries is a no-op + return &tcpStream{} + } + // Create entries for each analyzer + entries := make([]*tcpStreamEntry, 0, len(ans)) + for _, a := range ans { + entries = append(entries, &tcpStreamEntry{ + Name: a.Name(), + Stream: a.NewTCP(analyzer.TCPInfo{ + SrcIP: ipSrc, + DstIP: ipDst, + SrcPort: uint16(tcp.SrcPort), + DstPort: uint16(tcp.DstPort), + }, &analyzerLogger{ + StreamID: id.Int64(), + Name: a.Name(), + Logger: f.Logger, + }), + HasLimit: a.Limit() > 0, + Quota: a.Limit(), + }) + } + return &tcpStream{ + info: info, + virgin: true, + logger: f.Logger, + ruleset: rs, + activeEntries: entries, + } +} + +func (f *tcpStreamFactory) UpdateRuleset(r 
ruleset.Ruleset) error { + f.RulesetMutex.Lock() + defer f.RulesetMutex.Unlock() + f.Ruleset = r + return nil +} + +type tcpStream struct { + info ruleset.StreamInfo + virgin bool // true if no packets have been processed + logger Logger + ruleset ruleset.Ruleset + activeEntries []*tcpStreamEntry + doneEntries []*tcpStreamEntry +} + +type tcpStreamEntry struct { + Name string + Stream analyzer.TCPStream + HasLimit bool + Quota int +} + +func (s *tcpStream) Accept(tcp *layers.TCP, ci gopacket.CaptureInfo, dir reassembly.TCPFlowDirection, nextSeq reassembly.Sequence, start *bool, ac reassembly.AssemblerContext) bool { + // Only accept packets if we still have active entries + return len(s.activeEntries) > 0 +} + +func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.AssemblerContext) { + dir, start, end, skip := sg.Info() + rev := dir == reassembly.TCPDirServerToClient + avail, _ := sg.Lengths() + data := sg.Fetch(avail) + updated := false + for i := len(s.activeEntries) - 1; i >= 0; i-- { + // Important: reverse order so we can remove entries + entry := s.activeEntries[i] + update, closeUpdate, done := s.feedEntry(entry, rev, start, end, skip, data) + updated = updated || processPropUpdate(s.info.Props, entry.Name, update) + updated = updated || processPropUpdate(s.info.Props, entry.Name, closeUpdate) + if done { + s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...) 
+ s.doneEntries = append(s.doneEntries, entry) + } + } + ctx := ac.(*tcpContext) + if updated || s.virgin { + s.virgin = false + s.logger.TCPStreamPropUpdate(s.info, false) + // Match properties against ruleset + result, err := s.ruleset.Match(s.info) + if err != nil { + s.logger.MatchError(s.info, err) + } + action := result.Action + if action != ruleset.ActionMaybe && action != ruleset.ActionModify { + ctx.Verdict = actionToTCPVerdict(action) + s.logger.TCPStreamAction(s.info, action, false) + // Verdict issued, no need to process any more packets + s.closeActiveEntries() + } + } + if len(s.activeEntries) == 0 && ctx.Verdict == tcpVerdictAccept { + // All entries are done but no verdict issued, accept stream + ctx.Verdict = tcpVerdictAcceptStream + s.logger.TCPStreamAction(s.info, ruleset.ActionAllow, true) + } +} + +func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { + s.closeActiveEntries() + return true +} + +func (s *tcpStream) closeActiveEntries() { + // Signal close to all active entries & move them to doneEntries + updated := false + for _, entry := range s.activeEntries { + update := entry.Stream.Close(false) + updated = updated || processPropUpdate(s.info.Props, entry.Name, update) + } + if updated { + s.logger.TCPStreamPropUpdate(s.info, true) + } + s.doneEntries = append(s.doneEntries, s.activeEntries...) 
+ s.activeEntries = nil +} + +func (s *tcpStream) feedEntry(entry *tcpStreamEntry, rev, start, end bool, skip int, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) { + if !entry.HasLimit { + update, done = entry.Stream.Feed(rev, start, end, skip, data) + } else { + qData := data + if len(qData) > entry.Quota { + qData = qData[:entry.Quota] + } + update, done = entry.Stream.Feed(rev, start, end, skip, qData) + entry.Quota -= len(qData) + if entry.Quota <= 0 { + // Quota exhausted, signal close & move to doneEntries + closeUpdate = entry.Stream.Close(true) + done = true + } + } + return +} + +func analyzersToTCPAnalyzers(ans []analyzer.Analyzer) []analyzer.TCPAnalyzer { + tcpAns := make([]analyzer.TCPAnalyzer, 0, len(ans)) + for _, a := range ans { + if tcpM, ok := a.(analyzer.TCPAnalyzer); ok { + tcpAns = append(tcpAns, tcpM) + } + } + return tcpAns +} + +func actionToTCPVerdict(a ruleset.Action) tcpVerdict { + switch a { + case ruleset.ActionMaybe, ruleset.ActionAllow, ruleset.ActionModify: + return tcpVerdictAcceptStream + case ruleset.ActionBlock, ruleset.ActionDrop: + return tcpVerdictDropStream + default: + // Should never happen + return tcpVerdictAcceptStream + } +} diff --git a/engine/udp.go b/engine/udp.go new file mode 100644 index 0000000..a5bf7d3 --- /dev/null +++ b/engine/udp.go @@ -0,0 +1,295 @@ +package engine + +import ( + "errors" + "net" + "sync" + + "github.com/apernet/OpenGFW/analyzer" + "github.com/apernet/OpenGFW/io" + "github.com/apernet/OpenGFW/modifier" + "github.com/apernet/OpenGFW/ruleset" + + "github.com/bwmarrin/snowflake" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + lru "github.com/hashicorp/golang-lru/v2" +) + +// udpVerdict is a subset of io.Verdict for UDP streams. +// For UDP, we support all verdicts. 
+type udpVerdict io.Verdict + +const ( + udpVerdictAccept = udpVerdict(io.VerdictAccept) + udpVerdictAcceptModify = udpVerdict(io.VerdictAcceptModify) + udpVerdictAcceptStream = udpVerdict(io.VerdictAcceptStream) + udpVerdictDrop = udpVerdict(io.VerdictDrop) + udpVerdictDropStream = udpVerdict(io.VerdictDropStream) +) + +var errInvalidModifier = errors.New("invalid modifier") + +type udpContext struct { + Verdict udpVerdict + Packet []byte +} + +type udpStreamFactory struct { + WorkerID int + Logger Logger + Node *snowflake.Node + + RulesetMutex sync.RWMutex + Ruleset ruleset.Ruleset +} + +func (f *udpStreamFactory) New(ipFlow, udpFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) *udpStream { + id := f.Node.Generate() + ipSrc, ipDst := net.IP(ipFlow.Src().Raw()), net.IP(ipFlow.Dst().Raw()) + info := ruleset.StreamInfo{ + ID: id.Int64(), + Protocol: ruleset.ProtocolUDP, + SrcIP: ipSrc, + DstIP: ipDst, + SrcPort: uint16(udp.SrcPort), + DstPort: uint16(udp.DstPort), + Props: make(analyzer.CombinedPropMap), + } + f.Logger.UDPStreamNew(f.WorkerID, info) + f.RulesetMutex.RLock() + rs := f.Ruleset + f.RulesetMutex.RUnlock() + ans := analyzersToUDPAnalyzers(rs.Analyzers(info)) + if len(ans) == 0 { + uc.Verdict = udpVerdictAcceptStream + f.Logger.UDPStreamAction(info, ruleset.ActionAllow, true) + // a udpStream with no activeEntries is a no-op + return &udpStream{} + } + // Create entries for each analyzer + entries := make([]*udpStreamEntry, 0, len(ans)) + for _, a := range ans { + entries = append(entries, &udpStreamEntry{ + Name: a.Name(), + Stream: a.NewUDP(analyzer.UDPInfo{ + SrcIP: ipSrc, + DstIP: ipDst, + SrcPort: uint16(udp.SrcPort), + DstPort: uint16(udp.DstPort), + }, &analyzerLogger{ + StreamID: id.Int64(), + Name: a.Name(), + Logger: f.Logger, + }), + HasLimit: a.Limit() > 0, + Quota: a.Limit(), + }) + } + return &udpStream{ + info: info, + virgin: true, + logger: f.Logger, + ruleset: rs, + activeEntries: entries, + } +} + +func (f *udpStreamFactory) 
UpdateRuleset(r ruleset.Ruleset) error { + f.RulesetMutex.Lock() + defer f.RulesetMutex.Unlock() + f.Ruleset = r + return nil +} + +type udpStreamManager struct { + factory *udpStreamFactory + streams *lru.Cache[uint32, *udpStreamValue] +} + +type udpStreamValue struct { + Stream *udpStream + IPFlow gopacket.Flow + UDPFlow gopacket.Flow +} + +func (v *udpStreamValue) Match(ipFlow, udpFlow gopacket.Flow) (ok, rev bool) { + fwd := v.IPFlow == ipFlow && v.UDPFlow == udpFlow + rev = v.IPFlow == ipFlow.Reverse() && v.UDPFlow == udpFlow.Reverse() + return fwd || rev, rev +} + +func newUDPStreamManager(factory *udpStreamFactory, maxStreams int) (*udpStreamManager, error) { + ss, err := lru.New[uint32, *udpStreamValue](maxStreams) + if err != nil { + return nil, err + } + return &udpStreamManager{ + factory: factory, + streams: ss, + }, nil +} + +func (m *udpStreamManager) MatchWithContext(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP, uc *udpContext) { + rev := false + value, ok := m.streams.Get(streamID) + if !ok { + // New stream + value = &udpStreamValue{ + Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc), + IPFlow: ipFlow, + UDPFlow: udp.TransportFlow(), + } + m.streams.Add(streamID, value) + } else { + // Stream ID exists, but is it really the same stream? 
+ ok, rev = value.Match(ipFlow, udp.TransportFlow()) + if !ok { + // It's not - close the old stream & replace it with a new one + value.Stream.Close() + value = &udpStreamValue{ + Stream: m.factory.New(ipFlow, udp.TransportFlow(), udp, uc), + IPFlow: ipFlow, + UDPFlow: udp.TransportFlow(), + } + m.streams.Add(streamID, value) + } + } + if value.Stream.Accept(udp, rev, uc) { + value.Stream.Feed(udp, rev, uc) + } +} + +type udpStream struct { + info ruleset.StreamInfo + virgin bool // true if no packets have been processed + logger Logger + ruleset ruleset.Ruleset + activeEntries []*udpStreamEntry + doneEntries []*udpStreamEntry +} + +type udpStreamEntry struct { + Name string + Stream analyzer.UDPStream + HasLimit bool + Quota int +} + +func (s *udpStream) Accept(udp *layers.UDP, rev bool, uc *udpContext) bool { + // Only accept packets if we still have active entries + return len(s.activeEntries) > 0 +} + +func (s *udpStream) Feed(udp *layers.UDP, rev bool, uc *udpContext) { + updated := false + for i := len(s.activeEntries) - 1; i >= 0; i-- { + // Important: reverse order so we can remove entries + entry := s.activeEntries[i] + update, closeUpdate, done := s.feedEntry(entry, rev, udp.Payload) + updated = updated || processPropUpdate(s.info.Props, entry.Name, update) + updated = updated || processPropUpdate(s.info.Props, entry.Name, closeUpdate) + if done { + s.activeEntries = append(s.activeEntries[:i], s.activeEntries[i+1:]...) 
+ s.doneEntries = append(s.doneEntries, entry) + } + } + if updated || s.virgin { + s.virgin = false + s.logger.UDPStreamPropUpdate(s.info, false) + // Match properties against ruleset + result, err := s.ruleset.Match(s.info) + if err != nil { + s.logger.MatchError(s.info, err) + } + action := result.Action + if action == ruleset.ActionModify { + // Call the modifier instance + udpMI, ok := result.ModInstance.(modifier.UDPModifierInstance) + if !ok { + // Not for UDP, fallback to maybe + s.logger.ModifyError(s.info, errInvalidModifier) + action = ruleset.ActionMaybe + } else { + uc.Packet, err = udpMI.Process(udp.Payload) + if err != nil { + // Modifier error, fallback to maybe + s.logger.ModifyError(s.info, err) + action = ruleset.ActionMaybe + } + } + } + if action != ruleset.ActionMaybe { + var final bool + uc.Verdict, final = actionToUDPVerdict(action) + s.logger.UDPStreamAction(s.info, action, false) + if final { + s.closeActiveEntries() + } + } + } + if len(s.activeEntries) == 0 && uc.Verdict == udpVerdictAccept { + // All entries are done but no verdict issued, accept stream + uc.Verdict = udpVerdictAcceptStream + s.logger.UDPStreamAction(s.info, ruleset.ActionAllow, true) + } +} + +func (s *udpStream) Close() { + s.closeActiveEntries() +} + +func (s *udpStream) closeActiveEntries() { + // Signal close to all active entries & move them to doneEntries + updated := false + for _, entry := range s.activeEntries { + update := entry.Stream.Close(false) + updated = updated || processPropUpdate(s.info.Props, entry.Name, update) + } + if updated { + s.logger.UDPStreamPropUpdate(s.info, true) + } + s.doneEntries = append(s.doneEntries, s.activeEntries...) 
+ s.activeEntries = nil +} + +func (s *udpStream) feedEntry(entry *udpStreamEntry, rev bool, data []byte) (update *analyzer.PropUpdate, closeUpdate *analyzer.PropUpdate, done bool) { + update, done = entry.Stream.Feed(rev, data) + if entry.HasLimit { + entry.Quota -= len(data) + if entry.Quota <= 0 { + // Quota exhausted, signal close & move to doneEntries + closeUpdate = entry.Stream.Close(true) + done = true + } + } + return +} + +func analyzersToUDPAnalyzers(ans []analyzer.Analyzer) []analyzer.UDPAnalyzer { + udpAns := make([]analyzer.UDPAnalyzer, 0, len(ans)) + for _, a := range ans { + if udpM, ok := a.(analyzer.UDPAnalyzer); ok { + udpAns = append(udpAns, udpM) + } + } + return udpAns +} + +func actionToUDPVerdict(a ruleset.Action) (v udpVerdict, final bool) { + switch a { + case ruleset.ActionMaybe: + return udpVerdictAccept, false + case ruleset.ActionAllow: + return udpVerdictAcceptStream, true + case ruleset.ActionBlock: + return udpVerdictDropStream, true + case ruleset.ActionDrop: + return udpVerdictDrop, false + case ruleset.ActionModify: + return udpVerdictAcceptModify, false + default: + // Should never happen + return udpVerdictAccept, false + } +} diff --git a/engine/utils.go b/engine/utils.go new file mode 100644 index 0000000..41098ac --- /dev/null +++ b/engine/utils.go @@ -0,0 +1,50 @@ +package engine + +import "github.com/apernet/OpenGFW/analyzer" + +var _ analyzer.Logger = (*analyzerLogger)(nil) + +type analyzerLogger struct { + StreamID int64 + Name string + Logger Logger +} + +func (l *analyzerLogger) Debugf(format string, args ...interface{}) { + l.Logger.AnalyzerDebugf(l.StreamID, l.Name, format, args...) +} + +func (l *analyzerLogger) Infof(format string, args ...interface{}) { + l.Logger.AnalyzerInfof(l.StreamID, l.Name, format, args...) +} + +func (l *analyzerLogger) Errorf(format string, args ...interface{}) { + l.Logger.AnalyzerErrorf(l.StreamID, l.Name, format, args...) 
+} + +func processPropUpdate(cpm analyzer.CombinedPropMap, name string, update *analyzer.PropUpdate) (updated bool) { + if update == nil || update.Type == analyzer.PropUpdateNone { + return false + } + switch update.Type { + case analyzer.PropUpdateMerge: + m := cpm[name] + if m == nil { + m = make(analyzer.PropMap) + cpm[name] = m + } + for k, v := range update.M { + m[k] = v + } + return true + case analyzer.PropUpdateReplace: + cpm[name] = update.M + return true + case analyzer.PropUpdateDelete: + delete(cpm, name) + return true + default: + // Invalid update type, ignore for now + return false + } +} diff --git a/engine/worker.go b/engine/worker.go new file mode 100644 index 0000000..2bca8e0 --- /dev/null +++ b/engine/worker.go @@ -0,0 +1,182 @@ +package engine + +import ( + "context" + + "github.com/apernet/OpenGFW/io" + "github.com/apernet/OpenGFW/ruleset" + + "github.com/bwmarrin/snowflake" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/reassembly" +) + +const ( + defaultChanSize = 64 + defaultTCPMaxBufferedPagesTotal = 4096 + defaultTCPMaxBufferedPagesPerConnection = 64 + defaultUDPMaxStreams = 4096 +) + +type workerPacket struct { + StreamID uint32 + Packet gopacket.Packet + SetVerdict func(io.Verdict, []byte) error +} + +type worker struct { + id int + packetChan chan *workerPacket + logger Logger + + tcpStreamFactory *tcpStreamFactory + tcpStreamPool *reassembly.StreamPool + tcpAssembler *reassembly.Assembler + + udpStreamFactory *udpStreamFactory + udpStreamManager *udpStreamManager + + modSerializeBuffer gopacket.SerializeBuffer +} + +type workerConfig struct { + ID int + ChanSize int + Logger Logger + Ruleset ruleset.Ruleset + TCPMaxBufferedPagesTotal int + TCPMaxBufferedPagesPerConn int + UDPMaxStreams int +} + +func (c *workerConfig) fillDefaults() { + if c.ChanSize <= 0 { + c.ChanSize = defaultChanSize + } + if c.TCPMaxBufferedPagesTotal <= 0 { + c.TCPMaxBufferedPagesTotal = 
defaultTCPMaxBufferedPagesTotal + } + if c.TCPMaxBufferedPagesPerConn <= 0 { + c.TCPMaxBufferedPagesPerConn = defaultTCPMaxBufferedPagesPerConnection + } + if c.UDPMaxStreams <= 0 { + c.UDPMaxStreams = defaultUDPMaxStreams + } +} + +func newWorker(config workerConfig) (*worker, error) { + config.fillDefaults() + sfNode, err := snowflake.NewNode(int64(config.ID)) + if err != nil { + return nil, err + } + tcpSF := &tcpStreamFactory{ + WorkerID: config.ID, + Logger: config.Logger, + Node: sfNode, + Ruleset: config.Ruleset, + } + tcpStreamPool := reassembly.NewStreamPool(tcpSF) + tcpAssembler := reassembly.NewAssembler(tcpStreamPool) + tcpAssembler.MaxBufferedPagesTotal = config.TCPMaxBufferedPagesTotal + tcpAssembler.MaxBufferedPagesPerConnection = config.TCPMaxBufferedPagesPerConn + udpSF := &udpStreamFactory{ + WorkerID: config.ID, + Logger: config.Logger, + Node: sfNode, + Ruleset: config.Ruleset, + } + udpSM, err := newUDPStreamManager(udpSF, config.UDPMaxStreams) + if err != nil { + return nil, err + } + return &worker{ + id: config.ID, + packetChan: make(chan *workerPacket, config.ChanSize), + logger: config.Logger, + tcpStreamFactory: tcpSF, + tcpStreamPool: tcpStreamPool, + tcpAssembler: tcpAssembler, + udpStreamFactory: udpSF, + udpStreamManager: udpSM, + modSerializeBuffer: gopacket.NewSerializeBuffer(), + }, nil +} + +func (w *worker) Feed(p *workerPacket) { + w.packetChan <- p +} + +func (w *worker) Run(ctx context.Context) { + w.logger.WorkerStart(w.id) + defer w.logger.WorkerStop(w.id) + for { + select { + case <-ctx.Done(): + return + case wPkt := <-w.packetChan: + if wPkt == nil { + // Closed + return + } + v, b := w.handle(wPkt.StreamID, wPkt.Packet) + _ = wPkt.SetVerdict(v, b) + } + } +} + +func (w *worker) UpdateRuleset(r ruleset.Ruleset) error { + return w.tcpStreamFactory.UpdateRuleset(r) +} + +func (w *worker) handle(streamID uint32, p gopacket.Packet) (io.Verdict, []byte) { + netLayer, trLayer := p.NetworkLayer(), p.TransportLayer() + if 
netLayer == nil || trLayer == nil { + // Invalid packet + return io.VerdictAccept, nil + } + ipFlow := netLayer.NetworkFlow() + switch tr := trLayer.(type) { + case *layers.TCP: + return w.handleTCP(ipFlow, p.Metadata(), tr), nil + case *layers.UDP: + v, modPayload := w.handleUDP(streamID, ipFlow, tr) + if v == io.VerdictAcceptModify && modPayload != nil { + tr.Payload = modPayload + _ = tr.SetNetworkLayerForChecksum(netLayer) + _ = w.modSerializeBuffer.Clear() + err := gopacket.SerializePacket(w.modSerializeBuffer, + gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + }, p) + if err != nil { + // Just accept without modification for now + return io.VerdictAccept, nil + } + return v, w.modSerializeBuffer.Bytes() + } + return v, nil + default: + // Unsupported protocol + return io.VerdictAccept, nil + } +} + +func (w *worker) handleTCP(ipFlow gopacket.Flow, pMeta *gopacket.PacketMetadata, tcp *layers.TCP) io.Verdict { + ctx := &tcpContext{ + PacketMetadata: pMeta, + Verdict: tcpVerdictAccept, + } + w.tcpAssembler.AssembleWithContext(ipFlow, tcp, ctx) + return io.Verdict(ctx.Verdict) +} + +func (w *worker) handleUDP(streamID uint32, ipFlow gopacket.Flow, udp *layers.UDP) (io.Verdict, []byte) { + ctx := &udpContext{ + Verdict: udpVerdictAccept, + } + w.udpStreamManager.MatchWithContext(streamID, ipFlow, udp, ctx) + return io.Verdict(ctx.Verdict), ctx.Packet +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..605d503 --- /dev/null +++ b/go.mod @@ -0,0 +1,43 @@ +module github.com/apernet/OpenGFW + +go 1.20 + +require ( + github.com/bwmarrin/snowflake v0.3.0 + github.com/coreos/go-iptables v0.7.0 + github.com/expr-lang/expr v1.15.7 + github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf + github.com/google/gopacket v1.1.19 + github.com/hashicorp/golang-lru/v2 v2.0.7 + github.com/mdlayher/netlink v1.6.0 + github.com/spf13/cobra v1.8.0 + github.com/spf13/viper v1.18.2 + go.uber.org/zap v1.26.0 + gopkg.in/yaml.v3 v3.0.1 
+) + +require ( + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/google/go-cmp v0.5.9 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/josharian/native v1.0.0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mdlayher/socket v0.1.1 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + golang.org/x/net v0.19.0 // indirect + golang.org/x/sync v0.5.0 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..21e9167 --- /dev/null +++ b/go.sum @@ -0,0 +1,120 @@ +github.com/bwmarrin/snowflake v0.3.0 h1:xm67bEhkKh6ij1790JB83OujPR5CzNe8QuQqAgISZN0= +github.com/bwmarrin/snowflake v0.3.0/go.mod h1:NdZxfVWX+oR6y2K0o6qAYv6gIOP9rjG0/E9WsDpxqwE= +github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8= +github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= 
+github.com/expr-lang/expr v1.15.7 h1:BK0JcWUkoW6nrbLBo6xCKhz4BvH5DSOOu1Gx5lucyZo= +github.com/expr-lang/expr v1.15.7/go.mod h1:uCkhfG+x7fcZ5A5sXHKuQ07jGZRl6J0FCAaf2k4PtVQ= +github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf h1:NqGS3vTHzVENbIfd87cXZwdpO6MB2R1PjHMJLi4Z3ow= +github.com/florianl/go-nfqueue v1.3.2-0.20231218173729-f2bdeb033acf/go.mod h1:eSnAor2YCfMCVYrVNEhkLGN/r1L+J4uDjc0EUy0tfq4= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/josharian/native v1.0.0 h1:Ts/E8zCSEsG17dUqv7joXJFybuMLjQfWE04tsBODTxk= +github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/kr/pretty v0.3.1 
h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mdlayher/netlink v1.6.0 h1:rOHX5yl7qnlpiVkFWoqccueppMtXzeziFjWAjLg6sz0= +github.com/mdlayher/netlink v1.6.0/go.mod h1:0o3PlBmGst1xve7wQ7j/hwpNaFaH4qCRyWCdcZk8/vA= +github.com/mdlayher/socket v0.1.1 h1:q3uOGirUPfAV2MUoaC7BavjQ154J7+JOkTWyiV+intI= +github.com/mdlayher/socket v0.1.1/go.mod h1:mYV5YIZAfHh4dzDVzI8x8tWLWCliuX8Mon5Awbj+qDs= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= +github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero 
v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= +github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= +github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= +go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= 
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210928044308-7d9f5e0b762b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 
// Verdict is the engine's decision for a single packet and, in the *Stream
// variants, for the rest of its stream.
type Verdict int

const (
	// VerdictAccept accepts the packet, but continues to process the stream.
	VerdictAccept Verdict = iota
	// VerdictAcceptModify is like VerdictAccept, but replaces the packet with a new one.
	VerdictAcceptModify
	// VerdictAcceptStream accepts the packet and stops processing the stream.
	VerdictAcceptStream
	// VerdictDrop drops the packet, but does not block the stream.
	VerdictDrop
	// VerdictDropStream drops the packet and blocks the stream.
	VerdictDropStream
)

// Packet represents an IP packet.
type Packet interface {
	// StreamID is the ID of the stream the packet belongs to.
	StreamID() uint32
	// Data is the raw packet data, starting with the IP header.
	Data() []byte
}

// PacketCallback is called for each packet received.
// Return false to "unregister" and stop receiving packets.
// It must be safe for concurrent use.
type PacketCallback func(Packet, error) bool

// PacketIO is the interface between the engine and a packet source/sink
// (e.g. NFQUEUE).
type PacketIO interface {
	// Register registers a callback to be called for each packet received.
	// The callback should be called in one or more separate goroutines,
	// and stop when the context is cancelled.
	Register(context.Context, PacketCallback) error
	// SetVerdict sets the verdict for a packet.
	SetVerdict(Packet, Verdict, []byte) error
	// Close closes the packet IO.
	Close() error
}

// ErrInvalidPacket wraps an error describing why a packet could not be
// handled by a PacketIO implementation.
type ErrInvalidPacket struct {
	Err error
}

func (e *ErrInvalidPacket) Error() string {
	return "invalid packet: " + e.Err.Error()
}
"OUTPUT", []string{"-m", "connmark", "--mark", strconv.Itoa(nfqueueConnMarkDrop), "-j", "DROP"}}, + {"filter", "OUTPUT", []string{"-j", "NFQUEUE", "--queue-num", strconv.Itoa(nfqueueNum), "--queue-bypass"}}, +} + +var _ PacketIO = (*nfqueuePacketIO)(nil) + +var errNotNFQueuePacket = errors.New("not an NFQueue packet") + +type nfqueuePacketIO struct { + n *nfqueue.Nfqueue + local bool + ipt4 *iptables.IPTables + ipt6 *iptables.IPTables +} + +type NFQueuePacketIOConfig struct { + QueueSize uint32 + Local bool +} + +func NewNFQueuePacketIO(config NFQueuePacketIOConfig) (PacketIO, error) { + if config.QueueSize == 0 { + config.QueueSize = nfqueueDefaultQueueSize + } + ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + if err != nil { + return nil, err + } + ipt6, err := iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + return nil, err + } + n, err := nfqueue.Open(&nfqueue.Config{ + NfQueue: nfqueueNum, + MaxPacketLen: nfqueueMaxPacketLen, + MaxQueueLen: config.QueueSize, + Copymode: nfqueue.NfQnlCopyPacket, + Flags: nfqueue.NfQaCfgFlagConntrack, + }) + if err != nil { + return nil, err + } + io := &nfqueuePacketIO{ + n: n, + local: config.Local, + ipt4: ipt4, + ipt6: ipt6, + } + err = io.setupIpt(config.Local, false) + if err != nil { + _ = n.Close() + return nil, err + } + return io, nil +} + +func (n *nfqueuePacketIO) Register(ctx context.Context, cb PacketCallback) error { + return n.n.RegisterWithErrorFunc(ctx, + func(a nfqueue.Attribute) int { + if a.PacketID == nil || a.Ct == nil || a.Payload == nil || len(*a.Payload) < 20 { + // Invalid packet, ignore + // 20 is the minimum possible size of an IP packet + return 0 + } + p := &nfqueuePacket{ + id: *a.PacketID, + streamID: ctIDFromCtBytes(*a.Ct), + data: *a.Payload, + } + return okBoolToInt(cb(p, nil)) + }, + func(e error) int { + return okBoolToInt(cb(nil, e)) + }) +} + +func (n *nfqueuePacketIO) SetVerdict(p Packet, v Verdict, newPacket []byte) error { + nP, ok := p.(*nfqueuePacket) 
+ if !ok { + return &ErrInvalidPacket{Err: errNotNFQueuePacket} + } + switch v { + case VerdictAccept: + return n.n.SetVerdict(nP.id, nfqueue.NfAccept) + case VerdictAcceptModify: + return n.n.SetVerdictModPacket(nP.id, nfqueue.NfAccept, newPacket) + case VerdictAcceptStream: + return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfAccept, nfqueueConnMarkAccept) + case VerdictDrop: + return n.n.SetVerdict(nP.id, nfqueue.NfDrop) + case VerdictDropStream: + return n.n.SetVerdictWithConnMark(nP.id, nfqueue.NfDrop, nfqueueConnMarkDrop) + default: + // Invalid verdict, ignore for now + return nil + } +} + +func (n *nfqueuePacketIO) setupIpt(local, remove bool) error { + var rules []iptRule + if local { + rules = iptRulesLocal + } else { + rules = iptRulesForward + } + var err error + if remove { + err = iptsBatchDeleteIfExists([]*iptables.IPTables{n.ipt4, n.ipt6}, rules) + } else { + err = iptsBatchAppendUnique([]*iptables.IPTables{n.ipt4, n.ipt6}, rules) + } + if err != nil { + return err + } + return nil +} + +func (n *nfqueuePacketIO) Close() error { + err := n.setupIpt(n.local, true) + _ = n.n.Close() + return err +} + +var _ Packet = (*nfqueuePacket)(nil) + +type nfqueuePacket struct { + id uint32 + streamID uint32 + data []byte +} + +func (p *nfqueuePacket) StreamID() uint32 { + return p.streamID +} + +func (p *nfqueuePacket) Data() []byte { + return p.data +} + +func okBoolToInt(ok bool) int { + if ok { + return 0 + } else { + return 1 + } +} + +type iptRule struct { + Table, Chain string + RuleSpec []string +} + +func iptsBatchAppendUnique(ipts []*iptables.IPTables, rules []iptRule) error { + for _, r := range rules { + for _, ipt := range ipts { + err := ipt.AppendUnique(r.Table, r.Chain, r.RuleSpec...) + if err != nil { + return err + } + } + } + return nil +} + +func iptsBatchDeleteIfExists(ipts []*iptables.IPTables, rules []iptRule) error { + for _, r := range rules { + for _, ipt := range ipts { + err := ipt.DeleteIfExists(r.Table, r.Chain, r.RuleSpec...) 
+ if err != nil { + return err + } + } + } + return nil +} + +func ctIDFromCtBytes(ct []byte) uint32 { + ctAttrs, err := netlink.UnmarshalAttributes(ct) + if err != nil { + return 0 + } + for _, attr := range ctAttrs { + if attr.Type == 12 { // CTA_ID + return binary.BigEndian.Uint32(attr.Data) + } + } + return 0 +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..b5b7915 --- /dev/null +++ b/main.go @@ -0,0 +1,7 @@ +package main + +import "github.com/apernet/OpenGFW/cmd" + +func main() { + cmd.Execute() +} diff --git a/modifier/interface.go b/modifier/interface.go new file mode 100644 index 0000000..0340a87 --- /dev/null +++ b/modifier/interface.go @@ -0,0 +1,32 @@ +package modifier + +type Modifier interface { + // Name returns the name of the modifier. + Name() string + // New returns a new modifier instance. + New(args map[string]interface{}) (Instance, error) +} + +type Instance interface{} + +type UDPModifierInstance interface { + Instance + // Process takes a UDP packet and returns a modified UDP packet. 
+ Process(data []byte) ([]byte, error) +} + +type ErrInvalidPacket struct { + Err error +} + +func (e *ErrInvalidPacket) Error() string { + return "invalid packet: " + e.Err.Error() +} + +type ErrInvalidArgs struct { + Err error +} + +func (e *ErrInvalidArgs) Error() string { + return "invalid args: " + e.Err.Error() +} diff --git a/modifier/udp/dns.go b/modifier/udp/dns.go new file mode 100644 index 0000000..afab276 --- /dev/null +++ b/modifier/udp/dns.go @@ -0,0 +1,96 @@ +package udp + +import ( + "errors" + "net" + + "github.com/apernet/OpenGFW/modifier" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +var _ modifier.Modifier = (*DNSModifier)(nil) + +var ( + errInvalidIP = errors.New("invalid ip") + errNotValidDNSResponse = errors.New("not a valid dns response") + errEmptyDNSQuestion = errors.New("empty dns question") +) + +type DNSModifier struct{} + +func (m *DNSModifier) Name() string { + return "dns" +} + +func (m *DNSModifier) New(args map[string]interface{}) (modifier.Instance, error) { + i := &dnsModifierInstance{} + aStr, ok := args["a"].(string) + if ok { + a := net.ParseIP(aStr).To4() + if a == nil { + return nil, &modifier.ErrInvalidArgs{Err: errInvalidIP} + } + i.A = a + } + aaaaStr, ok := args["aaaa"].(string) + if ok { + aaaa := net.ParseIP(aaaaStr).To16() + if aaaa == nil { + return nil, &modifier.ErrInvalidArgs{Err: errInvalidIP} + } + i.AAAA = aaaa + } + return i, nil +} + +var _ modifier.UDPModifierInstance = (*dnsModifierInstance)(nil) + +type dnsModifierInstance struct { + A net.IP + AAAA net.IP +} + +func (i *dnsModifierInstance) Process(data []byte) ([]byte, error) { + dns := &layers.DNS{} + err := dns.DecodeFromBytes(data, gopacket.NilDecodeFeedback) + if err != nil { + return nil, &modifier.ErrInvalidPacket{Err: err} + } + if !dns.QR || dns.ResponseCode != layers.DNSResponseCodeNoErr { + return nil, &modifier.ErrInvalidPacket{Err: errNotValidDNSResponse} + } + if len(dns.Questions) == 0 { + return nil, 
&modifier.ErrInvalidPacket{Err: errEmptyDNSQuestion} + } + // In practice, most if not all DNS clients only send one question + // per packet, so we don't care about the rest for now. + q := dns.Questions[0] + switch q.Type { + case layers.DNSTypeA: + if i.A != nil { + dns.Answers = []layers.DNSResourceRecord{{ + Name: q.Name, + Type: layers.DNSTypeA, + Class: layers.DNSClassIN, + IP: i.A, + }} + } + case layers.DNSTypeAAAA: + if i.AAAA != nil { + dns.Answers = []layers.DNSResourceRecord{{ + Name: q.Name, + Type: layers.DNSTypeAAAA, + Class: layers.DNSClassIN, + IP: i.AAAA, + }} + } + } + buf := gopacket.NewSerializeBuffer() // Modifiers must be safe for concurrent use, so we can't reuse the buffer + err = gopacket.SerializeLayers(buf, gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + }, dns) + return buf.Bytes(), err +} diff --git a/ruleset/expr.go b/ruleset/expr.go new file mode 100644 index 0000000..45b7e05 --- /dev/null +++ b/ruleset/expr.go @@ -0,0 +1,219 @@ +package ruleset + +import ( + "fmt" + "os" + "reflect" + "strings" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/conf" + "github.com/expr-lang/expr/vm" + "gopkg.in/yaml.v3" + + "github.com/apernet/OpenGFW/analyzer" + "github.com/apernet/OpenGFW/modifier" +) + +// ExprRule is the external representation of an expression rule. +type ExprRule struct { + Name string `yaml:"name"` + Action string `yaml:"action"` + Modifier ModifierEntry `yaml:"modifier"` + Expr string `yaml:"expr"` +} + +type ModifierEntry struct { + Name string `yaml:"name"` + Args map[string]interface{} `yaml:"args"` +} + +func ExprRulesFromYAML(file string) ([]ExprRule, error) { + bs, err := os.ReadFile(file) + if err != nil { + return nil, err + } + var rules []ExprRule + err = yaml.Unmarshal(bs, &rules) + return rules, err +} + +// compiledExprRule is the internal, compiled representation of an expression rule. 
+type compiledExprRule struct { + Name string + Action Action + ModInstance modifier.Instance + Program *vm.Program + Analyzers map[string]struct{} +} + +var _ Ruleset = (*exprRuleset)(nil) + +type exprRuleset struct { + Rules []compiledExprRule + Ans []analyzer.Analyzer +} + +func (r *exprRuleset) Analyzers(info StreamInfo) []analyzer.Analyzer { + return r.Ans +} + +func (r *exprRuleset) Match(info StreamInfo) (MatchResult, error) { + env := streamInfoToExprEnv(info) + for _, rule := range r.Rules { + v, err := vm.Run(rule.Program, env) + if err != nil { + return MatchResult{ + Action: ActionMaybe, + }, fmt.Errorf("rule %q failed to run: %w", rule.Name, err) + } + if vBool, ok := v.(bool); ok && vBool { + return MatchResult{ + Action: rule.Action, + ModInstance: rule.ModInstance, + }, nil + } + } + return MatchResult{ + Action: ActionMaybe, + }, nil +} + +// CompileExprRules compiles a list of expression rules into a ruleset. +// It returns an error if any of the rules are invalid, or if any of the analyzers +// used by the rules are unknown (not provided in the analyzer list). +func CompileExprRules(rules []ExprRule, ans []analyzer.Analyzer, mods []modifier.Modifier) (Ruleset, error) { + var compiledRules []compiledExprRule + fullAnMap := analyzersToMap(ans) + fullModMap := modifiersToMap(mods) + depAnMap := make(map[string]analyzer.Analyzer) + // Compile all rules and build a map of analyzers that are used by the rules. 
+ for _, rule := range rules { + action, ok := actionStringToAction(rule.Action) + if !ok { + return nil, fmt.Errorf("rule %q has invalid action %q", rule.Name, rule.Action) + } + visitor := &depVisitor{Analyzers: make(map[string]struct{})} + program, err := expr.Compile(rule.Expr, + func(c *conf.Config) { + c.Strict = false + c.Expect = reflect.Bool + c.Visitors = append(c.Visitors, visitor) + }, + ) + if err != nil { + return nil, fmt.Errorf("rule %q has invalid expression: %w", rule.Name, err) + } + for name := range visitor.Analyzers { + a, ok := fullAnMap[name] + if !ok && !isBuiltInAnalyzer(name) { + return nil, fmt.Errorf("rule %q uses unknown analyzer %q", rule.Name, name) + } + depAnMap[name] = a + } + cr := compiledExprRule{ + Name: rule.Name, + Action: action, + Program: program, + Analyzers: visitor.Analyzers, + } + if action == ActionModify { + mod, ok := fullModMap[rule.Modifier.Name] + if !ok { + return nil, fmt.Errorf("rule %q uses unknown modifier %q", rule.Name, rule.Modifier.Name) + } + modInst, err := mod.New(rule.Modifier.Args) + if err != nil { + return nil, fmt.Errorf("rule %q failed to create modifier instance: %w", rule.Name, err) + } + cr.ModInstance = modInst + } + compiledRules = append(compiledRules, cr) + } + // Convert the analyzer map to a list. 
+ var depAns []analyzer.Analyzer + for _, a := range depAnMap { + depAns = append(depAns, a) + } + return &exprRuleset{ + Rules: compiledRules, + Ans: depAns, + }, nil +} + +func streamInfoToExprEnv(info StreamInfo) map[string]interface{} { + m := map[string]interface{}{ + "id": info.ID, + "proto": info.Protocol.String(), + "ip": map[string]string{ + "src": info.SrcIP.String(), + "dst": info.DstIP.String(), + }, + "port": map[string]uint16{ + "src": info.SrcPort, + "dst": info.DstPort, + }, + } + for anName, anProps := range info.Props { + if len(anProps) != 0 { + // Ignore analyzers with empty properties + m[anName] = anProps + } + } + return m +} + +func isBuiltInAnalyzer(name string) bool { + switch name { + case "id", "proto", "ip", "port": + return true + default: + return false + } +} + +func actionStringToAction(action string) (Action, bool) { + switch strings.ToLower(action) { + case "allow": + return ActionAllow, true + case "block": + return ActionBlock, true + case "drop": + return ActionDrop, true + case "modify": + return ActionModify, true + default: + return ActionMaybe, false + } +} + +// analyzersToMap converts a list of analyzers to a map of name -> analyzer. +// This is for easier lookup when compiling rules. +func analyzersToMap(ans []analyzer.Analyzer) map[string]analyzer.Analyzer { + anMap := make(map[string]analyzer.Analyzer) + for _, a := range ans { + anMap[a.Name()] = a + } + return anMap +} + +// modifiersToMap converts a list of modifiers to a map of name -> modifier. +// This is for easier lookup when compiling rules. 
+func modifiersToMap(mods []modifier.Modifier) map[string]modifier.Modifier { + modMap := make(map[string]modifier.Modifier) + for _, m := range mods { + modMap[m.Name()] = m + } + return modMap +} + +type depVisitor struct { + Analyzers map[string]struct{} +} + +func (v *depVisitor) Visit(node *ast.Node) { + if idNode, ok := (*node).(*ast.IdentifierNode); ok { + v.Analyzers[idNode.Value] = struct{}{} + } +} diff --git a/ruleset/interface.go b/ruleset/interface.go new file mode 100644 index 0000000..30bece7 --- /dev/null +++ b/ruleset/interface.go @@ -0,0 +1,94 @@ +package ruleset + +import ( + "net" + "strconv" + + "github.com/apernet/OpenGFW/analyzer" + "github.com/apernet/OpenGFW/modifier" +) + +type Action int + +const ( + // ActionMaybe indicates that the ruleset hasn't seen anything worth blocking based on + // current information, but that may change if volatile fields change in the future. + ActionMaybe Action = iota + // ActionAllow indicates that the stream should be allowed regardless of future changes. + ActionAllow + // ActionBlock indicates that the stream should be blocked. + ActionBlock + // ActionDrop indicates that the current packet should be dropped, + // but the stream should be allowed to continue. + // Only valid for UDP streams. Equivalent to ActionBlock for TCP streams. + ActionDrop + // ActionModify indicates that the current packet should be modified, + // and the stream should be allowed to continue. + // Only valid for UDP streams. Equivalent to ActionMaybe for TCP streams. 
+ ActionModify +) + +func (a Action) String() string { + switch a { + case ActionMaybe: + return "maybe" + case ActionAllow: + return "allow" + case ActionBlock: + return "block" + case ActionDrop: + return "drop" + case ActionModify: + return "modify" + default: + return "unknown" + } +} + +type Protocol int + +func (p Protocol) String() string { + switch p { + case ProtocolTCP: + return "tcp" + case ProtocolUDP: + return "udp" + default: + return "unknown" + } +} + +const ( + ProtocolTCP Protocol = iota + ProtocolUDP +) + +type StreamInfo struct { + ID int64 + Protocol Protocol + SrcIP, DstIP net.IP + SrcPort, DstPort uint16 + Props analyzer.CombinedPropMap +} + +func (i StreamInfo) SrcString() string { + return net.JoinHostPort(i.SrcIP.String(), strconv.Itoa(int(i.SrcPort))) +} + +func (i StreamInfo) DstString() string { + return net.JoinHostPort(i.DstIP.String(), strconv.Itoa(int(i.DstPort))) +} + +type MatchResult struct { + Action Action + ModInstance modifier.Instance +} + +type Ruleset interface { + // Analyzers returns the list of analyzers to use for a stream. + // It must be safe for concurrent use by multiple workers. + Analyzers(StreamInfo) []analyzer.Analyzer + // Match matches a stream against the ruleset and returns the result. + // It must be safe for concurrent use by multiple workers. + Match(StreamInfo) (MatchResult, error) +}