diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..6df5a28 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,32 @@ +# ---------- 构建阶段 ---------- +FROM golang:1.25.2-alpine AS builder + +WORKDIR /app +COPY . . + +ENV CGO_ENABLED=0 GOOS=linux GOARCH=amd64 +RUN go build -o dot main.go + +# ---------- 运行阶段 ---------- +FROM alpine:3.20 + +WORKDIR /app + +# 只复制编译好的二进制,不再打包证书 +COPY --from=builder /app/dot /app/dot + +# 运行时定义可覆盖的环境变量(不在构建时生效) +ENV CERT_FILE=aixiao.me.cer +ENV KEY_FILE=aixiao.me.key + +EXPOSE 853/tcp + +# 启动命令,使用运行时传入的证书路径 +ENTRYPOINT ["sh", "-c", "./dot \ + -cert ${CERT_FILE} \ + -key ${KEY_FILE} \ + -addr :853 \ + -upstream \"119.29.29.29:53,223.5.5.5:53,114.114.114.114:53\" \ + -cache-ttl 300s \ + -timeout 3s \ + -max-parallel 3"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..7d1419a --- /dev/null +++ b/README.md @@ -0,0 +1,121 @@ +# 🧠 DNS-over-TLS Cache Proxy + +一个基于 Go 的高性能 **DNS-over-TLS (DoT)** 缓存代理服务器。 +支持多上游并发解析、智能缓存、隐私保护与优雅关闭。 +轻量、无依赖、可直接部署。 + +## ✨ 特性 + +- 🔒 **加密传输** — 完全支持 DNS-over-TLS (RFC 7858) +- ⚡ **多上游并发查询** — 类似“快乐眼球”机制,提升解析速度 +- 🧠 **TTL 智能缓存** — 支持正向与负面缓存(RFC 2308) +- 🧹 **自动清理** — 定期清除过期缓存 +- 🧩 **隐私保护** — 默认剥离 ECS (EDNS Client Subnet) +- 🪶 **轻量高效** — 单文件可执行,零外部依赖 + + +## 📦 安装 + +### 🧰 源码构建 + +```bash +git clone https://git.aixiao.me/aixiao/dot.git +cd dot +go build -o dot main.go +``` + +### 🐳 Docker 构建 + +```bash +#构建、启动 +bash build.sh build +bash build.sh run + +#清理 +bash build.sh stop +bash build.sh clean +``` + + +## 🚀 启动服务 + +```bash +./dot \ + -cert=server.crt \ + -key=server.key \ + -addr=":853" \ + -upstream="8.8.8.8:53,1.1.1.1:53" \ + -cache-ttl=120s \ + -timeout=3s \ + -max-parallel=2 \ + -strip-ecs=true \ + -tcp-fallback=true \ + -v + +``` + +输出示例: +``` +🚀 starting DNS-over-TLS on :853 +[req] A www.example.com. (id=40192 cd=false do=true from=127.0.0.1:58877) +[cache] MISS A www.example.com. +[answer] www.example.com. 300 IN A 93.184.216.34 +``` + + +## 🧩 配置参数 + +| 参数 | 默认值 | 说明 | +|------|---------|------| +| `--addr` | `:853` | 监听地址 | +| `--cert` | `server.crt` | TLS 证书路径 | +| `--key` | `server.key` | TLS 私钥路径 | +| `--upstream` | `8.8.8.8:53,1.1.1.1:53` | 上游 DNS 服务器 | +| `--cache-ttl` | `60s` | 最大缓存 TTL | +| `--timeout` | `3s` | 上游查询超时 | +| `--max-parallel` | `3` | 并发上游查询数 | +| `--strip-ecs` | `true` | 是否剥离 ECS 信息 | +| `--tcp-fallback` | `true` | 是否启用 TCP 回退 | +| `--v` | `false` | 详细日志模式 | + + +## 🧪 测试解析 + +使用 `kdig` 或 `dig` 进行测试: + +```bash +kdig @127.0.0.1 +tls-ca +tls-host=dot.local www.example.com +``` + + +## 📊 缓存机制 + +- **缓存键**:`domain|type|class|DO|CD` +- **正向缓存**:取最小 TTL 与配置上限的较小值 +- **负面缓存**:依据 SOA.MINIMUM(RFC 2308) +- **动态 TTL 续算**:返回时根据剩余时间更新 TTL +- **清理周期**:每 5 分钟清除过期项 + + +## 🔐 安全特性 + +- 默认支持 **TLS 1.2 / 1.3** +- 剥离 **EDNS Client Subnet** +- 不缓存 OPT/TSIG 伪记录 +- 独立缓存空间隔离 DO/CD 查询 + + +## 🧭 路线图 + +- [ ] 支持 DoH (DNS-over-HTTPS) +- [ ] LRU 缓存上限控制 +- [ ] 增加配置文件支持 (YAML/JSON) +- [ ] 集成 Docker Compose & CI/CD + + +## 👨‍💻 作者信息 + +**Email:** aixiao@aixiao.me +**License:** MIT +**Language:** Go 1.22+ +**Dependency:** [github.com/miekg/dns](https://github.com/miekg/dns) diff --git a/build.sh b/build.sh new file mode 100644 index 0000000..955a88a --- /dev/null +++ b/build.sh @@ -0,0 +1,96 @@ +#!/usr/bin/env bash +# +# build.sh — Build & Run Helper for DNS-over-TLS Cache Proxy +# Author: niuyuling +# Email: aixiao@aixiao.me +# License: MIT +# ----------------------------------------------------------- +# 用途:快速构建、运行和管理 dot 容器镜像。 +# 用法: +# ./build.sh build 构建镜像 +# ./build.sh run 启动容器(后台) +# ./build.sh logs 查看日志 +# ./build.sh stop 停止容器 +# ./build.sh clean 删除容器和镜像 +# ./build.sh rebuild 重新构建镜像并启动 +# ----------------------------------------------------------- + +set -e + +IMAGE_NAME="dot" +CONTAINER_NAME="dot" +TAG="latest" +PORT="853" +CERT_FILE="jinllpay.com.cer" +KEY_FILE="jinllpay.com.key" + +# ---------- 函数区 ---------- + +build() { + echo "🔨 Building Docker image: ${IMAGE_NAME}:${TAG} ..." + docker build -t "${IMAGE_NAME}:${TAG}" . + echo "✅ Build complete." +} + +run() { + echo "🚀 Starting container ${CONTAINER_NAME}..." + + # 确保旧容器不冲突 + if docker ps -a --format '{{.Names}}' | grep -w "${CONTAINER_NAME}" >/dev/null 2>&1; then + echo "⚠️ Existing container found. Removing..." + docker rm -f "${CONTAINER_NAME}" >/dev/null 2>&1 || true + fi + + docker run -d \ + --name "${CONTAINER_NAME}" \ + --memory=256m \ + --memory-swap=384m \ + --memory-reservation=128m \ + -p ${PORT}:853 \ + -e CERT_FILE="/app/${CERT_FILE}" \ + -e KEY_FILE="/app/${KEY_FILE}" \ + -v "$(pwd)/${CERT_FILE}:/app/${CERT_FILE}:ro" \ + -v "$(pwd)/${KEY_FILE}:/app/${KEY_FILE}:ro" \ + "${IMAGE_NAME}:${TAG}" + + echo "✅ Container started on port ${PORT}." +} + +logs() { + echo "📜 Showing logs..." + docker logs -f "${CONTAINER_NAME}" +} + +stop() { + echo "🛑 Stopping container..." + docker stop "${CONTAINER_NAME}" >/dev/null 2>&1 || true + docker rm "${CONTAINER_NAME}" >/dev/null 2>&1 || true + echo "✅ Container stopped and removed." +} + +clean() { + stop + echo "🧹 Removing image ${IMAGE_NAME}:${TAG}..." + docker rmi "${IMAGE_NAME}:${TAG}" >/dev/null 2>&1 || true + echo "✅ Cleanup complete." +} + +rebuild() { + clean + build + run +} + +# ---------- 主逻辑 ---------- +case "$1" in + build) build ;; + run) run ;; + logs) logs ;; + stop) stop ;; + clean) clean ;; + rebuild) rebuild ;; + *) + echo "Usage: ./build.sh [build|run|logs|stop|clean|rebuild]" + exit 1 + ;; +esac diff --git a/dot b/dot index f19db30..c6bc347 100644 Binary files a/dot and b/dot differ diff --git a/go.mod b/go.mod index 86b472f..8ee954f 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.25.2 require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/miekg/dns v1.1.68 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/prometheus/client_golang v1.23.2 // indirect diff --git a/go.sum b/go.sum index 399281d..c2e984d 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA= github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= diff --git a/main.go b/main.go index d2c714a..5ffb9a2 100644 --- a/main.go +++ b/main.go @@ -15,14 +15,13 @@ import ( "time" "github.com/miekg/dns" + + lru "github.com/hashicorp/golang-lru/v2" ) /****************************************************************** * 日志初始化 ******************************************************************/ - -// initLogger 根据 -v 开关设置日志格式;verbose=true 时附加源码文件与行号。 -// 说明:Lmicroseconds 便于排查毫秒级时序问题。 func initLogger(verbose bool) { flags := log.Ldate | log.Ltime | log.Lmicroseconds if verbose { @@ -32,24 +31,25 @@ func initLogger(verbose bool) { } /****************************************************************** - * 缓存结构 + * 缓存结构(支持 TTL + LRU) ******************************************************************/ -// cacheEntry 表示一条缓存项:保存上游响应的副本,以及过期时间。 -// 注:将完整 *dns.Msg 存入缓存,便于原样复用 Answer/Ns/Extra。 type cacheEntry struct { - msg *dns.Msg // 上游完整响应(拷贝存储) - expireAt time.Time // 过期时间 + msg *dns.Msg + expireAt time.Time } -// cache 使用 sync.Map 作为并发安全的键值存储。 -// 注:此实现为 TTL 驱动的简单缓存;如需上限/淘汰策略,可叠加 LRU。 -var cache sync.Map +var ( + cache *lru.Cache[string, *cacheEntry] + cacheMutex sync.RWMutex +) -const cacheCleanupInterval = 5 * time.Minute +const ( + cacheCleanupInterval = 5 * time.Minute + defaultCacheSize = 10000 // 默认最大缓存条目数 +) -// startCacheCleaner 定时清理过期项,避免缓存无限增长。 -// 使用 ctx 控制生命周期,与主服务一同退出。 +// startCacheCleaner 定期清理过期缓存(修复:在删除前二次校验) func startCacheCleaner(ctx context.Context) { go func() { ticker := time.NewTicker(cacheCleanupInterval) @@ -60,25 +60,35 @@ func startCacheCleaner(ctx context.Context) { return case <-ticker.C: now := time.Now() - n := 0 - cache.Range(func(k, v any) bool { - e := v.(*cacheEntry) - if now.After(e.expireAt) { - cache.Delete(k) - n++ + var toDelete []string + + cacheMutex.RLock() + for _, k := range cache.Keys() { + if v, ok := cache.Peek(k); ok && now.After(v.expireAt) { + toDelete = append(toDelete, k) + } + } + cacheMutex.RUnlock() + + if len(toDelete) > 0 { + pruned := 0 + cacheMutex.Lock() + for _, k := range toDelete { + if v, ok := cache.Peek(k); ok && now.After(v.expireAt) { + cache.Remove(k) + pruned++ + } + } + cacheMutex.Unlock() + if pruned > 0 { + log.Printf("[cache] cleaned %d expired entries", pruned) } - return true - }) - if n > 0 { - log.Printf("[cache] cleaned %d expired entries", n) } } } }() } -// 计算缓存键:name + type + class + DO + CD。 -// 说明:包含 DO/CD 可避免不同验证上下文污染同一键。 func cacheKeyFromMsg(q dns.Question, do, cd bool) string { var b strings.Builder b.Grow(len(q.Name) + 32) @@ -96,8 +106,6 @@ func cacheKeyFromMsg(q dns.Question, do, cd bool) string { return b.String() } -// 识别伪 RR(OPT/TSIG);这些记录的“TTL 字段”并非真实 TTL,不可参与 TTL 计算或改写。 -// OPT:其 TTL 字段承载扩展 RCODE 与 DO 位等标志;TSIG:签名,不应缓存或改写。 func isPseudo(rr dns.RR) bool { switch rr.(type) { case *dns.OPT, *dns.TSIG: @@ -107,66 +115,93 @@ func isPseudo(rr dns.RR) bool { } } -// tryCacheRead 尝试读取缓存并回填“剩余 TTL”;对 Answer/Ns/Extra 普通 RR 截断 TTL,跳过伪 RR。 -// 返回响应副本,保证对外不可变(防止调用方修改缓存内部对象)。 +// 读取缓存(修复:Get 在写锁下;在锁外调整 TTL) func tryCacheRead(key string) (*dns.Msg, bool) { - v, ok := cache.Load(key) - if !ok { - return nil, false - } - e := v.(*cacheEntry) now := time.Now() + + cacheMutex.Lock() + e, ok := cache.Get(key) // Get 会更新 LRU,必须在写锁下 + if !ok { + cacheMutex.Unlock() + return nil, false + } if now.After(e.expireAt) { - cache.Delete(key) + cache.Remove(key) + cacheMutex.Unlock() return nil, false } + // 拷贝副本,在锁外改 TTL,减少临界区时间 out := e.msg.Copy() - remaining := uint32(e.expireAt.Sub(now).Seconds()) + expireAt := e.expireAt + cacheMutex.Unlock() + + remaining := uint32(expireAt.Sub(now).Seconds()) if remaining == 0 { - cache.Delete(key) + cacheMutex.Lock() + cache.Remove(key) + cacheMutex.Unlock() return nil, false } - // 回填剩余 TTL(不增加,只做上限截断) - for i := range out.Answer { - if out.Answer[i].Header().Ttl > remaining { - out.Answer[i].Header().Ttl = remaining - } - } - for i := range out.Ns { - if out.Ns[i].Header().Ttl > remaining { - out.Ns[i].Header().Ttl = remaining - } - } - for i := range out.Extra { - if isPseudo(out.Extra[i]) { - continue - } - if out.Extra[i].Header().Ttl > remaining { - out.Extra[i].Header().Ttl = remaining + + for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} { + for _, rr := range sec { + if isPseudo(rr) { + continue + } + if rr.Header().Ttl > remaining { + rr.Header().Ttl = remaining + } } } return out, true } -// negativeTTL 依据 RFC 2308 计算负面缓存 TTL。 -// 对 NXDOMAIN 或 NODATA(NOERROR + Answer 为空)取 min(SOA.TTL, SOA.MINIMUM),再与配置上限取 min。 -func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { - // NXDOMAIN,或 NOERROR 但 Answer 为空(NODATA) - if m.Rcode != dns.RcodeNameError && !(m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0) { - return 0, false +// 计算负面 TTL +// hasAnswerForType 判断报文中是否存在回答“请求类型”的 RRset +func hasAnswerForType(m *dns.Msg, q dns.Question) bool { + for _, rr := range m.Answer { + h := rr.Header() + if h.Rrtype == q.Qtype && strings.EqualFold(h.Name, q.Name) { + return true + } } + return false +} + +// 计算负面 TTL(修复:正确识别 NODATA,包括 CNAME 等场景) +func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { + // NXDOMAIN:肯定是负面 + if m.Rcode != dns.RcodeNameError { + // 不是 NXDOMAIN,则仅当 NOERROR 但没有“匹配 QTYPE 的答案”时才是 NODATA + if m.Rcode != dns.RcodeSuccess || len(m.Question) == 0 || hasAnswerForType(m, m.Question[0]) { + return 0, false + } + } + + // 按 RFC 2308,从 Authority(Ns)优先取 SOA(多数实现都只放在 Authority) var soa *dns.SOA - // SOA 可能出现在 Authority(Ns) 或 Additional(Extra) - for _, rr := range append(m.Ns, m.Extra...) { + for _, rr := range m.Ns { if s, ok := rr.(*dns.SOA); ok { soa = s break } } + // 兼容性:偶尔也有人把 SOA 放 Extra(不规范,但为了兼容可以兜底看看) if soa == nil { - // 无 SOA 无法可靠计算负面 TTL;此时不缓存或由上限兜底(本实现选择不缓存) + for _, rr := range m.Extra { + if s, ok := rr.(*dns.SOA); ok { + soa = s + break + } + } + } + if soa == nil { + // 建议:无 SOA 时不做负面缓存(返回 0,false) + // 如你更希望兜底,可改成:return uint32(maxTTL.Seconds()), true return 0, false } + + // 负面 TTL 取 min(SOA.MINIMUM, SOA 自身 TTL),再与配置上限比较 ttl := soa.Hdr.Ttl if soa.Minttl < ttl { ttl = soa.Minttl @@ -178,8 +213,6 @@ func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { return ttl, ttl > 0 } -// minRRsetTTL 获取普通(正向)响应的最小 TTL(Answer/Ns/Extra 中的普通 RR),跳过伪 RR。 -// 用于决定缓存过期时间的上限(与配置上限再取 min)。 func minRRsetTTL(m *dns.Msg) (uint32, bool) { minTTL := uint32(0) hasTTL := false @@ -198,8 +231,6 @@ func minRRsetTTL(m *dns.Msg) (uint32, bool) { return minTTL, hasTTL } -// stripPseudoExtras 从消息中剥离伪 RR(OPT/TSIG)。 -// 用途:缓存前剥离,避免将传输层细节或签名内容写入缓存。 func stripPseudoExtras(m *dns.Msg) { if len(m.Extra) == 0 { return @@ -214,58 +245,44 @@ func stripPseudoExtras(m *dns.Msg) { m.Extra = out } -// cacheWrite 写缓存:优先处理负面缓存;正面缓存取 min(应答中最小 TTL, 配置上限)。 -// 写入前统一剥离伪 RR,保证缓存与传输解耦。 +// 写缓存 func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { if in == nil { return } - // 仅缓存 NOERROR / NXDOMAIN,其余不缓存(如 SERVFAIL/REFUSED 等) if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError { return } - // 负面缓存 - if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 { - expire := time.Now().Add(time.Duration(ttl) * time.Second) - cp := in.Copy() - stripPseudoExtras(cp) - cache.Store(key, &cacheEntry{msg: cp, expireAt: expire}) - return - } - // 正向缓存:minTTL 与 maxTTL 取较小 - minTTL, ok := minRRsetTTL(in) - if !ok { - // 没有 TTL 时可用上限兜底(也可选择不缓存,这里选择兜底) - if maxTTL > 0 { - expire := time.Now().Add(maxTTL) - cp := in.Copy() - stripPseudoExtras(cp) - cache.Store(key, &cacheEntry{msg: cp, expireAt: expire}) + var ttl uint32 + var ok bool + if ttl, ok = negativeTTL(in, maxTTL); !ok { + minTTL, has := minRRsetTTL(in) + if has { + cfgTTL := uint32(maxTTL.Seconds()) + if cfgTTL > 0 && minTTL > cfgTTL { + minTTL = cfgTTL + } + ttl = minTTL + } else { + ttl = uint32(maxTTL.Seconds()) } + } + if ttl == 0 { return } - cfgTTL := uint32(maxTTL.Seconds()) - finalTTL := minTTL - if cfgTTL > 0 && finalTTL > cfgTTL { - finalTTL = cfgTTL - } - if finalTTL == 0 { - return - } - expire := time.Now().Add(time.Duration(finalTTL) * time.Second) + expire := time.Now().Add(time.Duration(ttl) * time.Second) cp := in.Copy() stripPseudoExtras(cp) - cache.Store(key, &cacheEntry{msg: cp, expireAt: expire}) + cacheMutex.Lock() + cache.Add(key, &cacheEntry{msg: cp, expireAt: expire}) + cacheMutex.Unlock() } /****************************************************************** - * 上游查询(带 context 取消、并发上限、UDP→TCP 回退) + * 上游查询 ******************************************************************/ - -// 全局可复用 UDP 客户端(Net=udp,UDPSize 放大以承载更大的响应) var udpClient *dns.Client -// shuffled 打乱上游列表,避免固定顺序导致单点拥塞或偏置。 func shuffled(xs []string) []string { out := make([]string, len(xs)) copy(out, xs) @@ -273,10 +290,22 @@ func shuffled(xs []string) []string { return out } -// queryUpstreamsLimited 并发向多个上游发起查询,返回首个有效结果。 -// - timeout:整个查询窗口的上限(基于子 context) -// - maxParallel:同时在飞请求上限 -// - allowTCPFallback:若 UDP 截断(TC 位)则回退 TCP 重试 +// clampEDNSForUpstream 返回一个 msg 副本,把 EDNS UDP size 夹到给定大小 +func clampEDNSForUpstream(in *dns.Msg, size uint16) *dns.Msg { + m := in.Copy() + o := m.IsEdns0() + if o == nil { + o = &dns.OPT{} + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + m.Extra = append(m.Extra, o) + } + if size > 0 { + o.SetUDPSize(size) + } + return m +} + func queryUpstreamsLimited( ctx context.Context, req *dns.Msg, @@ -290,39 +319,29 @@ func queryUpstreamsLimited( } servers := shuffled(upstreams) - // 每次查询一个带超时的子 context;拿到首个有效结果后 cancel,取消其他请求。 cctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - type result struct { - msg *dns.Msg - } - ch := make(chan result, len(servers)) // 缓冲至多保存所有返回,避免阻塞 - sem := make(chan struct{}, maxParallel) // 并发信号量,限制同时在飞请求数 + type result struct{ msg *dns.Msg } + ch := make(chan result, len(servers)) + sem := make(chan struct{}, maxParallel) - // execOne 对单个上游发起查询;对 req 使用 Copy() 防止在多 goroutine 下共享同一 *dns.Msg。 execOne := func(svr string) { - // UDP 查询(带 context);使用 req.Copy() 防止并发读写 - resp, _, err := udpClient.ExchangeContext(cctx, req.Copy(), svr) + upReq := clampEDNSForUpstream(req, 1232) // 或做成 flag + resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr) if err == nil && resp != nil && resp.Truncated && allowTCPFallback { - // TCP 回退(响应被截断) log.Printf("[upstream] UDP truncated, retry TCP: %s", svr) tcpClient := *udpClient tcpClient.Net = "tcp" + resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr) } if err != nil || resp == nil { - // 错误仅在未被上层取消时打印,避免超时后噪声 - if err != nil { - if cctx.Err() == nil { - log.Printf("[upstream] %s error: %v", svr, err) - } - } else { - log.Printf("[upstream] %s nil response", svr) + if err != nil && cctx.Err() == nil { + log.Printf("[upstream] %s error: %v", svr, err) } return } - // 过滤差 RCODE(不参与竞速):SERVFAIL / REFUSED / FORMERR if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError { return } @@ -332,7 +351,6 @@ func queryUpstreamsLimited( } } - // “快乐眼球”式启动:生产者不阻塞,获取配额在 goroutine 内部完成;这样可以更快进入接收循环并在首个成功后取消其余。 for _, s := range servers { s := s go func() { @@ -346,7 +364,6 @@ func queryUpstreamsLimited( }() } - // 返回第一个非空结果,并 cancel 其他 goroutine。 for i := 0; i < len(servers); i++ { select { case r := <-ch: @@ -363,11 +380,8 @@ func queryUpstreamsLimited( } /****************************************************************** - * EDNS(0) / ECS 处理 + * EDNS / 响应构造 ******************************************************************/ - -// stripECS 从请求中去除 EDNS Client Subnet(ECS),减少缓存污染并保护隐私。 -// 注:ECS 会导致上游按地理/网络分区返回不同答案,不适合集中缓存。 func stripECS(m *dns.Msg) { if o := m.IsEdns0(); o != nil { var kept []dns.EDNS0 @@ -380,7 +394,6 @@ func stripECS(m *dns.Msg) { } } -// getDOFlag 读取 DO(DNSSEC OK)位,构成缓存键的一部分。 func getDOFlag(m *dns.Msg) bool { if o := m.IsEdns0(); o != nil { return o.Do() @@ -388,67 +401,47 @@ func getDOFlag(m *dns.Msg) bool { return false } -/****************************************************************** - * 响应构造:使用客户端请求头构造 reply,复制上游内容 - ******************************************************************/ - -// writeReply 根据客户端请求构造响应:复制上游 Answer/Ns/非伪 Extra, -// 并按客户端请求重建 OPT(UDPSize/DO),同时继承上游的扩展 RCODE 与 EDNS 版本; -// 可选透传上游的 EDE(Extended DNS Errors)以保留诊断信息。 func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) { if upstream == nil { dns.HandleFailed(w, req) return } - out := new(dns.Msg) out.SetReply(req) out.Authoritative = false out.RecursionAvailable = upstream.RecursionAvailable out.AuthenticatedData = upstream.AuthenticatedData - out.CheckingDisabled = req.CheckingDisabled // 反映客户端 CD 位 - out.Rcode = upstream.Rcode // 主 RCODE(低 4 位) + out.CheckingDisabled = req.CheckingDisabled + out.Rcode = upstream.Rcode out.Answer = upstream.Answer out.Ns = upstream.Ns - // 复制上游的非伪 RR 额外记录;OPT/TSIG 不透传 - extras := make([]dns.RR, 0, len(upstream.Extra)) + var extras []dns.RR for _, rr := range upstream.Extra { - if isPseudo(rr) { - continue + if !isPseudo(rr) { + extras = append(extras, rr) } - extras = append(extras, rr) } - // 基于客户端请求镜像 EDNS(UDPSize + DO) if ro := req.IsEdns0(); ro != nil { o := new(dns.OPT) o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT - - // 与客户端保持一致的 UDPSize / DO 位 o.SetUDPSize(ro.UDPSize()) if ro.Do() { o.SetDo() } - - // 继承上游的扩展 RCODE 与 EDNS 版本(注意不同版本签名差异,这里显式转换) if uo := upstream.IsEdns0(); uo != nil { - // 你当前库期望 uint16,这里强转;若你的库期望 uint8,也可改成 uint8(...) o.SetExtendedRcode(uint16(uo.ExtendedRcode())) o.SetVersion(uint8(uo.Version())) - - // 可选:透传只读的 EDE 诊断信息 for _, opt := range uo.Option { if ede, ok := opt.(*dns.EDNS0_EDE); ok { o.Option = append(o.Option, ede) } } } - extras = append(extras, o) } - out.Extra = extras out.Compress = true @@ -458,11 +451,8 @@ func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) { } /****************************************************************** - * 处理器 + * 主处理器 ******************************************************************/ - -// handleDNS 为每个请求执行:日志 → (可选)剥离 ECS → 缓存命中 → 上游查询 → 写缓存 → 回写。 -// 注意:缓存键包含 DO/CD;同时通过 tryCacheRead 回填剩余 TTL。 func handleDNS( upstreams []string, cacheMaxTTL, timeout time.Duration, @@ -476,20 +466,14 @@ func handleDNS( return } q := r.Question[0] - - // 基本访问日志:类型/域名/ID/CD/DO/来源 IP:Port log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)", dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr()) - // 可选:去除 ECS(推荐) if stripECSBeforeForward { stripECS(r) } - - // 缓存键(域名小写 + QTYPE/QCLASS + DO/CD) key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) - // 1) 缓存命中:命中即快速返回 if cached, ok := tryCacheRead(key); ok { log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) writeReply(w, r, cached) @@ -497,7 +481,6 @@ func handleDNS( } log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) - // 2) 上游查询(带 context 取消 & TCP 可选回退);并发向多个上游竞速 ctx := context.Background() resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback) if resp == nil { @@ -505,11 +488,7 @@ func handleDNS( dns.HandleFailed(w, r) return } - - // 3) 写缓存(负面/正面均处理;剥离伪 RR) cacheWrite(key, resp, cacheMaxTTL) - - // 4) 回写给客户端,并打印 Answer 方便调试 for _, ans := range resp.Answer { log.Printf("[answer] %s", ans.String()) } @@ -517,31 +496,38 @@ func handleDNS( } } +/****************************************************************** + * 主函数 + ******************************************************************/ func main() { - rand.Seed(time.Now().UnixNano()) // 随机源:用于上游列表打乱 + rand.Seed(time.Now().UnixNano()) - // 命令行参数:可根据部署场景调整默认值 - certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)") - keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)") - addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853)") - upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)") - cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL(默认 60s;实际取 min(上游最小TTL, 本值))") - timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s)") - maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限") - stripECSFlag := flag.Bool("strip-ecs", true, "转发上游前去除 EDNS Client Subnet") - allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时允许 TCP 回退") - verbose := flag.Bool("v", false, "verbose 日志(包含源码位置)") + certFile := flag.String("cert", "server.crt", "TLS 证书文件路径") + keyFile := flag.String("key", "server.key", "TLS 私钥文件路径") + addr := flag.String("addr", ":853", "DoT 监听地址") + upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表") + cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL") + cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限") + timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时") + maxParallel := flag.Int("max-parallel", 3, "并发上游数量") + stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS") + allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退") + verbose := flag.Bool("v", false, "verbose 日志") flag.Parse() initLogger(*verbose) - // 加载 TLS 证书/私钥;用于 DoT(RFC 7858)监听 + var err error + cache, err = lru.New[string, *cacheEntry](*cacheSizeFlag) + if err != nil { + log.Fatalf("[fatal] failed to init LRU cache: %v", err) + } + cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) if err != nil { log.Fatalf("[fatal] failed to load cert: %v", err) } - // 解析上游地址:支持“host”或“host:port”,缺省端口补 53。 var upstreams []string for _, s := range strings.Split(*upstreamStr, ",") { if t := strings.TrimSpace(s); t != "" { @@ -555,65 +541,44 @@ func main() { log.Fatal("[fatal] no upstream DNS servers provided") } - // 全局 UDP 客户端(不设置 Client.Timeout,改用 ExchangeContext 控制超时) - udpClient = &dns.Client{ - Net: "udp", - UDPSize: 4096, // 放大到 4K,减小 UDP 截断概率 - } + udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true} - // context 用于优雅退出(SIGINT/SIGTERM 收到后取消) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // 启动缓存清理器(后台 goroutine) startCacheCleaner(ctx) - // 注册处理器:使用 ServeMux 将所有查询交给 handleDNS mux := dns.NewServeMux() - mux.HandleFunc(".", handleDNS( - upstreams, - *cacheTTLFlag, - *timeoutFlag, - *maxParallel, - *stripECSFlag, - *allowTCPFallback, - )) + mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel, *stripECSFlag, *allowTCPFallback)) - // 构造 DoT(tcp-tls)服务器:显式开启 TLS1.2/1.3,增加读写超时防止慢连接。 - // NextProtos: "dot"(可选,部分客户端用于 ALPN 检测) srv := &dns.Server{ Addr: *addr, Net: "tcp-tls", TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, // 允许 1.2/1.3(默认启用 1.3) - // 不设置 CipherSuites,交由 Go 自动选择(TLS1.3 有自身套件) - NextProtos: []string{"dot"}, // 可选:显式 ALPN + MinVersion: tls.VersionTLS12, + NextProtos: []string{"dot"}, }, Handler: mux, - ReadTimeout: 10 * time.Second, // 防止慢连接 + ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, } - // 捕获信号以便优雅退出(关闭监听、结束后台协程) stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) - // 异步启动服务,错误通过 errCh 返回 errCh := make(chan error, 1) go func() { log.Printf("🚀 starting DNS-over-TLS on %s", *addr) - log.Printf(" upstreams=%v | cache_max_ttl=%s | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v", - upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback) + log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v", + upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback) errCh <- srv.ListenAndServe() }() - // 等待退出信号或服务器错误 select { case sig := <-stop: log.Printf("[shutdown] caught signal: %s", sig) cancel() - // miekg/dns 提供 Shutdown();部分版本无 ShutdownContext,这里用 Shutdown() if err := srv.Shutdown(); err != nil { log.Printf("[shutdown] server shutdown error: %v", err) }