feat(core): 改进 DNS 响应构造逻辑以支持扩展错误码和诊断信息 在 writeReply 函数中,增强对上游 EDNS 选项的处理: - 继承上游的扩展 RCODE 与 EDNS 版本号 - 可选透传 EDE(Extended DNS Errors)记录,保留更多诊断信息 - 明确设置 UDPSize、DO 位及 EDNS 版本转换逻辑 这有助于提升调试能力与协议兼容性,同时保持客户端请求的一致性。 ```
628 lines
19 KiB
Go
628 lines
19 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"flag"
|
||
"fmt"
|
||
"log"
|
||
"math/rand"
|
||
"os"
|
||
"os/signal"
|
||
"strings"
|
||
"sync"
|
||
"syscall"
|
||
"time"
|
||
|
||
"github.com/miekg/dns"
|
||
)
|
||
|
||
/******************************************************************
|
||
* 日志初始化
|
||
******************************************************************/
|
||
|
||
// initLogger 根据 -v 开关设置日志格式;verbose=true 时附加源码文件与行号。
|
||
// 说明:Lmicroseconds 便于排查毫秒级时序问题。
|
||
func initLogger(verbose bool) {
|
||
flags := log.Ldate | log.Ltime | log.Lmicroseconds
|
||
if verbose {
|
||
flags |= log.Lshortfile
|
||
}
|
||
log.SetFlags(flags)
|
||
}
|
||
|
||
/******************************************************************
|
||
* 缓存结构
|
||
******************************************************************/
|
||
|
||
// cacheEntry 表示一条缓存项:保存上游响应的副本,以及过期时间。
|
||
// 注:将完整 *dns.Msg 存入缓存,便于原样复用 Answer/Ns/Extra。
|
||
type cacheEntry struct {
|
||
msg *dns.Msg // 上游完整响应(拷贝存储)
|
||
expireAt time.Time // 过期时间
|
||
}
|
||
|
||
// cache 使用 sync.Map 作为并发安全的键值存储。
|
||
// 注:此实现为 TTL 驱动的简单缓存;如需上限/淘汰策略,可叠加 LRU。
|
||
var cache sync.Map
|
||
|
||
const cacheCleanupInterval = 5 * time.Minute
|
||
|
||
// startCacheCleaner 定时清理过期项,避免缓存无限增长。
|
||
// 使用 ctx 控制生命周期,与主服务一同退出。
|
||
func startCacheCleaner(ctx context.Context) {
|
||
go func() {
|
||
ticker := time.NewTicker(cacheCleanupInterval)
|
||
defer ticker.Stop()
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return
|
||
case <-ticker.C:
|
||
now := time.Now()
|
||
n := 0
|
||
cache.Range(func(k, v any) bool {
|
||
e := v.(*cacheEntry)
|
||
if now.After(e.expireAt) {
|
||
cache.Delete(k)
|
||
n++
|
||
}
|
||
return true
|
||
})
|
||
if n > 0 {
|
||
log.Printf("[cache] cleaned %d expired entries", n)
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
// 计算缓存键:name + type + class + DO + CD。
|
||
// 说明:包含 DO/CD 可避免不同验证上下文污染同一键。
|
||
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
|
||
var b strings.Builder
|
||
b.Grow(len(q.Name) + 32)
|
||
b.WriteString(strings.ToLower(q.Name))
|
||
b.WriteString("|T=")
|
||
b.WriteString(dns.TypeToString[q.Qtype])
|
||
b.WriteString("|C=")
|
||
b.WriteString(dns.ClassToString[q.Qclass])
|
||
if do {
|
||
b.WriteString("|DO")
|
||
}
|
||
if cd {
|
||
b.WriteString("|CD")
|
||
}
|
||
return b.String()
|
||
}
|
||
|
||
// 识别伪 RR(OPT/TSIG);这些记录的“TTL 字段”并非真实 TTL,不可参与 TTL 计算或改写。
|
||
// OPT:其 TTL 字段承载扩展 RCODE 与 DO 位等标志;TSIG:签名,不应缓存或改写。
|
||
func isPseudo(rr dns.RR) bool {
|
||
switch rr.(type) {
|
||
case *dns.OPT, *dns.TSIG:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// tryCacheRead 尝试读取缓存并回填“剩余 TTL”;对 Answer/Ns/Extra 普通 RR 截断 TTL,跳过伪 RR。
|
||
// 返回响应副本,保证对外不可变(防止调用方修改缓存内部对象)。
|
||
func tryCacheRead(key string) (*dns.Msg, bool) {
|
||
v, ok := cache.Load(key)
|
||
if !ok {
|
||
return nil, false
|
||
}
|
||
e := v.(*cacheEntry)
|
||
now := time.Now()
|
||
if now.After(e.expireAt) {
|
||
cache.Delete(key)
|
||
return nil, false
|
||
}
|
||
out := e.msg.Copy()
|
||
remaining := uint32(e.expireAt.Sub(now).Seconds())
|
||
if remaining == 0 {
|
||
cache.Delete(key)
|
||
return nil, false
|
||
}
|
||
// 回填剩余 TTL(不增加,只做上限截断)
|
||
for i := range out.Answer {
|
||
if out.Answer[i].Header().Ttl > remaining {
|
||
out.Answer[i].Header().Ttl = remaining
|
||
}
|
||
}
|
||
for i := range out.Ns {
|
||
if out.Ns[i].Header().Ttl > remaining {
|
||
out.Ns[i].Header().Ttl = remaining
|
||
}
|
||
}
|
||
for i := range out.Extra {
|
||
if isPseudo(out.Extra[i]) {
|
||
continue
|
||
}
|
||
if out.Extra[i].Header().Ttl > remaining {
|
||
out.Extra[i].Header().Ttl = remaining
|
||
}
|
||
}
|
||
return out, true
|
||
}
|
||
|
||
// negativeTTL 依据 RFC 2308 计算负面缓存 TTL。
|
||
// 对 NXDOMAIN 或 NODATA(NOERROR + Answer 为空)取 min(SOA.TTL, SOA.MINIMUM),再与配置上限取 min。
|
||
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
|
||
// NXDOMAIN,或 NOERROR 但 Answer 为空(NODATA)
|
||
if m.Rcode != dns.RcodeNameError && !(m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0) {
|
||
return 0, false
|
||
}
|
||
var soa *dns.SOA
|
||
// SOA 可能出现在 Authority(Ns) 或 Additional(Extra)
|
||
for _, rr := range append(m.Ns, m.Extra...) {
|
||
if s, ok := rr.(*dns.SOA); ok {
|
||
soa = s
|
||
break
|
||
}
|
||
}
|
||
if soa == nil {
|
||
// 无 SOA 无法可靠计算负面 TTL;此时不缓存或由上限兜底(本实现选择不缓存)
|
||
return 0, false
|
||
}
|
||
ttl := soa.Hdr.Ttl
|
||
if soa.Minttl < ttl {
|
||
ttl = soa.Minttl
|
||
}
|
||
capTTL := uint32(maxTTL.Seconds())
|
||
if capTTL > 0 && ttl > capTTL {
|
||
ttl = capTTL
|
||
}
|
||
return ttl, ttl > 0
|
||
}
|
||
|
||
// minRRsetTTL 获取普通(正向)响应的最小 TTL(Answer/Ns/Extra 中的普通 RR),跳过伪 RR。
|
||
// 用于决定缓存过期时间的上限(与配置上限再取 min)。
|
||
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
|
||
minTTL := uint32(0)
|
||
hasTTL := false
|
||
for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} {
|
||
for _, rr := range sec {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
ttl := rr.Header().Ttl
|
||
if !hasTTL || ttl < minTTL {
|
||
minTTL = ttl
|
||
hasTTL = true
|
||
}
|
||
}
|
||
}
|
||
return minTTL, hasTTL
|
||
}
|
||
|
||
// stripPseudoExtras 从消息中剥离伪 RR(OPT/TSIG)。
|
||
// 用途:缓存前剥离,避免将传输层细节或签名内容写入缓存。
|
||
func stripPseudoExtras(m *dns.Msg) {
|
||
if len(m.Extra) == 0 {
|
||
return
|
||
}
|
||
out := m.Extra[:0]
|
||
for _, rr := range m.Extra {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
out = append(out, rr)
|
||
}
|
||
m.Extra = out
|
||
}
|
||
|
||
// cacheWrite 写缓存:优先处理负面缓存;正面缓存取 min(应答中最小 TTL, 配置上限)。
|
||
// 写入前统一剥离伪 RR,保证缓存与传输解耦。
|
||
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
||
if in == nil {
|
||
return
|
||
}
|
||
// 仅缓存 NOERROR / NXDOMAIN,其余不缓存(如 SERVFAIL/REFUSED 等)
|
||
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
|
||
return
|
||
}
|
||
// 负面缓存
|
||
if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 {
|
||
expire := time.Now().Add(time.Duration(ttl) * time.Second)
|
||
cp := in.Copy()
|
||
stripPseudoExtras(cp)
|
||
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
|
||
return
|
||
}
|
||
// 正向缓存:minTTL 与 maxTTL 取较小
|
||
minTTL, ok := minRRsetTTL(in)
|
||
if !ok {
|
||
// 没有 TTL 时可用上限兜底(也可选择不缓存,这里选择兜底)
|
||
if maxTTL > 0 {
|
||
expire := time.Now().Add(maxTTL)
|
||
cp := in.Copy()
|
||
stripPseudoExtras(cp)
|
||
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
|
||
}
|
||
return
|
||
}
|
||
cfgTTL := uint32(maxTTL.Seconds())
|
||
finalTTL := minTTL
|
||
if cfgTTL > 0 && finalTTL > cfgTTL {
|
||
finalTTL = cfgTTL
|
||
}
|
||
if finalTTL == 0 {
|
||
return
|
||
}
|
||
expire := time.Now().Add(time.Duration(finalTTL) * time.Second)
|
||
cp := in.Copy()
|
||
stripPseudoExtras(cp)
|
||
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
|
||
}
|
||
|
||
/******************************************************************
|
||
* 上游查询(带 context 取消、并发上限、UDP→TCP 回退)
|
||
******************************************************************/
|
||
|
||
// 全局可复用 UDP 客户端(Net=udp,UDPSize 放大以承载更大的响应)
|
||
var udpClient *dns.Client
|
||
|
||
// shuffled 打乱上游列表,避免固定顺序导致单点拥塞或偏置。
|
||
func shuffled(xs []string) []string {
|
||
out := make([]string, len(xs))
|
||
copy(out, xs)
|
||
rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] })
|
||
return out
|
||
}
|
||
|
||
// queryUpstreamsLimited 并发向多个上游发起查询,返回首个有效结果。
|
||
// - timeout:整个查询窗口的上限(基于子 context)
|
||
// - maxParallel:同时在飞请求上限
|
||
// - allowTCPFallback:若 UDP 截断(TC 位)则回退 TCP 重试
|
||
func queryUpstreamsLimited(
|
||
ctx context.Context,
|
||
req *dns.Msg,
|
||
upstreams []string,
|
||
timeout time.Duration,
|
||
maxParallel int,
|
||
allowTCPFallback bool,
|
||
) *dns.Msg {
|
||
if maxParallel <= 0 {
|
||
maxParallel = 1
|
||
}
|
||
servers := shuffled(upstreams)
|
||
|
||
// 每次查询一个带超时的子 context;拿到首个有效结果后 cancel,取消其他请求。
|
||
cctx, cancel := context.WithTimeout(ctx, timeout)
|
||
defer cancel()
|
||
|
||
type result struct {
|
||
msg *dns.Msg
|
||
}
|
||
ch := make(chan result, len(servers)) // 缓冲至多保存所有返回,避免阻塞
|
||
sem := make(chan struct{}, maxParallel) // 并发信号量,限制同时在飞请求数
|
||
|
||
// execOne 对单个上游发起查询;对 req 使用 Copy() 防止在多 goroutine 下共享同一 *dns.Msg。
|
||
execOne := func(svr string) {
|
||
// UDP 查询(带 context);使用 req.Copy() 防止并发读写
|
||
resp, _, err := udpClient.ExchangeContext(cctx, req.Copy(), svr)
|
||
if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
|
||
// TCP 回退(响应被截断)
|
||
log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
|
||
tcpClient := *udpClient
|
||
tcpClient.Net = "tcp"
|
||
resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr)
|
||
}
|
||
if err != nil || resp == nil {
|
||
// 错误仅在未被上层取消时打印,避免超时后噪声
|
||
if err != nil {
|
||
if cctx.Err() == nil {
|
||
log.Printf("[upstream] %s error: %v", svr, err)
|
||
}
|
||
} else {
|
||
log.Printf("[upstream] %s nil response", svr)
|
||
}
|
||
return
|
||
}
|
||
// 过滤差 RCODE(不参与竞速):SERVFAIL / REFUSED / FORMERR
|
||
if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError {
|
||
return
|
||
}
|
||
select {
|
||
case ch <- result{msg: resp}:
|
||
case <-cctx.Done():
|
||
}
|
||
}
|
||
|
||
// “快乐眼球”式启动:生产者不阻塞,获取配额在 goroutine 内部完成;这样可以更快进入接收循环并在首个成功后取消其余。
|
||
for _, s := range servers {
|
||
s := s
|
||
go func() {
|
||
select {
|
||
case sem <- struct{}{}:
|
||
defer func() { <-sem }()
|
||
case <-cctx.Done():
|
||
return
|
||
}
|
||
execOne(s)
|
||
}()
|
||
}
|
||
|
||
// 返回第一个非空结果,并 cancel 其他 goroutine。
|
||
for i := 0; i < len(servers); i++ {
|
||
select {
|
||
case r := <-ch:
|
||
if r.msg != nil {
|
||
cancel()
|
||
return r.msg
|
||
}
|
||
case <-cctx.Done():
|
||
log.Printf("[upstream] timeout after %v", timeout)
|
||
return nil
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
/******************************************************************
|
||
* EDNS(0) / ECS 处理
|
||
******************************************************************/
|
||
|
||
// stripECS 从请求中去除 EDNS Client Subnet(ECS),减少缓存污染并保护隐私。
|
||
// 注:ECS 会导致上游按地理/网络分区返回不同答案,不适合集中缓存。
|
||
func stripECS(m *dns.Msg) {
|
||
if o := m.IsEdns0(); o != nil {
|
||
var kept []dns.EDNS0
|
||
for _, e := range o.Option {
|
||
if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS {
|
||
kept = append(kept, e)
|
||
}
|
||
}
|
||
o.Option = kept
|
||
}
|
||
}
|
||
|
||
// getDOFlag 读取 DO(DNSSEC OK)位,构成缓存键的一部分。
|
||
func getDOFlag(m *dns.Msg) bool {
|
||
if o := m.IsEdns0(); o != nil {
|
||
return o.Do()
|
||
}
|
||
return false
|
||
}
|
||
|
||
/******************************************************************
|
||
* 响应构造:使用客户端请求头构造 reply,复制上游内容
|
||
******************************************************************/
|
||
|
||
// writeReply 根据客户端请求构造响应:复制上游 Answer/Ns/非伪 Extra,
|
||
// 并按客户端请求重建 OPT(UDPSize/DO),同时继承上游的扩展 RCODE 与 EDNS 版本;
|
||
// 可选透传上游的 EDE(Extended DNS Errors)以保留诊断信息。
|
||
func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
|
||
if upstream == nil {
|
||
dns.HandleFailed(w, req)
|
||
return
|
||
}
|
||
|
||
out := new(dns.Msg)
|
||
out.SetReply(req)
|
||
out.Authoritative = false
|
||
out.RecursionAvailable = upstream.RecursionAvailable
|
||
out.AuthenticatedData = upstream.AuthenticatedData
|
||
out.CheckingDisabled = req.CheckingDisabled // 反映客户端 CD 位
|
||
out.Rcode = upstream.Rcode // 主 RCODE(低 4 位)
|
||
out.Answer = upstream.Answer
|
||
out.Ns = upstream.Ns
|
||
|
||
// 复制上游的非伪 RR 额外记录;OPT/TSIG 不透传
|
||
extras := make([]dns.RR, 0, len(upstream.Extra))
|
||
for _, rr := range upstream.Extra {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
extras = append(extras, rr)
|
||
}
|
||
|
||
// 基于客户端请求镜像 EDNS(UDPSize + DO)
|
||
if ro := req.IsEdns0(); ro != nil {
|
||
o := new(dns.OPT)
|
||
o.Hdr.Name = "."
|
||
o.Hdr.Rrtype = dns.TypeOPT
|
||
|
||
// 与客户端保持一致的 UDPSize / DO 位
|
||
o.SetUDPSize(ro.UDPSize())
|
||
if ro.Do() {
|
||
o.SetDo()
|
||
}
|
||
|
||
// 继承上游的扩展 RCODE 与 EDNS 版本(注意不同版本签名差异,这里显式转换)
|
||
if uo := upstream.IsEdns0(); uo != nil {
|
||
// 你当前库期望 uint16,这里强转;若你的库期望 uint8,也可改成 uint8(...)
|
||
o.SetExtendedRcode(uint16(uo.ExtendedRcode()))
|
||
o.SetVersion(uint8(uo.Version()))
|
||
|
||
// 可选:透传只读的 EDE 诊断信息
|
||
for _, opt := range uo.Option {
|
||
if ede, ok := opt.(*dns.EDNS0_EDE); ok {
|
||
o.Option = append(o.Option, ede)
|
||
}
|
||
}
|
||
}
|
||
|
||
extras = append(extras, o)
|
||
}
|
||
|
||
out.Extra = extras
|
||
out.Compress = true
|
||
|
||
if err := w.WriteMsg(out); err != nil {
|
||
log.Printf("[write] WriteMsg error: %v", err)
|
||
}
|
||
}
|
||
|
||
/******************************************************************
|
||
* 处理器
|
||
******************************************************************/
|
||
|
||
// handleDNS 为每个请求执行:日志 → (可选)剥离 ECS → 缓存命中 → 上游查询 → 写缓存 → 回写。
|
||
// 注意:缓存键包含 DO/CD;同时通过 tryCacheRead 回填剩余 TTL。
|
||
func handleDNS(
|
||
upstreams []string,
|
||
cacheMaxTTL, timeout time.Duration,
|
||
maxParallel int,
|
||
stripECSBeforeForward bool,
|
||
allowTCPFallback bool,
|
||
) dns.HandlerFunc {
|
||
return func(w dns.ResponseWriter, r *dns.Msg) {
|
||
if len(r.Question) == 0 {
|
||
dns.HandleFailed(w, r)
|
||
return
|
||
}
|
||
q := r.Question[0]
|
||
|
||
// 基本访问日志:类型/域名/ID/CD/DO/来源 IP:Port
|
||
log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
|
||
dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())
|
||
|
||
// 可选:去除 ECS(推荐)
|
||
if stripECSBeforeForward {
|
||
stripECS(r)
|
||
}
|
||
|
||
// 缓存键(域名小写 + QTYPE/QCLASS + DO/CD)
|
||
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
|
||
|
||
// 1) 缓存命中:命中即快速返回
|
||
if cached, ok := tryCacheRead(key); ok {
|
||
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||
writeReply(w, r, cached)
|
||
return
|
||
}
|
||
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||
|
||
// 2) 上游查询(带 context 取消 & TCP 可选回退);并发向多个上游竞速
|
||
ctx := context.Background()
|
||
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
|
||
if resp == nil {
|
||
log.Printf("[error] all upstreams failed for %s", q.Name)
|
||
dns.HandleFailed(w, r)
|
||
return
|
||
}
|
||
|
||
// 3) 写缓存(负面/正面均处理;剥离伪 RR)
|
||
cacheWrite(key, resp, cacheMaxTTL)
|
||
|
||
// 4) 回写给客户端,并打印 Answer 方便调试
|
||
for _, ans := range resp.Answer {
|
||
log.Printf("[answer] %s", ans.String())
|
||
}
|
||
writeReply(w, r, resp)
|
||
}
|
||
}
|
||
|
||
func main() {
|
||
rand.Seed(time.Now().UnixNano()) // 随机源:用于上游列表打乱
|
||
|
||
// 命令行参数:可根据部署场景调整默认值
|
||
certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)")
|
||
keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)")
|
||
addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853)")
|
||
upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)")
|
||
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL(默认 60s;实际取 min(上游最小TTL, 本值))")
|
||
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s)")
|
||
maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限")
|
||
stripECSFlag := flag.Bool("strip-ecs", true, "转发上游前去除 EDNS Client Subnet")
|
||
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时允许 TCP 回退")
|
||
verbose := flag.Bool("v", false, "verbose 日志(包含源码位置)")
|
||
flag.Parse()
|
||
|
||
initLogger(*verbose)
|
||
|
||
// 加载 TLS 证书/私钥;用于 DoT(RFC 7858)监听
|
||
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
|
||
if err != nil {
|
||
log.Fatalf("[fatal] failed to load cert: %v", err)
|
||
}
|
||
|
||
// 解析上游地址:支持“host”或“host:port”,缺省端口补 53。
|
||
var upstreams []string
|
||
for _, s := range strings.Split(*upstreamStr, ",") {
|
||
if t := strings.TrimSpace(s); t != "" {
|
||
if !strings.Contains(t, ":") {
|
||
t = fmt.Sprintf("%s:53", t)
|
||
}
|
||
upstreams = append(upstreams, t)
|
||
}
|
||
}
|
||
if len(upstreams) == 0 {
|
||
log.Fatal("[fatal] no upstream DNS servers provided")
|
||
}
|
||
|
||
// 全局 UDP 客户端(不设置 Client.Timeout,改用 ExchangeContext 控制超时)
|
||
udpClient = &dns.Client{
|
||
Net: "udp",
|
||
UDPSize: 4096, // 放大到 4K,减小 UDP 截断概率
|
||
}
|
||
|
||
// context 用于优雅退出(SIGINT/SIGTERM 收到后取消)
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
|
||
// 启动缓存清理器(后台 goroutine)
|
||
startCacheCleaner(ctx)
|
||
|
||
// 注册处理器:使用 ServeMux 将所有查询交给 handleDNS
|
||
mux := dns.NewServeMux()
|
||
mux.HandleFunc(".", handleDNS(
|
||
upstreams,
|
||
*cacheTTLFlag,
|
||
*timeoutFlag,
|
||
*maxParallel,
|
||
*stripECSFlag,
|
||
*allowTCPFallback,
|
||
))
|
||
|
||
// 构造 DoT(tcp-tls)服务器:显式开启 TLS1.2/1.3,增加读写超时防止慢连接。
|
||
// NextProtos: "dot"(可选,部分客户端用于 ALPN 检测)
|
||
srv := &dns.Server{
|
||
Addr: *addr,
|
||
Net: "tcp-tls",
|
||
TLSConfig: &tls.Config{
|
||
Certificates: []tls.Certificate{cert},
|
||
MinVersion: tls.VersionTLS12, // 允许 1.2/1.3(默认启用 1.3)
|
||
// 不设置 CipherSuites,交由 Go 自动选择(TLS1.3 有自身套件)
|
||
NextProtos: []string{"dot"}, // 可选:显式 ALPN
|
||
},
|
||
Handler: mux,
|
||
ReadTimeout: 10 * time.Second, // 防止慢连接
|
||
WriteTimeout: 10 * time.Second,
|
||
}
|
||
|
||
// 捕获信号以便优雅退出(关闭监听、结束后台协程)
|
||
stop := make(chan os.Signal, 1)
|
||
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||
|
||
// 异步启动服务,错误通过 errCh 返回
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
|
||
log.Printf(" upstreams=%v | cache_max_ttl=%s | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v",
|
||
upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback)
|
||
errCh <- srv.ListenAndServe()
|
||
}()
|
||
|
||
// 等待退出信号或服务器错误
|
||
select {
|
||
case sig := <-stop:
|
||
log.Printf("[shutdown] caught signal: %s", sig)
|
||
cancel()
|
||
// miekg/dns 提供 Shutdown();部分版本无 ShutdownContext,这里用 Shutdown()
|
||
if err := srv.Shutdown(); err != nil {
|
||
log.Printf("[shutdown] server shutdown error: %v", err)
|
||
}
|
||
case err := <-errCh:
|
||
if err != nil {
|
||
log.Fatalf("[fatal] server error: %v", err)
|
||
}
|
||
}
|
||
|
||
log.Println("[bye] server stopped.")
|
||
}
|