package main import ( "context" "crypto/tls" "flag" "fmt" "log" "math/rand" "os" "os/signal" "strings" "sync" "syscall" "time" "github.com/miekg/dns" lru "github.com/hashicorp/golang-lru/v2" ) /****************************************************************** * 日志初始化 ******************************************************************/ func initLogger(verbose bool) { flags := log.Ldate | log.Ltime | log.Lmicroseconds if verbose { flags |= log.Lshortfile } log.SetFlags(flags) } /****************************************************************** * 缓存结构(支持 TTL + LRU) ******************************************************************/ type cacheEntry struct { msg *dns.Msg expireAt time.Time } var ( cache *lru.Cache[string, *cacheEntry] cacheMutex sync.RWMutex ) const ( cacheCleanupInterval = 5 * time.Minute defaultCacheSize = 10000 // 默认最大缓存条目数 ) // startCacheCleaner 定期清理过期缓存(修复:在删除前二次校验) func startCacheCleaner(ctx context.Context) { go func() { ticker := time.NewTicker(cacheCleanupInterval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: now := time.Now() var toDelete []string cacheMutex.RLock() for _, k := range cache.Keys() { if v, ok := cache.Peek(k); ok && now.After(v.expireAt) { toDelete = append(toDelete, k) } } cacheMutex.RUnlock() if len(toDelete) > 0 { pruned := 0 cacheMutex.Lock() for _, k := range toDelete { if v, ok := cache.Peek(k); ok && now.After(v.expireAt) { cache.Remove(k) pruned++ } } cacheMutex.Unlock() if pruned > 0 { log.Printf("[cache] cleaned %d expired entries", pruned) } } } } }() } func cacheKeyFromMsg(q dns.Question, do, cd bool) string { var b strings.Builder b.Grow(len(q.Name) + 32) b.WriteString(strings.ToLower(q.Name)) b.WriteString("|T=") b.WriteString(dns.TypeToString[q.Qtype]) b.WriteString("|C=") b.WriteString(dns.ClassToString[q.Qclass]) if do { b.WriteString("|DO") } if cd { b.WriteString("|CD") } return b.String() } func isPseudo(rr dns.RR) bool { switch rr.(type) { case *dns.OPT, *dns.TSIG: return true default: return false } } // 读取缓存(修复:Get 在写锁下;在锁外调整 TTL) func tryCacheRead(key string) (*dns.Msg, bool) { now := time.Now() cacheMutex.Lock() e, ok := cache.Get(key) // Get 会更新 LRU,必须在写锁下 if !ok { cacheMutex.Unlock() return nil, false } if now.After(e.expireAt) { cache.Remove(key) cacheMutex.Unlock() return nil, false } // 拷贝副本,在锁外改 TTL,减少临界区时间 out := e.msg.Copy() expireAt := e.expireAt cacheMutex.Unlock() remaining := uint32(expireAt.Sub(now).Seconds()) if remaining == 0 { cacheMutex.Lock() cache.Remove(key) cacheMutex.Unlock() return nil, false } for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} { for _, rr := range sec { if isPseudo(rr) { continue } if rr.Header().Ttl > remaining { rr.Header().Ttl = remaining } } } return out, true } // 计算负面 TTL // hasAnswerForType 判断报文中是否存在回答“请求类型”的 RRset func hasAnswerForType(m *dns.Msg, q dns.Question) bool { for _, rr := range m.Answer { h := rr.Header() if h.Rrtype == q.Qtype && strings.EqualFold(h.Name, q.Name) { return true } } return false } // 计算负面 TTL(修复:正确识别 NODATA,包括 CNAME 等场景) func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { // NXDOMAIN:肯定是负面 if m.Rcode != dns.RcodeNameError { // 不是 NXDOMAIN,则仅当 NOERROR 但没有“匹配 QTYPE 的答案”时才是 NODATA if m.Rcode != dns.RcodeSuccess || len(m.Question) == 0 || hasAnswerForType(m, m.Question[0]) { return 0, false } } // 按 RFC 2308,从 Authority(Ns)优先取 SOA(多数实现都只放在 Authority) var soa *dns.SOA for _, rr := range m.Ns { if s, ok := rr.(*dns.SOA); ok { soa = s break } } // 兼容性:偶尔也有人把 SOA 放 Extra(不规范,但为了兼容可以兜底看看) if soa == nil { for _, rr := range m.Extra { if s, ok := rr.(*dns.SOA); ok { soa = s break } } } if soa == nil { // 建议:无 SOA 时不做负面缓存(返回 0,false) // 如你更希望兜底,可改成:return uint32(maxTTL.Seconds()), true return 0, false } // 负面 TTL 取 min(SOA.MINIMUM, SOA 自身 TTL),再与配置上限比较 ttl := soa.Hdr.Ttl if soa.Minttl < ttl { ttl = soa.Minttl } capTTL := uint32(maxTTL.Seconds()) if capTTL > 0 && ttl > capTTL { ttl = capTTL } return ttl, ttl > 0 } func minRRsetTTL(m *dns.Msg) (uint32, bool) { minTTL := uint32(0) hasTTL := false for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} { for _, rr := range sec { if isPseudo(rr) { continue } ttl := rr.Header().Ttl if !hasTTL || ttl < minTTL { minTTL = ttl hasTTL = true } } } return minTTL, hasTTL } func stripPseudoExtras(m *dns.Msg) { if len(m.Extra) == 0 { return } out := m.Extra[:0] for _, rr := range m.Extra { if isPseudo(rr) { continue } out = append(out, rr) } m.Extra = out } // 写缓存 func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { if in == nil { return } if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError { return } var ttl uint32 var ok bool if ttl, ok = negativeTTL(in, maxTTL); !ok { minTTL, has := minRRsetTTL(in) if has { cfgTTL := uint32(maxTTL.Seconds()) if cfgTTL > 0 && minTTL > cfgTTL { minTTL = cfgTTL } ttl = minTTL } else { ttl = uint32(maxTTL.Seconds()) } } if ttl == 0 { return } expire := time.Now().Add(time.Duration(ttl) * time.Second) cp := in.Copy() stripPseudoExtras(cp) cacheMutex.Lock() cache.Add(key, &cacheEntry{msg: cp, expireAt: expire}) cacheMutex.Unlock() } /****************************************************************** * 上游查询 ******************************************************************/ var udpClient *dns.Client func shuffled(xs []string) []string { out := make([]string, len(xs)) copy(out, xs) rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] }) return out } // clampEDNSForUpstream 返回一个 msg 副本,把 EDNS UDP size 夹到给定大小 func clampEDNSForUpstream(in *dns.Msg, size uint16) *dns.Msg { m := in.Copy() o := m.IsEdns0() if o == nil { o = &dns.OPT{} o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT m.Extra = append(m.Extra, o) } if size > 0 { o.SetUDPSize(size) } return m } func queryUpstreamsLimited( ctx context.Context, req *dns.Msg, upstreams []string, timeout time.Duration, maxParallel int, allowTCPFallback bool, ) *dns.Msg { if maxParallel <= 0 { maxParallel = 1 } servers := shuffled(upstreams) cctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() type result struct{ msg *dns.Msg } ch := make(chan result, len(servers)) sem := make(chan struct{}, maxParallel) execOne := func(svr string) { upReq := clampEDNSForUpstream(req, 1232) // 或做成 flag resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr) if err == nil && resp != nil && resp.Truncated && allowTCPFallback { log.Printf("[upstream] UDP truncated, retry TCP: %s", svr) tcpClient := *udpClient tcpClient.Net = "tcp" resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr) } if err != nil || resp == nil { if err != nil && cctx.Err() == nil { log.Printf("[upstream] %s error: %v", svr, err) } return } if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError { return } select { case ch <- result{msg: resp}: case <-cctx.Done(): } } for _, s := range servers { s := s go func() { select { case sem <- struct{}{}: defer func() { <-sem }() case <-cctx.Done(): return } execOne(s) }() } for i := 0; i < len(servers); i++ { select { case r := <-ch: if r.msg != nil { cancel() return r.msg } case <-cctx.Done(): log.Printf("[upstream] timeout after %v", timeout) return nil } } return nil } /****************************************************************** * EDNS / 响应构造 ******************************************************************/ func stripECS(m *dns.Msg) { if o := m.IsEdns0(); o != nil { var kept []dns.EDNS0 for _, e := range o.Option { if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS { kept = append(kept, e) } } o.Option = kept } } func getDOFlag(m *dns.Msg) bool { if o := m.IsEdns0(); o != nil { return o.Do() } return false } func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) { if upstream == nil { dns.HandleFailed(w, req) return } out := new(dns.Msg) out.SetReply(req) out.Authoritative = false out.RecursionAvailable = upstream.RecursionAvailable out.AuthenticatedData = upstream.AuthenticatedData out.CheckingDisabled = req.CheckingDisabled out.Rcode = upstream.Rcode out.Answer = upstream.Answer out.Ns = upstream.Ns var extras []dns.RR for _, rr := range upstream.Extra { if !isPseudo(rr) { extras = append(extras, rr) } } if ro := req.IsEdns0(); ro != nil { o := new(dns.OPT) o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT o.SetUDPSize(ro.UDPSize()) if ro.Do() { o.SetDo() } if uo := upstream.IsEdns0(); uo != nil { o.SetExtendedRcode(uint16(uo.ExtendedRcode())) o.SetVersion(uint8(uo.Version())) for _, opt := range uo.Option { if ede, ok := opt.(*dns.EDNS0_EDE); ok { o.Option = append(o.Option, ede) } } } extras = append(extras, o) } out.Extra = extras out.Compress = true if err := w.WriteMsg(out); err != nil { log.Printf("[write] WriteMsg error: %v", err) } } /****************************************************************** * 主处理器 ******************************************************************/ func handleDNS( upstreams []string, cacheMaxTTL, timeout time.Duration, maxParallel int, stripECSBeforeForward bool, allowTCPFallback bool, ) dns.HandlerFunc { return func(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { dns.HandleFailed(w, r) return } q := r.Question[0] log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)", dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr()) if stripECSBeforeForward { stripECS(r) } key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) if cached, ok := tryCacheRead(key); ok { log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) writeReply(w, r, cached) return } log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) ctx := context.Background() resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback) if resp == nil { log.Printf("[error] all upstreams failed for %s", q.Name) dns.HandleFailed(w, r) return } cacheWrite(key, resp, cacheMaxTTL) for _, ans := range resp.Answer { log.Printf("[answer] %s", ans.String()) } writeReply(w, r, resp) } } /****************************************************************** * 主函数 ******************************************************************/ func main() { rand.Seed(time.Now().UnixNano()) certFile := flag.String("cert", "server.crt", "TLS 证书文件路径") keyFile := flag.String("key", "server.key", "TLS 私钥文件路径") addr := flag.String("addr", ":853", "DoT 监听地址") upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表") cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL") cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限") timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时") maxParallel := flag.Int("max-parallel", 3, "并发上游数量") stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS") allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退") verbose := flag.Bool("v", false, "verbose 日志") flag.Parse() initLogger(*verbose) var err error cache, err = lru.New[string, *cacheEntry](*cacheSizeFlag) if err != nil { log.Fatalf("[fatal] failed to init LRU cache: %v", err) } cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) if err != nil { log.Fatalf("[fatal] failed to load cert: %v", err) } var upstreams []string for _, s := range strings.Split(*upstreamStr, ",") { if t := strings.TrimSpace(s); t != "" { if !strings.Contains(t, ":") { t = fmt.Sprintf("%s:53", t) } upstreams = append(upstreams, t) } } if len(upstreams) == 0 { log.Fatal("[fatal] no upstream DNS servers provided") } udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true} ctx, cancel := context.WithCancel(context.Background()) defer cancel() startCacheCleaner(ctx) mux := dns.NewServeMux() mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel, *stripECSFlag, *allowTCPFallback)) srv := &dns.Server{ Addr: *addr, Net: "tcp-tls", TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12, NextProtos: []string{"dot"}, }, Handler: mux, ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, } stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) errCh := make(chan error, 1) go func() { log.Printf("🚀 starting DNS-over-TLS on %s", *addr) log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v", upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback) errCh <- srv.ListenAndServe() }() select { case sig := <-stop: log.Printf("[shutdown] caught signal: %s", sig) cancel() if err := srv.Shutdown(); err != nil { log.Printf("[shutdown] server shutdown error: %v", err) } case err := <-errCh: if err != nil { log.Fatalf("[fatal] server error: %v", err) } } log.Println("[bye] server stopped.") }