diff --git a/dot b/dot index 726a612..5515c04 100644 Binary files a/dot and b/dot differ diff --git a/main.go b/main.go index 45c4334..0531226 100644 --- a/main.go +++ b/main.go @@ -1,3 +1,4 @@ +// main.go package main import ( @@ -7,6 +8,7 @@ import ( "fmt" "log" "math/rand" + "net" "os" "os/signal" "strings" @@ -14,9 +16,9 @@ import ( "syscall" "time" - "github.com/miekg/dns" - lru "github.com/hashicorp/golang-lru/v2" + "github.com/miekg/dns" + "golang.org/x/sync/singleflight" ) var BuildDate = "unknown" // 由编译时注入 @@ -49,6 +51,7 @@ type cacheEntry struct { var ( cache *lru.Cache[string, *cacheEntry] cacheMutex sync.RWMutex + inflight singleflight.Group ) const ( @@ -56,7 +59,7 @@ const ( defaultCacheSize = 10000 // 默认最大缓存条目数 ) -// startCacheCleaner 定期清理过期缓存(修复:在删除前二次校验) +// startCacheCleaner 定期清理过期缓存(在删除前二次校验) func startCacheCleaner(ctx context.Context) { go func() { ticker := time.NewTicker(cacheCleanupInterval) @@ -114,6 +117,48 @@ func cacheKeyFromMsg(q dns.Question, do, cd bool) string { return b.String() } +// ecsKeyPart 在未 strip ECS 时,把 ECS 归一化后的“网络”信息并入缓存 key +// 为最小改动:当 strip=true(即启用去 ECS)时直接返回空字符串 +func ecsKeyPart(m *dns.Msg, strip bool) string { + if strip { + return "" + } + o := m.IsEdns0() + if o == nil { + return "" + } + for _, opt := range o.Option { + s, ok := opt.(*dns.EDNS0_SUBNET) + if !ok { + continue + } + fam := s.Family + pfx := int(s.SourceNetmask) + addr := append(net.IP(nil), s.Address...) // 拷贝以免原切片被改 + switch fam { + case 1: // IPv4 + ip := addr.To4() + if ip != nil { + mask := net.CIDRMask(pfx, 32) + ip = ip.Mask(mask) + return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, ip) + } + case 2: // IPv6 + ip := addr.To16() + if ip != nil { + mask := net.CIDRMask(pfx, 128) + for i := 0; i < 16; i++ { + ip[i] &= mask[i] + } + return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, ip) + } + } + // 回退:不做掩码 + return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, addr) + } + return "" +} + func isPseudo(rr dns.RR) bool { switch rr.(type) { case *dns.OPT, *dns.TSIG: @@ -146,7 +191,7 @@ func extractEDNSMeta(m *dns.Msg) (present bool, version uint8, ext uint16, ede [ return } -// 读取缓存(修复:Get 在写锁下;在锁外调整 TTL;返回 EDNS 元数据) +// 读取缓存(Get 在写锁下;在锁外调整 TTL;返回 EDNS 元数据) func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) { now := time.Now() @@ -187,7 +232,6 @@ func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) { return out, e, true } -// 计算负面 TTL // hasAnswerForType 判断报文中是否存在回答“请求类型”的 RRset func hasAnswerForType(m *dns.Msg, q dns.Question) bool { for _, rr := range m.Answer { @@ -199,7 +243,7 @@ func hasAnswerForType(m *dns.Msg, q dns.Question) bool { return false } -// 计算负面 TTL(修复:正确识别 NODATA,包括 CNAME 等场景) +// 计算负面 TTL(正确识别 NODATA,包括 CNAME 等场景) func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) { // NXDOMAIN:肯定是负面 if m.Rcode != dns.RcodeNameError { @@ -379,7 +423,7 @@ func queryUpstreamsLimited( sem := make(chan struct{}, maxParallel) execOne := func(svr string) { - upReq := clampEDNSForUpstream(req, 1232) // 或做成 flag + upReq := clampEDNSForUpstream(req, 1232) // 采用 1232 降低分片风险 resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr) if err == nil && resp != nil && resp.Truncated && allowTCPFallback { log.Printf("[upstream] UDP truncated, retry TCP: %s", svr) @@ -394,6 +438,7 @@ func queryUpstreamsLimited( } return } + // 丢弃对客户端无意义/不可靠的错误 if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError { return } @@ -461,6 +506,7 @@ func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cac out := new(dns.Msg) out.SetReply(req) out.Authoritative = false + // RA 语义修正:RA 表示“服务器是否支持递归”,与客户端 RD 无关 out.RecursionAvailable = upstream.RecursionAvailable out.AuthenticatedData = upstream.AuthenticatedData out.CheckingDisabled = req.CheckingDisabled @@ -483,9 +529,10 @@ func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cac if ro.Do() { o.SetDo() } + // 优先使用“请求”的 EDNS 版本;扩展 RCODE/EDE 来自上游/缓存 + o.SetVersion(ro.Version()) if uo := upstream.IsEdns0(); uo != nil { o.SetExtendedRcode(uint16(uo.ExtendedRcode())) - o.SetVersion(uint8(uo.Version())) for _, opt := range uo.Option { if ede, ok := opt.(*dns.EDNS0_EDE); ok { o.Option = append(o.Option, ede) @@ -493,8 +540,7 @@ func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cac } } else if meta != nil && meta.ednsPresent { // Upstream/cached msg has no OPT(例如缓存时被剥离),用缓存元数据重建 - o.SetExtendedRcode(meta.ednsExtRcode) - o.SetVersion(meta.ednsVersion) + o.SetExtendedRcode(uint16(meta.ednsExtRcode)) for _, e := range meta.ednsEDE { o.Option = append(o.Option, e) } @@ -543,7 +589,8 @@ func handleDNS( return } - key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) + // 缓存 key:基础 + (未 strip 时的)ECS 片段 + key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) + ecsKeyPart(r, stripECSBeforeForward) if cachedMsg, cachedMeta, ok := tryCacheRead(key); ok { log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) @@ -552,14 +599,21 @@ func handleDNS( } log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) - ctx := context.Background() - resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback) + // 使用 singleflight 合并相同 key 的并发查询,避免上游雪崩 + v, _, _ := inflight.Do(key, func() (any, error) { + ctx := context.Background() + resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback) + if resp != nil { + cacheWrite(key, resp, cacheMaxTTL) + } + return resp, nil + }) + resp, _ := v.(*dns.Msg) if resp == nil { log.Printf("[error] all upstreams failed for %s", q.Name) dns.HandleFailed(w, r) return } - cacheWrite(key, resp, cacheMaxTTL) for _, ans := range resp.Answer { log.Printf("[answer] %s", ans.String()) } @@ -581,11 +635,13 @@ func main() { cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL") cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限") timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时") + readTimeoutFlag := flag.Duration("read-timeout", 0, "DoT 连接读超时(0=不限制)") + writeTimeoutFlag := flag.Duration("write-timeout", 0, "DoT 连接写超时(0=不限制)") maxParallel := flag.Int("max-parallel", 3, "并发上游数量") stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS") allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退") blacklistStr := flag.String("blacklist", "", "逗号分隔的黑名单域名(后缀匹配;支持如 *.example.com)") - blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;后缀匹配)") + blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;可接受 hosts 风格)") blacklistRcodeFlag := flag.String("blacklist-rcode", "REFUSED", "命中黑名单返回的 RCODE:REFUSED|NXDOMAIN|SERVFAIL") verbose := flag.Bool("v", false, "verbose 日志") @@ -632,6 +688,7 @@ func main() { log.Fatal("[fatal] no upstream DNS servers provided") } + // 如需更保守的 UDP 尺寸以减少分片,可将 UDPSize 改为 1232 udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true} ctx, cancel := context.WithCancel(context.Background()) @@ -659,12 +716,12 @@ func main() { Net: "tcp-tls", TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, + MinVersion: tls.VersionTLS13, // TLS 1.3 NextProtos: []string{"dot"}, }, Handler: mux, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, + ReadTimeout: *readTimeoutFlag, + WriteTimeout: *writeTimeoutFlag, } stop := make(chan os.Signal, 1) @@ -673,8 +730,10 @@ func main() { errCh := make(chan error, 1) go func() { log.Printf("🚀 starting DNS-over-TLS on %s", *addr) - log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | blacklist_rules=%d | blacklist_rcode=%s", - upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback, len(bl.rules), strings.ToUpper(*blacklistRcodeFlag)) + log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | read_timeout=%s | write_timeout=%s | blacklist_rules=%d | blacklist_rcode=%s", + upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback, + readTimeoutFlag.String(), writeTimeoutFlag.String(), + len(bl.rules), strings.ToUpper(*blacklistRcodeFlag)) errCh <- srv.ListenAndServe() }()