Files
dot/main.go
aixiao 05d3be286e ```
feat(cache): 实现并发安全的 DNS 响应缓存机制

新增基于 sync.Map 的缓存结构,支持 TTL 管理与定时清理。
完善缓存键生成逻辑,包含 DO/CD 标志避免上下文污染。
增强缓存读写处理:
- 自动跳过伪 RR(OPT/TSIG)防止干扰 TTL 计算
- 写入前剥离传输层细节提升通用性
- 支持负面缓存(NXDOMAIN/NODATA)并遵循 RFC 2308
- 回填剩余 TTL 并确保对外不可变

优化上游查询模块:
- 并发向多个上游竞速,支持 UDP 截断后 TCP 回退
- 过滤不良 RCODE(SERVFAIL/REFUSED/FORMERR)提升稳定性
- 使用信号量控制最大并发数,改善资源利用率
- 快速失败机制减少无效等待

其他改进:
- 完善日志记录,区分缓存命中/未命中及响应内容
- 显式构建客户端兼容的 EDNS0 选项
- 增加注释说明关键设计决策和行为边界
```
2025-10-14 10:52:57 +08:00

607 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"crypto/tls"
"flag"
"fmt"
"log"
"math/rand"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"github.com/miekg/dns"
)
/******************************************************************
* 日志初始化
******************************************************************/
// initLogger 根据 -v 开关设置日志格式verbose=true 时附加源码文件与行号。
// 说明Lmicroseconds 便于排查毫秒级时序问题。
func initLogger(verbose bool) {
flags := log.Ldate | log.Ltime | log.Lmicroseconds
if verbose {
flags |= log.Lshortfile
}
log.SetFlags(flags)
}
/******************************************************************
* 缓存结构
******************************************************************/
// cacheEntry 表示一条缓存项:保存上游响应的副本,以及过期时间。
// 注:将完整 *dns.Msg 存入缓存,便于原样复用 Answer/Ns/Extra。
type cacheEntry struct {
msg *dns.Msg // 上游完整响应(拷贝存储)
expireAt time.Time // 过期时间
}
// cache 使用 sync.Map 作为并发安全的键值存储。
// 注:此实现为 TTL 驱动的简单缓存;如需上限/淘汰策略,可叠加 LRU。
var cache sync.Map
const cacheCleanupInterval = 5 * time.Minute
// startCacheCleaner 定时清理过期项,避免缓存无限增长。
// 使用 ctx 控制生命周期,与主服务一同退出。
func startCacheCleaner(ctx context.Context) {
go func() {
ticker := time.NewTicker(cacheCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
n := 0
cache.Range(func(k, v any) bool {
e := v.(*cacheEntry)
if now.After(e.expireAt) {
cache.Delete(k)
n++
}
return true
})
if n > 0 {
log.Printf("[cache] cleaned %d expired entries", n)
}
}
}
}()
}
// 计算缓存键name + type + class + DO + CD。
// 说明:包含 DO/CD 可避免不同验证上下文污染同一键。
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
var b strings.Builder
b.Grow(len(q.Name) + 32)
b.WriteString(strings.ToLower(q.Name))
b.WriteString("|T=")
b.WriteString(dns.TypeToString[q.Qtype])
b.WriteString("|C=")
b.WriteString(dns.ClassToString[q.Qclass])
if do {
b.WriteString("|DO")
}
if cd {
b.WriteString("|CD")
}
return b.String()
}
// 识别伪 RROPT/TSIG这些记录的“TTL 字段”并非真实 TTL不可参与 TTL 计算或改写。
// OPT其 TTL 字段承载扩展 RCODE 与 DO 位等标志TSIG签名不应缓存或改写。
func isPseudo(rr dns.RR) bool {
switch rr.(type) {
case *dns.OPT, *dns.TSIG:
return true
default:
return false
}
}
// tryCacheRead 尝试读取缓存并回填“剩余 TTL”对 Answer/Ns/Extra 普通 RR 截断 TTL跳过伪 RR。
// 返回响应副本,保证对外不可变(防止调用方修改缓存内部对象)。
func tryCacheRead(key string) (*dns.Msg, bool) {
v, ok := cache.Load(key)
if !ok {
return nil, false
}
e := v.(*cacheEntry)
now := time.Now()
if now.After(e.expireAt) {
cache.Delete(key)
return nil, false
}
out := e.msg.Copy()
remaining := uint32(e.expireAt.Sub(now).Seconds())
if remaining == 0 {
cache.Delete(key)
return nil, false
}
// 回填剩余 TTL不增加只做上限截断
for i := range out.Answer {
if out.Answer[i].Header().Ttl > remaining {
out.Answer[i].Header().Ttl = remaining
}
}
for i := range out.Ns {
if out.Ns[i].Header().Ttl > remaining {
out.Ns[i].Header().Ttl = remaining
}
}
for i := range out.Extra {
if isPseudo(out.Extra[i]) {
continue
}
if out.Extra[i].Header().Ttl > remaining {
out.Extra[i].Header().Ttl = remaining
}
}
return out, true
}
// negativeTTL 依据 RFC 2308 计算负面缓存 TTL。
// 对 NXDOMAIN 或 NODATANOERROR + Answer 为空)取 min(SOA.TTL, SOA.MINIMUM),再与配置上限取 min。
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
// NXDOMAIN或 NOERROR 但 Answer 为空NODATA
if m.Rcode != dns.RcodeNameError && !(m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0) {
return 0, false
}
var soa *dns.SOA
// SOA 可能出现在 Authority(Ns) 或 Additional(Extra)
for _, rr := range append(m.Ns, m.Extra...) {
if s, ok := rr.(*dns.SOA); ok {
soa = s
break
}
}
if soa == nil {
// 无 SOA 无法可靠计算负面 TTL此时不缓存或由上限兜底本实现选择不缓存
return 0, false
}
ttl := soa.Hdr.Ttl
if soa.Minttl < ttl {
ttl = soa.Minttl
}
capTTL := uint32(maxTTL.Seconds())
if capTTL > 0 && ttl > capTTL {
ttl = capTTL
}
return ttl, ttl > 0
}
// minRRsetTTL 获取普通(正向)响应的最小 TTLAnswer/Ns/Extra 中的普通 RR跳过伪 RR。
// 用于决定缓存过期时间的上限(与配置上限再取 min
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
minTTL := uint32(0)
hasTTL := false
for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} {
for _, rr := range sec {
if isPseudo(rr) {
continue
}
ttl := rr.Header().Ttl
if !hasTTL || ttl < minTTL {
minTTL = ttl
hasTTL = true
}
}
}
return minTTL, hasTTL
}
// stripPseudoExtras 从消息中剥离伪 RROPT/TSIG
// 用途:缓存前剥离,避免将传输层细节或签名内容写入缓存。
func stripPseudoExtras(m *dns.Msg) {
if len(m.Extra) == 0 {
return
}
out := m.Extra[:0]
for _, rr := range m.Extra {
if isPseudo(rr) {
continue
}
out = append(out, rr)
}
m.Extra = out
}
// cacheWrite 写缓存:优先处理负面缓存;正面缓存取 min(应答中最小 TTL, 配置上限)。
// 写入前统一剥离伪 RR保证缓存与传输解耦。
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
if in == nil {
return
}
// 仅缓存 NOERROR / NXDOMAIN其余不缓存如 SERVFAIL/REFUSED 等)
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
return
}
// 负面缓存
if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 {
expire := time.Now().Add(time.Duration(ttl) * time.Second)
cp := in.Copy()
stripPseudoExtras(cp)
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
return
}
// 正向缓存minTTL 与 maxTTL 取较小
minTTL, ok := minRRsetTTL(in)
if !ok {
// 没有 TTL 时可用上限兜底(也可选择不缓存,这里选择兜底)
if maxTTL > 0 {
expire := time.Now().Add(maxTTL)
cp := in.Copy()
stripPseudoExtras(cp)
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
}
return
}
cfgTTL := uint32(maxTTL.Seconds())
finalTTL := minTTL
if cfgTTL > 0 && finalTTL > cfgTTL {
finalTTL = cfgTTL
}
if finalTTL == 0 {
return
}
expire := time.Now().Add(time.Duration(finalTTL) * time.Second)
cp := in.Copy()
stripPseudoExtras(cp)
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
}
/******************************************************************
* 上游查询(带 context 取消、并发上限、UDP→TCP 回退)
******************************************************************/
// 全局可复用 UDP 客户端Net=udpUDPSize 放大以承载更大的响应)
var udpClient *dns.Client
// shuffled 打乱上游列表,避免固定顺序导致单点拥塞或偏置。
func shuffled(xs []string) []string {
out := make([]string, len(xs))
copy(out, xs)
rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] })
return out
}
// queryUpstreamsLimited 并发向多个上游发起查询,返回首个有效结果。
// - timeout整个查询窗口的上限基于子 context
// - maxParallel同时在飞请求上限
// - allowTCPFallback若 UDP 截断TC 位)则回退 TCP 重试
func queryUpstreamsLimited(
ctx context.Context,
req *dns.Msg,
upstreams []string,
timeout time.Duration,
maxParallel int,
allowTCPFallback bool,
) *dns.Msg {
if maxParallel <= 0 {
maxParallel = 1
}
servers := shuffled(upstreams)
// 每次查询一个带超时的子 context拿到首个有效结果后 cancel取消其他请求。
cctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
type result struct {
msg *dns.Msg
}
ch := make(chan result, len(servers)) // 缓冲至多保存所有返回,避免阻塞
sem := make(chan struct{}, maxParallel) // 并发信号量,限制同时在飞请求数
// execOne 对单个上游发起查询;对 req 使用 Copy() 防止在多 goroutine 下共享同一 *dns.Msg。
execOne := func(svr string) {
// UDP 查询(带 context使用 req.Copy() 防止并发读写
resp, _, err := udpClient.ExchangeContext(cctx, req.Copy(), svr)
if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
// TCP 回退(响应被截断)
log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
tcpClient := *udpClient
tcpClient.Net = "tcp"
resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr)
}
if err != nil || resp == nil {
// 错误仅在未被上层取消时打印,避免超时后噪声
if err != nil {
if cctx.Err() == nil {
log.Printf("[upstream] %s error: %v", svr, err)
}
} else {
log.Printf("[upstream] %s nil response", svr)
}
return
}
// 过滤差 RCODE不参与竞速SERVFAIL / REFUSED / FORMERR
if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError {
return
}
select {
case ch <- result{msg: resp}:
case <-cctx.Done():
}
}
// “快乐眼球”式启动:生产者不阻塞,获取配额在 goroutine 内部完成;这样可以更快进入接收循环并在首个成功后取消其余。
for _, s := range servers {
s := s
go func() {
select {
case sem <- struct{}{}:
defer func() { <-sem }()
case <-cctx.Done():
return
}
execOne(s)
}()
}
// 返回第一个非空结果,并 cancel 其他 goroutine。
for i := 0; i < len(servers); i++ {
select {
case r := <-ch:
if r.msg != nil {
cancel()
return r.msg
}
case <-cctx.Done():
log.Printf("[upstream] timeout after %v", timeout)
return nil
}
}
return nil
}
/******************************************************************
* EDNS(0) / ECS 处理
******************************************************************/
// stripECS 从请求中去除 EDNS Client SubnetECS减少缓存污染并保护隐私。
// 注ECS 会导致上游按地理/网络分区返回不同答案,不适合集中缓存。
func stripECS(m *dns.Msg) {
if o := m.IsEdns0(); o != nil {
var kept []dns.EDNS0
for _, e := range o.Option {
if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS {
kept = append(kept, e)
}
}
o.Option = kept
}
}
// getDOFlag 读取 DODNSSEC OK构成缓存键的一部分。
func getDOFlag(m *dns.Msg) bool {
if o := m.IsEdns0(); o != nil {
return o.Do()
}
return false
}
/******************************************************************
* 响应构造:使用客户端请求头构造 reply复制上游内容
******************************************************************/
// writeReply 根据客户端请求构造响应骨架:复制上游的 Answer/Ns/非伪Extra
// 并按照**客户端请求**重建 OPTUDPSize/DO避免直接采纳上游的 OPT/TSIG。
func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
if upstream == nil {
dns.HandleFailed(w, req)
return
}
out := new(dns.Msg)
out.SetReply(req)
out.Authoritative = false
out.RecursionAvailable = upstream.RecursionAvailable
out.AuthenticatedData = upstream.AuthenticatedData
out.CheckingDisabled = req.CheckingDisabled // 反映客户端 CD 位
out.Rcode = upstream.Rcode
out.Answer = upstream.Answer
out.Ns = upstream.Ns
// 复制上游的非伪 RR 额外记录OPT/TSIG 不透传
extras := make([]dns.RR, 0, len(upstream.Extra))
for _, rr := range upstream.Extra {
if isPseudo(rr) {
continue
}
extras = append(extras, rr)
}
// 基于客户端请求镜像 EDNSUDPSize + DO保持传输一致性
if ro := req.IsEdns0(); ro != nil {
o := new(dns.OPT)
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
o.SetUDPSize(ro.UDPSize())
if ro.Do() {
o.SetDo()
}
// 如需转发 NSID/COOKIE 等可在此附加;为减少复杂性与缓存污染,建议保守最小集合
extras = append(extras, o)
}
out.Extra = extras
out.Compress = true
if err := w.WriteMsg(out); err != nil {
log.Printf("[write] WriteMsg error: %v", err)
}
}
/******************************************************************
* 处理器
******************************************************************/
// handleDNS 为每个请求执行:日志 → (可选)剥离 ECS → 缓存命中 → 上游查询 → 写缓存 → 回写。
// 注意:缓存键包含 DO/CD同时通过 tryCacheRead 回填剩余 TTL。
func handleDNS(
upstreams []string,
cacheMaxTTL, timeout time.Duration,
maxParallel int,
stripECSBeforeForward bool,
allowTCPFallback bool,
) dns.HandlerFunc {
return func(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
dns.HandleFailed(w, r)
return
}
q := r.Question[0]
// 基本访问日志:类型/域名/ID/CD/DO/来源 IP:Port
log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())
// 可选:去除 ECS推荐
if stripECSBeforeForward {
stripECS(r)
}
// 缓存键(域名小写 + QTYPE/QCLASS + DO/CD
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
// 1) 缓存命中:命中即快速返回
if cached, ok := tryCacheRead(key); ok {
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
writeReply(w, r, cached)
return
}
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
// 2) 上游查询(带 context 取消 & TCP 可选回退);并发向多个上游竞速
ctx := context.Background()
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
if resp == nil {
log.Printf("[error] all upstreams failed for %s", q.Name)
dns.HandleFailed(w, r)
return
}
// 3) 写缓存(负面/正面均处理;剥离伪 RR
cacheWrite(key, resp, cacheMaxTTL)
// 4) 回写给客户端,并打印 Answer 方便调试
for _, ans := range resp.Answer {
log.Printf("[answer] %s", ans.String())
}
writeReply(w, r, resp)
}
}
func main() {
rand.Seed(time.Now().UnixNano()) // 随机源:用于上游列表打乱
// 命令行参数:可根据部署场景调整默认值
certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)")
keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)")
addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853")
upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)")
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL默认 60s实际取 min(上游最小TTL, 本值)")
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s")
maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限")
stripECSFlag := flag.Bool("strip-ecs", true, "转发上游前去除 EDNS Client Subnet")
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时允许 TCP 回退")
verbose := flag.Bool("v", false, "verbose 日志(包含源码位置)")
flag.Parse()
initLogger(*verbose)
// 加载 TLS 证书/私钥;用于 DoTRFC 7858监听
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
if err != nil {
log.Fatalf("[fatal] failed to load cert: %v", err)
}
// 解析上游地址支持“host”或“host:port”缺省端口补 53。
var upstreams []string
for _, s := range strings.Split(*upstreamStr, ",") {
if t := strings.TrimSpace(s); t != "" {
if !strings.Contains(t, ":") {
t = fmt.Sprintf("%s:53", t)
}
upstreams = append(upstreams, t)
}
}
if len(upstreams) == 0 {
log.Fatal("[fatal] no upstream DNS servers provided")
}
// 全局 UDP 客户端(不设置 Client.Timeout改用 ExchangeContext 控制超时)
udpClient = &dns.Client{
Net: "udp",
UDPSize: 4096, // 放大到 4K减小 UDP 截断概率
}
// context 用于优雅退出SIGINT/SIGTERM 收到后取消)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// 启动缓存清理器(后台 goroutine
startCacheCleaner(ctx)
// 注册处理器:使用 ServeMux 将所有查询交给 handleDNS
mux := dns.NewServeMux()
mux.HandleFunc(".", handleDNS(
upstreams,
*cacheTTLFlag,
*timeoutFlag,
*maxParallel,
*stripECSFlag,
*allowTCPFallback,
))
// 构造 DoTtcp-tls服务器显式开启 TLS1.2/1.3,增加读写超时防止慢连接。
// NextProtos: "dot"(可选,部分客户端用于 ALPN 检测)
srv := &dns.Server{
Addr: *addr,
Net: "tcp-tls",
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12, // 允许 1.2/1.3(默认启用 1.3
// 不设置 CipherSuites交由 Go 自动选择TLS1.3 有自身套件)
NextProtos: []string{"dot"}, // 可选:显式 ALPN
},
Handler: mux,
ReadTimeout: 10 * time.Second, // 防止慢连接
WriteTimeout: 10 * time.Second,
}
// 捕获信号以便优雅退出(关闭监听、结束后台协程)
stop := make(chan os.Signal, 1)
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
// 异步启动服务,错误通过 errCh 返回
errCh := make(chan error, 1)
go func() {
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
log.Printf(" upstreams=%v | cache_max_ttl=%s | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v",
upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback)
errCh <- srv.ListenAndServe()
}()
// 等待退出信号或服务器错误
select {
case sig := <-stop:
log.Printf("[shutdown] caught signal: %s", sig)
cancel()
// miekg/dns 提供 Shutdown();部分版本无 ShutdownContext这里用 Shutdown()
if err := srv.Shutdown(); err != nil {
log.Printf("[shutdown] server shutdown error: %v", err)
}
case err := <-errCh:
if err != nil {
log.Fatalf("[fatal] server error: %v", err)
}
}
log.Println("[bye] server stopped.")
}