// main.go package main import ( "context" "crypto/tls" "flag" "fmt" "log" "math/rand" "net" "os" "os/signal" "strings" "sync" "sync/atomic" "syscall" "time" lru "github.com/hashicorp/golang-lru/v2" "github.com/miekg/dns" "golang.org/x/sync/singleflight" ) var BuildDate = "unknown" // 由编译时注入 /****************************************************************** * 日志初始化 ******************************************************************/ func initLogger(verbose bool) { flags := log.Ldate | log.Ltime | log.Lmicroseconds if verbose { flags |= log.Lshortfile } log.SetFlags(flags) } /****************************************************************** * 缓存结构(支持 TTL + LRU) ******************************************************************/ type cacheEntry struct { msg *dns.Msg expireAt time.Time // EDNS metadata (to reproduce Extended RCODE / EDE on cache hits) ednsPresent bool ednsVersion uint8 ednsExtRcode uint16 ednsEDE []*dns.EDNS0_EDE } var ( cache *lru.Cache[string, *cacheEntry] cacheMutex sync.RWMutex inflight singleflight.Group ) const ( cacheCleanupInterval = 5 * time.Minute defaultCacheSize = 10000 // 默认最大缓存条目数 ) // startCacheCleaner 定期清理过期缓存(在删除前二次校验) func startCacheCleaner(ctx context.Context) { go func() { ticker := time.NewTicker(cacheCleanupInterval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: now := time.Now() var toDelete []string cacheMutex.RLock() for _, k := range cache.Keys() { if v, ok := cache.Peek(k); ok && now.After(v.expireAt) { toDelete = append(toDelete, k) } } cacheMutex.RUnlock() if len(toDelete) > 0 { pruned := 0 cacheMutex.Lock() for _, k := range toDelete { if v, ok := cache.Peek(k); ok && now.After(v.expireAt) { cache.Remove(k) pruned++ } } cacheMutex.Unlock() if pruned > 0 { log.Printf("[cache] cleaned %d expired entries", pruned) } } } } }() } func cacheKeyFromMsg(q dns.Question, do, cd bool) string { var b strings.Builder b.Grow(len(q.Name) + 32) // 采用规范化域名,避免尾随点/IDNA/大小写造成的重复键 b.WriteString(dns.CanonicalName(q.Name)) b.WriteString("|T=") b.WriteString(dns.TypeToString[q.Qtype]) b.WriteString("|C=") b.WriteString(dns.ClassToString[q.Qclass]) if do { b.WriteString("|DO") } if cd { b.WriteString("|CD") } return b.String() } // ecsKeyPart 在未 strip ECS 时,把 ECS 归一化后的“网络”信息并入缓存 key // 为最小改动:当 strip=true(即启用去 ECS)时直接返回空字符串 func ecsKeyPart(m *dns.Msg, strip bool) string { if strip { return "" } o := m.IsEdns0() if o == nil { return "" } for _, opt := range o.Option { s, ok := opt.(*dns.EDNS0_SUBNET) if !ok { continue } fam := s.Family pfx := int(s.SourceNetmask) addr := append(net.IP(nil), s.Address...) // 拷贝以免原切片被改 switch fam { case 1: // IPv4 ip := addr.To4() if ip != nil { mask := net.CIDRMask(pfx, 32) ip = ip.Mask(mask) return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, ip) } case 2: // IPv6 ip := addr.To16() if ip != nil { mask := net.CIDRMask(pfx, 128) for i := 0; i < 16; i++ { ip[i] &= mask[i] } return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, ip) } } // 回退:不做掩码 return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, addr) } return "" } func isPseudo(rr dns.RR) bool { switch rr.(type) { case *dns.OPT, *dns.TSIG: return true default: return false } } // clone and extract EDNS metadata (present, version, ext-rcode, all EDEs) func cloneEDE(in *dns.EDNS0_EDE) *dns.EDNS0_EDE { if in == nil { return nil } cp := *in return &cp } func extractEDNSMeta(m *dns.Msg) (present bool, version uint8, ext uint16, ede []*dns.EDNS0_EDE) { if o := m.IsEdns0(); o != nil { present = true version = o.Version() ext = uint16(o.ExtendedRcode()) for _, opt := range o.Option { if e, ok := opt.(*dns.EDNS0_EDE); ok { ede = append(ede, cloneEDE(e)) } } } return } // 读取缓存(Get 在写锁下;在锁外调整 TTL;返回 EDNS 元数据) func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) { now := time.Now() cacheMutex.Lock() e, ok := cache.Get(key) // Get 会更新 LRU,必须在写锁下 if !ok { cacheMutex.Unlock() return nil, nil, false } if now.After(e.expireAt) { cache.Remove(key) cacheMutex.Unlock() return nil, nil, false } // 拷贝副本,在锁外改 TTL,减少临界区时间 out := e.msg.Copy() expireAt := e.expireAt cacheMutex.Unlock() remaining := uint32(expireAt.Sub(now).Seconds()) if remaining == 0 { cacheMutex.Lock() cache.Remove(key) cacheMutex.Unlock() return nil, nil, false } for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} { for _, rr := range sec { if isPseudo(rr) { continue } if rr.Header().Ttl > remaining { rr.Header().Ttl = remaining } } } return out, e, true } // hasAnswerForType 判断报文中是否存在回答“请求类型”的 RRset func hasAnswerForType(m *dns.Msg, q dns.Question) bool { for _, rr := range m.Answer { h := rr.Header() if h.Rrtype == q.Qtype && strings.EqualFold(h.Name, q.Name) { return true } } return false } // 计算负面 TTL(正确识别 NODATA,包括 CNAME 等场景) func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { // NXDOMAIN:肯定是负面 if m.Rcode != dns.RcodeNameError { // 不是 NXDOMAIN,则仅当 NOERROR 但没有“匹配 QTYPE 的答案”时才是 NODATA if m.Rcode != dns.RcodeSuccess || len(m.Question) == 0 || hasAnswerForType(m, m.Question[0]) { return 0, false } } // 按 RFC 2308,从 Authority(Ns)优先取 SOA(多数实现都只放在 Authority) var soa *dns.SOA for _, rr := range m.Ns { if s, ok := rr.(*dns.SOA); ok { soa = s break } } // 兼容性:偶尔也有人把 SOA 放 Extra(不规范,但为了兼容可以兜底看看) if soa == nil { for _, rr := range m.Extra { if s, ok := rr.(*dns.SOA); ok { soa = s break } } } if soa == nil { // 建议:无 SOA 时不做负面缓存(返回 0,false) return 0, false } // 负面 TTL 取 min(SOA.MINIMUM, SOA 自身 TTL),再与配置上限比较 ttl := soa.Hdr.Ttl if soa.Minttl < ttl { ttl = soa.Minttl } capTTL := uint32(maxTTL.Seconds()) if capTTL > 0 && ttl > capTTL { ttl = capTTL } return ttl, ttl > 0 } func minRRsetTTL(m *dns.Msg) (uint32, bool) { minTTL := uint32(0) hasTTL := false // 优先 Answer -> Ns;若都为空,再考虑 Extra(排除伪记录) for _, sec := range [][]dns.RR{m.Answer, m.Ns} { for _, rr := range sec { if isPseudo(rr) { continue } ttl := rr.Header().Ttl if !hasTTL || ttl < minTTL { minTTL = ttl hasTTL = true } } } if !hasTTL { for _, rr := range m.Extra { if isPseudo(rr) { continue } ttl := rr.Header().Ttl if !hasTTL || ttl < minTTL { minTTL = ttl hasTTL = true } } } return minTTL, hasTTL } func stripPseudoExtras(m *dns.Msg) { if len(m.Extra) == 0 { return } out := m.Extra[:0] for _, rr := range m.Extra { if isPseudo(rr) { continue } out = append(out, rr) } m.Extra = out } // 写缓存(保存 EDNS 元数据,命中时可重建扩展 RCODE/EDE) func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { if in == nil { return } if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError { return } var ttl uint32 var ok bool // 判断负面响应(NXDOMAIN 或 NODATA) neg, isNodata := in.Rcode == dns.RcodeNameError, false if in.Rcode == dns.RcodeSuccess && len(in.Question) > 0 && !hasAnswerForType(in, in.Question[0]) { isNodata = true } if ttl, ok = negativeTTL(in, maxTTL); !ok { if neg || isNodata { return // 负面但无 SOA → 不缓存 } minTTL, has := minRRsetTTL(in) if has { cfgTTL := uint32(maxTTL.Seconds()) if cfgTTL > 0 && minTTL > cfgTTL { minTTL = cfgTTL } ttl = minTTL } else { ttl = uint32(maxTTL.Seconds()) } } if ttl == 0 { return } expire := time.Now().Add(time.Duration(ttl) * time.Second) cp := in.Copy() // 提取 EDNS 元数据后再剥离伪记录 present, ver, ext, ede := extractEDNSMeta(cp) stripPseudoExtras(cp) cacheMutex.Lock() cache.Add(key, &cacheEntry{ msg: cp, expireAt: expire, ednsPresent: present, ednsVersion: ver, ednsExtRcode: ext, ednsEDE: ede, }) cacheMutex.Unlock() } /****************************************************************** * 上游查询 ******************************************************************/ var udpClient *dns.Client func shuffled(xs []string) []string { out := make([]string, len(xs)) copy(out, xs) rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] }) return out } // clampEDNSForUpstream 返回一个 msg 副本,把 EDNS UDP size 夹到给定大小 func clampEDNSForUpstream(in *dns.Msg, size uint16) *dns.Msg { m := in.Copy() o := m.IsEdns0() if o == nil { o = &dns.OPT{} o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT m.Extra = append(m.Extra, o) } if size > 0 { o.SetUDPSize(size) } return m } func queryUpstreamsLimited( ctx context.Context, req *dns.Msg, upstreams []string, timeout time.Duration, maxParallel int, allowTCPFallback bool, ) *dns.Msg { if maxParallel <= 0 { maxParallel = 1 } servers := shuffled(upstreams) cctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() type result struct { msg *dns.Msg } ch := make(chan result, len(servers)) done := make(chan struct{}, len(servers)) sem := make(chan struct{}, maxParallel) // 单个上游执行 execOne := func(svr string) { // 并发限流(可被超时取消) select { case sem <- struct{}{}: defer func() { <-sem }() case <-cctx.Done(): // 超时/取消,直接放弃 return } defer func() { done <- struct{}{} }() // 为 UDP 上游把 EDNS UDP size 夹到 1232,降低分片风险 upReq := clampEDNSForUpstream(req, 1232) // 先走 UDP resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr) // 截断且允许回退则走 TCP if err == nil && resp != nil && resp.Truncated && allowTCPFallback { log.Printf("[upstream] UDP truncated, retry TCP: %s", svr) tcpClient := *udpClient tcpClient.Net = "tcp" resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr) } // 失败直接返回(但不写入 ch);只在未超时情况下打印错误 if err != nil || resp == nil { if err != nil && cctx.Err() == nil { log.Printf("[upstream] %s: %v", svr, err) } return } // 过滤不可用的错误 RCODE(避免造成“假性超时”的错觉) if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError { return } // 投递可用结果(若已经超时则丢弃) select { case ch <- result{msg: resp}: case <-cctx.Done(): } } // 并发发起 for _, s := range servers { s := s go execOne(s) } finished := 0 total := len(servers) // 聚合:首个可用响应直接返回;区分“真超时”与“无可用结果” for finished < total { select { case r := <-ch: if r.msg != nil { cancel() return r.msg } case <-done: finished++ case <-cctx.Done(): log.Printf("[upstream] timeout after %v (finished=%d/%d)", timeout, finished, total) return nil } } // 所有上游都结束,但没有一个可用 log.Printf("[upstream] no acceptable upstream response (finished=%d/%d)", finished, total) return nil } /****************************************************************** * EDNS / 响应构造 ******************************************************************/ func stripECS(m *dns.Msg) { if o := m.IsEdns0(); o != nil { var kept []dns.EDNS0 for _, e := range o.Option { if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS { kept = append(kept, e) } } o.Option = kept } } func getDOFlag(m *dns.Msg) bool { if o := m.IsEdns0(); o != nil { return o.Do() } return false } func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cacheEntry) { if upstream == nil { dns.HandleFailed(w, req) return } out := new(dns.Msg) out.SetReply(req) out.Authoritative = false // RA 语义修正:RA 表示“服务器是否支持递归”,与客户端 RD 无关 out.RecursionAvailable = upstream.RecursionAvailable out.AuthenticatedData = upstream.AuthenticatedData out.CheckingDisabled = req.CheckingDisabled out.Rcode = upstream.Rcode out.Answer = upstream.Answer out.Ns = upstream.Ns var extras []dns.RR for _, rr := range upstream.Extra { if !isPseudo(rr) { extras = append(extras, rr) } } if ro := req.IsEdns0(); ro != nil { o := new(dns.OPT) o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT o.SetUDPSize(ro.UDPSize()) if ro.Do() { o.SetDo() } // 优先使用“请求”的 EDNS 版本;扩展 RCODE/EDE 来自上游/缓存 o.SetVersion(ro.Version()) if uo := upstream.IsEdns0(); uo != nil { o.SetExtendedRcode(uint16(uo.ExtendedRcode())) for _, opt := range uo.Option { if ede, ok := opt.(*dns.EDNS0_EDE); ok { o.Option = append(o.Option, ede) } } } else if meta != nil && meta.ednsPresent { // Upstream/cached msg has no OPT(例如缓存时被剥离),用缓存元数据重建 o.SetExtendedRcode(uint16(meta.ednsExtRcode)) for _, e := range meta.ednsEDE { o.Option = append(o.Option, e) } } extras = append(extras, o) } out.Extra = extras out.Compress = true if err := w.WriteMsg(out); err != nil { log.Printf("[write] WriteMsg error: %v", err) } } /****************************************************************** * 主处理器 ******************************************************************/ func handleDNS( upstreams []string, cacheMaxTTL, timeout time.Duration, maxParallel int, stripECSBeforeForward bool, allowTCPFallback bool, blPtr *atomic.Pointer[suffixMatcher], blRcode int, ) dns.HandlerFunc { return func(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { dns.HandleFailed(w, r) return } q := r.Question[0] log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)", dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr()) if stripECSBeforeForward { stripECS(r) } // 黑名单拦截:命中则不查上游,直接返回 if rule, ok := blPtr.Load().match(q.Name); ok { nameCanon := dns.CanonicalName(q.Name) log.Printf("[blacklist] HIT %s rule=%s (no upstream query)", nameCanon, rule) up := makeBlockedUpstream(blRcode, rule) writeReply(w, r, up, nil) return } // 缓存 key:基础 + (未 strip 时的)ECS 片段 key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) + ecsKeyPart(r, stripECSBeforeForward) if cachedMsg, cachedMeta, ok := tryCacheRead(key); ok { log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) writeReply(w, r, cachedMsg, cachedMeta) return } log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) // 使用 singleflight 合并相同 key 的并发查询,避免上游雪崩 v, _, _ := inflight.Do(key, func() (any, error) { ctx := context.Background() resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback) if resp != nil { cacheWrite(key, resp, cacheMaxTTL) } return resp, nil }) resp, _ := v.(*dns.Msg) if resp == nil { log.Printf("[error] all upstreams failed for %s", q.Name) dns.HandleFailed(w, r) return } for _, ans := range resp.Answer { log.Printf("[answer] %s", ans.String()) } writeReply(w, r, resp, nil) } } /****************************************************************** * 主函数 ******************************************************************/ func main() { rand.Seed(time.Now().UnixNano()) var help bool certFile := flag.String("cert", "server.crt", "TLS 证书文件路径") keyFile := flag.String("key", "server.key", "TLS 私钥文件路径") addr := flag.String("addr", ":853", "DoT 监听地址") upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表") cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL") cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限") timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时") readTimeoutFlag := flag.Duration("read-timeout", 0, "DoT 连接读超时(0=不限制)") writeTimeoutFlag := flag.Duration("write-timeout", 0, "DoT 连接写超时(0=不限制)") maxParallel := flag.Int("max-parallel", 3, "并发上游数量") stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS") allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退") blacklistStr := flag.String("blacklist", "", "逗号分隔的黑名单域名(后缀匹配;支持如 *.example.com)") blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;可接受 hosts 风格)") blacklistRcodeFlag := flag.String("blacklist-rcode", "REFUSED", "命中黑名单返回的 RCODE:REFUSED|NXDOMAIN|SERVFAIL") verbose := flag.Bool("v", false, "verbose 日志") flag.BoolVar(&help, "h", false, "") flag.BoolVar(&help, "help", false, "帮助信息") flag.Parse() if help { fmt.Printf( "\t\tDNS-over-TLS (DoT)\n"+ "\tVersion 0.1\n"+ "\tE-mail: aixiao@aixiao.me\n"+ "\tBuild Date: %s\n", BuildDate) flag.Usage() fmt.Printf("\n") os.Exit(0) } initLogger(*verbose) var err error cache, err = lru.New[string, *cacheEntry](*cacheSizeFlag) if err != nil { log.Fatalf("[fatal] failed to init LRU cache: %v", err) } cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) if err != nil { log.Fatalf("[fatal] failed to load cert: %v", err) } var upstreams []string for _, s := range strings.Split(*upstreamStr, ",") { if t := strings.TrimSpace(s); t != "" { if !strings.Contains(t, ":") { t = fmt.Sprintf("%s:53", t) } upstreams = append(upstreams, t) } } if len(upstreams) == 0 { log.Fatal("[fatal] no upstream DNS servers provided") } // 如需更保守的 UDP 尺寸以减少分片,可将 UDPSize 改为 1232 udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true} ctx, cancel := context.WithCancel(context.Background()) defer cancel() startCacheCleaner(ctx) // 加载黑名单规则 blPtr, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag) mux := dns.NewServeMux() mux.HandleFunc(".", handleDNS( upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel, *stripECSFlag, *allowTCPFallback, blPtr, blRcode, )) srv := &dns.Server{ Addr: *addr, Net: "tcp-tls", TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS13, // TLS 1.3 NextProtos: []string{"dot"}, }, Handler: mux, ReadTimeout: *readTimeoutFlag, WriteTimeout: *writeTimeoutFlag, } stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) errCh := make(chan error, 1) go func() { log.Printf("🚀 starting DNS-over-TLS on %s", *addr) log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | read_timeout=%s | write_timeout=%s | blacklist_rules=%d | blacklist_rcode=%s", upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback, readTimeoutFlag.String(), writeTimeoutFlag.String(), len(blPtr.Load().rules), strings.ToUpper(*blacklistRcodeFlag)) errCh <- srv.ListenAndServe() }() select { case sig := <-stop: log.Printf("[shutdown] caught signal: %s", sig) cancel() if err := srv.Shutdown(); err != nil { log.Printf("[shutdown] server shutdown error: %v", err) } case err := <-errCh: if err != nil { log.Fatalf("[fatal] server error: %v", err) } } log.Println("[bye] server stopped.") }