feat(cache): 实现并发安全的 DNS 响应缓存机制

新增基于 sync.Map 的缓存结构,支持 TTL 管理与定时清理。
完善缓存键生成逻辑,包含 DO/CD 标志避免上下文污染。
增强缓存读写处理:
- 自动跳过伪 RR(OPT/TSIG)防止干扰 TTL 计算
- 写入前剥离传输层细节提升通用性
- 支持负面缓存(NXDOMAIN/NODATA)并遵循 RFC 2308
- 回填剩余 TTL 并确保对外不可变

优化上游查询模块:
- 并发向多个上游竞速,支持 UDP 截断后 TCP 回退
- 过滤不良 RCODE(SERVFAIL/REFUSED/FORMERR)提升稳定性
- 使用信号量控制最大并发数,改善资源利用率
- 快速失败机制减少无效等待

其他改进:
- 完善日志记录,区分缓存命中/未命中及响应内容
- 显式构建客户端兼容的 EDNS0 选项
- 增加注释说明关键设计决策和行为边界
```
This commit is contained in:
2025-10-14 10:52:57 +08:00
parent d540b302f1
commit 05d3be286e
4 changed files with 182 additions and 52 deletions

198
main.go
View File

@@ -21,6 +21,8 @@ import (
* 日志初始化
******************************************************************/
// initLogger 根据 -v 开关设置日志格式verbose=true 时附加源码文件与行号。
// 说明Lmicroseconds 便于排查毫秒级时序问题。
func initLogger(verbose bool) {
flags := log.Ldate | log.Ltime | log.Lmicroseconds
if verbose {
@@ -33,15 +35,21 @@ func initLogger(verbose bool) {
* 缓存结构
******************************************************************/
// cacheEntry 表示一条缓存项:保存上游响应的副本,以及过期时间。
// 注:将完整 *dns.Msg 存入缓存,便于原样复用 Answer/Ns/Extra。
type cacheEntry struct {
msg *dns.Msg // 上游完整响应(拷贝存储)
expireAt time.Time // 过期时间
}
// cache 使用 sync.Map 作为并发安全的键值存储。
// 注:此实现为 TTL 驱动的简单缓存;如需上限/淘汰策略,可叠加 LRU。
var cache sync.Map
const cacheCleanupInterval = 5 * time.Minute
// startCacheCleaner 定时清理过期项,避免缓存无限增长。
// 使用 ctx 控制生命周期,与主服务一同退出。
func startCacheCleaner(ctx context.Context) {
go func() {
ticker := time.NewTicker(cacheCleanupInterval)
@@ -69,7 +77,8 @@ func startCacheCleaner(ctx context.Context) {
}()
}
// 计算缓存键name + type + class + DO + CD
// 计算缓存键name + type + class + DO + CD
// 说明:包含 DO/CD 可避免不同验证上下文污染同一键。
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
var b strings.Builder
b.Grow(len(q.Name) + 32)
@@ -87,7 +96,19 @@ func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
return b.String()
}
// 命中缓存:回填剩余 TTL不超过记录自身 TTL
// 识别伪 RROPT/TSIG这些记录的“TTL 字段”并非真实 TTL不可参与 TTL 计算或改写。
// OPT其 TTL 字段承载扩展 RCODE 与 DO 位等标志TSIG签名不应缓存或改写。
func isPseudo(rr dns.RR) bool {
switch rr.(type) {
case *dns.OPT, *dns.TSIG:
return true
default:
return false
}
}
// tryCacheRead 尝试读取缓存并回填“剩余 TTL”对 Answer/Ns/Extra 普通 RR 截断 TTL跳过伪 RR。
// 返回响应副本,保证对外不可变(防止调用方修改缓存内部对象)。
func tryCacheRead(key string) (*dns.Msg, bool) {
v, ok := cache.Load(key)
if !ok {
@@ -105,6 +126,7 @@ func tryCacheRead(key string) (*dns.Msg, bool) {
cache.Delete(key)
return nil, false
}
// 回填剩余 TTL不增加只做上限截断
for i := range out.Answer {
if out.Answer[i].Header().Ttl > remaining {
out.Answer[i].Header().Ttl = remaining
@@ -116,6 +138,9 @@ func tryCacheRead(key string) (*dns.Msg, bool) {
}
}
for i := range out.Extra {
if isPseudo(out.Extra[i]) {
continue
}
if out.Extra[i].Header().Ttl > remaining {
out.Extra[i].Header().Ttl = remaining
}
@@ -123,13 +148,15 @@ func tryCacheRead(key string) (*dns.Msg, bool) {
return out, true
}
// 负面缓存 TTLRFC 2308NXDOMAIN 或 NODATA 使用 SOA.MINIMUM 与 SOA TTL 的较小者,再与配置上限取 min
// negativeTTL 依据 RFC 2308 计算负面缓存 TTL。
// 对 NXDOMAIN 或 NODATANOERROR + Answer 为空)取 min(SOA.TTL, SOA.MINIMUM),再与配置上限取 min。
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
// NXDOMAIN或 NOERROR 但 Answer 为空NODATA
if m.Rcode != dns.RcodeNameError && !(m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0) {
return 0, false
}
var soa *dns.SOA
// SOA 可能出现在 Authority(Ns) 或 Additional(Extra)
for _, rr := range append(m.Ns, m.Extra...) {
if s, ok := rr.(*dns.SOA); ok {
soa = s
@@ -137,6 +164,7 @@ func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
}
}
if soa == nil {
// 无 SOA 无法可靠计算负面 TTL此时不缓存或由上限兜底本实现选择不缓存
return 0, false
}
ttl := soa.Hdr.Ttl
@@ -150,12 +178,16 @@ func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
return ttl, ttl > 0
}
// 普通(正向)响应的最小 TTLAnswer/Ns/Extra
// minRRsetTTL 获取普通(正向)响应的最小 TTLAnswer/Ns/Extra 中的普通 RR跳过伪 RR。
// 用于决定缓存过期时间的上限(与配置上限再取 min
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
minTTL := uint32(0)
hasTTL := false
for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} {
for _, rr := range sec {
if isPseudo(rr) {
continue
}
ttl := rr.Header().Ttl
if !hasTTL || ttl < minTTL {
minTTL = ttl
@@ -166,28 +198,49 @@ func minRRsetTTL(m *dns.Msg) (uint32, bool) {
return minTTL, hasTTL
}
// 写缓存:先处理负面缓存,再处理正面缓存
// stripPseudoExtras 从消息中剥离伪 RROPT/TSIG
// 用途:缓存前剥离,避免将传输层细节或签名内容写入缓存。
func stripPseudoExtras(m *dns.Msg) {
if len(m.Extra) == 0 {
return
}
out := m.Extra[:0]
for _, rr := range m.Extra {
if isPseudo(rr) {
continue
}
out = append(out, rr)
}
m.Extra = out
}
// cacheWrite 写缓存:优先处理负面缓存;正面缓存取 min(应答中最小 TTL, 配置上限)。
// 写入前统一剥离伪 RR保证缓存与传输解耦。
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
if in == nil {
return
}
// 仅缓存 NOERROR / NXDOMAIN其余不缓存
// 仅缓存 NOERROR / NXDOMAIN其余不缓存(如 SERVFAIL/REFUSED 等)
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
return
}
// 负面缓存
if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 {
expire := time.Now().Add(time.Duration(ttl) * time.Second)
cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire})
cp := in.Copy()
stripPseudoExtras(cp)
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
return
}
// 正向缓存minTTL 与 maxTTL 取较小
minTTL, ok := minRRsetTTL(in)
if !ok {
// 没有 TTL 时可用上限兜底(也可选择不缓存)
// 没有 TTL 时可用上限兜底(也可选择不缓存,这里选择兜底
if maxTTL > 0 {
expire := time.Now().Add(maxTTL)
cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire})
cp := in.Copy()
stripPseudoExtras(cp)
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
}
return
}
@@ -200,16 +253,19 @@ func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
return
}
expire := time.Now().Add(time.Duration(finalTTL) * time.Second)
cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire})
cp := in.Copy()
stripPseudoExtras(cp)
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
}
/******************************************************************
* 上游查询(带 context 取消、并发上限、UDP→TCP 回退)
******************************************************************/
// 全局可复用 UDP 客户端
// 全局可复用 UDP 客户端Net=udpUDPSize 放大以承载更大的响应)
var udpClient *dns.Client
// shuffled 打乱上游列表,避免固定顺序导致单点拥塞或偏置。
func shuffled(xs []string) []string {
out := make([]string, len(xs))
copy(out, xs)
@@ -217,6 +273,10 @@ func shuffled(xs []string) []string {
return out
}
// queryUpstreamsLimited 并发向多个上游发起查询,返回首个有效结果。
// - timeout整个查询窗口的上限基于子 context
// - maxParallel同时在飞请求上限
// - allowTCPFallback若 UDP 截断TC 位)则回退 TCP 重试
func queryUpstreamsLimited(
ctx context.Context,
req *dns.Msg,
@@ -230,28 +290,29 @@ func queryUpstreamsLimited(
}
servers := shuffled(upstreams)
// 每次查询一个带超时的子 context拿到首个有效结果后取消
// 每次查询一个带超时的子 context拿到首个有效结果后 cancel取消其他请求
cctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
type result struct {
msg *dns.Msg
}
ch := make(chan result, len(servers))
sem := make(chan struct{}, maxParallel)
ch := make(chan result, len(servers)) // 缓冲至多保存所有返回,避免阻塞
sem := make(chan struct{}, maxParallel) // 并发信号量,限制同时在飞请求数
// execOne 对单个上游发起查询;对 req 使用 Copy() 防止在多 goroutine 下共享同一 *dns.Msg。
execOne := func(svr string) {
defer func() { <-sem }()
// UDP 查询(带 context
resp, _, err := udpClient.ExchangeContext(cctx, req, svr)
// UDP 查询(带 context使用 req.Copy() 防止并发读写
resp, _, err := udpClient.ExchangeContext(cctx, req.Copy(), svr)
if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
// TCP 回退
// TCP 回退(响应被截断)
log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
tcpClient := *udpClient
tcpClient.Net = "tcp"
resp, _, err = tcpClient.ExchangeContext(cctx, req, svr)
resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr)
}
if err != nil || resp == nil {
// 错误仅在未被上层取消时打印,避免超时后噪声
if err != nil {
if cctx.Err() == nil {
log.Printf("[upstream] %s error: %v", svr, err)
@@ -261,8 +322,8 @@ func queryUpstreamsLimited(
}
return
}
// 丢弃 SERVFAIL
if resp.Rcode == dns.RcodeServerFailure {
// 过滤差 RCODE不参与竞速SERVFAIL / REFUSED / FORMERR
if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError {
return
}
select {
@@ -271,12 +332,21 @@ func queryUpstreamsLimited(
}
}
// “快乐眼球”式启动:生产者不阻塞,获取配额在 goroutine 内部完成;这样可以更快进入接收循环并在首个成功后取消其余。
for _, s := range servers {
sem <- struct{}{}
go execOne(s)
s := s
go func() {
select {
case sem <- struct{}{}:
defer func() { <-sem }()
case <-cctx.Done():
return
}
execOne(s)
}()
}
// 返回第一个非空结果,并 cancel 其他 goroutine
// 返回第一个非空结果,并 cancel 其他 goroutine
for i := 0; i < len(servers); i++ {
select {
case r := <-ch:
@@ -296,7 +366,8 @@ func queryUpstreamsLimited(
* EDNS(0) / ECS 处理
******************************************************************/
// 去除 EDNS Client Subnet避免缓存污染与隐私泄露)
// stripECS 从请求中去除 EDNS Client SubnetECS减少缓存污染并保护隐私。
// 注ECS 会导致上游按地理/网络分区返回不同答案,不适合集中缓存。
func stripECS(m *dns.Msg) {
if o := m.IsEdns0(); o != nil {
var kept []dns.EDNS0
@@ -309,7 +380,7 @@ func stripECS(m *dns.Msg) {
}
}
// 取 DO/EDNS
// getDOFlag 读取 DODNSSEC OK构成缓存键的一部分。
func getDOFlag(m *dns.Msg) bool {
if o := m.IsEdns0(); o != nil {
return o.Do()
@@ -321,6 +392,8 @@ func getDOFlag(m *dns.Msg) bool {
* 响应构造:使用客户端请求头构造 reply复制上游内容
******************************************************************/
// writeReply 根据客户端请求构造响应骨架:复制上游的 Answer/Ns/非伪Extra
// 并按照**客户端请求**重建 OPTUDPSize/DO避免直接采纳上游的 OPT/TSIG。
func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
if upstream == nil {
dns.HandleFailed(w, req)
@@ -335,7 +408,28 @@ func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
out.Rcode = upstream.Rcode
out.Answer = upstream.Answer
out.Ns = upstream.Ns
out.Extra = upstream.Extra
// 复制上游的非伪 RR 额外记录OPT/TSIG 不透传
extras := make([]dns.RR, 0, len(upstream.Extra))
for _, rr := range upstream.Extra {
if isPseudo(rr) {
continue
}
extras = append(extras, rr)
}
// 基于客户端请求镜像 EDNSUDPSize + DO保持传输一致性
if ro := req.IsEdns0(); ro != nil {
o := new(dns.OPT)
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
o.SetUDPSize(ro.UDPSize())
if ro.Do() {
o.SetDo()
}
// 如需转发 NSID/COOKIE 等可在此附加;为减少复杂性与缓存污染,建议保守最小集合
extras = append(extras, o)
}
out.Extra = extras
out.Compress = true
if err := w.WriteMsg(out); err != nil {
log.Printf("[write] WriteMsg error: %v", err)
@@ -346,6 +440,8 @@ func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
* 处理器
******************************************************************/
// handleDNS 为每个请求执行:日志 → (可选)剥离 ECS → 缓存命中 → 上游查询 → 写缓存 → 回写。
// 注意:缓存键包含 DO/CD同时通过 tryCacheRead 回填剩余 TTL。
func handleDNS(
upstreams []string,
cacheMaxTTL, timeout time.Duration,
@@ -360,7 +456,7 @@ func handleDNS(
}
q := r.Question[0]
// 记录请求
// 基本访问日志:类型/域名/ID/CD/DO/来源 IP:Port
log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())
@@ -369,10 +465,10 @@ func handleDNS(
stripECS(r)
}
// 缓存键
// 缓存键(域名小写 + QTYPE/QCLASS + DO/CD
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
// 1) 缓存命中
// 1) 缓存命中:命中即快速返回
if cached, ok := tryCacheRead(key); ok {
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
writeReply(w, r, cached)
@@ -380,7 +476,7 @@ func handleDNS(
}
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
// 2) 上游查询(带 context 取消 & TCP 可选回退)
// 2) 上游查询(带 context 取消 & TCP 可选回退);并发向多个上游竞速
ctx := context.Background()
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
if resp == nil {
@@ -389,10 +485,10 @@ func handleDNS(
return
}
// 3) 写缓存
// 3) 写缓存(负面/正面均处理;剥离伪 RR
cacheWrite(key, resp, cacheMaxTTL)
// 4) 回写
// 4) 回写给客户端,并打印 Answer 方便调试
for _, ans := range resp.Answer {
log.Printf("[answer] %s", ans.String())
}
@@ -400,14 +496,10 @@ func handleDNS(
}
}
/******************************************************************
* 主程序
******************************************************************/
func main() {
rand.Seed(time.Now().UnixNano())
rand.Seed(time.Now().UnixNano()) // 随机源:用于上游列表打乱
// 参数
// 命令行参数:可根据部署场景调整默认值
certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)")
keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)")
addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853")
@@ -422,13 +514,13 @@ func main() {
initLogger(*verbose)
// 证书
// 加载 TLS 证书/私钥;用于 DoTRFC 7858监听
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
if err != nil {
log.Fatalf("[fatal] failed to load cert: %v", err)
}
// 上游
// 解析上游地址支持“host”或“host:port”缺省端口补 53。
var upstreams []string
for _, s := range strings.Split(*upstreamStr, ",") {
if t := strings.TrimSpace(s); t != "" {
@@ -442,20 +534,20 @@ func main() {
log.Fatal("[fatal] no upstream DNS servers provided")
}
// 全局 UDP 客户端(UDPSize 放大;不设置 Timeout使用 ExchangeContext 控制)
// 全局 UDP 客户端(不设置 Client.Timeout用 ExchangeContext 控制超时
udpClient = &dns.Client{
Net: "udp",
UDPSize: 4096,
UDPSize: 4096, // 放大到 4K减小 UDP 截断概率
}
// context 用于优雅退出
// context 用于优雅退出SIGINT/SIGTERM 收到后取消)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// 启动缓存清理器
// 启动缓存清理器(后台 goroutine
startCacheCleaner(ctx)
// mux/handler
// 注册处理器:使用 ServeMux 将所有查询交给 handleDNS
mux := dns.NewServeMux()
mux.HandleFunc(".", handleDNS(
upstreams,
@@ -466,7 +558,8 @@ func main() {
*allowTCPFallback,
))
// DoT 服务(启用 TLS1.3;不手动设 CipherSuites 以使用 Go 默认安全套件)
// 构造 DoTtcp-tls服务器显式开启 TLS1.2/1.3,增加读写超时防止慢连接。
// NextProtos: "dot"(可选,部分客户端用于 ALPN 检测)
srv := &dns.Server{
Addr: *addr,
Net: "tcp-tls",
@@ -474,14 +567,18 @@ func main() {
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12, // 允许 1.2/1.3(默认启用 1.3
// 不设置 CipherSuites交由 Go 自动选择TLS1.3 有自身套件)
NextProtos: []string{"dot"}, // 可选:显式 ALPN
},
Handler: mux,
Handler: mux,
ReadTimeout: 10 * time.Second, // 防止慢连接
WriteTimeout: 10 * time.Second,
}
// 捕获信号优雅退出
// 捕获信号以便优雅退出(关闭监听、结束后台协程)
stop := make(chan os.Signal, 1)
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
// 异步启动服务,错误通过 errCh 返回
errCh := make(chan error, 1)
go func() {
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
@@ -490,11 +587,12 @@ func main() {
errCh <- srv.ListenAndServe()
}()
// 等待退出信号或服务器错误
select {
case sig := <-stop:
log.Printf("[shutdown] caught signal: %s", sig)
cancel()
// miekg/dns 提供 Shutdown();部分版本没有 ShutdownContext这里用 Shutdown()
// miekg/dns 提供 Shutdown();部分版本 ShutdownContext这里用 Shutdown()
if err := srv.Shutdown(); err != nil {
log.Printf("[shutdown] server shutdown error: %v", err)
}