feat(dns): 引入 ECS 缓存键支持与并发查询合并

新增对 EDNS Client Subnet (ECS) 的缓存键支持,当未启用 strip-ecs 时,
将归一化后的 ECS 网络信息加入缓存 key。同时引入 singleflight 合并相同
查询请求,防止缓存击穿和上游负载突增。此外,修正了 DoT 服务的 TLS 版本
至 1.3,并开放读写超时配置选项。

- 添加 ecsKeyPart 函数用于生成 ECS 相关的缓存键片段
- 使用 singleflight.Group 避免重复并发查询
- 升级默认 TLS 最小版本到 VersionTLS13
- 增加 `-read-timeout` 和 `-write-timeout` 参数控制 DoT 连接超时
- 调整日志输出以包含新的超时参数
- 优化缓存写入逻辑,在 singleflight 回调中执行
This commit is contained in:
2025-10-17 09:43:50 +08:00
parent 20d0ddc18e
commit 224f575e68
2 changed files with 79 additions and 20 deletions

BIN
dot

Binary file not shown.

95
main.go
View File

@@ -1,3 +1,4 @@
// main.go
package main package main
import ( import (
@@ -7,6 +8,7 @@ import (
"fmt" "fmt"
"log" "log"
"math/rand" "math/rand"
"net"
"os" "os"
"os/signal" "os/signal"
"strings" "strings"
@@ -14,9 +16,9 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/miekg/dns"
lru "github.com/hashicorp/golang-lru/v2" lru "github.com/hashicorp/golang-lru/v2"
"github.com/miekg/dns"
"golang.org/x/sync/singleflight"
) )
var BuildDate = "unknown" // 由编译时注入 var BuildDate = "unknown" // 由编译时注入
@@ -49,6 +51,7 @@ type cacheEntry struct {
var ( var (
cache *lru.Cache[string, *cacheEntry] cache *lru.Cache[string, *cacheEntry]
cacheMutex sync.RWMutex cacheMutex sync.RWMutex
inflight singleflight.Group
) )
const ( const (
@@ -56,7 +59,7 @@ const (
defaultCacheSize = 10000 // 默认最大缓存条目数 defaultCacheSize = 10000 // 默认最大缓存条目数
) )
// startCacheCleaner 定期清理过期缓存(修复:在删除前二次校验) // startCacheCleaner 定期清理过期缓存(在删除前二次校验)
func startCacheCleaner(ctx context.Context) { func startCacheCleaner(ctx context.Context) {
go func() { go func() {
ticker := time.NewTicker(cacheCleanupInterval) ticker := time.NewTicker(cacheCleanupInterval)
@@ -114,6 +117,48 @@ func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
return b.String() return b.String()
} }
// ecsKeyPart returns the cache-key fragment derived from the request's EDNS
// Client Subnet (ECS) option, so that answers obtained for different client
// networks are cached under different keys.
//
// It returns "" when strip is true (ECS stripping is enabled, so ECS must not
// influence the key), when m carries no OPT record, or when the OPT record
// holds no ECS option. Otherwise the fragment has the form
// "|ECS=<family>/<source-netmask>/<addr>" and is built from the first ECS
// option found. For family 1 (IPv4) and family 2 (IPv6) the address is masked
// to the source prefix first, so requests that differ only in the bits beyond
// the prefix share a key.
//
// When the address cannot be normalized (unknown family, an address that does
// not fit the family, or a prefix length larger than the family's bit width)
// the unmasked address is used instead.
func ecsKeyPart(m *dns.Msg, strip bool) string {
	if strip {
		return ""
	}
	opt := m.IsEdns0()
	if opt == nil {
		return ""
	}
	for _, o := range opt.Option {
		subnet, ok := o.(*dns.EDNS0_SUBNET)
		if !ok {
			continue
		}

		// Work on a copy so the request's own address bytes are never modified.
		addr := append(net.IP(nil), subnet.Address...)

		var ip net.IP
		bits := 0
		switch subnet.Family {
		case 1: // IPv4
			ip, bits = addr.To4(), 8*net.IPv4len
		case 2: // IPv6
			ip, bits = addr.To16(), 8*net.IPv6len
		}
		if ip != nil {
			// net.CIDRMask returns nil when the prefix length is out of range
			// for the family. Masking with a nil mask would either panic or
			// produce a nil address, so such input keeps the unmasked address.
			if mask := net.CIDRMask(int(subnet.SourceNetmask), bits); mask != nil {
				addr = ip.Mask(mask)
			}
		}

		// net.IP implements fmt.Stringer, so %x hex-encodes the textual form
		// of the address; the verb is kept to preserve the existing key format.
		return fmt.Sprintf("|ECS=%d/%d/%x", subnet.Family, subnet.SourceNetmask, addr)
	}
	return ""
}
func isPseudo(rr dns.RR) bool { func isPseudo(rr dns.RR) bool {
switch rr.(type) { switch rr.(type) {
case *dns.OPT, *dns.TSIG: case *dns.OPT, *dns.TSIG:
@@ -146,7 +191,7 @@ func extractEDNSMeta(m *dns.Msg) (present bool, version uint8, ext uint16, ede [
return return
} }
// 读取缓存(修复:Get 在写锁下;在锁外调整 TTL返回 EDNS 元数据) // 读取缓存Get 在写锁下;在锁外调整 TTL返回 EDNS 元数据)
func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) { func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) {
now := time.Now() now := time.Now()
@@ -187,7 +232,6 @@ func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) {
return out, e, true return out, e, true
} }
// 计算负面 TTL
// hasAnswerForType 判断报文中是否存在回答“请求类型”的 RRset // hasAnswerForType 判断报文中是否存在回答“请求类型”的 RRset
func hasAnswerForType(m *dns.Msg, q dns.Question) bool { func hasAnswerForType(m *dns.Msg, q dns.Question) bool {
for _, rr := range m.Answer { for _, rr := range m.Answer {
@@ -199,7 +243,7 @@ func hasAnswerForType(m *dns.Msg, q dns.Question) bool {
return false return false
} }
// 计算负面 TTL修复:正确识别 NODATA包括 CNAME 等场景) // 计算负面 TTL正确识别 NODATA包括 CNAME 等场景)
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
// NXDOMAIN肯定是负面 // NXDOMAIN肯定是负面
if m.Rcode != dns.RcodeNameError { if m.Rcode != dns.RcodeNameError {
@@ -379,7 +423,7 @@ func queryUpstreamsLimited(
sem := make(chan struct{}, maxParallel) sem := make(chan struct{}, maxParallel)
execOne := func(svr string) { execOne := func(svr string) {
upReq := clampEDNSForUpstream(req, 1232) // 或做成 flag upReq := clampEDNSForUpstream(req, 1232) // 采用 1232 降低分片风险
resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr) resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr)
if err == nil && resp != nil && resp.Truncated && allowTCPFallback { if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
log.Printf("[upstream] UDP truncated, retry TCP: %s", svr) log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
@@ -394,6 +438,7 @@ func queryUpstreamsLimited(
} }
return return
} }
// 丢弃对客户端无意义/不可靠的错误
if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError { if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError {
return return
} }
@@ -461,6 +506,7 @@ func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cac
out := new(dns.Msg) out := new(dns.Msg)
out.SetReply(req) out.SetReply(req)
out.Authoritative = false out.Authoritative = false
// RA 语义修正:RA 表示“服务器是否支持递归”,与客户端 RD 无关
out.RecursionAvailable = upstream.RecursionAvailable out.RecursionAvailable = upstream.RecursionAvailable
out.AuthenticatedData = upstream.AuthenticatedData out.AuthenticatedData = upstream.AuthenticatedData
out.CheckingDisabled = req.CheckingDisabled out.CheckingDisabled = req.CheckingDisabled
@@ -483,9 +529,10 @@ func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cac
if ro.Do() { if ro.Do() {
o.SetDo() o.SetDo()
} }
// 优先使用“请求”的 EDNS 版本;扩展 RCODE/EDE 来自上游/缓存
o.SetVersion(ro.Version())
if uo := upstream.IsEdns0(); uo != nil { if uo := upstream.IsEdns0(); uo != nil {
o.SetExtendedRcode(uint16(uo.ExtendedRcode())) o.SetExtendedRcode(uint16(uo.ExtendedRcode()))
o.SetVersion(uint8(uo.Version()))
for _, opt := range uo.Option { for _, opt := range uo.Option {
if ede, ok := opt.(*dns.EDNS0_EDE); ok { if ede, ok := opt.(*dns.EDNS0_EDE); ok {
o.Option = append(o.Option, ede) o.Option = append(o.Option, ede)
@@ -493,8 +540,7 @@ func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cac
} }
} else if meta != nil && meta.ednsPresent { } else if meta != nil && meta.ednsPresent {
// Upstream/cached msg has no OPT例如缓存时被剥离用缓存元数据重建 // Upstream/cached msg has no OPT例如缓存时被剥离用缓存元数据重建
o.SetExtendedRcode(meta.ednsExtRcode) o.SetExtendedRcode(uint16(meta.ednsExtRcode))
o.SetVersion(meta.ednsVersion)
for _, e := range meta.ednsEDE { for _, e := range meta.ednsEDE {
o.Option = append(o.Option, e) o.Option = append(o.Option, e)
} }
@@ -543,7 +589,8 @@ func handleDNS(
return return
} }
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) // 缓存 key基础 + (未 strip 时的ECS 片段
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) + ecsKeyPart(r, stripECSBeforeForward)
if cachedMsg, cachedMeta, ok := tryCacheRead(key); ok { if cachedMsg, cachedMeta, ok := tryCacheRead(key); ok {
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
@@ -552,14 +599,21 @@ func handleDNS(
} }
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
// 使用 singleflight 合并相同 key 的并发查询,避免上游雪崩
v, _, _ := inflight.Do(key, func() (any, error) {
ctx := context.Background() ctx := context.Background()
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback) resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
if resp != nil {
cacheWrite(key, resp, cacheMaxTTL)
}
return resp, nil
})
resp, _ := v.(*dns.Msg)
if resp == nil { if resp == nil {
log.Printf("[error] all upstreams failed for %s", q.Name) log.Printf("[error] all upstreams failed for %s", q.Name)
dns.HandleFailed(w, r) dns.HandleFailed(w, r)
return return
} }
cacheWrite(key, resp, cacheMaxTTL)
for _, ans := range resp.Answer { for _, ans := range resp.Answer {
log.Printf("[answer] %s", ans.String()) log.Printf("[answer] %s", ans.String())
} }
@@ -581,11 +635,13 @@ func main() {
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL") cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL")
cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限") cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限")
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时") timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时")
readTimeoutFlag := flag.Duration("read-timeout", 0, "DoT 连接读超时0=不限制)")
writeTimeoutFlag := flag.Duration("write-timeout", 0, "DoT 连接写超时0=不限制)")
maxParallel := flag.Int("max-parallel", 3, "并发上游数量") maxParallel := flag.Int("max-parallel", 3, "并发上游数量")
stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS") stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS")
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退") allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退")
blacklistStr := flag.String("blacklist", "", "逗号分隔的黑名单域名(后缀匹配;支持如 *.example.com") blacklistStr := flag.String("blacklist", "", "逗号分隔的黑名单域名(后缀匹配;支持如 *.example.com")
blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;后缀匹配") blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;可接受 hosts 风格")
blacklistRcodeFlag := flag.String("blacklist-rcode", "REFUSED", "命中黑名单返回的 RCODEREFUSED|NXDOMAIN|SERVFAIL") blacklistRcodeFlag := flag.String("blacklist-rcode", "REFUSED", "命中黑名单返回的 RCODEREFUSED|NXDOMAIN|SERVFAIL")
verbose := flag.Bool("v", false, "verbose 日志") verbose := flag.Bool("v", false, "verbose 日志")
@@ -632,6 +688,7 @@ func main() {
log.Fatal("[fatal] no upstream DNS servers provided") log.Fatal("[fatal] no upstream DNS servers provided")
} }
// 如需更保守的 UDP 尺寸以减少分片,可将 UDPSize 改为 1232
udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true} udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@@ -659,12 +716,12 @@ func main() {
Net: "tcp-tls", Net: "tcp-tls",
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS13, // TLS 1.3
NextProtos: []string{"dot"}, NextProtos: []string{"dot"},
}, },
Handler: mux, Handler: mux,
ReadTimeout: 10 * time.Second, ReadTimeout: *readTimeoutFlag,
WriteTimeout: 10 * time.Second, WriteTimeout: *writeTimeoutFlag,
} }
stop := make(chan os.Signal, 1) stop := make(chan os.Signal, 1)
@@ -673,8 +730,10 @@ func main() {
errCh := make(chan error, 1) errCh := make(chan error, 1)
go func() { go func() {
log.Printf("🚀 starting DNS-over-TLS on %s", *addr) log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | blacklist_rules=%d | blacklist_rcode=%s", log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | read_timeout=%s | write_timeout=%s | blacklist_rules=%d | blacklist_rcode=%s",
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback, len(bl.rules), strings.ToUpper(*blacklistRcodeFlag)) upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback,
readTimeoutFlag.String(), writeTimeoutFlag.String(),
len(bl.rules), strings.ToUpper(*blacklistRcodeFlag))
errCh <- srv.ListenAndServe() errCh <- srv.ListenAndServe()
}() }()