package main import ( "context" "crypto/tls" "flag" "fmt" "log" "math/rand" "os" "os/signal" "strings" "sync" "syscall" "time" "github.com/miekg/dns" ) /****************************************************************** * 日志初始化 ******************************************************************/ func initLogger(verbose bool) { flags := log.Ldate | log.Ltime | log.Lmicroseconds if verbose { flags |= log.Lshortfile } log.SetFlags(flags) } /****************************************************************** * 缓存结构 ******************************************************************/ type cacheEntry struct { msg *dns.Msg // 上游完整响应(拷贝存储) expireAt time.Time // 过期时间 } var cache sync.Map const cacheCleanupInterval = 5 * time.Minute func startCacheCleaner(ctx context.Context) { go func() { ticker := time.NewTicker(cacheCleanupInterval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: now := time.Now() n := 0 cache.Range(func(k, v any) bool { e := v.(*cacheEntry) if now.After(e.expireAt) { cache.Delete(k) n++ } return true }) if n > 0 { log.Printf("[cache] cleaned %d expired entries", n) } } } }() } // 计算缓存键:name + type + class + DO + CD func cacheKeyFromMsg(q dns.Question, do, cd bool) string { var b strings.Builder b.Grow(len(q.Name) + 32) b.WriteString(strings.ToLower(q.Name)) b.WriteString("|T=") b.WriteString(dns.TypeToString[q.Qtype]) b.WriteString("|C=") b.WriteString(dns.ClassToString[q.Qclass]) if do { b.WriteString("|DO") } if cd { b.WriteString("|CD") } return b.String() } // 命中缓存:回填剩余 TTL(不超过记录自身 TTL) func tryCacheRead(key string) (*dns.Msg, bool) { v, ok := cache.Load(key) if !ok { return nil, false } e := v.(*cacheEntry) now := time.Now() if now.After(e.expireAt) { cache.Delete(key) return nil, false } out := e.msg.Copy() remaining := uint32(e.expireAt.Sub(now).Seconds()) if remaining == 0 { cache.Delete(key) return nil, false } for i := range out.Answer { if out.Answer[i].Header().Ttl > remaining { out.Answer[i].Header().Ttl = remaining } } for i := range out.Ns { if out.Ns[i].Header().Ttl > remaining { out.Ns[i].Header().Ttl = remaining } } for i := range out.Extra { if out.Extra[i].Header().Ttl > remaining { out.Extra[i].Header().Ttl = remaining } } return out, true } // 负面缓存 TTL(RFC 2308):NXDOMAIN 或 NODATA 使用 SOA.MINIMUM 与 SOA TTL 的较小者,再与配置上限取 min func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { // NXDOMAIN,或 NOERROR 但 Answer 为空(NODATA) if m.Rcode != dns.RcodeNameError && !(m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0) { return 0, false } var soa *dns.SOA for _, rr := range append(m.Ns, m.Extra...) { if s, ok := rr.(*dns.SOA); ok { soa = s break } } if soa == nil { return 0, false } ttl := soa.Hdr.Ttl if soa.Minttl < ttl { ttl = soa.Minttl } capTTL := uint32(maxTTL.Seconds()) if capTTL > 0 && ttl > capTTL { ttl = capTTL } return ttl, ttl > 0 } // 普通(正向)响应的最小 TTL(Answer/Ns/Extra) func minRRsetTTL(m *dns.Msg) (uint32, bool) { minTTL := uint32(0) hasTTL := false for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} { for _, rr := range sec { ttl := rr.Header().Ttl if !hasTTL || ttl < minTTL { minTTL = ttl hasTTL = true } } } return minTTL, hasTTL } // 写缓存:先处理负面缓存,再处理正面缓存 func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { if in == nil { return } // 仅缓存 NOERROR / NXDOMAIN,其余不缓存 if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError { return } // 负面缓存 if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 { expire := time.Now().Add(time.Duration(ttl) * time.Second) cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire}) return } // 正向缓存:minTTL 与 maxTTL 取较小 minTTL, ok := minRRsetTTL(in) if !ok { // 没有 TTL 时可用上限兜底(也可选择不缓存) if maxTTL > 0 { expire := time.Now().Add(maxTTL) cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire}) } return } cfgTTL := uint32(maxTTL.Seconds()) finalTTL := minTTL if cfgTTL > 0 && finalTTL > cfgTTL { finalTTL = cfgTTL } if finalTTL == 0 { return } expire := time.Now().Add(time.Duration(finalTTL) * time.Second) cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire}) } /****************************************************************** * 上游查询(带 context 取消、并发上限、UDP→TCP 回退) ******************************************************************/ // 全局可复用 UDP 客户端 var udpClient *dns.Client func shuffled(xs []string) []string { out := make([]string, len(xs)) copy(out, xs) rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] }) return out } func queryUpstreamsLimited( ctx context.Context, req *dns.Msg, upstreams []string, timeout time.Duration, maxParallel int, allowTCPFallback bool, ) *dns.Msg { if maxParallel <= 0 { maxParallel = 1 } servers := shuffled(upstreams) // 每次查询一个带超时的子 context;拿到首个有效结果后取消。 cctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() type result struct { msg *dns.Msg } ch := make(chan result, len(servers)) sem := make(chan struct{}, maxParallel) execOne := func(svr string) { defer func() { <-sem }() // UDP 查询(带 context) resp, _, err := udpClient.ExchangeContext(cctx, req, svr) if err == nil && resp != nil && resp.Truncated && allowTCPFallback { // TCP 回退 log.Printf("[upstream] UDP truncated, retry TCP: %s", svr) tcpClient := *udpClient tcpClient.Net = "tcp" resp, _, err = tcpClient.ExchangeContext(cctx, req, svr) } if err != nil || resp == nil { if err != nil { if cctx.Err() == nil { log.Printf("[upstream] %s error: %v", svr, err) } } else { log.Printf("[upstream] %s nil response", svr) } return } // 丢弃 SERVFAIL if resp.Rcode == dns.RcodeServerFailure { return } select { case ch <- result{msg: resp}: case <-cctx.Done(): } } for _, s := range servers { sem <- struct{}{} go execOne(s) } // 返回第一个非空结果,并 cancel 其他 goroutine for i := 0; i < len(servers); i++ { select { case r := <-ch: if r.msg != nil { cancel() return r.msg } case <-cctx.Done(): log.Printf("[upstream] timeout after %v", timeout) return nil } } return nil } /****************************************************************** * EDNS(0) / ECS 处理 ******************************************************************/ // 去除 EDNS Client Subnet(避免缓存污染与隐私泄露) func stripECS(m *dns.Msg) { if o := m.IsEdns0(); o != nil { var kept []dns.EDNS0 for _, e := range o.Option { if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS { kept = append(kept, e) } } o.Option = kept } } // 获取 DO/EDNS func getDOFlag(m *dns.Msg) bool { if o := m.IsEdns0(); o != nil { return o.Do() } return false } /****************************************************************** * 响应构造:使用客户端请求头构造 reply,复制上游内容 ******************************************************************/ func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) { if upstream == nil { dns.HandleFailed(w, req) return } out := new(dns.Msg) out.SetReply(req) out.Authoritative = false out.RecursionAvailable = upstream.RecursionAvailable out.AuthenticatedData = upstream.AuthenticatedData out.CheckingDisabled = req.CheckingDisabled // 反映客户端 CD 位 out.Rcode = upstream.Rcode out.Answer = upstream.Answer out.Ns = upstream.Ns out.Extra = upstream.Extra out.Compress = true if err := w.WriteMsg(out); err != nil { log.Printf("[write] WriteMsg error: %v", err) } } /****************************************************************** * 处理器 ******************************************************************/ func handleDNS( upstreams []string, cacheMaxTTL, timeout time.Duration, maxParallel int, stripECSBeforeForward bool, allowTCPFallback bool, ) dns.HandlerFunc { return func(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { dns.HandleFailed(w, r) return } q := r.Question[0] // 记录请求 log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)", dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr()) // 可选:去除 ECS(推荐) if stripECSBeforeForward { stripECS(r) } // 缓存键 key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) // 1) 缓存命中 if cached, ok := tryCacheRead(key); ok { log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) writeReply(w, r, cached) return } log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) // 2) 上游查询(带 context 取消 & TCP 可选回退) ctx := context.Background() resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback) if resp == nil { log.Printf("[error] all upstreams failed for %s", q.Name) dns.HandleFailed(w, r) return } // 3) 写缓存 cacheWrite(key, resp, cacheMaxTTL) // 4) 回写 for _, ans := range resp.Answer { log.Printf("[answer] %s", ans.String()) } writeReply(w, r, resp) } } /****************************************************************** * 主程序 ******************************************************************/ func main() { rand.Seed(time.Now().UnixNano()) // 参数 certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)") keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)") addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853)") upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)") cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL(默认 60s;实际取 min(上游最小TTL, 本值))") timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s)") maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限") stripECSFlag := flag.Bool("strip-ecs", true, "转发上游前去除 EDNS Client Subnet") allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时允许 TCP 回退") verbose := flag.Bool("v", false, "verbose 日志(包含源码位置)") flag.Parse() initLogger(*verbose) // 证书 cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) if err != nil { log.Fatalf("[fatal] failed to load cert: %v", err) } // 上游 var upstreams []string for _, s := range strings.Split(*upstreamStr, ",") { if t := strings.TrimSpace(s); t != "" { if !strings.Contains(t, ":") { t = fmt.Sprintf("%s:53", t) } upstreams = append(upstreams, t) } } if len(upstreams) == 0 { log.Fatal("[fatal] no upstream DNS servers provided") } // 全局 UDP 客户端(UDPSize 放大;不设置 Timeout,使用 ExchangeContext 控制) udpClient = &dns.Client{ Net: "udp", UDPSize: 4096, } // context 用于优雅退出 ctx, cancel := context.WithCancel(context.Background()) defer cancel() // 启动缓存清理器 startCacheCleaner(ctx) // mux/handler mux := dns.NewServeMux() mux.HandleFunc(".", handleDNS( upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel, *stripECSFlag, *allowTCPFallback, )) // DoT 服务(启用 TLS1.3;不手动设 CipherSuites 以使用 Go 默认安全套件) srv := &dns.Server{ Addr: *addr, Net: "tcp-tls", TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12, // 允许 1.2/1.3(默认启用 1.3) // 不设置 CipherSuites,交由 Go 自动选择(TLS1.3 有自身套件) }, Handler: mux, } // 捕获信号优雅退出 stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) errCh := make(chan error, 1) go func() { log.Printf("🚀 starting DNS-over-TLS on %s", *addr) log.Printf(" upstreams=%v | cache_max_ttl=%s | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v", upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback) errCh <- srv.ListenAndServe() }() select { case sig := <-stop: log.Printf("[shutdown] caught signal: %s", sig) cancel() // miekg/dns 提供 Shutdown();部分版本没有 ShutdownContext,这里用 Shutdown() if err := srv.Shutdown(); err != nil { log.Printf("[shutdown] server shutdown error: %v", err) } case err := <-errCh: if err != nil { log.Fatalf("[fatal] server error: %v", err) } } log.Println("[bye] server stopped.") }