feat(cache): 引入 LRU 缓存并优化缓存清理与 TTL 处理

- 使用 github.com/hashicorp/golang-lru/v2 替代原生 sync.Map 实现 LRU 缓存
- 修复缓存读写过程中的并发安全问题,使用 RWMutex 保护共享状态
- 调整缓存键结构注释,明确支持 TTL 和 LRU 策略
- 优化负面缓存 TTL 计算逻辑,更准确识别 NODATA 场景
- 在缓存写入前统一剥离伪 RR(如 OPT、TSIG)
- 增加 cache-size 命令行参数,支持配置 LRU 缓存最大条目数
- 移除旧的缓存清理协程中不必要的全量遍历逻辑
- 更新日志输出内容,包含 cache-size 配置项
```
This commit is contained in:
2025-10-14 16:37:39 +08:00
parent bcd0914b2f
commit 916a7c8127
7 changed files with 430 additions and 213 deletions

32
Dockerfile Normal file
View File

@@ -0,0 +1,32 @@
# ---------- 构建阶段 ----------
FROM golang:1.25.2-alpine AS builder
WORKDIR /app
COPY . .
ENV CGO_ENABLED=0 GOOS=linux GOARCH=amd64
RUN go build -o dot main.go
# ---------- 运行阶段 ----------
FROM alpine:3.20
WORKDIR /app
# 只复制编译好的二进制,不再打包证书
COPY --from=builder /app/dot /app/dot
# 运行时定义可覆盖的环境变量(不在构建时生效)
ENV CERT_FILE=aixiao.me.cer
ENV KEY_FILE=aixiao.me.key
EXPOSE 853/tcp
# 启动命令,使用运行时传入的证书路径
ENTRYPOINT ["sh", "-c", "./dot \
-cert ${CERT_FILE} \
-key ${KEY_FILE} \
-addr :853 \
-upstream \"119.29.29.29:53,223.5.5.5:53,114.114.114.114:53\" \
-cache-ttl 300s \
-timeout 3s \
-max-parallel 3"]

121
README.md Normal file
View File

@@ -0,0 +1,121 @@
# 🧠 DNS-over-TLS Cache Proxy
一个基于 Go 的高性能 **DNS-over-TLS (DoT)** 缓存代理服务器。
支持多上游并发解析、智能缓存、隐私保护与优雅关闭。
轻量、无依赖、可直接部署。
## ✨ 特性
- 🔒 **加密传输** — 完全支持 DNS-over-TLS (RFC 7858)
-**多上游并发查询** — 类似“快乐眼球”机制,提升解析速度
- 🧠 **TTL 智能缓存** — 支持正向与负面缓存RFC 2308
- 🧹 **自动清理** — 定期清除过期缓存
- 🧩 **隐私保护** — 默认剥离 ECS (EDNS Client Subnet)
- 🪶 **轻量高效** — 单文件可执行,零外部依赖
## 📦 安装
### 🧰 源码构建
```bash
git clone https://git.aixiao.me/aixiao/dot.git
cd dot
go build -o dot main.go
```
### 🐳 Docker 构建
```bash
#构建、启动
bash build.sh build
bash build.sh run
#清理
bash build.sh stop
bash build.sh clean
```
## 🚀 启动服务
```bash
./dot \
-cert=server.crt \
-key=server.key \
-addr=":853" \
-upstream="8.8.8.8:53,1.1.1.1:53" \
-cache-ttl=120s \
-timeout=3s \
-max-parallel=2 \
-strip-ecs=true \
-tcp-fallback=true \
-v
```
输出示例:
```
🚀 starting DNS-over-TLS on :853
[req] A www.example.com. (id=40192 cd=false do=true from=127.0.0.1:58877)
[cache] MISS A www.example.com.
[answer] www.example.com. 300 IN A 93.184.216.34
```
## 🧩 配置参数
| 参数 | 默认值 | 说明 |
|------|---------|------|
| `--addr` | `:853` | 监听地址 |
| `--cert` | `server.crt` | TLS 证书路径 |
| `--key` | `server.key` | TLS 私钥路径 |
| `--upstream` | `8.8.8.8:53,1.1.1.1:53` | 上游 DNS 服务器 |
| `--cache-ttl` | `60s` | 最大缓存 TTL |
| `--timeout` | `3s` | 上游查询超时 |
| `--max-parallel` | `3` | 并发上游查询数 |
| `--strip-ecs` | `true` | 是否剥离 ECS 信息 |
| `--tcp-fallback` | `true` | 是否启用 TCP 回退 |
| `--v` | `false` | 详细日志模式 |
## 🧪 测试解析
使用 `kdig``dig` 进行测试:
```bash
kdig @127.0.0.1 +tls-ca +tls-host=dot.local www.example.com
```
## 📊 缓存机制
- **缓存键**`domain|type|class|DO|CD`
- **正向缓存**:取最小 TTL 与配置上限的较小值
- **负面缓存**:依据 SOA.MINIMUMRFC 2308
- **动态 TTL 续算**:返回时根据剩余时间更新 TTL
- **清理周期**:每 5 分钟清除过期项
## 🔐 安全特性
- 默认支持 **TLS 1.2 / 1.3**
- 剥离 **EDNS Client Subnet**
- 不缓存 OPT/TSIG 伪记录
- 独立缓存空间隔离 DO/CD 查询
## 🧭 路线图
- [ ] 支持 DoH (DNS-over-HTTPS)
- [ ] LRU 缓存上限控制
- [ ] 增加配置文件支持 (YAML/JSON)
- [ ] 集成 Docker Compose & CI/CD
## 👨‍💻 作者信息
**Email:** aixiao@aixiao.me
**License:** MIT
**Language:** Go 1.22+
**Dependency:** [github.com/miekg/dns](https://github.com/miekg/dns)

96
build.sh Normal file
View File

@@ -0,0 +1,96 @@
#!/usr/bin/env bash
#
# build.sh — Build & Run Helper for DNS-over-TLS Cache Proxy
# Author: niuyuling
# Email: aixiao@aixiao.me
# License: MIT
# -----------------------------------------------------------
# 用途:快速构建、运行和管理 dot 容器镜像。
# 用法:
# ./build.sh build 构建镜像
# ./build.sh run 启动容器(后台)
# ./build.sh logs 查看日志
# ./build.sh stop 停止容器
# ./build.sh clean 删除容器和镜像
# ./build.sh rebuild 重新构建镜像并启动
# -----------------------------------------------------------
set -e
IMAGE_NAME="dot"
CONTAINER_NAME="dot"
TAG="latest"
PORT="853"
CERT_FILE="jinllpay.com.cer"
KEY_FILE="jinllpay.com.key"
# ---------- 函数区 ----------
build() {
echo "🔨 Building Docker image: ${IMAGE_NAME}:${TAG} ..."
docker build -t "${IMAGE_NAME}:${TAG}" .
echo "✅ Build complete."
}
run() {
echo "🚀 Starting container ${CONTAINER_NAME}..."
# 确保旧容器不冲突
if docker ps -a --format '{{.Names}}' | grep -w "${CONTAINER_NAME}" >/dev/null 2>&1; then
echo "⚠️ Existing container found. Removing..."
docker rm -f "${CONTAINER_NAME}" >/dev/null 2>&1 || true
fi
docker run -d \
--name "${CONTAINER_NAME}" \
--memory=256m \
--memory-swap=384m \
--memory-reservation=128m \
-p ${PORT}:853 \
-e CERT_FILE="/app/${CERT_FILE}" \
-e KEY_FILE="/app/${KEY_FILE}" \
-v "$(pwd)/${CERT_FILE}:/app/${CERT_FILE}:ro" \
-v "$(pwd)/${KEY_FILE}:/app/${KEY_FILE}:ro" \
"${IMAGE_NAME}:${TAG}"
echo "✅ Container started on port ${PORT}."
}
logs() {
echo "📜 Showing logs..."
docker logs -f "${CONTAINER_NAME}"
}
stop() {
echo "🛑 Stopping container..."
docker stop "${CONTAINER_NAME}" >/dev/null 2>&1 || true
docker rm "${CONTAINER_NAME}" >/dev/null 2>&1 || true
echo "✅ Container stopped and removed."
}
clean() {
stop
echo "🧹 Removing image ${IMAGE_NAME}:${TAG}..."
docker rmi "${IMAGE_NAME}:${TAG}" >/dev/null 2>&1 || true
echo "✅ Cleanup complete."
}
rebuild() {
clean
build
run
}
# ---------- 主逻辑 ----------
case "$1" in
build) build ;;
run) run ;;
logs) logs ;;
stop) stop ;;
clean) clean ;;
rebuild) rebuild ;;
*)
echo "Usage: ./build.sh [build|run|logs|stop|clean|rebuild]"
exit 1
;;
esac

BIN
dot

Binary file not shown.

1
go.mod
View File

@@ -5,6 +5,7 @@ go 1.25.2
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/miekg/dns v1.1.68 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/prometheus/client_golang v1.23.2 // indirect

2
go.sum
View File

@@ -2,6 +2,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=

391
main.go
View File

@@ -15,14 +15,13 @@ import (
"time"
"github.com/miekg/dns"
lru "github.com/hashicorp/golang-lru/v2"
)
/******************************************************************
* 日志初始化
******************************************************************/
// initLogger 根据 -v 开关设置日志格式verbose=true 时附加源码文件与行号。
// 说明Lmicroseconds 便于排查毫秒级时序问题。
func initLogger(verbose bool) {
flags := log.Ldate | log.Ltime | log.Lmicroseconds
if verbose {
@@ -32,24 +31,25 @@ func initLogger(verbose bool) {
}
/******************************************************************
* 缓存结构
* 缓存结构(支持 TTL + LRU
******************************************************************/
// cacheEntry 表示一条缓存项:保存上游响应的副本,以及过期时间。
// 注:将完整 *dns.Msg 存入缓存,便于原样复用 Answer/Ns/Extra。
type cacheEntry struct {
msg *dns.Msg // 上游完整响应(拷贝存储)
expireAt time.Time // 过期时间
msg *dns.Msg
expireAt time.Time
}
// cache 使用 sync.Map 作为并发安全的键值存储。
// 注:此实现为 TTL 驱动的简单缓存;如需上限/淘汰策略,可叠加 LRU。
var cache sync.Map
var (
cache *lru.Cache[string, *cacheEntry]
cacheMutex sync.RWMutex
)
const cacheCleanupInterval = 5 * time.Minute
const (
cacheCleanupInterval = 5 * time.Minute
defaultCacheSize = 10000 // 默认最大缓存条目数
)
// startCacheCleaner 定清理过期项,避免缓存无限增长。
// 使用 ctx 控制生命周期,与主服务一同退出。
// startCacheCleaner 定清理过期缓存(修复:在删除前二次校验)
func startCacheCleaner(ctx context.Context) {
go func() {
ticker := time.NewTicker(cacheCleanupInterval)
@@ -60,25 +60,35 @@ func startCacheCleaner(ctx context.Context) {
return
case <-ticker.C:
now := time.Now()
n := 0
cache.Range(func(k, v any) bool {
e := v.(*cacheEntry)
if now.After(e.expireAt) {
cache.Delete(k)
n++
var toDelete []string
cacheMutex.RLock()
for _, k := range cache.Keys() {
if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
toDelete = append(toDelete, k)
}
}
cacheMutex.RUnlock()
if len(toDelete) > 0 {
pruned := 0
cacheMutex.Lock()
for _, k := range toDelete {
if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
cache.Remove(k)
pruned++
}
}
cacheMutex.Unlock()
if pruned > 0 {
log.Printf("[cache] cleaned %d expired entries", pruned)
}
return true
})
if n > 0 {
log.Printf("[cache] cleaned %d expired entries", n)
}
}
}
}()
}
// 计算缓存键name + type + class + DO + CD。
// 说明:包含 DO/CD 可避免不同验证上下文污染同一键。
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
var b strings.Builder
b.Grow(len(q.Name) + 32)
@@ -96,8 +106,6 @@ func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
return b.String()
}
// 识别伪 RROPT/TSIG这些记录的“TTL 字段”并非真实 TTL不可参与 TTL 计算或改写。
// OPT其 TTL 字段承载扩展 RCODE 与 DO 位等标志TSIG签名不应缓存或改写。
func isPseudo(rr dns.RR) bool {
switch rr.(type) {
case *dns.OPT, *dns.TSIG:
@@ -107,66 +115,93 @@ func isPseudo(rr dns.RR) bool {
}
}
// tryCacheRead 尝试读取缓存并回填“剩余 TTL”对 Answer/Ns/Extra 普通 RR 截断 TTL跳过伪 RR。
// 返回响应副本,保证对外不可变(防止调用方修改缓存内部对象)。
// 读取缓存修复Get 在写锁下;在锁外调整 TTL
func tryCacheRead(key string) (*dns.Msg, bool) {
v, ok := cache.Load(key)
if !ok {
return nil, false
}
e := v.(*cacheEntry)
now := time.Now()
cacheMutex.Lock()
e, ok := cache.Get(key) // Get 会更新 LRU必须在写锁下
if !ok {
cacheMutex.Unlock()
return nil, false
}
if now.After(e.expireAt) {
cache.Delete(key)
cache.Remove(key)
cacheMutex.Unlock()
return nil, false
}
// 拷贝副本,在锁外改 TTL减少临界区时间
out := e.msg.Copy()
remaining := uint32(e.expireAt.Sub(now).Seconds())
expireAt := e.expireAt
cacheMutex.Unlock()
remaining := uint32(expireAt.Sub(now).Seconds())
if remaining == 0 {
cache.Delete(key)
cacheMutex.Lock()
cache.Remove(key)
cacheMutex.Unlock()
return nil, false
}
// 回填剩余 TTL不增加只做上限截断
for i := range out.Answer {
if out.Answer[i].Header().Ttl > remaining {
out.Answer[i].Header().Ttl = remaining
}
}
for i := range out.Ns {
if out.Ns[i].Header().Ttl > remaining {
out.Ns[i].Header().Ttl = remaining
}
}
for i := range out.Extra {
if isPseudo(out.Extra[i]) {
continue
}
if out.Extra[i].Header().Ttl > remaining {
out.Extra[i].Header().Ttl = remaining
for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} {
for _, rr := range sec {
if isPseudo(rr) {
continue
}
if rr.Header().Ttl > remaining {
rr.Header().Ttl = remaining
}
}
}
return out, true
}
// negativeTTL 依据 RFC 2308 计算负面缓存 TTL
// 对 NXDOMAIN 或 NODATANOERROR + Answer 为空)取 min(SOA.TTL, SOA.MINIMUM),再与配置上限取 min。
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
// NXDOMAIN或 NOERROR 但 Answer 为空NODATA
if m.Rcode != dns.RcodeNameError && !(m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0) {
return 0, false
// 计算负面 TTL
// hasAnswerForType 判断报文中是否存在回答“请求类型”的 RRset
func hasAnswerForType(m *dns.Msg, q dns.Question) bool {
for _, rr := range m.Answer {
h := rr.Header()
if h.Rrtype == q.Qtype && strings.EqualFold(h.Name, q.Name) {
return true
}
}
return false
}
// 计算负面 TTL修复正确识别 NODATA包括 CNAME 等场景)
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
// NXDOMAIN肯定是负面
if m.Rcode != dns.RcodeNameError {
// 不是 NXDOMAIN则仅当 NOERROR 但没有“匹配 QTYPE 的答案”时才是 NODATA
if m.Rcode != dns.RcodeSuccess || len(m.Question) == 0 || hasAnswerForType(m, m.Question[0]) {
return 0, false
}
}
// 按 RFC 2308从 AuthorityNs优先取 SOA多数实现都只放在 Authority
var soa *dns.SOA
// SOA 可能出现在 Authority(Ns) 或 Additional(Extra)
for _, rr := range append(m.Ns, m.Extra...) {
for _, rr := range m.Ns {
if s, ok := rr.(*dns.SOA); ok {
soa = s
break
}
}
// 兼容性:偶尔也有人把 SOA 放 Extra不规范但为了兼容可以兜底看看
if soa == nil {
// 无 SOA 无法可靠计算负面 TTL此时不缓存或由上限兜底本实现选择不缓存
for _, rr := range m.Extra {
if s, ok := rr.(*dns.SOA); ok {
soa = s
break
}
}
}
if soa == nil {
// 建议:无 SOA 时不做负面缓存(返回 0,false
// 如你更希望兜底可改成return uint32(maxTTL.Seconds()), true
return 0, false
}
// 负面 TTL 取 min(SOA.MINIMUM, SOA 自身 TTL),再与配置上限比较
ttl := soa.Hdr.Ttl
if soa.Minttl < ttl {
ttl = soa.Minttl
@@ -178,8 +213,6 @@ func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
return ttl, ttl > 0
}
// minRRsetTTL 获取普通(正向)响应的最小 TTLAnswer/Ns/Extra 中的普通 RR跳过伪 RR。
// 用于决定缓存过期时间的上限(与配置上限再取 min
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
minTTL := uint32(0)
hasTTL := false
@@ -198,8 +231,6 @@ func minRRsetTTL(m *dns.Msg) (uint32, bool) {
return minTTL, hasTTL
}
// stripPseudoExtras 从消息中剥离伪 RROPT/TSIG
// 用途:缓存前剥离,避免将传输层细节或签名内容写入缓存。
func stripPseudoExtras(m *dns.Msg) {
if len(m.Extra) == 0 {
return
@@ -214,58 +245,44 @@ func stripPseudoExtras(m *dns.Msg) {
m.Extra = out
}
// cacheWrite 写缓存:优先处理负面缓存;正面缓存取 min(应答中最小 TTL, 配置上限)。
// 写入前统一剥离伪 RR保证缓存与传输解耦。
// 写缓存
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
if in == nil {
return
}
// 仅缓存 NOERROR / NXDOMAIN其余不缓存如 SERVFAIL/REFUSED 等)
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
return
}
// 负面缓存
if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 {
expire := time.Now().Add(time.Duration(ttl) * time.Second)
cp := in.Copy()
stripPseudoExtras(cp)
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
return
}
// 正向缓存minTTL 与 maxTTL 取较小
minTTL, ok := minRRsetTTL(in)
if !ok {
// 没有 TTL 时可用上限兜底(也可选择不缓存,这里选择兜底)
if maxTTL > 0 {
expire := time.Now().Add(maxTTL)
cp := in.Copy()
stripPseudoExtras(cp)
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
var ttl uint32
var ok bool
if ttl, ok = negativeTTL(in, maxTTL); !ok {
minTTL, has := minRRsetTTL(in)
if has {
cfgTTL := uint32(maxTTL.Seconds())
if cfgTTL > 0 && minTTL > cfgTTL {
minTTL = cfgTTL
}
ttl = minTTL
} else {
ttl = uint32(maxTTL.Seconds())
}
}
if ttl == 0 {
return
}
cfgTTL := uint32(maxTTL.Seconds())
finalTTL := minTTL
if cfgTTL > 0 && finalTTL > cfgTTL {
finalTTL = cfgTTL
}
if finalTTL == 0 {
return
}
expire := time.Now().Add(time.Duration(finalTTL) * time.Second)
expire := time.Now().Add(time.Duration(ttl) * time.Second)
cp := in.Copy()
stripPseudoExtras(cp)
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
cacheMutex.Lock()
cache.Add(key, &cacheEntry{msg: cp, expireAt: expire})
cacheMutex.Unlock()
}
/******************************************************************
* 上游查询(带 context 取消、并发上限、UDP→TCP 回退)
* 上游查询
******************************************************************/
// 全局可复用 UDP 客户端Net=udpUDPSize 放大以承载更大的响应)
var udpClient *dns.Client
// shuffled 打乱上游列表,避免固定顺序导致单点拥塞或偏置。
func shuffled(xs []string) []string {
out := make([]string, len(xs))
copy(out, xs)
@@ -273,10 +290,22 @@ func shuffled(xs []string) []string {
return out
}
// queryUpstreamsLimited 并发向多个上游发起查询,返回首个有效结果。
// - timeout整个查询窗口的上限基于子 context
// - maxParallel同时在飞请求上限
// - allowTCPFallback若 UDP 截断TC 位)则回退 TCP 重试
// clampEDNSForUpstream 返回一个 msg 副本,把 EDNS UDP size 夹到给定大小
func clampEDNSForUpstream(in *dns.Msg, size uint16) *dns.Msg {
m := in.Copy()
o := m.IsEdns0()
if o == nil {
o = &dns.OPT{}
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
m.Extra = append(m.Extra, o)
}
if size > 0 {
o.SetUDPSize(size)
}
return m
}
func queryUpstreamsLimited(
ctx context.Context,
req *dns.Msg,
@@ -290,39 +319,29 @@ func queryUpstreamsLimited(
}
servers := shuffled(upstreams)
// 每次查询一个带超时的子 context拿到首个有效结果后 cancel取消其他请求。
cctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
type result struct {
msg *dns.Msg
}
ch := make(chan result, len(servers)) // 缓冲至多保存所有返回,避免阻塞
sem := make(chan struct{}, maxParallel) // 并发信号量,限制同时在飞请求数
type result struct{ msg *dns.Msg }
ch := make(chan result, len(servers))
sem := make(chan struct{}, maxParallel)
// execOne 对单个上游发起查询;对 req 使用 Copy() 防止在多 goroutine 下共享同一 *dns.Msg。
execOne := func(svr string) {
// UDP 查询(带 context使用 req.Copy() 防止并发读写
resp, _, err := udpClient.ExchangeContext(cctx, req.Copy(), svr)
upReq := clampEDNSForUpstream(req, 1232) // 或做成 flag
resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr)
if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
// TCP 回退(响应被截断)
log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
tcpClient := *udpClient
tcpClient.Net = "tcp"
resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr)
}
if err != nil || resp == nil {
// 错误仅在未被上层取消时打印,避免超时后噪声
if err != nil {
if cctx.Err() == nil {
log.Printf("[upstream] %s error: %v", svr, err)
}
} else {
log.Printf("[upstream] %s nil response", svr)
if err != nil && cctx.Err() == nil {
log.Printf("[upstream] %s error: %v", svr, err)
}
return
}
// 过滤差 RCODE不参与竞速SERVFAIL / REFUSED / FORMERR
if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError {
return
}
@@ -332,7 +351,6 @@ func queryUpstreamsLimited(
}
}
// “快乐眼球”式启动:生产者不阻塞,获取配额在 goroutine 内部完成;这样可以更快进入接收循环并在首个成功后取消其余。
for _, s := range servers {
s := s
go func() {
@@ -346,7 +364,6 @@ func queryUpstreamsLimited(
}()
}
// 返回第一个非空结果,并 cancel 其他 goroutine。
for i := 0; i < len(servers); i++ {
select {
case r := <-ch:
@@ -363,11 +380,8 @@ func queryUpstreamsLimited(
}
/******************************************************************
* EDNS(0) / ECS 处理
* EDNS / 响应构造
******************************************************************/
// stripECS 从请求中去除 EDNS Client SubnetECS减少缓存污染并保护隐私。
// 注ECS 会导致上游按地理/网络分区返回不同答案,不适合集中缓存。
func stripECS(m *dns.Msg) {
if o := m.IsEdns0(); o != nil {
var kept []dns.EDNS0
@@ -380,7 +394,6 @@ func stripECS(m *dns.Msg) {
}
}
// getDOFlag 读取 DODNSSEC OK构成缓存键的一部分。
func getDOFlag(m *dns.Msg) bool {
if o := m.IsEdns0(); o != nil {
return o.Do()
@@ -388,67 +401,47 @@ func getDOFlag(m *dns.Msg) bool {
return false
}
/******************************************************************
* 响应构造:使用客户端请求头构造 reply复制上游内容
******************************************************************/
// writeReply 根据客户端请求构造响应:复制上游 Answer/Ns/非伪 Extra
// 并按客户端请求重建 OPTUDPSize/DO同时继承上游的扩展 RCODE 与 EDNS 版本;
// 可选透传上游的 EDEExtended DNS Errors以保留诊断信息。
func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
if upstream == nil {
dns.HandleFailed(w, req)
return
}
out := new(dns.Msg)
out.SetReply(req)
out.Authoritative = false
out.RecursionAvailable = upstream.RecursionAvailable
out.AuthenticatedData = upstream.AuthenticatedData
out.CheckingDisabled = req.CheckingDisabled // 反映客户端 CD 位
out.Rcode = upstream.Rcode // 主 RCODE低 4 位)
out.CheckingDisabled = req.CheckingDisabled
out.Rcode = upstream.Rcode
out.Answer = upstream.Answer
out.Ns = upstream.Ns
// 复制上游的非伪 RR 额外记录OPT/TSIG 不透传
extras := make([]dns.RR, 0, len(upstream.Extra))
var extras []dns.RR
for _, rr := range upstream.Extra {
if isPseudo(rr) {
continue
if !isPseudo(rr) {
extras = append(extras, rr)
}
extras = append(extras, rr)
}
// 基于客户端请求镜像 EDNSUDPSize + DO
if ro := req.IsEdns0(); ro != nil {
o := new(dns.OPT)
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
// 与客户端保持一致的 UDPSize / DO 位
o.SetUDPSize(ro.UDPSize())
if ro.Do() {
o.SetDo()
}
// 继承上游的扩展 RCODE 与 EDNS 版本(注意不同版本签名差异,这里显式转换)
if uo := upstream.IsEdns0(); uo != nil {
// 你当前库期望 uint16这里强转若你的库期望 uint8也可改成 uint8(...)
o.SetExtendedRcode(uint16(uo.ExtendedRcode()))
o.SetVersion(uint8(uo.Version()))
// 可选:透传只读的 EDE 诊断信息
for _, opt := range uo.Option {
if ede, ok := opt.(*dns.EDNS0_EDE); ok {
o.Option = append(o.Option, ede)
}
}
}
extras = append(extras, o)
}
out.Extra = extras
out.Compress = true
@@ -458,11 +451,8 @@ func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
}
/******************************************************************
* 处理器
* 处理器
******************************************************************/
// handleDNS 为每个请求执行:日志 → (可选)剥离 ECS → 缓存命中 → 上游查询 → 写缓存 → 回写。
// 注意:缓存键包含 DO/CD同时通过 tryCacheRead 回填剩余 TTL。
func handleDNS(
upstreams []string,
cacheMaxTTL, timeout time.Duration,
@@ -476,20 +466,14 @@ func handleDNS(
return
}
q := r.Question[0]
// 基本访问日志:类型/域名/ID/CD/DO/来源 IP:Port
log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())
// 可选:去除 ECS推荐
if stripECSBeforeForward {
stripECS(r)
}
// 缓存键(域名小写 + QTYPE/QCLASS + DO/CD
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
// 1) 缓存命中:命中即快速返回
if cached, ok := tryCacheRead(key); ok {
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
writeReply(w, r, cached)
@@ -497,7 +481,6 @@ func handleDNS(
}
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
// 2) 上游查询(带 context 取消 & TCP 可选回退);并发向多个上游竞速
ctx := context.Background()
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
if resp == nil {
@@ -505,11 +488,7 @@ func handleDNS(
dns.HandleFailed(w, r)
return
}
// 3) 写缓存(负面/正面均处理;剥离伪 RR
cacheWrite(key, resp, cacheMaxTTL)
// 4) 回写给客户端,并打印 Answer 方便调试
for _, ans := range resp.Answer {
log.Printf("[answer] %s", ans.String())
}
@@ -517,31 +496,38 @@ func handleDNS(
}
}
/******************************************************************
* 主函数
******************************************************************/
func main() {
rand.Seed(time.Now().UnixNano()) // 随机源:用于上游列表打乱
rand.Seed(time.Now().UnixNano())
// 命令行参数:可根据部署场景调整默认值
certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)")
keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)")
addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853")
upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)")
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL默认 60s实际取 min(上游最小TTL, 本值)")
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s")
maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限")
stripECSFlag := flag.Bool("strip-ecs", true, "转发上游前去除 EDNS Client Subnet")
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时允许 TCP 回退")
verbose := flag.Bool("v", false, "verbose 日志(包含源码位置)")
certFile := flag.String("cert", "server.crt", "TLS 证书文件路径")
keyFile := flag.String("key", "server.key", "TLS 私钥文件路径")
addr := flag.String("addr", ":853", "DoT 监听地址")
upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表")
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL")
cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限")
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时")
maxParallel := flag.Int("max-parallel", 3, "并发上游数量")
stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS")
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退")
verbose := flag.Bool("v", false, "verbose 日志")
flag.Parse()
initLogger(*verbose)
// 加载 TLS 证书/私钥;用于 DoTRFC 7858监听
var err error
cache, err = lru.New[string, *cacheEntry](*cacheSizeFlag)
if err != nil {
log.Fatalf("[fatal] failed to init LRU cache: %v", err)
}
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
if err != nil {
log.Fatalf("[fatal] failed to load cert: %v", err)
}
// 解析上游地址支持“host”或“host:port”缺省端口补 53。
var upstreams []string
for _, s := range strings.Split(*upstreamStr, ",") {
if t := strings.TrimSpace(s); t != "" {
@@ -555,65 +541,44 @@ func main() {
log.Fatal("[fatal] no upstream DNS servers provided")
}
// 全局 UDP 客户端(不设置 Client.Timeout改用 ExchangeContext 控制超时)
udpClient = &dns.Client{
Net: "udp",
UDPSize: 4096, // 放大到 4K减小 UDP 截断概率
}
udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true}
// context 用于优雅退出SIGINT/SIGTERM 收到后取消)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// 启动缓存清理器(后台 goroutine
startCacheCleaner(ctx)
// 注册处理器:使用 ServeMux 将所有查询交给 handleDNS
mux := dns.NewServeMux()
mux.HandleFunc(".", handleDNS(
upstreams,
*cacheTTLFlag,
*timeoutFlag,
*maxParallel,
*stripECSFlag,
*allowTCPFallback,
))
mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel, *stripECSFlag, *allowTCPFallback))
// 构造 DoTtcp-tls服务器显式开启 TLS1.2/1.3,增加读写超时防止慢连接。
// NextProtos: "dot"(可选,部分客户端用于 ALPN 检测)
srv := &dns.Server{
Addr: *addr,
Net: "tcp-tls",
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12, // 允许 1.2/1.3(默认启用 1.3
// 不设置 CipherSuites交由 Go 自动选择TLS1.3 有自身套件)
NextProtos: []string{"dot"}, // 可选:显式 ALPN
MinVersion: tls.VersionTLS12,
NextProtos: []string{"dot"},
},
Handler: mux,
ReadTimeout: 10 * time.Second, // 防止慢连接
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
}
// 捕获信号以便优雅退出(关闭监听、结束后台协程)
stop := make(chan os.Signal, 1)
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
// 异步启动服务,错误通过 errCh 返回
errCh := make(chan error, 1)
go func() {
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
log.Printf(" upstreams=%v | cache_max_ttl=%s | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v",
upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback)
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v",
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback)
errCh <- srv.ListenAndServe()
}()
// 等待退出信号或服务器错误
select {
case sig := <-stop:
log.Printf("[shutdown] caught signal: %s", sig)
cancel()
// miekg/dns 提供 Shutdown();部分版本无 ShutdownContext这里用 Shutdown()
if err := srv.Shutdown(); err != nil {
log.Printf("[shutdown] server shutdown error: %v", err)
}