Files
dot/main.go
aixiao bcd0914b2f ```
feat(core): 改进 DNS 响应构造逻辑以支持扩展错误码和诊断信息

在 writeReply 函数中,增强对上游 EDNS 选项的处理:
- 继承上游的扩展 RCODE 与 EDNS 版本号
- 可选透传 EDE(Extended DNS Errors)记录,保留更多诊断信息
- 明确设置 UDPSize、DO 位及 EDNS 版本转换逻辑

这有助于提升调试能力与协议兼容性,同时保持客户端请求的一致性。
```
2025-10-14 13:55:01 +08:00

628 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"crypto/tls"
"flag"
"fmt"
"log"
"math/rand"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"github.com/miekg/dns"
)
/******************************************************************
* 日志初始化
******************************************************************/
// initLogger 根据 -v 开关设置日志格式verbose=true 时附加源码文件与行号。
// 说明Lmicroseconds 便于排查毫秒级时序问题。
func initLogger(verbose bool) {
flags := log.Ldate | log.Ltime | log.Lmicroseconds
if verbose {
flags |= log.Lshortfile
}
log.SetFlags(flags)
}
/******************************************************************
* 缓存结构
******************************************************************/
// cacheEntry 表示一条缓存项:保存上游响应的副本,以及过期时间。
// 注:将完整 *dns.Msg 存入缓存,便于原样复用 Answer/Ns/Extra。
type cacheEntry struct {
msg *dns.Msg // 上游完整响应(拷贝存储)
expireAt time.Time // 过期时间
}
// cache 使用 sync.Map 作为并发安全的键值存储。
// 注:此实现为 TTL 驱动的简单缓存;如需上限/淘汰策略,可叠加 LRU。
var cache sync.Map
const cacheCleanupInterval = 5 * time.Minute
// startCacheCleaner 定时清理过期项,避免缓存无限增长。
// 使用 ctx 控制生命周期,与主服务一同退出。
func startCacheCleaner(ctx context.Context) {
go func() {
ticker := time.NewTicker(cacheCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
n := 0
cache.Range(func(k, v any) bool {
e := v.(*cacheEntry)
if now.After(e.expireAt) {
cache.Delete(k)
n++
}
return true
})
if n > 0 {
log.Printf("[cache] cleaned %d expired entries", n)
}
}
}
}()
}
// 计算缓存键name + type + class + DO + CD。
// 说明:包含 DO/CD 可避免不同验证上下文污染同一键。
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
var b strings.Builder
b.Grow(len(q.Name) + 32)
b.WriteString(strings.ToLower(q.Name))
b.WriteString("|T=")
b.WriteString(dns.TypeToString[q.Qtype])
b.WriteString("|C=")
b.WriteString(dns.ClassToString[q.Qclass])
if do {
b.WriteString("|DO")
}
if cd {
b.WriteString("|CD")
}
return b.String()
}
// 识别伪 RROPT/TSIG这些记录的“TTL 字段”并非真实 TTL不可参与 TTL 计算或改写。
// OPT其 TTL 字段承载扩展 RCODE 与 DO 位等标志TSIG签名不应缓存或改写。
func isPseudo(rr dns.RR) bool {
switch rr.(type) {
case *dns.OPT, *dns.TSIG:
return true
default:
return false
}
}
// tryCacheRead 尝试读取缓存并回填“剩余 TTL”对 Answer/Ns/Extra 普通 RR 截断 TTL跳过伪 RR。
// 返回响应副本,保证对外不可变(防止调用方修改缓存内部对象)。
func tryCacheRead(key string) (*dns.Msg, bool) {
v, ok := cache.Load(key)
if !ok {
return nil, false
}
e := v.(*cacheEntry)
now := time.Now()
if now.After(e.expireAt) {
cache.Delete(key)
return nil, false
}
out := e.msg.Copy()
remaining := uint32(e.expireAt.Sub(now).Seconds())
if remaining == 0 {
cache.Delete(key)
return nil, false
}
// 回填剩余 TTL不增加只做上限截断
for i := range out.Answer {
if out.Answer[i].Header().Ttl > remaining {
out.Answer[i].Header().Ttl = remaining
}
}
for i := range out.Ns {
if out.Ns[i].Header().Ttl > remaining {
out.Ns[i].Header().Ttl = remaining
}
}
for i := range out.Extra {
if isPseudo(out.Extra[i]) {
continue
}
if out.Extra[i].Header().Ttl > remaining {
out.Extra[i].Header().Ttl = remaining
}
}
return out, true
}
// negativeTTL 依据 RFC 2308 计算负面缓存 TTL。
// 对 NXDOMAIN 或 NODATANOERROR + Answer 为空)取 min(SOA.TTL, SOA.MINIMUM),再与配置上限取 min。
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
// NXDOMAIN或 NOERROR 但 Answer 为空NODATA
if m.Rcode != dns.RcodeNameError && !(m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0) {
return 0, false
}
var soa *dns.SOA
// SOA 可能出现在 Authority(Ns) 或 Additional(Extra)
for _, rr := range append(m.Ns, m.Extra...) {
if s, ok := rr.(*dns.SOA); ok {
soa = s
break
}
}
if soa == nil {
// 无 SOA 无法可靠计算负面 TTL此时不缓存或由上限兜底本实现选择不缓存
return 0, false
}
ttl := soa.Hdr.Ttl
if soa.Minttl < ttl {
ttl = soa.Minttl
}
capTTL := uint32(maxTTL.Seconds())
if capTTL > 0 && ttl > capTTL {
ttl = capTTL
}
return ttl, ttl > 0
}
// minRRsetTTL 获取普通(正向)响应的最小 TTLAnswer/Ns/Extra 中的普通 RR跳过伪 RR。
// 用于决定缓存过期时间的上限(与配置上限再取 min
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
minTTL := uint32(0)
hasTTL := false
for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} {
for _, rr := range sec {
if isPseudo(rr) {
continue
}
ttl := rr.Header().Ttl
if !hasTTL || ttl < minTTL {
minTTL = ttl
hasTTL = true
}
}
}
return minTTL, hasTTL
}
// stripPseudoExtras 从消息中剥离伪 RROPT/TSIG
// 用途:缓存前剥离,避免将传输层细节或签名内容写入缓存。
func stripPseudoExtras(m *dns.Msg) {
if len(m.Extra) == 0 {
return
}
out := m.Extra[:0]
for _, rr := range m.Extra {
if isPseudo(rr) {
continue
}
out = append(out, rr)
}
m.Extra = out
}
// cacheWrite 写缓存:优先处理负面缓存;正面缓存取 min(应答中最小 TTL, 配置上限)。
// 写入前统一剥离伪 RR保证缓存与传输解耦。
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
if in == nil {
return
}
// 仅缓存 NOERROR / NXDOMAIN其余不缓存如 SERVFAIL/REFUSED 等)
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
return
}
// 负面缓存
if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 {
expire := time.Now().Add(time.Duration(ttl) * time.Second)
cp := in.Copy()
stripPseudoExtras(cp)
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
return
}
// 正向缓存minTTL 与 maxTTL 取较小
minTTL, ok := minRRsetTTL(in)
if !ok {
// 没有 TTL 时可用上限兜底(也可选择不缓存,这里选择兜底)
if maxTTL > 0 {
expire := time.Now().Add(maxTTL)
cp := in.Copy()
stripPseudoExtras(cp)
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
}
return
}
cfgTTL := uint32(maxTTL.Seconds())
finalTTL := minTTL
if cfgTTL > 0 && finalTTL > cfgTTL {
finalTTL = cfgTTL
}
if finalTTL == 0 {
return
}
expire := time.Now().Add(time.Duration(finalTTL) * time.Second)
cp := in.Copy()
stripPseudoExtras(cp)
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
}
/******************************************************************
* 上游查询(带 context 取消、并发上限、UDP→TCP 回退)
******************************************************************/
// 全局可复用 UDP 客户端Net=udpUDPSize 放大以承载更大的响应)
var udpClient *dns.Client
// shuffled 打乱上游列表,避免固定顺序导致单点拥塞或偏置。
func shuffled(xs []string) []string {
out := make([]string, len(xs))
copy(out, xs)
rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] })
return out
}
// queryUpstreamsLimited 并发向多个上游发起查询,返回首个有效结果。
// - timeout整个查询窗口的上限基于子 context
// - maxParallel同时在飞请求上限
// - allowTCPFallback若 UDP 截断TC 位)则回退 TCP 重试
func queryUpstreamsLimited(
ctx context.Context,
req *dns.Msg,
upstreams []string,
timeout time.Duration,
maxParallel int,
allowTCPFallback bool,
) *dns.Msg {
if maxParallel <= 0 {
maxParallel = 1
}
servers := shuffled(upstreams)
// 每次查询一个带超时的子 context拿到首个有效结果后 cancel取消其他请求。
cctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
type result struct {
msg *dns.Msg
}
ch := make(chan result, len(servers)) // 缓冲至多保存所有返回,避免阻塞
sem := make(chan struct{}, maxParallel) // 并发信号量,限制同时在飞请求数
// execOne 对单个上游发起查询;对 req 使用 Copy() 防止在多 goroutine 下共享同一 *dns.Msg。
execOne := func(svr string) {
// UDP 查询(带 context使用 req.Copy() 防止并发读写
resp, _, err := udpClient.ExchangeContext(cctx, req.Copy(), svr)
if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
// TCP 回退(响应被截断)
log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
tcpClient := *udpClient
tcpClient.Net = "tcp"
resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr)
}
if err != nil || resp == nil {
// 错误仅在未被上层取消时打印,避免超时后噪声
if err != nil {
if cctx.Err() == nil {
log.Printf("[upstream] %s error: %v", svr, err)
}
} else {
log.Printf("[upstream] %s nil response", svr)
}
return
}
// 过滤差 RCODE不参与竞速SERVFAIL / REFUSED / FORMERR
if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError {
return
}
select {
case ch <- result{msg: resp}:
case <-cctx.Done():
}
}
// “快乐眼球”式启动:生产者不阻塞,获取配额在 goroutine 内部完成;这样可以更快进入接收循环并在首个成功后取消其余。
for _, s := range servers {
s := s
go func() {
select {
case sem <- struct{}{}:
defer func() { <-sem }()
case <-cctx.Done():
return
}
execOne(s)
}()
}
// 返回第一个非空结果,并 cancel 其他 goroutine。
for i := 0; i < len(servers); i++ {
select {
case r := <-ch:
if r.msg != nil {
cancel()
return r.msg
}
case <-cctx.Done():
log.Printf("[upstream] timeout after %v", timeout)
return nil
}
}
return nil
}
/******************************************************************
* EDNS(0) / ECS 处理
******************************************************************/
// stripECS 从请求中去除 EDNS Client SubnetECS减少缓存污染并保护隐私。
// 注ECS 会导致上游按地理/网络分区返回不同答案,不适合集中缓存。
func stripECS(m *dns.Msg) {
if o := m.IsEdns0(); o != nil {
var kept []dns.EDNS0
for _, e := range o.Option {
if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS {
kept = append(kept, e)
}
}
o.Option = kept
}
}
// getDOFlag 读取 DODNSSEC OK构成缓存键的一部分。
func getDOFlag(m *dns.Msg) bool {
if o := m.IsEdns0(); o != nil {
return o.Do()
}
return false
}
/******************************************************************
* 响应构造:使用客户端请求头构造 reply复制上游内容
******************************************************************/
// writeReply 根据客户端请求构造响应:复制上游 Answer/Ns/非伪 Extra
// 并按客户端请求重建 OPTUDPSize/DO同时继承上游的扩展 RCODE 与 EDNS 版本;
// 可选透传上游的 EDEExtended DNS Errors以保留诊断信息。
func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
if upstream == nil {
dns.HandleFailed(w, req)
return
}
out := new(dns.Msg)
out.SetReply(req)
out.Authoritative = false
out.RecursionAvailable = upstream.RecursionAvailable
out.AuthenticatedData = upstream.AuthenticatedData
out.CheckingDisabled = req.CheckingDisabled // 反映客户端 CD 位
out.Rcode = upstream.Rcode // 主 RCODE低 4 位)
out.Answer = upstream.Answer
out.Ns = upstream.Ns
// 复制上游的非伪 RR 额外记录OPT/TSIG 不透传
extras := make([]dns.RR, 0, len(upstream.Extra))
for _, rr := range upstream.Extra {
if isPseudo(rr) {
continue
}
extras = append(extras, rr)
}
// 基于客户端请求镜像 EDNSUDPSize + DO
if ro := req.IsEdns0(); ro != nil {
o := new(dns.OPT)
o.Hdr.Name = "."
o.Hdr.Rrtype = dns.TypeOPT
// 与客户端保持一致的 UDPSize / DO 位
o.SetUDPSize(ro.UDPSize())
if ro.Do() {
o.SetDo()
}
// 继承上游的扩展 RCODE 与 EDNS 版本(注意不同版本签名差异,这里显式转换)
if uo := upstream.IsEdns0(); uo != nil {
// 你当前库期望 uint16这里强转若你的库期望 uint8也可改成 uint8(...)
o.SetExtendedRcode(uint16(uo.ExtendedRcode()))
o.SetVersion(uint8(uo.Version()))
// 可选:透传只读的 EDE 诊断信息
for _, opt := range uo.Option {
if ede, ok := opt.(*dns.EDNS0_EDE); ok {
o.Option = append(o.Option, ede)
}
}
}
extras = append(extras, o)
}
out.Extra = extras
out.Compress = true
if err := w.WriteMsg(out); err != nil {
log.Printf("[write] WriteMsg error: %v", err)
}
}
/******************************************************************
* 处理器
******************************************************************/
// handleDNS 为每个请求执行:日志 → (可选)剥离 ECS → 缓存命中 → 上游查询 → 写缓存 → 回写。
// 注意:缓存键包含 DO/CD同时通过 tryCacheRead 回填剩余 TTL。
func handleDNS(
upstreams []string,
cacheMaxTTL, timeout time.Duration,
maxParallel int,
stripECSBeforeForward bool,
allowTCPFallback bool,
) dns.HandlerFunc {
return func(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
dns.HandleFailed(w, r)
return
}
q := r.Question[0]
// 基本访问日志:类型/域名/ID/CD/DO/来源 IP:Port
log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())
// 可选:去除 ECS推荐
if stripECSBeforeForward {
stripECS(r)
}
// 缓存键(域名小写 + QTYPE/QCLASS + DO/CD
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
// 1) 缓存命中:命中即快速返回
if cached, ok := tryCacheRead(key); ok {
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
writeReply(w, r, cached)
return
}
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
// 2) 上游查询(带 context 取消 & TCP 可选回退);并发向多个上游竞速
ctx := context.Background()
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
if resp == nil {
log.Printf("[error] all upstreams failed for %s", q.Name)
dns.HandleFailed(w, r)
return
}
// 3) 写缓存(负面/正面均处理;剥离伪 RR
cacheWrite(key, resp, cacheMaxTTL)
// 4) 回写给客户端,并打印 Answer 方便调试
for _, ans := range resp.Answer {
log.Printf("[answer] %s", ans.String())
}
writeReply(w, r, resp)
}
}
func main() {
rand.Seed(time.Now().UnixNano()) // 随机源:用于上游列表打乱
// 命令行参数:可根据部署场景调整默认值
certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)")
keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)")
addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853")
upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)")
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL默认 60s实际取 min(上游最小TTL, 本值)")
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s")
maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限")
stripECSFlag := flag.Bool("strip-ecs", true, "转发上游前去除 EDNS Client Subnet")
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时允许 TCP 回退")
verbose := flag.Bool("v", false, "verbose 日志(包含源码位置)")
flag.Parse()
initLogger(*verbose)
// 加载 TLS 证书/私钥;用于 DoTRFC 7858监听
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
if err != nil {
log.Fatalf("[fatal] failed to load cert: %v", err)
}
// 解析上游地址支持“host”或“host:port”缺省端口补 53。
var upstreams []string
for _, s := range strings.Split(*upstreamStr, ",") {
if t := strings.TrimSpace(s); t != "" {
if !strings.Contains(t, ":") {
t = fmt.Sprintf("%s:53", t)
}
upstreams = append(upstreams, t)
}
}
if len(upstreams) == 0 {
log.Fatal("[fatal] no upstream DNS servers provided")
}
// 全局 UDP 客户端(不设置 Client.Timeout改用 ExchangeContext 控制超时)
udpClient = &dns.Client{
Net: "udp",
UDPSize: 4096, // 放大到 4K减小 UDP 截断概率
}
// context 用于优雅退出SIGINT/SIGTERM 收到后取消)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// 启动缓存清理器(后台 goroutine
startCacheCleaner(ctx)
// 注册处理器:使用 ServeMux 将所有查询交给 handleDNS
mux := dns.NewServeMux()
mux.HandleFunc(".", handleDNS(
upstreams,
*cacheTTLFlag,
*timeoutFlag,
*maxParallel,
*stripECSFlag,
*allowTCPFallback,
))
// 构造 DoTtcp-tls服务器显式开启 TLS1.2/1.3,增加读写超时防止慢连接。
// NextProtos: "dot"(可选,部分客户端用于 ALPN 检测)
srv := &dns.Server{
Addr: *addr,
Net: "tcp-tls",
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12, // 允许 1.2/1.3(默认启用 1.3
// 不设置 CipherSuites交由 Go 自动选择TLS1.3 有自身套件)
NextProtos: []string{"dot"}, // 可选:显式 ALPN
},
Handler: mux,
ReadTimeout: 10 * time.Second, // 防止慢连接
WriteTimeout: 10 * time.Second,
}
// 捕获信号以便优雅退出(关闭监听、结束后台协程)
stop := make(chan os.Signal, 1)
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
// 异步启动服务,错误通过 errCh 返回
errCh := make(chan error, 1)
go func() {
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
log.Printf(" upstreams=%v | cache_max_ttl=%s | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v",
upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback)
errCh <- srv.ListenAndServe()
}()
// 等待退出信号或服务器错误
select {
case sig := <-stop:
log.Printf("[shutdown] caught signal: %s", sig)
cancel()
// miekg/dns 提供 Shutdown();部分版本无 ShutdownContext这里用 Shutdown()
if err := srv.Shutdown(); err != nil {
log.Printf("[shutdown] server shutdown error: %v", err)
}
case err := <-errCh:
if err != nil {
log.Fatalf("[fatal] server error: %v", err)
}
}
log.Println("[bye] server stopped.")
}