package main import ( "context" "crypto/tls" "flag" "log" "math/rand" "os" "os/signal" "strings" "sync" "syscall" "time" "github.com/miekg/dns" ) /********************************* * 缓存结构与全局对象 *********************************/ type cacheEntry struct { msg *dns.Msg // 缓存的完整响应(深拷贝) expireAt time.Time // 过期时间(由动态 TTL 决定) } // 并发安全缓存 var cache sync.Map // 后台清理:每隔 N 分钟清一次 const cacheCleanupInterval = 5 * time.Minute // 启动带 context 的缓存清理器(优雅退出) func startCacheCleaner(ctx context.Context) { go func() { ticker := time.NewTicker(cacheCleanupInterval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: now := time.Now() n := 0 cache.Range(func(k, v any) bool { e := v.(*cacheEntry) if now.After(e.expireAt) { cache.Delete(k) n++ } return true }) if n > 0 { log.Printf("[Cache] Cleaned %d expired entries", n) } } } }() } // 生成缓存键 func cacheKey(name string, qtype uint16) string { return strings.ToLower(name) + ":" + dns.TypeToString[qtype] } // 读取缓存;命中则回填剩余 TTL 并返回拷贝 func tryCacheRead(key string) (*dns.Msg, bool) { v, ok := cache.Load(key) if !ok { return nil, false } e := v.(*cacheEntry) now := time.Now() if now.After(e.expireAt) { cache.Delete(key) return nil, false } out := e.msg.Copy() remaining := uint32(e.expireAt.Sub(now).Seconds()) // 回填剩余 TTL,避免客户端收到过期 TTL for i := range out.Answer { if out.Answer[i].Header().Ttl > remaining { out.Answer[i].Header().Ttl = remaining } } for i := range out.Ns { if out.Ns[i].Header().Ttl > remaining { out.Ns[i].Header().Ttl = remaining } } for i := range out.Extra { if out.Extra[i].Header().Ttl > remaining { out.Extra[i].Header().Ttl = remaining } } return out, true } // 写缓存:以「上游最小 TTL」与「配置上限 TTL」取较小值 func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { if in == nil { return } // 可按需缓存 NXDOMAIN;这里允许缓存 NOERROR 与 NXDOMAIN if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError { return } // 计算报文中的最小 TTL(Answer/Ns/Extra) minTTL := uint32(0) hasTTL := false for _, sec := range [][]dns.RR{in.Answer, in.Ns, in.Extra} { for _, rr := range sec { ttl := rr.Header().Ttl if !hasTTL || ttl < minTTL { minTTL = ttl hasTTL = true } } } // 若无 TTL,可用配置上限作为兜底(也可选择不缓存) cfgTTL := uint32(maxTTL.Seconds()) var finalTTL uint32 switch { case !hasTTL && cfgTTL > 0: finalTTL = cfgTTL case hasTTL && cfgTTL > 0 && minTTL > cfgTTL: finalTTL = cfgTTL case hasTTL: finalTTL = minTTL default: return } if finalTTL == 0 { return } expire := time.Now().Add(time.Duration(finalTTL) * time.Second) cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire}) } /********************************* * 上游查询与并发控制 *********************************/ // 全局可复用的 DNS 客户端(默认 UDP) var dnsClient *dns.Client // 并发上限通过信号量限制 func queryUpstreamsLimited(r *dns.Msg, upstreams []string, timeout time.Duration, maxParallel int) *dns.Msg { if maxParallel <= 0 { maxParallel = 1 } ch := make(chan *dns.Msg, len(upstreams)) sem := make(chan struct{}, maxParallel) // 在 UDP 上查询,遇到截断再 TCP fallback queryOnce := func(server string) *dns.Msg { resp, _, err := dnsClient.Exchange(r, server) if err == nil && resp != nil && resp.Truncated { // UDP 截断,尝试 TCP log.Printf("[Info] UDP truncated, retry TCP: %s", server) tcpClient := *dnsClient tcpClient.Net = "tcp" resp, _, err = tcpClient.Exchange(r, server) } if err != nil || resp == nil { if err != nil { log.Printf("[Warn] Upstream %s failed: %v", server, err) } else { log.Printf("[Warn] Upstream %s failed: nil response", server) } return nil } // 可选:丢弃 SERVFAIL if resp.Rcode == dns.RcodeServerFailure { return nil } return resp } for _, server := range upstreams { sem <- struct{}{} go func(svr string) { defer func() { <-sem }() resp := queryOnce(svr) // 非阻塞/限时写入,防止消费者意外退出导致阻塞 select { case ch <- resp: case <-time.After(1 * time.Second): } }(server) } timer := time.NewTimer(timeout) defer timer.Stop() // 在超时内返回第一个非空的结果 for i := 0; i < len(upstreams); i++ { select { case resp := <-ch: if resp != nil { return resp } case <-timer.C: log.Printf("[Error] Upstream query timeout after %v", timeout) return nil } } return nil } /********************************* * DNS 处理逻辑 *********************************/ func shuffled(slice []string) []string { out := make([]string, len(slice)) copy(out, slice) rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] }) return out } func handleDNS(upstreams []string, cacheMaxTTL, timeout time.Duration, maxParallel int) dns.HandlerFunc { return func(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { dns.HandleFailed(w, r) return } q := r.Question[0] key := cacheKey(q.Name, q.Qtype) // 1) 缓存命中 if cached, ok := tryCacheRead(key); ok { log.Printf("[Cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) _ = w.WriteMsg(cached) return } log.Printf("[Cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) // 2) 随机化上游并并发查询(带 fallback) servers := shuffled(upstreams) resp := queryUpstreamsLimited(r, servers, timeout, maxParallel) if resp == nil { log.Printf("[Error] All upstreams failed for %s", q.Name) dns.HandleFailed(w, r) return } // 3) 写入缓存(动态 TTL) cacheWrite(key, resp, cacheMaxTTL) // 4) 返回结果 for _, ans := range resp.Answer { log.Printf("[Answer] %s", ans.String()) } _ = w.WriteMsg(resp) } } /********************************* * 主程序(优雅退出) *********************************/ func main() { rand.Seed(time.Now().UnixNano()) // 参数 certFile := flag.String("cert", "aixiao.me.cer", "TLS 证书文件路径 (.cer/.crt)") keyFile := flag.String("key", "aixiao.me.key", "TLS 私钥文件路径 (.key)") addr := flag.String("addr", ":853", "DoT 服务监听地址(默认 :853)") upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)") cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL(默认 60s;实际取 min(上游最小TTL, 本值))") timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s)") maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限") flag.Parse() // 证书 cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) if err != nil { log.Fatalf("failed to load cert: %v", err) } // 上游列表 raw := strings.Split(*upstreamStr, ",") upstreams := make([]string, 0, len(raw)) for _, s := range raw { if t := strings.TrimSpace(s); t != "" { upstreams = append(upstreams, t) } } if len(upstreams) == 0 { log.Fatal("no upstream DNS servers provided") } // 全局 DNS 客户端(UDP,扩大 UDPSize;fallback 在查询函数中完成) dnsClient = &dns.Client{ Net: "udp", UDPSize: 4096, // 防截断 Timeout: *timeoutFlag, } // context 用于优雅退出与清理协程 ctx, cancel := context.WithCancel(context.Background()) defer cancel() // 启动缓存清理器 startCacheCleaner(ctx) // DNS 处理器 mux := dns.NewServeMux() mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel)) // DoT 服务器(TLS 会话缓存 + 安全套件 + TLS1.2+) srv := &dns.Server{ Addr: *addr, Net: "tcp-tls", TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, ClientSessionCache: tls.NewLRUClientSessionCache(256), MinVersion: tls.VersionTLS12, CipherSuites: []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, }, }, Handler: mux, } // 捕获退出信号,优雅关闭 stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) errCh := make(chan error, 1) go func() { log.Printf("🚀 Starting DNS-over-TLS server on %s", *addr) log.Printf("Upstreams=%v | MaxTTL=%s | Timeout=%s | MaxParallel=%d", upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel) errCh <- srv.ListenAndServe() }() select { case sig := <-stop: log.Printf("[Shutdown] Caught signal: %s", sig) cancel() // 结束清理器 // 优雅关闭服务器 // miekg/dns 提供 Shutdown();若你的版本支持 ShutdownContext,可改用带 ctx 的版本 if err := srv.Shutdown(); err != nil { log.Printf("[Shutdown] server shutdown error: %v", err) } case err := <-errCh: if err != nil { log.Fatalf("server error: %v", err) } } log.Println("[Shutdown] Bye.") }