diff --git a/dot b/dot index 73ec45e..59a11fd 100644 Binary files a/dot and b/dot differ diff --git a/main.go b/main.go index 85e5624..567a4bb 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "flag" + "fmt" "log" "math/rand" "os" @@ -16,22 +17,31 @@ import ( "github.com/miekg/dns" ) -/********************************* - * 缓存结构与全局对象 - *********************************/ +/****************************************************************** + * 日志初始化 + ******************************************************************/ -type cacheEntry struct { - msg *dns.Msg // 缓存的完整响应(深拷贝) - expireAt time.Time // 过期时间(由动态 TTL 决定) +func initLogger(verbose bool) { + flags := log.Ldate | log.Ltime | log.Lmicroseconds + if verbose { + flags |= log.Lshortfile + } + log.SetFlags(flags) +} + +/****************************************************************** + * 缓存结构 + ******************************************************************/ + +type cacheEntry struct { + msg *dns.Msg // 上游完整响应(拷贝存储) + expireAt time.Time // 过期时间 } -// 并发安全缓存 var cache sync.Map -// 后台清理:每隔 N 分钟清一次 const cacheCleanupInterval = 5 * time.Minute -// 启动带 context 的缓存清理器(优雅退出) func startCacheCleaner(ctx context.Context) { go func() { ticker := time.NewTicker(cacheCleanupInterval) @@ -52,19 +62,32 @@ func startCacheCleaner(ctx context.Context) { return true }) if n > 0 { - log.Printf("[Cache] Cleaned %d expired entries", n) + log.Printf("[cache] cleaned %d expired entries", n) } } } }() } -// 生成缓存键 -func cacheKey(name string, qtype uint16) string { - return strings.ToLower(name) + ":" + dns.TypeToString[qtype] +// 计算缓存键:name + type + class + DO + CD +func cacheKeyFromMsg(q dns.Question, do, cd bool) string { + var b strings.Builder + b.Grow(len(q.Name) + 32) + b.WriteString(strings.ToLower(q.Name)) + b.WriteString("|T=") + b.WriteString(dns.TypeToString[q.Qtype]) + b.WriteString("|C=") + b.WriteString(dns.ClassToString[q.Qclass]) + if do { + b.WriteString("|DO") + } + if cd { + b.WriteString("|CD") + } + return b.String() } -// 读取缓存;命中则回填剩余 TTL 并返回拷贝 +// 命中缓存:回填剩余 TTL(不超过记录自身 TTL) func tryCacheRead(key string) (*dns.Msg, bool) { v, ok := cache.Load(key) if !ok { @@ -78,7 +101,10 @@ func tryCacheRead(key string) (*dns.Msg, bool) { } out := e.msg.Copy() remaining := uint32(e.expireAt.Sub(now).Seconds()) - // 回填剩余 TTL,避免客户端收到过期 TTL + if remaining == 0 { + cache.Delete(key) + return nil, false + } for i := range out.Answer { if out.Answer[i].Header().Ttl > remaining { out.Answer[i].Header().Ttl = remaining @@ -97,19 +123,38 @@ func tryCacheRead(key string) (*dns.Msg, bool) { return out, true } -// 写缓存:以「上游最小 TTL」与「配置上限 TTL」取较小值 -func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { - if in == nil { - return +// 负面缓存 TTL(RFC 2308):NXDOMAIN 或 NODATA 使用 SOA.MINIMUM 与 SOA TTL 的较小者,再与配置上限取 min +func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { + // NXDOMAIN,或 NOERROR 但 Answer 为空(NODATA) + if m.Rcode != dns.RcodeNameError && !(m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0) { + return 0, false } - // 可按需缓存 NXDOMAIN;这里允许缓存 NOERROR 与 NXDOMAIN - if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError { - return + var soa *dns.SOA + for _, rr := range append(m.Ns, m.Extra...) { + if s, ok := rr.(*dns.SOA); ok { + soa = s + break + } } - // 计算报文中的最小 TTL(Answer/Ns/Extra) + if soa == nil { + return 0, false + } + ttl := soa.Hdr.Ttl + if soa.Minttl < ttl { + ttl = soa.Minttl + } + capTTL := uint32(maxTTL.Seconds()) + if capTTL > 0 && ttl > capTTL { + ttl = capTTL + } + return ttl, ttl > 0 +} + +// 普通(正向)响应的最小 TTL(Answer/Ns/Extra) +func minRRsetTTL(m *dns.Msg) (uint32, bool) { minTTL := uint32(0) hasTTL := false - for _, sec := range [][]dns.RR{in.Answer, in.Ns, in.Extra} { + for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} { for _, rr := range sec { ttl := rr.Header().Ttl if !hasTTL || ttl < minTTL { @@ -118,19 +163,39 @@ func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { } } } - // 若无 TTL,可用配置上限作为兜底(也可选择不缓存) - cfgTTL := uint32(maxTTL.Seconds()) - var finalTTL uint32 - switch { - case !hasTTL && cfgTTL > 0: - finalTTL = cfgTTL - case hasTTL && cfgTTL > 0 && minTTL > cfgTTL: - finalTTL = cfgTTL - case hasTTL: - finalTTL = minTTL - default: + return minTTL, hasTTL +} + +// 写缓存:先处理负面缓存,再处理正面缓存 +func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { + if in == nil { return } + // 仅缓存 NOERROR / NXDOMAIN,其余不缓存 + if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError { + return + } + // 负面缓存 + if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 { + expire := time.Now().Add(time.Duration(ttl) * time.Second) + cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire}) + return + } + // 正向缓存:minTTL 与 maxTTL 取较小 + minTTL, ok := minRRsetTTL(in) + if !ok { + // 没有 TTL 时可用上限兜底(也可选择不缓存) + if maxTTL > 0 { + expire := time.Now().Add(maxTTL) + cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire}) + } + return + } + cfgTTL := uint32(maxTTL.Seconds()) + finalTTL := minTTL + if cfgTTL > 0 && finalTTL > cfgTTL { + finalTTL = cfgTTL + } if finalTTL == 0 { return } @@ -138,219 +203,306 @@ func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire}) } -/********************************* - * 上游查询与并发控制 - *********************************/ +/****************************************************************** + * 上游查询(带 context 取消、并发上限、UDP→TCP 回退) + ******************************************************************/ -// 全局可复用的 DNS 客户端(默认 UDP) -var dnsClient *dns.Client +// 全局可复用 UDP 客户端 +var udpClient *dns.Client -// 并发上限通过信号量限制 -func queryUpstreamsLimited(r *dns.Msg, upstreams []string, timeout time.Duration, maxParallel int) *dns.Msg { +func shuffled(xs []string) []string { + out := make([]string, len(xs)) + copy(out, xs) + rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] }) + return out +} + +func queryUpstreamsLimited( + ctx context.Context, + req *dns.Msg, + upstreams []string, + timeout time.Duration, + maxParallel int, + allowTCPFallback bool, +) *dns.Msg { if maxParallel <= 0 { maxParallel = 1 } - ch := make(chan *dns.Msg, len(upstreams)) + servers := shuffled(upstreams) + + // 每次查询一个带超时的子 context;拿到首个有效结果后取消。 + cctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + type result struct { + msg *dns.Msg + } + ch := make(chan result, len(servers)) sem := make(chan struct{}, maxParallel) - // 在 UDP 上查询,遇到截断再 TCP fallback - queryOnce := func(server string) *dns.Msg { - resp, _, err := dnsClient.Exchange(r, server) - if err == nil && resp != nil && resp.Truncated { - // UDP 截断,尝试 TCP - log.Printf("[Info] UDP truncated, retry TCP: %s", server) - tcpClient := *dnsClient + execOne := func(svr string) { + defer func() { <-sem }() + // UDP 查询(带 context) + resp, _, err := udpClient.ExchangeContext(cctx, req, svr) + if err == nil && resp != nil && resp.Truncated && allowTCPFallback { + // TCP 回退 + log.Printf("[upstream] UDP truncated, retry TCP: %s", svr) + tcpClient := *udpClient tcpClient.Net = "tcp" - resp, _, err = tcpClient.Exchange(r, server) + resp, _, err = tcpClient.ExchangeContext(cctx, req, svr) } if err != nil || resp == nil { if err != nil { - log.Printf("[Warn] Upstream %s failed: %v", server, err) + if cctx.Err() == nil { + log.Printf("[upstream] %s error: %v", svr, err) + } } else { - log.Printf("[Warn] Upstream %s failed: nil response", server) + log.Printf("[upstream] %s nil response", svr) } - return nil + return } - // 可选:丢弃 SERVFAIL + // 丢弃 SERVFAIL if resp.Rcode == dns.RcodeServerFailure { - return nil + return } - return resp - } - - for _, server := range upstreams { - sem <- struct{}{} - go func(svr string) { - defer func() { <-sem }() - resp := queryOnce(svr) - // 非阻塞/限时写入,防止消费者意外退出导致阻塞 - select { - case ch <- resp: - case <-time.After(1 * time.Second): - } - }(server) - } - - timer := time.NewTimer(timeout) - defer timer.Stop() - - // 在超时内返回第一个非空的结果 - for i := 0; i < len(upstreams); i++ { select { - case resp := <-ch: - if resp != nil { - return resp + case ch <- result{msg: resp}: + case <-cctx.Done(): + } + } + + for _, s := range servers { + sem <- struct{}{} + go execOne(s) + } + + // 返回第一个非空结果,并 cancel 其他 goroutine + for i := 0; i < len(servers); i++ { + select { + case r := <-ch: + if r.msg != nil { + cancel() + return r.msg } - case <-timer.C: - log.Printf("[Error] Upstream query timeout after %v", timeout) + case <-cctx.Done(): + log.Printf("[upstream] timeout after %v", timeout) return nil } } return nil } -/********************************* - * DNS 处理逻辑 - *********************************/ +/****************************************************************** + * EDNS(0) / ECS 处理 + ******************************************************************/ -func shuffled(slice []string) []string { - out := make([]string, len(slice)) - copy(out, slice) - rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] }) - return out +// 去除 EDNS Client Subnet(避免缓存污染与隐私泄露) +func stripECS(m *dns.Msg) { + if o := m.IsEdns0(); o != nil { + var kept []dns.EDNS0 + for _, e := range o.Option { + if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS { + kept = append(kept, e) + } + } + o.Option = kept + } } -func handleDNS(upstreams []string, cacheMaxTTL, timeout time.Duration, maxParallel int) dns.HandlerFunc { +// 获取 DO/EDNS +func getDOFlag(m *dns.Msg) bool { + if o := m.IsEdns0(); o != nil { + return o.Do() + } + return false +} + +/****************************************************************** + * 响应构造:使用客户端请求头构造 reply,复制上游内容 + ******************************************************************/ + +func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) { + if upstream == nil { + dns.HandleFailed(w, req) + return + } + out := new(dns.Msg) + out.SetReply(req) + out.Authoritative = false + out.RecursionAvailable = upstream.RecursionAvailable + out.AuthenticatedData = upstream.AuthenticatedData + out.CheckingDisabled = req.CheckingDisabled // 反映客户端 CD 位 + out.Rcode = upstream.Rcode + out.Answer = upstream.Answer + out.Ns = upstream.Ns + out.Extra = upstream.Extra + out.Compress = true + if err := w.WriteMsg(out); err != nil { + log.Printf("[write] WriteMsg error: %v", err) + } +} + +/****************************************************************** + * 处理器 + ******************************************************************/ + +func handleDNS( + upstreams []string, + cacheMaxTTL, timeout time.Duration, + maxParallel int, + stripECSBeforeForward bool, + allowTCPFallback bool, +) dns.HandlerFunc { return func(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { dns.HandleFailed(w, r) return } q := r.Question[0] - key := cacheKey(q.Name, q.Qtype) + + // 记录请求 + log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)", + dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr()) + + // 可选:去除 ECS(推荐) + if stripECSBeforeForward { + stripECS(r) + } + + // 缓存键 + key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) // 1) 缓存命中 if cached, ok := tryCacheRead(key); ok { - log.Printf("[Cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) - _ = w.WriteMsg(cached) + log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) + writeReply(w, r, cached) return } - log.Printf("[Cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) + log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) - // 2) 随机化上游并并发查询(带 fallback) - servers := shuffled(upstreams) - resp := queryUpstreamsLimited(r, servers, timeout, maxParallel) + // 2) 上游查询(带 context 取消 & TCP 可选回退) + ctx := context.Background() + resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback) if resp == nil { - log.Printf("[Error] All upstreams failed for %s", q.Name) + log.Printf("[error] all upstreams failed for %s", q.Name) dns.HandleFailed(w, r) return } - // 3) 写入缓存(动态 TTL) + // 3) 写缓存 cacheWrite(key, resp, cacheMaxTTL) - // 4) 返回结果 + // 4) 回写 for _, ans := range resp.Answer { - log.Printf("[Answer] %s", ans.String()) + log.Printf("[answer] %s", ans.String()) } - _ = w.WriteMsg(resp) + writeReply(w, r, resp) } } -/********************************* - * 主程序(优雅退出) - *********************************/ +/****************************************************************** + * 主程序 + ******************************************************************/ func main() { rand.Seed(time.Now().UnixNano()) // 参数 - certFile := flag.String("cert", "aixiao.me.cer", "TLS 证书文件路径 (.cer/.crt)") - keyFile := flag.String("key", "aixiao.me.key", "TLS 私钥文件路径 (.key)") - addr := flag.String("addr", ":853", "DoT 服务监听地址(默认 :853)") + certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)") + keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)") + addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853)") upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)") cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL(默认 60s;实际取 min(上游最小TTL, 本值))") timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s)") maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限") + stripECSFlag := flag.Bool("strip-ecs", true, "转发上游前去除 EDNS Client Subnet") + allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时允许 TCP 回退") + verbose := flag.Bool("v", false, "verbose 日志(包含源码位置)") flag.Parse() + initLogger(*verbose) + // 证书 cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) if err != nil { - log.Fatalf("failed to load cert: %v", err) + log.Fatalf("[fatal] failed to load cert: %v", err) } - // 上游列表 - raw := strings.Split(*upstreamStr, ",") - upstreams := make([]string, 0, len(raw)) - for _, s := range raw { + // 上游 + var upstreams []string + for _, s := range strings.Split(*upstreamStr, ",") { if t := strings.TrimSpace(s); t != "" { + if !strings.Contains(t, ":") { + t = fmt.Sprintf("%s:53", t) + } upstreams = append(upstreams, t) } } if len(upstreams) == 0 { - log.Fatal("no upstream DNS servers provided") + log.Fatal("[fatal] no upstream DNS servers provided") } - // 全局 DNS 客户端(UDP,扩大 UDPSize;fallback 在查询函数中完成) - dnsClient = &dns.Client{ + // 全局 UDP 客户端(UDPSize 放大;不设置 Timeout,使用 ExchangeContext 控制) + udpClient = &dns.Client{ Net: "udp", - UDPSize: 4096, // 防截断 - Timeout: *timeoutFlag, + UDPSize: 4096, } - // context 用于优雅退出与清理协程 + // context 用于优雅退出 ctx, cancel := context.WithCancel(context.Background()) defer cancel() // 启动缓存清理器 startCacheCleaner(ctx) - // DNS 处理器 + // mux/handler mux := dns.NewServeMux() - mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel)) + mux.HandleFunc(".", handleDNS( + upstreams, + *cacheTTLFlag, + *timeoutFlag, + *maxParallel, + *stripECSFlag, + *allowTCPFallback, + )) - // DoT 服务器(TLS 会话缓存 + 安全套件 + TLS1.2+) + // DoT 服务(启用 TLS1.3;不手动设 CipherSuites 以使用 Go 默认安全套件) srv := &dns.Server{ Addr: *addr, Net: "tcp-tls", TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, - ClientSessionCache: tls.NewLRUClientSessionCache(256), - MinVersion: tls.VersionTLS12, - CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - }, + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, // 允许 1.2/1.3(默认启用 1.3) + // 不设置 CipherSuites,交由 Go 自动选择(TLS1.3 有自身套件) }, Handler: mux, } - // 捕获退出信号,优雅关闭 + // 捕获信号优雅退出 stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) errCh := make(chan error, 1) go func() { - log.Printf("🚀 Starting DNS-over-TLS server on %s", *addr) - log.Printf("Upstreams=%v | MaxTTL=%s | Timeout=%s | MaxParallel=%d", - upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel) + log.Printf("🚀 starting DNS-over-TLS on %s", *addr) + log.Printf(" upstreams=%v | cache_max_ttl=%s | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v", + upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback) errCh <- srv.ListenAndServe() }() select { case sig := <-stop: - log.Printf("[Shutdown] Caught signal: %s", sig) - cancel() // 结束清理器 - // 优雅关闭服务器 - // miekg/dns 提供 Shutdown();若你的版本支持 ShutdownContext,可改用带 ctx 的版本 + log.Printf("[shutdown] caught signal: %s", sig) + cancel() + // miekg/dns 提供 Shutdown();部分版本没有 ShutdownContext,这里用 Shutdown() if err := srv.Shutdown(); err != nil { - log.Printf("[Shutdown] server shutdown error: %v", err) + log.Printf("[shutdown] server shutdown error: %v", err) } case err := <-errCh: if err != nil { - log.Fatalf("server error: %v", err) + log.Fatalf("[fatal] server error: %v", err) } } - log.Println("[Shutdown] Bye.") + + log.Println("[bye] server stopped.") }