package main import ( "context" "crypto/tls" "flag" "fmt" "log" "math/rand" "os" "os/signal" "strings" "sync" "syscall" "time" "github.com/miekg/dns" ) /****************************************************************** * 日志初始化 ******************************************************************/ // initLogger 根据 -v 开关设置日志格式;verbose=true 时附加源码文件与行号。 // 说明:Lmicroseconds 便于排查毫秒级时序问题。 func initLogger(verbose bool) { flags := log.Ldate | log.Ltime | log.Lmicroseconds if verbose { flags |= log.Lshortfile } log.SetFlags(flags) } /****************************************************************** * 缓存结构 ******************************************************************/ // cacheEntry 表示一条缓存项:保存上游响应的副本,以及过期时间。 // 注:将完整 *dns.Msg 存入缓存,便于原样复用 Answer/Ns/Extra。 type cacheEntry struct { msg *dns.Msg // 上游完整响应(拷贝存储) expireAt time.Time // 过期时间 } // cache 使用 sync.Map 作为并发安全的键值存储。 // 注:此实现为 TTL 驱动的简单缓存;如需上限/淘汰策略,可叠加 LRU。 var cache sync.Map const cacheCleanupInterval = 5 * time.Minute // startCacheCleaner 定时清理过期项,避免缓存无限增长。 // 使用 ctx 控制生命周期,与主服务一同退出。 func startCacheCleaner(ctx context.Context) { go func() { ticker := time.NewTicker(cacheCleanupInterval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: now := time.Now() n := 0 cache.Range(func(k, v any) bool { e := v.(*cacheEntry) if now.After(e.expireAt) { cache.Delete(k) n++ } return true }) if n > 0 { log.Printf("[cache] cleaned %d expired entries", n) } } } }() } // 计算缓存键:name + type + class + DO + CD。 // 说明:包含 DO/CD 可避免不同验证上下文污染同一键。 func cacheKeyFromMsg(q dns.Question, do, cd bool) string { var b strings.Builder b.Grow(len(q.Name) + 32) b.WriteString(strings.ToLower(q.Name)) b.WriteString("|T=") b.WriteString(dns.TypeToString[q.Qtype]) b.WriteString("|C=") b.WriteString(dns.ClassToString[q.Qclass]) if do { b.WriteString("|DO") } if cd { b.WriteString("|CD") } return b.String() } // 识别伪 RR(OPT/TSIG);这些记录的“TTL 字段”并非真实 TTL,不可参与 TTL 计算或改写。 // OPT:其 TTL 字段承载扩展 RCODE 与 DO 位等标志;TSIG:签名,不应缓存或改写。 func isPseudo(rr dns.RR) bool { switch rr.(type) { case *dns.OPT, *dns.TSIG: return true default: return false } } // tryCacheRead 尝试读取缓存并回填“剩余 TTL”;对 Answer/Ns/Extra 普通 RR 截断 TTL,跳过伪 RR。 // 返回响应副本,保证对外不可变(防止调用方修改缓存内部对象)。 func tryCacheRead(key string) (*dns.Msg, bool) { v, ok := cache.Load(key) if !ok { return nil, false } e := v.(*cacheEntry) now := time.Now() if now.After(e.expireAt) { cache.Delete(key) return nil, false } out := e.msg.Copy() remaining := uint32(e.expireAt.Sub(now).Seconds()) if remaining == 0 { cache.Delete(key) return nil, false } // 回填剩余 TTL(不增加,只做上限截断) for i := range out.Answer { if out.Answer[i].Header().Ttl > remaining { out.Answer[i].Header().Ttl = remaining } } for i := range out.Ns { if out.Ns[i].Header().Ttl > remaining { out.Ns[i].Header().Ttl = remaining } } for i := range out.Extra { if isPseudo(out.Extra[i]) { continue } if out.Extra[i].Header().Ttl > remaining { out.Extra[i].Header().Ttl = remaining } } return out, true } // negativeTTL 依据 RFC 2308 计算负面缓存 TTL。 // 对 NXDOMAIN 或 NODATA(NOERROR + Answer 为空)取 min(SOA.TTL, SOA.MINIMUM),再与配置上限取 min。 func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { // NXDOMAIN,或 NOERROR 但 Answer 为空(NODATA) if m.Rcode != dns.RcodeNameError && !(m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0) { return 0, false } var soa *dns.SOA // SOA 可能出现在 Authority(Ns) 或 Additional(Extra) for _, rr := range append(m.Ns, m.Extra...) { if s, ok := rr.(*dns.SOA); ok { soa = s break } } if soa == nil { // 无 SOA 无法可靠计算负面 TTL;此时不缓存或由上限兜底(本实现选择不缓存) return 0, false } ttl := soa.Hdr.Ttl if soa.Minttl < ttl { ttl = soa.Minttl } capTTL := uint32(maxTTL.Seconds()) if capTTL > 0 && ttl > capTTL { ttl = capTTL } return ttl, ttl > 0 } // minRRsetTTL 获取普通(正向)响应的最小 TTL(Answer/Ns/Extra 中的普通 RR),跳过伪 RR。 // 用于决定缓存过期时间的上限(与配置上限再取 min)。 func minRRsetTTL(m *dns.Msg) (uint32, bool) { minTTL := uint32(0) hasTTL := false for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} { for _, rr := range sec { if isPseudo(rr) { continue } ttl := rr.Header().Ttl if !hasTTL || ttl < minTTL { minTTL = ttl hasTTL = true } } } return minTTL, hasTTL } // stripPseudoExtras 从消息中剥离伪 RR(OPT/TSIG)。 // 用途:缓存前剥离,避免将传输层细节或签名内容写入缓存。 func stripPseudoExtras(m *dns.Msg) { if len(m.Extra) == 0 { return } out := m.Extra[:0] for _, rr := range m.Extra { if isPseudo(rr) { continue } out = append(out, rr) } m.Extra = out } // cacheWrite 写缓存:优先处理负面缓存;正面缓存取 min(应答中最小 TTL, 配置上限)。 // 写入前统一剥离伪 RR,保证缓存与传输解耦。 func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { if in == nil { return } // 仅缓存 NOERROR / NXDOMAIN,其余不缓存(如 SERVFAIL/REFUSED 等) if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError { return } // 负面缓存 if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 { expire := time.Now().Add(time.Duration(ttl) * time.Second) cp := in.Copy() stripPseudoExtras(cp) cache.Store(key, &cacheEntry{msg: cp, expireAt: expire}) return } // 正向缓存:minTTL 与 maxTTL 取较小 minTTL, ok := minRRsetTTL(in) if !ok { // 没有 TTL 时可用上限兜底(也可选择不缓存,这里选择兜底) if maxTTL > 0 { expire := time.Now().Add(maxTTL) cp := in.Copy() stripPseudoExtras(cp) cache.Store(key, &cacheEntry{msg: cp, expireAt: expire}) } return } cfgTTL := uint32(maxTTL.Seconds()) finalTTL := minTTL if cfgTTL > 0 && finalTTL > cfgTTL { finalTTL = cfgTTL } if finalTTL == 0 { return } expire := time.Now().Add(time.Duration(finalTTL) * time.Second) cp := in.Copy() stripPseudoExtras(cp) cache.Store(key, &cacheEntry{msg: cp, expireAt: expire}) } /****************************************************************** * 上游查询(带 context 取消、并发上限、UDP→TCP 回退) ******************************************************************/ // 全局可复用 UDP 客户端(Net=udp,UDPSize 放大以承载更大的响应) var udpClient *dns.Client // shuffled 打乱上游列表,避免固定顺序导致单点拥塞或偏置。 func shuffled(xs []string) []string { out := make([]string, len(xs)) copy(out, xs) rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] }) return out } // queryUpstreamsLimited 并发向多个上游发起查询,返回首个有效结果。 // - timeout:整个查询窗口的上限(基于子 context) // - maxParallel:同时在飞请求上限 // - allowTCPFallback:若 UDP 截断(TC 位)则回退 TCP 重试 func queryUpstreamsLimited( ctx context.Context, req *dns.Msg, upstreams []string, timeout time.Duration, maxParallel int, allowTCPFallback bool, ) *dns.Msg { if maxParallel <= 0 { maxParallel = 1 } servers := shuffled(upstreams) // 每次查询一个带超时的子 context;拿到首个有效结果后 cancel,取消其他请求。 cctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() type result struct { msg *dns.Msg } ch := make(chan result, len(servers)) // 缓冲至多保存所有返回,避免阻塞 sem := make(chan struct{}, maxParallel) // 并发信号量,限制同时在飞请求数 // execOne 对单个上游发起查询;对 req 使用 Copy() 防止在多 goroutine 下共享同一 *dns.Msg。 execOne := func(svr string) { // UDP 查询(带 context);使用 req.Copy() 防止并发读写 resp, _, err := udpClient.ExchangeContext(cctx, req.Copy(), svr) if err == nil && resp != nil && resp.Truncated && allowTCPFallback { // TCP 回退(响应被截断) log.Printf("[upstream] UDP truncated, retry TCP: %s", svr) tcpClient := *udpClient tcpClient.Net = "tcp" resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr) } if err != nil || resp == nil { // 错误仅在未被上层取消时打印,避免超时后噪声 if err != nil { if cctx.Err() == nil { log.Printf("[upstream] %s error: %v", svr, err) } } else { log.Printf("[upstream] %s nil response", svr) } return } // 过滤差 RCODE(不参与竞速):SERVFAIL / REFUSED / FORMERR if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError { return } select { case ch <- result{msg: resp}: case <-cctx.Done(): } } // “快乐眼球”式启动:生产者不阻塞,获取配额在 goroutine 内部完成;这样可以更快进入接收循环并在首个成功后取消其余。 for _, s := range servers { s := s go func() { select { case sem <- struct{}{}: defer func() { <-sem }() case <-cctx.Done(): return } execOne(s) }() } // 返回第一个非空结果,并 cancel 其他 goroutine。 for i := 0; i < len(servers); i++ { select { case r := <-ch: if r.msg != nil { cancel() return r.msg } case <-cctx.Done(): log.Printf("[upstream] timeout after %v", timeout) return nil } } return nil } /****************************************************************** * EDNS(0) / ECS 处理 ******************************************************************/ // stripECS 从请求中去除 EDNS Client Subnet(ECS),减少缓存污染并保护隐私。 // 注:ECS 会导致上游按地理/网络分区返回不同答案,不适合集中缓存。 func stripECS(m *dns.Msg) { if o := m.IsEdns0(); o != nil { var kept []dns.EDNS0 for _, e := range o.Option { if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS { kept = append(kept, e) } } o.Option = kept } } // getDOFlag 读取 DO(DNSSEC OK)位,构成缓存键的一部分。 func getDOFlag(m *dns.Msg) bool { if o := m.IsEdns0(); o != nil { return o.Do() } return false } /****************************************************************** * 响应构造:使用客户端请求头构造 reply,复制上游内容 ******************************************************************/ // writeReply 根据客户端请求构造响应:复制上游 Answer/Ns/非伪 Extra, // 并按客户端请求重建 OPT(UDPSize/DO),同时继承上游的扩展 RCODE 与 EDNS 版本; // 可选透传上游的 EDE(Extended DNS Errors)以保留诊断信息。 func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) { if upstream == nil { dns.HandleFailed(w, req) return } out := new(dns.Msg) out.SetReply(req) out.Authoritative = false out.RecursionAvailable = upstream.RecursionAvailable out.AuthenticatedData = upstream.AuthenticatedData out.CheckingDisabled = req.CheckingDisabled // 反映客户端 CD 位 out.Rcode = upstream.Rcode // 主 RCODE(低 4 位) out.Answer = upstream.Answer out.Ns = upstream.Ns // 复制上游的非伪 RR 额外记录;OPT/TSIG 不透传 extras := make([]dns.RR, 0, len(upstream.Extra)) for _, rr := range upstream.Extra { if isPseudo(rr) { continue } extras = append(extras, rr) } // 基于客户端请求镜像 EDNS(UDPSize + DO) if ro := req.IsEdns0(); ro != nil { o := new(dns.OPT) o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT // 与客户端保持一致的 UDPSize / DO 位 o.SetUDPSize(ro.UDPSize()) if ro.Do() { o.SetDo() } // 继承上游的扩展 RCODE 与 EDNS 版本(注意不同版本签名差异,这里显式转换) if uo := upstream.IsEdns0(); uo != nil { // 你当前库期望 uint16,这里强转;若你的库期望 uint8,也可改成 uint8(...) o.SetExtendedRcode(uint16(uo.ExtendedRcode())) o.SetVersion(uint8(uo.Version())) // 可选:透传只读的 EDE 诊断信息 for _, opt := range uo.Option { if ede, ok := opt.(*dns.EDNS0_EDE); ok { o.Option = append(o.Option, ede) } } } extras = append(extras, o) } out.Extra = extras out.Compress = true if err := w.WriteMsg(out); err != nil { log.Printf("[write] WriteMsg error: %v", err) } } /****************************************************************** * 处理器 ******************************************************************/ // handleDNS 为每个请求执行:日志 → (可选)剥离 ECS → 缓存命中 → 上游查询 → 写缓存 → 回写。 // 注意:缓存键包含 DO/CD;同时通过 tryCacheRead 回填剩余 TTL。 func handleDNS( upstreams []string, cacheMaxTTL, timeout time.Duration, maxParallel int, stripECSBeforeForward bool, allowTCPFallback bool, ) dns.HandlerFunc { return func(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { dns.HandleFailed(w, r) return } q := r.Question[0] // 基本访问日志:类型/域名/ID/CD/DO/来源 IP:Port log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)", dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr()) // 可选:去除 ECS(推荐) if stripECSBeforeForward { stripECS(r) } // 缓存键(域名小写 + QTYPE/QCLASS + DO/CD) key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) // 1) 缓存命中:命中即快速返回 if cached, ok := tryCacheRead(key); ok { log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) writeReply(w, r, cached) return } log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) // 2) 上游查询(带 context 取消 & TCP 可选回退);并发向多个上游竞速 ctx := context.Background() resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback) if resp == nil { log.Printf("[error] all upstreams failed for %s", q.Name) dns.HandleFailed(w, r) return } // 3) 写缓存(负面/正面均处理;剥离伪 RR) cacheWrite(key, resp, cacheMaxTTL) // 4) 回写给客户端,并打印 Answer 方便调试 for _, ans := range resp.Answer { log.Printf("[answer] %s", ans.String()) } writeReply(w, r, resp) } } func main() { rand.Seed(time.Now().UnixNano()) // 随机源:用于上游列表打乱 // 命令行参数:可根据部署场景调整默认值 certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)") keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)") addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853)") upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)") cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL(默认 60s;实际取 min(上游最小TTL, 本值))") timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s)") maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限") stripECSFlag := flag.Bool("strip-ecs", true, "转发上游前去除 EDNS Client Subnet") allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时允许 TCP 回退") verbose := flag.Bool("v", false, "verbose 日志(包含源码位置)") flag.Parse() initLogger(*verbose) // 加载 TLS 证书/私钥;用于 DoT(RFC 7858)监听 cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) if err != nil { log.Fatalf("[fatal] failed to load cert: %v", err) } // 解析上游地址:支持“host”或“host:port”,缺省端口补 53。 var upstreams []string for _, s := range strings.Split(*upstreamStr, ",") { if t := strings.TrimSpace(s); t != "" { if !strings.Contains(t, ":") { t = fmt.Sprintf("%s:53", t) } upstreams = append(upstreams, t) } } if len(upstreams) == 0 { log.Fatal("[fatal] no upstream DNS servers provided") } // 全局 UDP 客户端(不设置 Client.Timeout,改用 ExchangeContext 控制超时) udpClient = &dns.Client{ Net: "udp", UDPSize: 4096, // 放大到 4K,减小 UDP 截断概率 } // context 用于优雅退出(SIGINT/SIGTERM 收到后取消) ctx, cancel := context.WithCancel(context.Background()) defer cancel() // 启动缓存清理器(后台 goroutine) startCacheCleaner(ctx) // 注册处理器:使用 ServeMux 将所有查询交给 handleDNS mux := dns.NewServeMux() mux.HandleFunc(".", handleDNS( upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel, *stripECSFlag, *allowTCPFallback, )) // 构造 DoT(tcp-tls)服务器:显式开启 TLS1.2/1.3,增加读写超时防止慢连接。 // NextProtos: "dot"(可选,部分客户端用于 ALPN 检测) srv := &dns.Server{ Addr: *addr, Net: "tcp-tls", TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12, // 允许 1.2/1.3(默认启用 1.3) // 不设置 CipherSuites,交由 Go 自动选择(TLS1.3 有自身套件) NextProtos: []string{"dot"}, // 可选:显式 ALPN }, Handler: mux, ReadTimeout: 10 * time.Second, // 防止慢连接 WriteTimeout: 10 * time.Second, } // 捕获信号以便优雅退出(关闭监听、结束后台协程) stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) // 异步启动服务,错误通过 errCh 返回 errCh := make(chan error, 1) go func() { log.Printf("🚀 starting DNS-over-TLS on %s", *addr) log.Printf(" upstreams=%v | cache_max_ttl=%s | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v", upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback) errCh <- srv.ListenAndServe() }() // 等待退出信号或服务器错误 select { case sig := <-stop: log.Printf("[shutdown] caught signal: %s", sig) cancel() // miekg/dns 提供 Shutdown();部分版本无 ShutdownContext,这里用 Shutdown() if err := srv.Shutdown(); err != nil { log.Printf("[shutdown] server shutdown error: %v", err) } case err := <-errCh: if err != nil { log.Fatalf("[fatal] server error: %v", err) } } log.Println("[bye] server stopped.") }