diff --git a/dot b/dot index 59a11fd..b1c1148 100644 Binary files a/dot and b/dot differ diff --git a/go.mod b/go.mod index 2977add..86b472f 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,19 @@ module dot go 1.25.2 require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/miekg/dns v1.1.68 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/prometheus/client_golang v1.23.2 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.66.1 // indirect + github.com/prometheus/procfs v0.16.1 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/mod v0.24.0 // indirect - golang.org/x/net v0.40.0 // indirect + golang.org/x/net v0.43.0 // indirect golang.org/x/sync v0.14.0 // indirect - golang.org/x/sys v0.33.0 // indirect + golang.org/x/sys v0.35.0 // indirect golang.org/x/tools v0.33.0 // indirect + google.golang.org/protobuf v1.36.8 // indirect ) diff --git a/go.sum b/go.sum index 3b3e82e..399281d 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,35 @@ +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA= github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= +github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/main.go b/main.go index 567a4bb..b21fcf5 100644 --- a/main.go +++ b/main.go @@ -21,6 +21,8 @@ import ( * 日志初始化 ******************************************************************/ +// initLogger 根据 -v 开关设置日志格式;verbose=true 时附加源码文件与行号。 +// 说明:Lmicroseconds 便于排查毫秒级时序问题。 func initLogger(verbose bool) { flags := log.Ldate | log.Ltime | log.Lmicroseconds if verbose { @@ -33,15 +35,21 @@ func initLogger(verbose bool) { * 缓存结构 ******************************************************************/ +// cacheEntry 表示一条缓存项:保存上游响应的副本,以及过期时间。 +// 注:将完整 *dns.Msg 存入缓存,便于原样复用 Answer/Ns/Extra。 type cacheEntry struct { msg *dns.Msg // 上游完整响应(拷贝存储) expireAt time.Time // 过期时间 } +// cache 使用 sync.Map 作为并发安全的键值存储。 +// 注:此实现为 TTL 驱动的简单缓存;如需上限/淘汰策略,可叠加 LRU。 var cache sync.Map const cacheCleanupInterval = 5 * time.Minute +// startCacheCleaner 定时清理过期项,避免缓存无限增长。 +// 使用 ctx 控制生命周期,与主服务一同退出。 func startCacheCleaner(ctx context.Context) { go func() { ticker := time.NewTicker(cacheCleanupInterval) @@ -69,7 +77,8 @@ func startCacheCleaner(ctx context.Context) { }() } -// 计算缓存键:name + type + class + DO + CD +// 计算缓存键:name + type + class + DO + CD。 +// 说明:包含 DO/CD 可避免不同验证上下文污染同一键。 func cacheKeyFromMsg(q dns.Question, do, cd bool) string { var b strings.Builder b.Grow(len(q.Name) + 32) @@ -87,7 +96,19 @@ func cacheKeyFromMsg(q dns.Question, do, cd bool) string { return b.String() } -// 命中缓存:回填剩余 TTL(不超过记录自身 TTL) +// 识别伪 RR(OPT/TSIG);这些记录的“TTL 字段”并非真实 TTL,不可参与 TTL 计算或改写。 +// OPT:其 TTL 字段承载扩展 RCODE 与 DO 位等标志;TSIG:签名,不应缓存或改写。 +func isPseudo(rr dns.RR) bool { + switch rr.(type) { + case *dns.OPT, *dns.TSIG: + return true + default: + return false + } +} + +// tryCacheRead 尝试读取缓存并回填“剩余 TTL”;对 Answer/Ns/Extra 普通 RR 截断 TTL,跳过伪 RR。 +// 返回响应副本,保证对外不可变(防止调用方修改缓存内部对象)。 func tryCacheRead(key string) (*dns.Msg, bool) { v, ok := cache.Load(key) if !ok { @@ -105,6 +126,7 @@ func tryCacheRead(key string) (*dns.Msg, bool) { cache.Delete(key) return nil, false } + // 回填剩余 TTL(不增加,只做上限截断) for i := range out.Answer { if out.Answer[i].Header().Ttl > remaining { out.Answer[i].Header().Ttl = remaining @@ -116,6 +138,9 @@ func tryCacheRead(key string) (*dns.Msg, bool) { } } for i := range out.Extra { + if isPseudo(out.Extra[i]) { + continue + } if out.Extra[i].Header().Ttl > remaining { out.Extra[i].Header().Ttl = remaining } @@ -123,13 +148,15 @@ func tryCacheRead(key string) (*dns.Msg, bool) { return out, true } -// 负面缓存 TTL(RFC 2308):NXDOMAIN 或 NODATA 使用 SOA.MINIMUM 与 SOA TTL 的较小者,再与配置上限取 min +// negativeTTL 依据 RFC 2308 计算负面缓存 TTL。 +// 对 NXDOMAIN 或 NODATA(NOERROR + Answer 为空)取 min(SOA.TTL, SOA.MINIMUM),再与配置上限取 min。 func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { // NXDOMAIN,或 NOERROR 但 Answer 为空(NODATA) if m.Rcode != dns.RcodeNameError && !(m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0) { return 0, false } var soa *dns.SOA + // SOA 可能出现在 Authority(Ns) 或 Additional(Extra) for _, rr := range append(m.Ns, m.Extra...) { if s, ok := rr.(*dns.SOA); ok { soa = s @@ -137,6 +164,7 @@ func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { } } if soa == nil { + // 无 SOA 无法可靠计算负面 TTL;此时不缓存或由上限兜底(本实现选择不缓存) return 0, false } ttl := soa.Hdr.Ttl @@ -150,12 +178,16 @@ func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { return ttl, ttl > 0 } -// 普通(正向)响应的最小 TTL(Answer/Ns/Extra) +// minRRsetTTL 获取普通(正向)响应的最小 TTL(Answer/Ns/Extra 中的普通 RR),跳过伪 RR。 +// 用于决定缓存过期时间的上限(与配置上限再取 min)。 func minRRsetTTL(m *dns.Msg) (uint32, bool) { minTTL := uint32(0) hasTTL := false for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} { for _, rr := range sec { + if isPseudo(rr) { + continue + } ttl := rr.Header().Ttl if !hasTTL || ttl < minTTL { minTTL = ttl @@ -166,28 +198,49 @@ func minRRsetTTL(m *dns.Msg) (uint32, bool) { return minTTL, hasTTL } -// 写缓存:先处理负面缓存,再处理正面缓存 +// stripPseudoExtras 从消息中剥离伪 RR(OPT/TSIG)。 +// 用途:缓存前剥离,避免将传输层细节或签名内容写入缓存。 +func stripPseudoExtras(m *dns.Msg) { + if len(m.Extra) == 0 { + return + } + out := m.Extra[:0] + for _, rr := range m.Extra { + if isPseudo(rr) { + continue + } + out = append(out, rr) + } + m.Extra = out +} + +// cacheWrite 写缓存:优先处理负面缓存;正面缓存取 min(应答中最小 TTL, 配置上限)。 +// 写入前统一剥离伪 RR,保证缓存与传输解耦。 func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { if in == nil { return } - // 仅缓存 NOERROR / NXDOMAIN,其余不缓存 + // 仅缓存 NOERROR / NXDOMAIN,其余不缓存(如 SERVFAIL/REFUSED 等) if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError { return } // 负面缓存 if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 { expire := time.Now().Add(time.Duration(ttl) * time.Second) - cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire}) + cp := in.Copy() + stripPseudoExtras(cp) + cache.Store(key, &cacheEntry{msg: cp, expireAt: expire}) return } // 正向缓存:minTTL 与 maxTTL 取较小 minTTL, ok := minRRsetTTL(in) if !ok { - // 没有 TTL 时可用上限兜底(也可选择不缓存) + // 没有 TTL 时可用上限兜底(也可选择不缓存,这里选择兜底) if maxTTL > 0 { expire := time.Now().Add(maxTTL) - cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire}) + cp := in.Copy() + stripPseudoExtras(cp) + cache.Store(key, &cacheEntry{msg: cp, expireAt: expire}) } return } @@ -200,16 +253,19 @@ func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { return } expire := time.Now().Add(time.Duration(finalTTL) * time.Second) - cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire}) + cp := in.Copy() + stripPseudoExtras(cp) + cache.Store(key, &cacheEntry{msg: cp, expireAt: expire}) } /****************************************************************** * 上游查询(带 context 取消、并发上限、UDP→TCP 回退) ******************************************************************/ -// 全局可复用 UDP 客户端 +// 全局可复用 UDP 客户端(Net=udp,UDPSize 放大以承载更大的响应) var udpClient *dns.Client +// shuffled 打乱上游列表,避免固定顺序导致单点拥塞或偏置。 func shuffled(xs []string) []string { out := make([]string, len(xs)) copy(out, xs) @@ -217,6 +273,10 @@ func shuffled(xs []string) []string { return out } +// queryUpstreamsLimited 并发向多个上游发起查询,返回首个有效结果。 +// - timeout:整个查询窗口的上限(基于子 context) +// - maxParallel:同时在飞请求上限 +// - allowTCPFallback:若 UDP 截断(TC 位)则回退 TCP 重试 func queryUpstreamsLimited( ctx context.Context, req *dns.Msg, @@ -230,28 +290,29 @@ func queryUpstreamsLimited( } servers := shuffled(upstreams) - // 每次查询一个带超时的子 context;拿到首个有效结果后取消。 + // 每次查询一个带超时的子 context;拿到首个有效结果后 cancel,取消其他请求。 cctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() type result struct { msg *dns.Msg } - ch := make(chan result, len(servers)) - sem := make(chan struct{}, maxParallel) + ch := make(chan result, len(servers)) // 缓冲至多保存所有返回,避免阻塞 + sem := make(chan struct{}, maxParallel) // 并发信号量,限制同时在飞请求数 + // execOne 对单个上游发起查询;对 req 使用 Copy() 防止在多 goroutine 下共享同一 *dns.Msg。 execOne := func(svr string) { - defer func() { <-sem }() - // UDP 查询(带 context) - resp, _, err := udpClient.ExchangeContext(cctx, req, svr) + // UDP 查询(带 context);使用 req.Copy() 防止并发读写 + resp, _, err := udpClient.ExchangeContext(cctx, req.Copy(), svr) if err == nil && resp != nil && resp.Truncated && allowTCPFallback { - // TCP 回退 + // TCP 回退(响应被截断) log.Printf("[upstream] UDP truncated, retry TCP: %s", svr) tcpClient := *udpClient tcpClient.Net = "tcp" - resp, _, err = tcpClient.ExchangeContext(cctx, req, svr) + resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr) } if err != nil || resp == nil { + // 错误仅在未被上层取消时打印,避免超时后噪声 if err != nil { if cctx.Err() == nil { log.Printf("[upstream] %s error: %v", svr, err) @@ -261,8 +322,8 @@ func queryUpstreamsLimited( } return } - // 丢弃 SERVFAIL - if resp.Rcode == dns.RcodeServerFailure { + // 过滤差 RCODE(不参与竞速):SERVFAIL / REFUSED / FORMERR + if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError { return } select { @@ -271,12 +332,21 @@ func queryUpstreamsLimited( } } + // “快乐眼球”式启动:生产者不阻塞,获取配额在 goroutine 内部完成;这样可以更快进入接收循环并在首个成功后取消其余。 for _, s := range servers { - sem <- struct{}{} - go execOne(s) + s := s + go func() { + select { + case sem <- struct{}{}: + defer func() { <-sem }() + case <-cctx.Done(): + return + } + execOne(s) + }() } - // 返回第一个非空结果,并 cancel 其他 goroutine + // 返回第一个非空结果,并 cancel 其他 goroutine。 for i := 0; i < len(servers); i++ { select { case r := <-ch: @@ -296,7 +366,8 @@ func queryUpstreamsLimited( * EDNS(0) / ECS 处理 ******************************************************************/ -// 去除 EDNS Client Subnet(避免缓存污染与隐私泄露) +// stripECS 从请求中去除 EDNS Client Subnet(ECS),减少缓存污染并保护隐私。 +// 注:ECS 会导致上游按地理/网络分区返回不同答案,不适合集中缓存。 func stripECS(m *dns.Msg) { if o := m.IsEdns0(); o != nil { var kept []dns.EDNS0 @@ -309,7 +380,7 @@ func stripECS(m *dns.Msg) { } } -// 获取 DO/EDNS +// getDOFlag 读取 DO(DNSSEC OK)位,构成缓存键的一部分。 func getDOFlag(m *dns.Msg) bool { if o := m.IsEdns0(); o != nil { return o.Do() @@ -321,6 +392,8 @@ func getDOFlag(m *dns.Msg) bool { * 响应构造:使用客户端请求头构造 reply,复制上游内容 ******************************************************************/ +// writeReply 根据客户端请求构造响应骨架:复制上游的 Answer/Ns/(非伪)Extra, +// 并按照**客户端请求**重建 OPT(UDPSize/DO),避免直接采纳上游的 OPT/TSIG。 func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) { if upstream == nil { dns.HandleFailed(w, req) @@ -335,7 +408,28 @@ func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) { out.Rcode = upstream.Rcode out.Answer = upstream.Answer out.Ns = upstream.Ns - out.Extra = upstream.Extra + + // 复制上游的非伪 RR 额外记录;OPT/TSIG 不透传 + extras := make([]dns.RR, 0, len(upstream.Extra)) + for _, rr := range upstream.Extra { + if isPseudo(rr) { + continue + } + extras = append(extras, rr) + } + // 基于客户端请求镜像 EDNS(UDPSize + DO),保持传输一致性 + if ro := req.IsEdns0(); ro != nil { + o := new(dns.OPT) + o.Hdr.Name = "." + o.Hdr.Rrtype = dns.TypeOPT + o.SetUDPSize(ro.UDPSize()) + if ro.Do() { + o.SetDo() + } + // 如需转发 NSID/COOKIE 等可在此附加;为减少复杂性与缓存污染,建议保守最小集合 + extras = append(extras, o) + } + out.Extra = extras out.Compress = true if err := w.WriteMsg(out); err != nil { log.Printf("[write] WriteMsg error: %v", err) @@ -346,6 +440,8 @@ func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) { * 处理器 ******************************************************************/ +// handleDNS 为每个请求执行:日志 → (可选)剥离 ECS → 缓存命中 → 上游查询 → 写缓存 → 回写。 +// 注意:缓存键包含 DO/CD;同时通过 tryCacheRead 回填剩余 TTL。 func handleDNS( upstreams []string, cacheMaxTTL, timeout time.Duration, @@ -360,7 +456,7 @@ func handleDNS( } q := r.Question[0] - // 记录请求 + // 基本访问日志:类型/域名/ID/CD/DO/来源 IP:Port log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)", dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr()) @@ -369,10 +465,10 @@ func handleDNS( stripECS(r) } - // 缓存键 + // 缓存键(域名小写 + QTYPE/QCLASS + DO/CD) key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) - // 1) 缓存命中 + // 1) 缓存命中:命中即快速返回 if cached, ok := tryCacheRead(key); ok { log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) writeReply(w, r, cached) @@ -380,7 +476,7 @@ func handleDNS( } log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) - // 2) 上游查询(带 context 取消 & TCP 可选回退) + // 2) 上游查询(带 context 取消 & TCP 可选回退);并发向多个上游竞速 ctx := context.Background() resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback) if resp == nil { @@ -389,10 +485,10 @@ func handleDNS( return } - // 3) 写缓存 + // 3) 写缓存(负面/正面均处理;剥离伪 RR) cacheWrite(key, resp, cacheMaxTTL) - // 4) 回写 + // 4) 回写给客户端,并打印 Answer 方便调试 for _, ans := range resp.Answer { log.Printf("[answer] %s", ans.String()) } @@ -400,14 +496,10 @@ func handleDNS( } } -/****************************************************************** - * 主程序 - ******************************************************************/ - func main() { - rand.Seed(time.Now().UnixNano()) + rand.Seed(time.Now().UnixNano()) // 随机源:用于上游列表打乱 - // 参数 + // 命令行参数:可根据部署场景调整默认值 certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)") keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)") addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853)") @@ -422,13 +514,13 @@ func main() { initLogger(*verbose) - // 证书 + // 加载 TLS 证书/私钥;用于 DoT(RFC 7858)监听 cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) if err != nil { log.Fatalf("[fatal] failed to load cert: %v", err) } - // 上游 + // 解析上游地址:支持“host”或“host:port”,缺省端口补 53。 var upstreams []string for _, s := range strings.Split(*upstreamStr, ",") { if t := strings.TrimSpace(s); t != "" { @@ -442,20 +534,20 @@ func main() { log.Fatal("[fatal] no upstream DNS servers provided") } - // 全局 UDP 客户端(UDPSize 放大;不设置 Timeout,使用 ExchangeContext 控制) + // 全局 UDP 客户端(不设置 Client.Timeout,改用 ExchangeContext 控制超时) udpClient = &dns.Client{ Net: "udp", - UDPSize: 4096, + UDPSize: 4096, // 放大到 4K,减小 UDP 截断概率 } - // context 用于优雅退出 + // context 用于优雅退出(SIGINT/SIGTERM 收到后取消) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // 启动缓存清理器 + // 启动缓存清理器(后台 goroutine) startCacheCleaner(ctx) - // mux/handler + // 注册处理器:使用 ServeMux 将所有查询交给 handleDNS mux := dns.NewServeMux() mux.HandleFunc(".", handleDNS( upstreams, @@ -466,7 +558,8 @@ func main() { *allowTCPFallback, )) - // DoT 服务(启用 TLS1.3;不手动设 CipherSuites 以使用 Go 默认安全套件) + // 构造 DoT(tcp-tls)服务器:显式开启 TLS1.2/1.3,增加读写超时防止慢连接。 + // NextProtos: "dot"(可选,部分客户端用于 ALPN 检测) srv := &dns.Server{ Addr: *addr, Net: "tcp-tls", @@ -474,14 +567,18 @@ func main() { Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12, // 允许 1.2/1.3(默认启用 1.3) // 不设置 CipherSuites,交由 Go 自动选择(TLS1.3 有自身套件) + NextProtos: []string{"dot"}, // 可选:显式 ALPN }, - Handler: mux, + Handler: mux, + ReadTimeout: 10 * time.Second, // 防止慢连接 + WriteTimeout: 10 * time.Second, } - // 捕获信号优雅退出 + // 捕获信号以便优雅退出(关闭监听、结束后台协程) stop := make(chan os.Signal, 1) signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + // 异步启动服务,错误通过 errCh 返回 errCh := make(chan error, 1) go func() { log.Printf("🚀 starting DNS-over-TLS on %s", *addr) @@ -490,11 +587,12 @@ func main() { errCh <- srv.ListenAndServe() }() + // 等待退出信号或服务器错误 select { case sig := <-stop: log.Printf("[shutdown] caught signal: %s", sig) cancel() - // miekg/dns 提供 Shutdown();部分版本没有 ShutdownContext,这里用 Shutdown() + // miekg/dns 提供 Shutdown();部分版本无 ShutdownContext,这里用 Shutdown() if err := srv.Shutdown(); err != nil { log.Printf("[shutdown] server shutdown error: %v", err) }