diff --git a/blacklist.go b/blacklist.go index f8b9c2c..774a123 100644 --- a/blacklist.go +++ b/blacklist.go @@ -12,6 +12,7 @@ import ( "time" "github.com/miekg/dns" + "golang.org/x/net/idna" ) // -------- Blacklist helpers (独立文件) -------- @@ -22,7 +23,14 @@ func canonicalFQDN(s string) string { if s == "" { return "" } + // 允许黑名单写 "*.example.com";内部匹配用裸后缀 s = strings.TrimPrefix(s, "*.") + + // 先把可能的中文/Unicode 域名转成 ASCII(punycode),再规范化 + if a, err := idna.Lookup.ToASCII(s); err == nil { + s = a + } + // CanonicalName 会做小写化与尾点规范化 return dns.CanonicalName(s) } @@ -43,38 +51,38 @@ func uniqueStrings(in []string) []string { } // 支持 # / ; 注释;每行一个域名;支持以 "*.example.com" 书写 +// 支持 # / ; 注释;每行一个域名;支持以 "*.example.com" 书写;支持 hosts 风格(首列为 IP) func loadBlacklistFile(path string) ([]string, error) { f, err := os.Open(path) if err != nil { return nil, err } defer f.Close() + sc := bufio.NewScanner(f) + // 默认 64KB 容量不够稳妥,这里放大到 2MB,兼容一些合并的大 hosts 列表 + sc.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + var rules []string for sc.Scan() { line := strings.TrimSpace(sc.Text()) if line == "" { continue } - // 行首注释 - if strings.HasPrefix(line, "//") || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") { - continue - } - // 行内注释:先 // 再 # ; - if i := strings.Index(line, "//"); i >= 0 { - line = strings.TrimSpace(line[:i]) - } + // 去掉注释(# 或 ; 之后的内容) if i := strings.IndexAny(line, "#;"); i >= 0 { line = strings.TrimSpace(line[:i]) + if line == "" { + continue + } } - if line == "" { - continue - } - // hosts 风格:第一个字段是 IP,则其余每个字段视为域名 + fields := strings.Fields(line) if len(fields) == 0 { continue } + + // hosts 风格:第一个字段是 IP,则其余每个字段视为域名 start := 0 if net.ParseIP(fields[0]) != nil { start = 1 @@ -88,7 +96,10 @@ func loadBlacklistFile(path string) ([]string, error) { if err := sc.Err(); err != nil { return nil, err } - return uniqueStrings(rules), nil + + sort.Strings(rules) + rules = uniqueStrings(rules) + return rules, nil } // 自动重载黑名单 diff --git a/dot b/dot index 0304963..1669f36 100644 Binary files a/dot and b/dot differ diff --git a/go.mod b/go.mod index da437fd..540f672 100644 --- a/go.mod +++ b/go.mod @@ -5,13 +5,14 @@ go 1.25.3 require ( github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/miekg/dns v1.1.68 + golang.org/x/net v0.46.0 golang.org/x/sync v0.17.0 ) require ( github.com/google/go-cmp v0.7.0 // indirect golang.org/x/mod v0.29.0 // indirect - golang.org/x/net v0.46.0 // indirect golang.org/x/sys v0.37.0 // indirect + golang.org/x/text v0.30.0 // indirect golang.org/x/tools v0.38.0 // indirect ) diff --git a/go.sum b/go.sum index 40b76e0..3ca3654 100644 --- a/go.sum +++ b/go.sum @@ -12,5 +12,7 @@ golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= +golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= diff --git a/main.go b/main.go index 565ad43..dd22914 100644 --- a/main.go +++ b/main.go @@ -343,7 +343,18 @@ func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { } var ttl uint32 var ok bool + + // 判断负面响应(NXDOMAIN 或 NODATA) + neg, isNodata := in.Rcode == dns.RcodeNameError, false + if in.Rcode == dns.RcodeSuccess && len(in.Question) > 0 && !hasAnswerForType(in, in.Question[0]) { + isNodata = true + } + if ttl, ok = negativeTTL(in, maxTTL); !ok { + if neg || isNodata { + return // 负面但无 SOA → 不缓存 + } + minTTL, has := minRRsetTTL(in) if has { cfgTTL := uint32(maxTTL.Seconds()) @@ -419,61 +430,85 @@ func queryUpstreamsLimited( cctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - type result struct{ msg *dns.Msg } + type result struct { + msg *dns.Msg + } ch := make(chan result, len(servers)) + done := make(chan struct{}, len(servers)) sem := make(chan struct{}, maxParallel) + // 单个上游执行 execOne := func(svr string) { - upReq := clampEDNSForUpstream(req, 1232) // 采用 1232 降低分片风险 + // 并发限流(可被超时取消) + select { + case sem <- struct{}{}: + defer func() { <-sem }() + case <-cctx.Done(): + // 超时/取消,直接放弃 + return + } + defer func() { done <- struct{}{} }() + + // 为 UDP 上游把 EDNS UDP size 夹到 1232,降低分片风险 + upReq := clampEDNSForUpstream(req, 1232) + + // 先走 UDP resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr) + // 截断且允许回退则走 TCP if err == nil && resp != nil && resp.Truncated && allowTCPFallback { log.Printf("[upstream] UDP truncated, retry TCP: %s", svr) tcpClient := *udpClient tcpClient.Net = "tcp" - resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr) } + // 失败直接返回(但不写入 ch);只在未超时情况下打印错误 if err != nil || resp == nil { if err != nil && cctx.Err() == nil { - log.Printf("[upstream] %s error: %v", svr, err) + log.Printf("[upstream] %s: %v", svr, err) } return } - // 丢弃对客户端无意义/不可靠的错误 - if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError { + // 过滤不可用的错误 RCODE(避免造成“假性超时”的错觉) + if resp.Rcode == dns.RcodeServerFailure || + resp.Rcode == dns.RcodeRefused || + resp.Rcode == dns.RcodeFormatError { return } + + // 投递可用结果(若已经超时则丢弃) select { case ch <- result{msg: resp}: case <-cctx.Done(): } } + // 并发发起 for _, s := range servers { s := s - go func() { - select { - case sem <- struct{}{}: - defer func() { <-sem }() - case <-cctx.Done(): - return - } - execOne(s) - }() + go execOne(s) } - for i := 0; i < len(servers); i++ { + finished := 0 + total := len(servers) + + // 聚合:首个可用响应直接返回;区分“真超时”与“无可用结果” + for finished < total { select { case r := <-ch: if r.msg != nil { cancel() return r.msg } + case <-done: + finished++ case <-cctx.Done(): - log.Printf("[upstream] timeout after %v", timeout) + log.Printf("[upstream] timeout after %v (finished=%d/%d)", timeout, finished, total) return nil } } + + // 所有上游都结束,但没有一个可用 + log.Printf("[upstream] no acceptable upstream response (finished=%d/%d)", finished, total) return nil }