feat(dns): 引入 ECS 缓存键支持与并发查询合并
新增对 EDNS Client Subnet (ECS) 的缓存键支持:当未启用 strip-ecs 时,将归一化后的 ECS 网络信息加入缓存 key。同时引入 singleflight 合并相同的并发查询请求,防止缓存击穿和上游负载突增。此外,将 DoT 服务的最低 TLS 版本提升至 1.3,并开放读写超时配置选项。
- 添加 ecsKeyPart 函数用于生成 ECS 相关的缓存键片段
- 使用 singleflight.Group 避免重复并发查询
- 升级默认 TLS 最小版本到 VersionTLS13
- 增加 `-read-timeout` 和 `-write-timeout` 参数控制 DoT 连接超时
- 调整日志输出以包含新的超时参数
- 优化缓存写入逻辑,改为在 singleflight 回调中执行缓存写入
This commit is contained in:
99
main.go
99
main.go
@@ -1,3 +1,4 @@
|
||||
// main.go
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -7,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
@@ -14,9 +16,9 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var BuildDate = "unknown" // 由编译时注入
|
||||
@@ -49,6 +51,7 @@ type cacheEntry struct {
|
||||
var (
|
||||
cache *lru.Cache[string, *cacheEntry]
|
||||
cacheMutex sync.RWMutex
|
||||
inflight singleflight.Group
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -56,7 +59,7 @@ const (
|
||||
defaultCacheSize = 10000 // 默认最大缓存条目数
|
||||
)
|
||||
|
||||
// startCacheCleaner 定期清理过期缓存(修复:在删除前二次校验)
|
||||
// startCacheCleaner 定期清理过期缓存(在删除前二次校验)
|
||||
func startCacheCleaner(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(cacheCleanupInterval)
|
||||
@@ -114,6 +117,48 @@ func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// ecsKeyPart 在未 strip ECS 时,把 ECS 归一化后的“网络”信息并入缓存 key
|
||||
// 为最小改动:当 strip=true(即启用去 ECS)时直接返回空字符串
|
||||
func ecsKeyPart(m *dns.Msg, strip bool) string {
|
||||
if strip {
|
||||
return ""
|
||||
}
|
||||
o := m.IsEdns0()
|
||||
if o == nil {
|
||||
return ""
|
||||
}
|
||||
for _, opt := range o.Option {
|
||||
s, ok := opt.(*dns.EDNS0_SUBNET)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fam := s.Family
|
||||
pfx := int(s.SourceNetmask)
|
||||
addr := append(net.IP(nil), s.Address...) // 拷贝以免原切片被改
|
||||
switch fam {
|
||||
case 1: // IPv4
|
||||
ip := addr.To4()
|
||||
if ip != nil {
|
||||
mask := net.CIDRMask(pfx, 32)
|
||||
ip = ip.Mask(mask)
|
||||
return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, ip)
|
||||
}
|
||||
case 2: // IPv6
|
||||
ip := addr.To16()
|
||||
if ip != nil {
|
||||
mask := net.CIDRMask(pfx, 128)
|
||||
for i := 0; i < 16; i++ {
|
||||
ip[i] &= mask[i]
|
||||
}
|
||||
return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, ip)
|
||||
}
|
||||
}
|
||||
// 回退:不做掩码
|
||||
return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, addr)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func isPseudo(rr dns.RR) bool {
|
||||
switch rr.(type) {
|
||||
case *dns.OPT, *dns.TSIG:
|
||||
@@ -146,7 +191,7 @@ func extractEDNSMeta(m *dns.Msg) (present bool, version uint8, ext uint16, ede [
|
||||
return
|
||||
}
|
||||
|
||||
// 读取缓存(修复:Get 在写锁下;在锁外调整 TTL;返回 EDNS 元数据)
|
||||
// 读取缓存(Get 在写锁下;在锁外调整 TTL;返回 EDNS 元数据)
|
||||
func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) {
|
||||
now := time.Now()
|
||||
|
||||
@@ -187,7 +232,6 @@ func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) {
|
||||
return out, e, true
|
||||
}
|
||||
|
||||
// 计算负面 TTL
|
||||
// hasAnswerForType 判断报文中是否存在回答“请求类型”的 RRset
|
||||
func hasAnswerForType(m *dns.Msg, q dns.Question) bool {
|
||||
for _, rr := range m.Answer {
|
||||
@@ -199,7 +243,7 @@ func hasAnswerForType(m *dns.Msg, q dns.Question) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// 计算负面 TTL(修复:正确识别 NODATA,包括 CNAME 等场景)
|
||||
// 计算负面 TTL(正确识别 NODATA,包括 CNAME 等场景)
|
||||
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
|
||||
// NXDOMAIN:肯定是负面
|
||||
if m.Rcode != dns.RcodeNameError {
|
||||
@@ -379,7 +423,7 @@ func queryUpstreamsLimited(
|
||||
sem := make(chan struct{}, maxParallel)
|
||||
|
||||
execOne := func(svr string) {
|
||||
upReq := clampEDNSForUpstream(req, 1232) // 或做成 flag
|
||||
upReq := clampEDNSForUpstream(req, 1232) // 采用 1232 降低分片风险
|
||||
resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr)
|
||||
if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
|
||||
log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
|
||||
@@ -394,6 +438,7 @@ func queryUpstreamsLimited(
|
||||
}
|
||||
return
|
||||
}
|
||||
// 丢弃对客户端无意义/不可靠的错误
|
||||
if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError {
|
||||
return
|
||||
}
|
||||
@@ -461,6 +506,7 @@ func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cac
|
||||
out := new(dns.Msg)
|
||||
out.SetReply(req)
|
||||
out.Authoritative = false
|
||||
// RA 语义修正:RA 表示“服务器是否支持递归”,与客户端 RD 无关
|
||||
out.RecursionAvailable = upstream.RecursionAvailable
|
||||
out.AuthenticatedData = upstream.AuthenticatedData
|
||||
out.CheckingDisabled = req.CheckingDisabled
|
||||
@@ -483,9 +529,10 @@ func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cac
|
||||
if ro.Do() {
|
||||
o.SetDo()
|
||||
}
|
||||
// 优先使用“请求”的 EDNS 版本;扩展 RCODE/EDE 来自上游/缓存
|
||||
o.SetVersion(ro.Version())
|
||||
if uo := upstream.IsEdns0(); uo != nil {
|
||||
o.SetExtendedRcode(uint16(uo.ExtendedRcode()))
|
||||
o.SetVersion(uint8(uo.Version()))
|
||||
for _, opt := range uo.Option {
|
||||
if ede, ok := opt.(*dns.EDNS0_EDE); ok {
|
||||
o.Option = append(o.Option, ede)
|
||||
@@ -493,8 +540,7 @@ func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cac
|
||||
}
|
||||
} else if meta != nil && meta.ednsPresent {
|
||||
// Upstream/cached msg has no OPT(例如缓存时被剥离),用缓存元数据重建
|
||||
o.SetExtendedRcode(meta.ednsExtRcode)
|
||||
o.SetVersion(meta.ednsVersion)
|
||||
o.SetExtendedRcode(uint16(meta.ednsExtRcode))
|
||||
for _, e := range meta.ednsEDE {
|
||||
o.Option = append(o.Option, e)
|
||||
}
|
||||
@@ -543,7 +589,8 @@ func handleDNS(
|
||||
return
|
||||
}
|
||||
|
||||
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
|
||||
// 缓存 key:基础 + (未 strip 时的)ECS 片段
|
||||
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) + ecsKeyPart(r, stripECSBeforeForward)
|
||||
|
||||
if cachedMsg, cachedMeta, ok := tryCacheRead(key); ok {
|
||||
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||||
@@ -552,14 +599,21 @@ func handleDNS(
|
||||
}
|
||||
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||||
|
||||
ctx := context.Background()
|
||||
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
|
||||
// 使用 singleflight 合并相同 key 的并发查询,避免上游雪崩
|
||||
v, _, _ := inflight.Do(key, func() (any, error) {
|
||||
ctx := context.Background()
|
||||
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
|
||||
if resp != nil {
|
||||
cacheWrite(key, resp, cacheMaxTTL)
|
||||
}
|
||||
return resp, nil
|
||||
})
|
||||
resp, _ := v.(*dns.Msg)
|
||||
if resp == nil {
|
||||
log.Printf("[error] all upstreams failed for %s", q.Name)
|
||||
dns.HandleFailed(w, r)
|
||||
return
|
||||
}
|
||||
cacheWrite(key, resp, cacheMaxTTL)
|
||||
for _, ans := range resp.Answer {
|
||||
log.Printf("[answer] %s", ans.String())
|
||||
}
|
||||
@@ -581,11 +635,13 @@ func main() {
|
||||
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL")
|
||||
cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限")
|
||||
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时")
|
||||
readTimeoutFlag := flag.Duration("read-timeout", 0, "DoT 连接读超时(0=不限制)")
|
||||
writeTimeoutFlag := flag.Duration("write-timeout", 0, "DoT 连接写超时(0=不限制)")
|
||||
maxParallel := flag.Int("max-parallel", 3, "并发上游数量")
|
||||
stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS")
|
||||
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退")
|
||||
blacklistStr := flag.String("blacklist", "", "逗号分隔的黑名单域名(后缀匹配;支持如 *.example.com)")
|
||||
blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;后缀匹配)")
|
||||
blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;可接受 hosts 风格)")
|
||||
blacklistRcodeFlag := flag.String("blacklist-rcode", "REFUSED", "命中黑名单返回的 RCODE:REFUSED|NXDOMAIN|SERVFAIL")
|
||||
verbose := flag.Bool("v", false, "verbose 日志")
|
||||
|
||||
@@ -632,6 +688,7 @@ func main() {
|
||||
log.Fatal("[fatal] no upstream DNS servers provided")
|
||||
}
|
||||
|
||||
// 如需更保守的 UDP 尺寸以减少分片,可将 UDPSize 改为 1232
|
||||
udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -659,12 +716,12 @@ func main() {
|
||||
Net: "tcp-tls",
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MinVersion: tls.VersionTLS13, // TLS 1.3
|
||||
NextProtos: []string{"dot"},
|
||||
},
|
||||
Handler: mux,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
ReadTimeout: *readTimeoutFlag,
|
||||
WriteTimeout: *writeTimeoutFlag,
|
||||
}
|
||||
|
||||
stop := make(chan os.Signal, 1)
|
||||
@@ -673,8 +730,10 @@ func main() {
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
|
||||
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | blacklist_rules=%d | blacklist_rcode=%s",
|
||||
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback, len(bl.rules), strings.ToUpper(*blacklistRcodeFlag))
|
||||
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | read_timeout=%s | write_timeout=%s | blacklist_rules=%d | blacklist_rcode=%s",
|
||||
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback,
|
||||
readTimeoutFlag.String(), writeTimeoutFlag.String(),
|
||||
len(bl.rules), strings.ToUpper(*blacklistRcodeFlag))
|
||||
errCh <- srv.ListenAndServe()
|
||||
}()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user