diff --git a/README.md b/README.md index 3587c82..49b6353 100644 --- a/README.md +++ b/README.md @@ -1,50 +1,47 @@ -# DNS-over-TLS Cache Proxy +# 🚀 Go-DoT: 高性能 DNS-over-TLS 缓存代理 -一个用 **Go** 编写的高性能 **DNS-over-TLS (DoT)** 缓存代理服务, -专为隐私保护与性能优化而设计。支持多上游并发解析、智能缓存、ECS 剥离和优雅关闭。 +[![Go Version](https://img.shields.io/badge/Go-1.22+-00ADD8?style=flat&logo=go)](https://golang.org) +[![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE) +[![Security](https://img.shields.io/badge/TLS-1.3_Only-green.svg)](https://en.wikipedia.org/wiki/Transport_Layer_Security#TLS_1.3) -## 🚀 特性概览 +**Go-DoT** 是一个用 Go 语言编写的生产级 **DNS-over-TLS (DoT)** 缓存代理服务器,专注于**隐私保护、极致性能与高可用性**。 + +--- + +## ✨ 特性概览 | 功能 | 描述 | |------|------| -| 🔒 **加密传输** | 完全支持 DNS-over-TLS(RFC 7858),支持 TLS 1.2/1.3 | -| ⚡ **多上游并发解析** | “快乐眼球”机制并行查询多个上游 DNS,取最快响应 | -| 🧠 **智能缓存系统** | 支持正向与负面缓存(RFC 2308),动态 TTL 调整 | -| 🧹 **自动清理机制** | 定期清理过期缓存项(默认每 5 分钟) | -| 🧩 **隐私保护** | 默认剥离 ECS (EDNS Client Subnet),防止地理泄露 | -| 🧱 **黑名单过滤** | 支持域名黑名单(后缀匹配、文件或命令行加载) | -| 🪶 **轻量高效** | 单一可执行文件,无外部依赖,易于容器化部署 | +| 🔒 加密传输 | 支持 RFC 7858,默认强制 TLS 1.3 | +| ⚡ 高效匹配 | 基于 Trie(前缀树)实现百万级黑名单匹配,复杂度 O(L) | +| 🧠 请求合并 | 使用 Singleflight 防止缓存击穿 | +| 🚀 并发查询 | 并行请求上游 DNS,返回最快响应 | +| 📦 智能缓存 | 支持正向缓存 + NXDOMAIN 负缓存(LRU) | +| 🧹 优雅退出 | 支持 SIGTERM 信号 | +| 🧩 隐私保护 | 默认移除 ECS(EDNS Client Subnet) | + +--- ## 🏗️ 快速开始 -### 从源码构建 +### 1. 编译构建 + +需要 Go 1.22 或更高版本: ```bash git clone https://git.aixiao.me/aixiao/dot.git cd dot -bash build.sh bin +go mod tidy +go build -ldflags="-s -w -X 'main.BuildDate=$(date)'" -o dot ``` -### 使用 Docker 构建与运行 +### 2. 生成 TLS 证书(测试用) ```bash -# 构建镜像 -bash build.sh build - -# 运行容器 -bash build.sh run - -# 查看日志 -bash build.sh logs - -# 停止与清理 -bash build.sh stop -bash build.sh clean +openssl req -x509 -newkey rsa:4096 -keyout server.key -out server.crt -days 365 -nodes -subj "/CN=dot.local" ``` -> 🧩 默认镜像名为 `dot:latest`,监听 `853` 端口,可通过 `PORT` 环境变量修改。 - -## ⚙️ 启动示例 +### 3. 启动服务 ```bash ./dot \ @@ -52,119 +49,79 @@ bash build.sh clean -key aixiao.me.key \ -addr :853 \ -upstream 119.29.29.29:53,223.5.5.5:53,114.114.114.114:53 \ - -cache-ttl 300s \ - -timeout 3s \ - -max-parallel 3 \ - -blacklist-file blacklist.txt - + -blacklist-file blacklist.txt \ + -v ``` -启动日志示例: - -```bash -🚀 starting DNS-over-TLS on :853 -[req] A www.example.com. (id=40192 cd=false do=true from=127.0.0.1:58877) -[cache] MISS A www.example.com. -[answer] www.example.com. 300 IN A 93.184.216.34 -``` +--- ## ⚙️ 参数说明 | 参数 | 默认值 | 描述 | -|------|---------|------| -| `--addr` | `:853` | 监听地址(支持 IPv4/IPv6) | -| `--cert` | `server.crt` | TLS 证书路径 | -| `--key` | `server.key` | TLS 私钥路径 | -| `--upstream` | `8.8.8.8:53,1.1.1.1:53` | 上游 DNS 服务器列表 | -| `--cache-ttl` | `60s` | 缓存最大 TTL(正向/负面均适用) | -| `--timeout` | `3s` | 上游查询超时 | -| `--max-parallel` | `3` | 最大并发上游查询数 | -| `--strip-ecs` | `true` | 是否剥离 ECS 信息 | -| `--tcp-fallback` | `true` | UDP 截断时是否自动 TCP 回退 | -| `--blacklist` | 空 | 逗号分隔的黑名单域名或通配后缀 | -| `--blacklist-file` | 空 | 黑名单文件路径(每行一个规则) | -| `--blacklist-rcode` | `REFUSED` | 黑名单命中返回码:`REFUSED` / `NXDOMAIN` / `SERVFAIL` | -| `--cache-size` | `10000` | LRU 缓存最大条目数 | -| `--v` | `false` | 启用详细日志模式 | +|------|--------|------| +| -addr | :853 | DoT 监听地址 | +| -cert | server.crt | TLS 证书路径 | +| -key | server.key | TLS 私钥路径 | +| -upstream | 8.8.8.8:53,1.1.1.1:53 | 上游 DNS | +| -blacklist-file | 空 | 黑名单文件路径 | +| -blacklist-rcode | REFUSED | 命中黑名单返回码:REFUSED / NXDOMAIN / SERVFAIL | +| -v | false | 启用详细日志模式 | -## 🔍 缓存机制详解 - -**缓存键格式:** - -```sh -domain|type|class|DO|CD -``` - -**缓存策略:** - -- ✅ **正向缓存**:取最小 TTL 与配置上限的较小值 -- 🚫 **负面缓存**:遵循 RFC 2308,从 SOA.MINIMUM 计算 TTL -- 🧭 **动态 TTL 调整**:返回时按剩余时间递减 TTL -- 🧹 **自动清理**:每 5 分钟扫描并删除过期条目 -- 🔒 **隔离逻辑**:DO/CD 不同查询独立缓存空间 +--- ## 🧱 黑名单功能 -支持两种配置方式: +支持 hosts + 通配符格式: -1. 命令行参数: +```text +# 注释 +ad.doubleclick.net +*.tracking.com +127.0.0.1 malicious-site.io +``` - ```bash - ./dot -blacklist="*.ads.com,*.tracking.net" - ``` +命中后直接返回指定 RCODE,并附带 EDE 说明。 -2. 文件加载(每行一个域名或后缀): +--- - ```sh - # blacklist.txt - *.ads.com - *.malware.net - ``` +## 🔍 缓存机制 - 启动命令: +- 缓存键: Type|DO|CD|ECS|Domain 组合键,确保不同请求策略的结果物理隔离。 - ```bash - ./dot -blacklist-file=blacklist.txt -blacklist-rcode=NXDOMAIN - ``` +- TTL 策略: + - ✅ 正向缓存: 取记录中最小 TTL 与配置上限的较小值。 + - 🚫 负面缓存: 遵循 RFC 2308,自动计算 SOA 的最小 TTL 进行缓存。 -黑名单命中后不再上游查询,直接返回指定 RCODE。 +--- -## 🧩 架构与运行原理 +## 🧩 工作流程 ```mermaid flowchart TD - A[Client (DoT Request)] --> B[DNS-over-TLS Server] - B --> C[Cache Lookup] - C -- HIT --> D[Return Cached Response] - C -- MISS --> E[Upstream Resolver Pool] - E -->|Fastest Response| F[DNS Response] - F --> G[Cache Write] - G --> H[Return to Client] + A[客户端请求] --> B{黑名单匹配} + B -- 命中 --> C[返回拦截] + B -- 未命中 --> D{缓存命中} + D -- 命中 --> E[返回缓存] + D -- 未命中 --> F[Singleflight] + F --> G[并发查询上游] + G --> H[写入缓存] + H --> I[返回结果] ``` -## 🧰 开发与维护 +--- -- 语言:**Go 1.22+** -- 依赖: - - [`github.com/miekg/dns`](https://github.com/miekg/dns) - - [`github.com/hashicorp/golang-lru/v2`](https://pkg.go.dev/github.com/hashicorp/golang-lru/v2) -- 推荐编译参数: +## 🛡️ 生产建议 - ```bash - go build -ldflags="-s -w" -o dot main.go - ``` +- 提高文件描述符限制:ulimit -n 65535 +- 使用合法 CA 证书(Android/iOS 必须) +- 关闭 -v 日志提升性能 +- 配置 ≥3 个上游 DNS -## 🧪 测试方法 +--- -使用 `kdig` 或 `dig` 测试解析: +## 👨‍💻 作者 -```bash -kdig @127.0.0.1 +tls-ca +tls-host=dot.local www.example.com -``` - -## 👨‍💻 作者信息 - -**Author:** niuyuling -**Email:** [aixiao@aixiao.me](mailto:aixiao@aixiao.me) -**License:** MIT -**Repository:** [git.aixiao.me/aixiao/dot](https://git.aixiao.me/aixiao/dot) +- Author: niuyuling +- Email: aixiao@aixiao.me +- License: MIT +- Repository: https://git.aixiao.me/aixiao/dot diff --git a/blacklist.go b/blacklist.go index 774a123..77e7694 100644 --- a/blacklist.go +++ b/blacklist.go @@ -6,214 +6,146 @@ import ( "log" "net" "os" - "sort" "strings" "sync/atomic" "time" "github.com/miekg/dns" - "golang.org/x/net/idna" ) -// -------- Blacklist helpers (独立文件) -------- - -// 将输入域名规范化为 canonical FQDN(去空格、去 *. 前缀、统一大小写/尾点) -func canonicalFQDN(s string) string { - s = strings.TrimSpace(s) - if s == "" { - return "" - } - // 允许黑名单写 "*.example.com";内部匹配用裸后缀 - s = strings.TrimPrefix(s, "*.") - - // 先把可能的中文/Unicode 域名转成 ASCII(punycode),再规范化 - if a, err := idna.Lookup.ToASCII(s); err == nil { - s = a - } - // CanonicalName 会做小写化与尾点规范化 - return dns.CanonicalName(s) +type BlacklistTrie struct { + root *trieNode } -func uniqueStrings(in []string) []string { - seen := make(map[string]struct{}, len(in)) - out := make([]string, 0, len(in)) - for _, s := range in { - if s == "" { - continue - } - if _, ok := seen[s]; ok { - continue - } - seen[s] = struct{}{} - out = append(out, s) - } - return out +type trieNode struct { + children map[string]*trieNode + isEnd bool + rule string } -// 支持 # / ; 注释;每行一个域名;支持以 "*.example.com" 书写 -// 支持 # / ; 注释;每行一个域名;支持以 "*.example.com" 书写;支持 hosts 风格(首列为 IP) -func loadBlacklistFile(path string) ([]string, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() - - sc := bufio.NewScanner(f) - // 默认 64KB 容量不够稳妥,这里放大到 2MB,兼容一些合并的大 hosts 列表 - sc.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) - - var rules []string - for sc.Scan() { - line := strings.TrimSpace(sc.Text()) - if line == "" { - continue - } - // 去掉注释(# 或 ; 之后的内容) - if i := strings.IndexAny(line, "#;"); i >= 0 { - line = strings.TrimSpace(line[:i]) - if line == "" { - continue - } - } - - fields := strings.Fields(line) - if len(fields) == 0 { - continue - } - - // hosts 风格:第一个字段是 IP,则其余每个字段视为域名 - start := 0 - if net.ParseIP(fields[0]) != nil { - start = 1 - } - for _, tok := range fields[start:] { - if r := canonicalFQDN(tok); r != "" { - rules = append(rules, r) - } - } - } - if err := sc.Err(); err != nil { - return nil, err - } - - sort.Strings(rules) - rules = uniqueStrings(rules) - return rules, nil +func newTrie() *BlacklistTrie { + return &BlacklistTrie{root: &trieNode{children: make(map[string]*trieNode)}} } -// 自动重载黑名单 -func startBlacklistReloader(ctx context.Context, path string, interval time.Duration, holder *atomic.Pointer[suffixMatcher]) { - if path == "" { - return - } - go func() { - var lastMod time.Time - ticker := time.NewTicker(interval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - fi, err := os.Stat(path) - if err != nil { - log.Printf("[blacklist] reload check failed: %v", err) - continue - } - modTime := fi.ModTime() - if modTime.After(lastMod) { - rules, err := loadBlacklistFile(path) - if err != nil { - log.Printf("[blacklist] reload failed: %v", err) - continue - } - holder.Store(newSuffixMatcher(rules)) - lastMod = modTime - log.Printf("[blacklist] reloaded %d rules (modified %s)", len(rules), modTime.Format(time.RFC3339)) - } - } +func (t *BlacklistTrie) Add(domain string) { + name := strings.TrimSuffix(strings.ToLower(domain), ".") + parts := strings.Split(name, ".") + curr := t.root + for i := len(parts) - 1; i >= 0; i-- { + p := parts[i] + if _, ok := curr.children[p]; !ok { + curr.children[p] = &trieNode{children: make(map[string]*trieNode)} } - }() -} - -// 后缀匹配器:rules 已 canonical,按长度降序排列(更精确的规则优先) -type suffixMatcher struct { - rules []string -} - -func newSuffixMatcher(rules []string) *suffixMatcher { - rs := uniqueStrings(rules) - sort.Slice(rs, func(i, j int) bool { return len(rs[i]) > len(rs[j]) }) - return &suffixMatcher{rules: rs} -} - -// 命中则返回匹配的规则 -func (m *suffixMatcher) match(name string) (string, bool) { - if m == nil || len(m.rules) == 0 { - return "", false + curr = curr.children[p] } - name = dns.CanonicalName(name) - for _, r := range m.rules { - if name == r || strings.HasSuffix(name, "."+r) { - return r, true + curr.isEnd = true + curr.rule = name +} + +func (t *BlacklistTrie) Match(domain string) (string, bool) { + name := strings.TrimSuffix(strings.ToLower(domain), ".") + parts := strings.Split(name, ".") + curr := t.root + for i := len(parts) - 1; i >= 0; i-- { + p := parts[i] + next, ok := curr.children[p] + if !ok { + return "", false + } + curr = next + if curr.isEnd { + return curr.rule, true } } return "", false } -// 解析 RCODE 文本到常量 +func initBlacklist(ctx context.Context, path, rcodeStr string) (*atomic.Pointer[BlacklistTrie], int) { + var holder atomic.Pointer[BlacklistTrie] + var lastMod time.Time + rcode := parseRcode(rcodeStr) + + load := func() { + if path == "" { + return + } + + info, err := os.Stat(path) + if err != nil { + log.Printf("[blacklist] Stat error: %v", err) + return + } + // 如果文件没动,不加载 + if !info.ModTime().After(lastMod) { + return + } + + f, err := os.Open(path) + if err != nil { + return + } + defer f.Close() + + trie := newTrie() + scanner := bufio.NewScanner(f) + count := 0 + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || line[0] == '#' || line[0] == ';' { + continue + } + fields := strings.Fields(line) + domain := fields[0] + if net.ParseIP(domain) != nil && len(fields) > 1 { + domain = fields[1] + } + trie.Add(domain) + count++ + } + holder.Store(trie) + lastMod = info.ModTime() + log.Printf("[blacklist] Loaded %d rules (updated: %v)", count, lastMod.Format(time.RFC3339)) + } + + load() + + if path != "" { + go func() { + // 缩短检查间隔,但因为有文件时间校验,所以不费 CPU + ticker := time.NewTicker(5 * time.Minute) + for { + select { + case <-ticker.C: + load() + case <-ctx.Done(): + return + } + } + }() + } + + return &holder, rcode +} + func parseRcode(s string) int { - switch strings.ToUpper(strings.TrimSpace(s)) { - case "REFUSED", "": - return dns.RcodeRefused + switch strings.ToUpper(s) { case "NXDOMAIN": return dns.RcodeNameError case "SERVFAIL": return dns.RcodeServerFailure default: - log.Printf("[blacklist] unknown rcode %q, fallback REFUSED", s) return dns.RcodeRefused } } -// 构造“伪上游”阻断响应;带 EDE=Blocked(15) 说明命中规则 -func makeBlockedUpstream(rcode int, rule string) *dns.Msg { +func makeBlockedMsg(rcode int, rule string) *dns.Msg { m := new(dns.Msg) - m.RecursionAvailable = true - m.AuthenticatedData = false m.Rcode = rcode - opt := &dns.OPT{} - opt.Hdr.Name = "." - opt.Hdr.Rrtype = dns.TypeOPT - ede := &dns.EDNS0_EDE{InfoCode: 15, ExtraText: "blocked by policy: " + rule} - opt.Option = append(opt.Option, ede) - m.Extra = append(m.Extra, opt) + m.RecursionAvailable = true + o := new(dns.OPT) + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + o.Option = append(o.Option, &dns.EDNS0_EDE{InfoCode: 15, ExtraText: "Blocked by policy"}) + m.Extra = append(m.Extra, o) return m } - -func initBlacklist(ctx context.Context, listStr, filePath, rcodeStr string) (*atomic.Pointer[suffixMatcher], int) { - var rules []string - if v := strings.TrimSpace(listStr); v != "" { - for _, s := range strings.Split(v, ",") { - if r := canonicalFQDN(s); r != "" { - rules = append(rules, r) - } - } - } - if filePath != "" { - if fs, err := loadBlacklistFile(filePath); err != nil { - log.Printf("[blacklist] load file error: %v", err) - } else { - rules = append(rules, fs...) - } - } - var holder atomic.Pointer[suffixMatcher] - holder.Store(newSuffixMatcher(rules)) - blRcode := parseRcode(rcodeStr) - if filePath != "" { - startBlacklistReloader(ctx, filePath, 30*time.Second, &holder) - } - log.Printf("[blacklist] loaded %d rules (file=%v, rcode=%s)", len(holder.Load().rules), filePath != "", strings.ToUpper(rcodeStr)) - return &holder, blRcode -} diff --git a/dot b/dot index 1669f36..993c8e5 100644 Binary files a/dot and b/dot differ diff --git a/go.mod b/go.mod index 540f672..8176e86 100644 --- a/go.mod +++ b/go.mod @@ -4,15 +4,14 @@ go 1.25.3 require ( github.com/hashicorp/golang-lru/v2 v2.0.7 - github.com/miekg/dns v1.1.68 - golang.org/x/net v0.46.0 - golang.org/x/sync v0.17.0 + github.com/miekg/dns v1.1.72 + golang.org/x/net v0.52.0 // indirect + golang.org/x/sync v0.20.0 ) require ( github.com/google/go-cmp v0.7.0 // indirect - golang.org/x/mod v0.29.0 // indirect - golang.org/x/sys v0.37.0 // indirect - golang.org/x/text v0.30.0 // indirect - golang.org/x/tools v0.38.0 // indirect + golang.org/x/mod v0.34.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/tools v0.42.0 // indirect ) diff --git a/go.sum b/go.sum index 3ca3654..28d077c 100644 --- a/go.sum +++ b/go.sum @@ -4,15 +4,43 @@ github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA= github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= +github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI= +github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs= golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= +golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= +golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/mod v0.34.0 h1:xIHgNUUnW6sYkcM5Jleh05DvLOtwc6RitGHbDk4akRI= +golang.org/x/mod v0.34.0/go.mod h1:ykgH52iCZe79kzLLMhyCUzhMci+nQj+0XkbXpNYtVjY= golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= +golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= diff --git a/main.go b/main.go index dd22914..27cf4a5 100644 --- a/main.go +++ b/main.go @@ -1,14 +1,10 @@ -// main.go package main import ( "context" "crypto/tls" "flag" - "fmt" "log" - "math/rand" - "net" "os" "os/signal" "strings" @@ -22,12 +18,31 @@ import ( "golang.org/x/sync/singleflight" ) -var BuildDate = "unknown" // 由编译时注入 +var ( + BuildDate = "unknown" + cache *lru.Cache[string, *cacheEntry] + inflight singleflight.Group + udpClient *dns.Client + tcpClient *dns.Client + verbose bool // 全局日志开关 +) -/****************************************************************** - * 日志初始化 - ******************************************************************/ -func initLogger(verbose bool) { +const ( + defaultCacheSize = 20000 + maxUDPSize = 1232 +) + +type cacheEntry struct { + msg *dns.Msg + expireAt time.Time + ednsPresent bool + ednsVersion uint8 + ednsExtRcode uint16 + ednsEDE []*dns.EDNS0_EDE +} + +func initLogger(v bool) { + verbose = v flags := log.Ldate | log.Ltime | log.Lmicroseconds if verbose { flags |= log.Lshortfile @@ -35,756 +50,320 @@ func initLogger(verbose bool) { log.SetFlags(flags) } -/****************************************************************** - * 缓存结构(支持 TTL + LRU) - ******************************************************************/ +func cacheKey(q dns.Question, r *dns.Msg, ecs string) string { + do, cd := "0", "0" + if o := r.IsEdns0(); o != nil && o.Do() { + do = "1" + } + if r.CheckingDisabled { + cd = "1" + } -type cacheEntry struct { - msg *dns.Msg - expireAt time.Time - // EDNS metadata (to reproduce Extended RCODE / EDE on cache hits) - ednsPresent bool - ednsVersion uint8 - ednsExtRcode uint16 - ednsEDE []*dns.EDNS0_EDE -} - -var ( - cache *lru.Cache[string, *cacheEntry] - cacheMutex sync.RWMutex - inflight singleflight.Group -) - -const ( - cacheCleanupInterval = 5 * time.Minute - defaultCacheSize = 10000 // 默认最大缓存条目数 -) - -// startCacheCleaner 定期清理过期缓存(在删除前二次校验) -func startCacheCleaner(ctx context.Context) { - go func() { - ticker := time.NewTicker(cacheCleanupInterval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - now := time.Now() - var toDelete []string - - cacheMutex.RLock() - for _, k := range cache.Keys() { - if v, ok := cache.Peek(k); ok && now.After(v.expireAt) { - toDelete = append(toDelete, k) - } - } - cacheMutex.RUnlock() - - if len(toDelete) > 0 { - pruned := 0 - cacheMutex.Lock() - for _, k := range toDelete { - if v, ok := cache.Peek(k); ok && now.After(v.expireAt) { - cache.Remove(k) - pruned++ - } - } - cacheMutex.Unlock() - if pruned > 0 { - log.Printf("[cache] cleaned %d expired entries", pruned) - } - } - } - } - }() -} - -func cacheKeyFromMsg(q dns.Question, do, cd bool) string { var b strings.Builder b.Grow(len(q.Name) + 32) - // 采用规范化域名,避免尾随点/IDNA/大小写造成的重复键 - b.WriteString(dns.CanonicalName(q.Name)) - b.WriteString("|T=") b.WriteString(dns.TypeToString[q.Qtype]) - b.WriteString("|C=") - b.WriteString(dns.ClassToString[q.Qclass]) - if do { - b.WriteString("|DO") - } - if cd { - b.WriteString("|CD") - } + b.WriteByte('|') + b.WriteString(do) + b.WriteString(cd) + b.WriteByte('|') + b.WriteString(ecs) + b.WriteByte('|') + b.WriteString(strings.ToLower(q.Name)) return b.String() } -// ecsKeyPart 在未 strip ECS 时,把 ECS 归一化后的“网络”信息并入缓存 key -// 为最小改动:当 strip=true(即启用去 ECS)时直接返回空字符串 -func ecsKeyPart(m *dns.Msg, strip bool) string { - if strip { - return "" - } - o := m.IsEdns0() - if o == nil { - return "" - } - for _, opt := range o.Option { - s, ok := opt.(*dns.EDNS0_SUBNET) - if !ok { - continue - } - fam := s.Family - pfx := int(s.SourceNetmask) - addr := append(net.IP(nil), s.Address...) // 拷贝以免原切片被改 - switch fam { - case 1: // IPv4 - ip := addr.To4() - if ip != nil { - mask := net.CIDRMask(pfx, 32) - ip = ip.Mask(mask) - return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, ip) +func queryUpstreams(ctx context.Context, req *dns.Msg, upstreams []string, timeout time.Duration, parallel int) *dns.Msg { + cctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + resCh := make(chan *dns.Msg, len(upstreams)) + var wg sync.WaitGroup + sem := make(chan struct{}, parallel) + + for _, svr := range upstreams { + wg.Add(1) + go func(s string) { + defer wg.Done() + select { + case sem <- struct{}{}: + defer func() { <-sem }() + case <-cctx.Done(): + return } - case 2: // IPv6 - ip := addr.To16() - if ip != nil { - mask := net.CIDRMask(pfx, 128) - for i := 0; i < 16; i++ { - ip[i] &= mask[i] + + uReq := req.Copy() + if o := uReq.IsEdns0(); o == nil { + uReq.SetEdns0(maxUDPSize, false) + } else { + o.SetUDPSize(maxUDPSize) + } + + resp, _, err := udpClient.ExchangeContext(cctx, uReq, s) + if err == nil && resp != nil && resp.Truncated { + resp, _, err = tcpClient.ExchangeContext(cctx, req, s) + } + + if err == nil && resp != nil { + resCh <- resp + } + }(svr) + } + + go func() { + wg.Wait() + close(resCh) + }() + + for r := range resCh { + if r.Rcode != dns.RcodeServerFailure && r.Rcode != dns.RcodeRefused { + return r + } + } + return nil +} + +func handleDNS(upstreams []string, maxTTL, timeout time.Duration, parallel int, stripECS bool, blPtr *atomic.Pointer[BlacklistTrie], blRcode int) dns.HandlerFunc { + return func(w dns.ResponseWriter, r *dns.Msg) { + defer func() { + if err := recover(); err != nil { + log.Printf("[PANIC] %v", err) + dns.HandleFailed(w, r) + } + }() + + if len(r.Question) == 0 { + dns.HandleFailed(w, r) + return + } + + q := r.Question[0] + startTime := time.Now() + + if trie := blPtr.Load(); trie != nil { + if rule, hit := trie.Match(q.Name); hit { + log.Printf("[BLOCK] %s rule=%s client=%s", q.Name, rule, w.RemoteAddr()) + writeReply(w, r, makeBlockedMsg(blRcode, rule), nil) + return + } + } + + ecs := "" + if !stripECS { + if o := r.IsEdns0(); o != nil { + for _, opt := range o.Option { + if s, ok := opt.(*dns.EDNS0_SUBNET); ok { + ecs = s.Address.String() + } } - return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, ip) } + } else { + stripECSFromMsg(r) } - // 回退:不做掩码 - return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, addr) - } - return "" -} -func isPseudo(rr dns.RR) bool { - switch rr.(type) { - case *dns.OPT, *dns.TSIG: - return true - default: - return false - } -} + key := cacheKey(q, r, ecs) -// clone and extract EDNS metadata (present, version, ext-rcode, all EDEs) -func cloneEDE(in *dns.EDNS0_EDE) *dns.EDNS0_EDE { - if in == nil { - return nil - } - cp := *in - return &cp -} - -func extractEDNSMeta(m *dns.Msg) (present bool, version uint8, ext uint16, ede []*dns.EDNS0_EDE) { - if o := m.IsEdns0(); o != nil { - present = true - version = o.Version() - ext = uint16(o.ExtendedRcode()) - for _, opt := range o.Option { - if e, ok := opt.(*dns.EDNS0_EDE); ok { - ede = append(ede, cloneEDE(e)) + if msg, meta, ok := tryCacheRead(key); ok { + if verbose { + log.Printf("[CACHE] HIT %s", q.Name) } + writeReply(w, r, msg, meta) + return + } + + v, _, _ := inflight.Do(key, func() (any, error) { + resp := queryUpstreams(context.Background(), r, upstreams, timeout, parallel) + if resp != nil { + cacheWrite(key, resp, maxTTL) + } + return resp, nil + }) + + if resp, _ := v.(*dns.Msg); resp != nil { + writeReply(w, r, resp, nil) + if verbose { + log.Printf("[QUERY] %s %s -> %s (%v)", dns.TypeToString[q.Qtype], q.Name, dns.RcodeToString[resp.Rcode], time.Since(startTime)) + } + } else { + dns.HandleFailed(w, r) } } - return } -// 读取缓存(Get 在写锁下;在锁外调整 TTL;返回 EDNS 元数据) func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) { - now := time.Now() - - cacheMutex.Lock() - e, ok := cache.Get(key) // Get 会更新 LRU,必须在写锁下 - if !ok { - cacheMutex.Unlock() + e, ok := cache.Get(key) + if !ok || time.Now().After(e.expireAt) { return nil, nil, false } - if now.After(e.expireAt) { - cache.Remove(key) - cacheMutex.Unlock() - return nil, nil, false - } - // 拷贝副本,在锁外改 TTL,减少临界区时间 out := e.msg.Copy() - expireAt := e.expireAt - cacheMutex.Unlock() - - remaining := uint32(expireAt.Sub(now).Seconds()) - if remaining == 0 { - cacheMutex.Lock() - cache.Remove(key) - cacheMutex.Unlock() - return nil, nil, false - } - + ttl := uint32(time.Until(e.expireAt).Seconds()) for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} { for _, rr := range sec { - if isPseudo(rr) { - continue - } - if rr.Header().Ttl > remaining { - rr.Header().Ttl = remaining + if rr.Header().Rrtype != dns.TypeOPT { + rr.Header().Ttl = ttl } } } return out, e, true } -// hasAnswerForType 判断报文中是否存在回答“请求类型”的 RRset -func hasAnswerForType(m *dns.Msg, q dns.Question) bool { - for _, rr := range m.Answer { - h := rr.Header() - if h.Rrtype == q.Qtype && strings.EqualFold(h.Name, q.Name) { - return true - } - } - return false -} - -// 计算负面 TTL(正确识别 NODATA,包括 CNAME 等场景) -func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { - // NXDOMAIN:肯定是负面 - if m.Rcode != dns.RcodeNameError { - // 不是 NXDOMAIN,则仅当 NOERROR 但没有“匹配 QTYPE 的答案”时才是 NODATA - if m.Rcode != dns.RcodeSuccess || len(m.Question) == 0 || hasAnswerForType(m, m.Question[0]) { - return 0, false - } - } - - // 按 RFC 2308,从 Authority(Ns)优先取 SOA(多数实现都只放在 Authority) - var soa *dns.SOA - for _, rr := range m.Ns { - if s, ok := rr.(*dns.SOA); ok { - soa = s - break - } - } - // 兼容性:偶尔也有人把 SOA 放 Extra(不规范,但为了兼容可以兜底看看) - if soa == nil { - for _, rr := range m.Extra { - if s, ok := rr.(*dns.SOA); ok { - soa = s - break - } - } - } - if soa == nil { - // 建议:无 SOA 时不做负面缓存(返回 0,false) - return 0, false - } - - // 负面 TTL 取 min(SOA.MINIMUM, SOA 自身 TTL),再与配置上限比较 - ttl := soa.Hdr.Ttl - if soa.Minttl < ttl { - ttl = soa.Minttl - } - capTTL := uint32(maxTTL.Seconds()) - if capTTL > 0 && ttl > capTTL { - ttl = capTTL - } - return ttl, ttl > 0 -} - -func minRRsetTTL(m *dns.Msg) (uint32, bool) { - minTTL := uint32(0) - hasTTL := false - // 优先 Answer -> Ns;若都为空,再考虑 Extra(排除伪记录) - for _, sec := range [][]dns.RR{m.Answer, m.Ns} { - for _, rr := range sec { - if isPseudo(rr) { - continue - } - ttl := rr.Header().Ttl - if !hasTTL || ttl < minTTL { - minTTL = ttl - hasTTL = true - } - } - } - if !hasTTL { - for _, rr := range m.Extra { - if isPseudo(rr) { - continue - } - ttl := rr.Header().Ttl - if !hasTTL || ttl < minTTL { - minTTL = ttl - hasTTL = true - } - } - } - return minTTL, hasTTL -} - -func stripPseudoExtras(m *dns.Msg) { - if len(m.Extra) == 0 { - return - } - out := m.Extra[:0] - for _, rr := range m.Extra { - if isPseudo(rr) { - continue - } - out = append(out, rr) - } - m.Extra = out -} - -// 写缓存(保存 EDNS 元数据,命中时可重建扩展 RCODE/EDE) func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { - if in == nil { - return - } if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError { return } - var ttl uint32 - var ok bool - // 判断负面响应(NXDOMAIN 或 NODATA) - neg, isNodata := in.Rcode == dns.RcodeNameError, false - if in.Rcode == dns.RcodeSuccess && len(in.Question) > 0 && !hasAnswerForType(in, in.Question[0]) { - isNodata = true - } - - if ttl, ok = negativeTTL(in, maxTTL); !ok { - if neg || isNodata { - return // 负面但无 SOA → 不缓存 + var ttl uint32 = uint32(maxTTL.Seconds()) + found := false + for _, rr := range in.Answer { + if rr.Header().Rrtype != dns.TypeOPT && rr.Header().Ttl < ttl { + ttl = rr.Header().Ttl + found = true } - - minTTL, has := minRRsetTTL(in) - if has { - cfgTTL := uint32(maxTTL.Seconds()) - if cfgTTL > 0 && minTTL > cfgTTL { - minTTL = cfgTTL + } + if !found { + for _, rr := range in.Ns { + if soa, ok := rr.(*dns.SOA); ok { + if soa.Minttl < ttl { + ttl = soa.Minttl + } + found = true } - ttl = minTTL - } else { - ttl = uint32(maxTTL.Seconds()) } } - if ttl == 0 { + if ttl < 10 { return } - expire := time.Now().Add(time.Duration(ttl) * time.Second) + cp := in.Copy() - // 提取 EDNS 元数据后再剥离伪记录 present, ver, ext, ede := extractEDNSMeta(cp) stripPseudoExtras(cp) - cacheMutex.Lock() + cache.Add(key, &cacheEntry{ msg: cp, - expireAt: expire, + expireAt: time.Now().Add(time.Duration(ttl) * time.Second), ednsPresent: present, ednsVersion: ver, ednsExtRcode: ext, ednsEDE: ede, }) - cacheMutex.Unlock() } -/****************************************************************** - * 上游查询 - ******************************************************************/ -var udpClient *dns.Client - -func shuffled(xs []string) []string { - out := make([]string, len(xs)) - copy(out, xs) - rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] }) - return out -} - -// clampEDNSForUpstream 返回一个 msg 副本,把 EDNS UDP size 夹到给定大小 -func clampEDNSForUpstream(in *dns.Msg, size uint16) *dns.Msg { - m := in.Copy() - o := m.IsEdns0() - if o == nil { - o = &dns.OPT{} - o.Hdr.Name = "." - o.Hdr.Rrtype = dns.TypeOPT - m.Extra = append(m.Extra, o) - } - if size > 0 { - o.SetUDPSize(size) - } - return m -} - -func queryUpstreamsLimited( - ctx context.Context, - req *dns.Msg, - upstreams []string, - timeout time.Duration, - maxParallel int, - allowTCPFallback bool, -) *dns.Msg { - if maxParallel <= 0 { - maxParallel = 1 - } - servers := shuffled(upstreams) - - cctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - type result struct { - msg *dns.Msg - } - ch := make(chan result, len(servers)) - done := make(chan struct{}, len(servers)) - sem := make(chan struct{}, maxParallel) - - // 单个上游执行 - execOne := func(svr string) { - // 并发限流(可被超时取消) - select { - case sem <- struct{}{}: - defer func() { <-sem }() - case <-cctx.Done(): - // 超时/取消,直接放弃 - return - } - defer func() { done <- struct{}{} }() - - // 为 UDP 上游把 EDNS UDP size 夹到 1232,降低分片风险 - upReq := clampEDNSForUpstream(req, 1232) - - // 先走 UDP - resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr) - // 截断且允许回退则走 TCP - if err == nil && resp != nil && resp.Truncated && allowTCPFallback { - log.Printf("[upstream] UDP truncated, retry TCP: %s", svr) - tcpClient := *udpClient - tcpClient.Net = "tcp" - resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr) - } - // 失败直接返回(但不写入 ch);只在未超时情况下打印错误 - if err != nil || resp == nil { - if err != nil && cctx.Err() == nil { - log.Printf("[upstream] %s: %v", svr, err) - } - return - } - // 过滤不可用的错误 RCODE(避免造成“假性超时”的错觉) - if resp.Rcode == dns.RcodeServerFailure || - resp.Rcode == dns.RcodeRefused || - resp.Rcode == dns.RcodeFormatError { - return - } - - // 投递可用结果(若已经超时则丢弃) - select { - case ch <- result{msg: resp}: - case <-cctx.Done(): - } - } - - // 并发发起 - for _, s := range servers { - s := s - go execOne(s) - } - - finished := 0 - total := len(servers) - - // 聚合:首个可用响应直接返回;区分“真超时”与“无可用结果” - for finished < total { - select { - case r := <-ch: - if r.msg != nil { - cancel() - return r.msg - } - case <-done: - finished++ - case <-cctx.Done(): - log.Printf("[upstream] timeout after %v (finished=%d/%d)", timeout, finished, total) - return nil - } - } - - // 所有上游都结束,但没有一个可用 - log.Printf("[upstream] no acceptable upstream response (finished=%d/%d)", finished, total) - return nil -} - -/****************************************************************** - * EDNS / 响应构造 - ******************************************************************/ -func stripECS(m *dns.Msg) { - if o := m.IsEdns0(); o != nil { - var kept []dns.EDNS0 - for _, e := range o.Option { - if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS { - kept = append(kept, e) - } - } - o.Option = kept - } -} - -func getDOFlag(m *dns.Msg) bool { - if o := m.IsEdns0(); o != nil { - return o.Do() - } - return false -} - -func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cacheEntry) { - if upstream == nil { - dns.HandleFailed(w, req) - return - } +func writeReply(w dns.ResponseWriter, req, resp *dns.Msg, meta *cacheEntry) { out := new(dns.Msg) out.SetReply(req) - out.Authoritative = false - // RA 语义修正:RA 表示“服务器是否支持递归”,与客户端 RD 无关 - out.RecursionAvailable = upstream.RecursionAvailable - out.AuthenticatedData = upstream.AuthenticatedData - out.CheckingDisabled = req.CheckingDisabled - out.Rcode = upstream.Rcode - out.Answer = upstream.Answer - out.Ns = upstream.Ns - - var extras []dns.RR - for _, rr := range upstream.Extra { - if !isPseudo(rr) { - extras = append(extras, rr) - } - } + out.Rcode = resp.Rcode + out.Answer = resp.Answer + out.Ns = resp.Ns + out.Extra = resp.Extra if ro := req.IsEdns0(); ro != nil { o := new(dns.OPT) o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT o.SetUDPSize(ro.UDPSize()) - if ro.Do() { - o.SetDo() - } - // 优先使用“请求”的 EDNS 版本;扩展 RCODE/EDE 来自上游/缓存 - o.SetVersion(ro.Version()) - if uo := upstream.IsEdns0(); uo != nil { + if uo := resp.IsEdns0(); uo != nil { + o.Option = uo.Option o.SetExtendedRcode(uint16(uo.ExtendedRcode())) - for _, opt := range uo.Option { - if ede, ok := opt.(*dns.EDNS0_EDE); ok { - o.Option = append(o.Option, ede) - } - } } else if meta != nil && meta.ednsPresent { - // Upstream/cached msg has no OPT(例如缓存时被剥离),用缓存元数据重建 - o.SetExtendedRcode(uint16(meta.ednsExtRcode)) + o.SetExtendedRcode(meta.ednsExtRcode) for _, e := range meta.ednsEDE { o.Option = append(o.Option, e) } } - extras = append(extras, o) + out.Extra = append(out.Extra, o) } - out.Extra = extras out.Compress = true - - if err := w.WriteMsg(out); err != nil { - log.Printf("[write] WriteMsg error: %v", err) - } + _ = w.WriteMsg(out) } -/****************************************************************** - * 主处理器 - ******************************************************************/ -func handleDNS( - upstreams []string, - cacheMaxTTL, timeout time.Duration, - maxParallel int, - stripECSBeforeForward bool, - allowTCPFallback bool, - blPtr *atomic.Pointer[suffixMatcher], - blRcode int, -) dns.HandlerFunc { - return func(w dns.ResponseWriter, r *dns.Msg) { - if len(r.Question) == 0 { - dns.HandleFailed(w, r) - return - } - q := r.Question[0] - log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)", - dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr()) - - if stripECSBeforeForward { - stripECS(r) - } - - // 黑名单拦截:命中则不查上游,直接返回 - if rule, ok := blPtr.Load().match(q.Name); ok { - nameCanon := dns.CanonicalName(q.Name) - log.Printf("[blacklist] HIT %s rule=%s (no upstream query)", nameCanon, rule) - up := makeBlockedUpstream(blRcode, rule) - writeReply(w, r, up, nil) - return - } - - // 缓存 key:基础 + (未 strip 时的)ECS 片段 - key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) + ecsKeyPart(r, stripECSBeforeForward) - - if cachedMsg, cachedMeta, ok := tryCacheRead(key); ok { - log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) - writeReply(w, r, cachedMsg, cachedMeta) - return - } - log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) - - // 使用 singleflight 合并相同 key 的并发查询,避免上游雪崩 - v, _, _ := inflight.Do(key, func() (any, error) { - ctx := context.Background() - resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback) - if resp != nil { - cacheWrite(key, resp, cacheMaxTTL) +func stripECSFromMsg(m *dns.Msg) { + if o := m.IsEdns0(); o != nil { + newOpt := make([]dns.EDNS0, 0, len(o.Option)) + for _, opt := range o.Option { + if opt.Option() != dns.EDNS0SUBNET { + newOpt = append(newOpt, opt) } - return resp, nil - }) - resp, _ := v.(*dns.Msg) - if resp == nil { - log.Printf("[error] all upstreams failed for %s", q.Name) - dns.HandleFailed(w, r) - return } - for _, ans := range resp.Answer { - log.Printf("[answer] %s", ans.String()) - } - writeReply(w, r, resp, nil) + o.Option = newOpt } } -/****************************************************************** - * 主函数 - ******************************************************************/ +func stripPseudoExtras(m *dns.Msg) { + newExtra := make([]dns.RR, 0, len(m.Extra)) + for _, rr := range m.Extra { + if rr.Header().Rrtype != dns.TypeOPT && rr.Header().Rrtype != dns.TypeTSIG { + newExtra = append(newExtra, rr) + } + } + m.Extra = newExtra +} + +func extractEDNSMeta(m *dns.Msg) (bool, uint8, uint16, []*dns.EDNS0_EDE) { + if o := m.IsEdns0(); o != nil { + var edes []*dns.EDNS0_EDE + for _, opt := range o.Option { + if e, ok := opt.(*dns.EDNS0_EDE); ok { + edes = append(edes, e) + } + } + return true, o.Version(), uint16(o.ExtendedRcode()), edes + } + return false, 0, 0, nil +} + func main() { - rand.Seed(time.Now().UnixNano()) - - var help bool - certFile := flag.String("cert", "server.crt", "TLS 证书文件路径") - keyFile := flag.String("key", "server.key", "TLS 私钥文件路径") - addr := flag.String("addr", ":853", "DoT 监听地址") - upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表") - cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL") - cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限") - timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时") - readTimeoutFlag := flag.Duration("read-timeout", 0, "DoT 连接读超时(0=不限制)") - writeTimeoutFlag := flag.Duration("write-timeout", 0, "DoT 连接写超时(0=不限制)") - maxParallel := flag.Int("max-parallel", 3, "并发上游数量") - stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS") - allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退") - blacklistStr := flag.String("blacklist", "", "逗号分隔的黑名单域名(后缀匹配;支持如 *.example.com)") - blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;可接受 hosts 风格)") - blacklistRcodeFlag := flag.String("blacklist-rcode", "REFUSED", "命中黑名单返回的 RCODE:REFUSED|NXDOMAIN|SERVFAIL") - verbose := flag.Bool("v", false, "verbose 日志") - - flag.BoolVar(&help, "h", false, "") - flag.BoolVar(&help, "help", false, "帮助信息") - + addr := flag.String("addr", ":853", "DoT address") + upStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "Upstreams") + certFile := flag.String("cert", "server.crt", "TLS Cert") + keyFile := flag.String("key", "server.key", "TLS Key") + blFile := flag.String("blacklist-file", "", "Blacklist file") + blRcodeStr := flag.String("blacklist-rcode", "REFUSED", "RCODE for blocked") + v := flag.Bool("v", false, "Verbose logging") flag.Parse() - if help { - fmt.Printf( - "\t\tDNS-over-TLS (DoT)\n"+ - "\tVersion 0.1\n"+ - "\tE-mail: aixiao@aixiao.me\n"+ - "\tBuild Date: %s\n", BuildDate) + initLogger(*v) - flag.Usage() - fmt.Printf("\n") - os.Exit(0) - } - - initLogger(*verbose) - - var err error - cache, err = lru.New[string, *cacheEntry](*cacheSizeFlag) - if err != nil { - log.Fatalf("[fatal] failed to init LRU cache: %v", err) - } + cache, _ = lru.New[string, *cacheEntry](defaultCacheSize) + udpClient = &dns.Client{Net: "udp", Timeout: 2 * time.Second, SingleInflight: true} + tcpClient = &dns.Client{Net: "tcp", Timeout: 3 * time.Second, SingleInflight: true} cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) if err != nil { - log.Fatalf("[fatal] failed to load cert: %v", err) + log.Fatalf("TLS Error: %v", err) } - var upstreams []string - for _, s := range strings.Split(*upstreamStr, ",") { - if t := strings.TrimSpace(s); t != "" { - if !strings.Contains(t, ":") { - t = fmt.Sprintf("%s:53", t) - } - upstreams = append(upstreams, t) - } - } - if len(upstreams) == 0 { - log.Fatal("[fatal] no upstream DNS servers provided") - } + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() - // 如需更保守的 UDP 尺寸以减少分片,可将 UDPSize 改为 1232 - udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true} - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - startCacheCleaner(ctx) - - // 加载黑名单规则 - blPtr, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag) + blPtr, blRcode := initBlacklist(ctx, *blFile, *blRcodeStr) mux := dns.NewServeMux() - mux.HandleFunc(".", handleDNS( - upstreams, - *cacheTTLFlag, - *timeoutFlag, - *maxParallel, - *stripECSFlag, - *allowTCPFallback, - blPtr, - blRcode, - )) + mux.HandleFunc(".", handleDNS(strings.Split(*upStr, ","), 1*time.Hour, 2*time.Second, 3, true, blPtr, blRcode)) - srv := &dns.Server{ - Addr: *addr, - Net: "tcp-tls", + server := &dns.Server{ + Addr: *addr, + Net: "tcp-tls", + Handler: mux, TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS13, // TLS 1.3 + MinVersion: tls.VersionTLS13, NextProtos: []string{"dot"}, }, - Handler: mux, - ReadTimeout: *readTimeoutFlag, - WriteTimeout: *writeTimeoutFlag, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, } - stop := make(chan os.Signal, 1) - signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) - - errCh := make(chan error, 1) go func() { - log.Printf("🚀 starting DNS-over-TLS on %s", *addr) - log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | read_timeout=%s | write_timeout=%s | blacklist_rules=%d | blacklist_rcode=%s", - upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback, - readTimeoutFlag.String(), writeTimeoutFlag.String(), - len(blPtr.Load().rules), strings.ToUpper(*blacklistRcodeFlag)) - errCh <- srv.ListenAndServe() + log.Printf("🚀 DoT Server started on %s (TLS 1.3)", *addr) + if err := server.ListenAndServe(); err != nil { + log.Printf("Server exit: %v", err) + } }() - select { - case sig := <-stop: - log.Printf("[shutdown] caught signal: %s", sig) - cancel() - if err := srv.Shutdown(); err != nil { - log.Printf("[shutdown] server shutdown error: %v", err) - } - case err := <-errCh: - if err != nil { - log.Fatalf("[fatal] server error: %v", err) - } - } + <-ctx.Done() + log.Println("Gracefully shutting down...") - log.Println("[bye] server stopped.") + // 给 5 秒处理残余请求 + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + server.ShutdownContext(shutdownCtx) }