diff --git a/blacklist.go b/blacklist.go index 7abd64a..f8b9c2c 100644 --- a/blacklist.go +++ b/blacklist.go @@ -4,9 +4,11 @@ import ( "bufio" "context" "log" + "net" "os" "sort" "strings" + "sync/atomic" "time" "github.com/miekg/dns" @@ -54,11 +56,33 @@ func loadBlacklistFile(path string) ([]string, error) { if line == "" { continue } + // 行首注释 + if strings.HasPrefix(line, "//") || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") { + continue + } + // 行内注释:先 // 再 # ; + if i := strings.Index(line, "//"); i >= 0 { + line = strings.TrimSpace(line[:i]) + } if i := strings.IndexAny(line, "#;"); i >= 0 { line = strings.TrimSpace(line[:i]) } - if r := canonicalFQDN(line); r != "" { - rules = append(rules, r) + if line == "" { + continue + } + // hosts 风格:第一个字段是 IP,则其余每个字段视为域名 + fields := strings.Fields(line) + if len(fields) == 0 { + continue + } + start := 0 + if net.ParseIP(fields[0]) != nil { + start = 1 + } + for _, tok := range fields[start:] { + if r := canonicalFQDN(tok); r != "" { + rules = append(rules, r) + } } } if err := sc.Err(); err != nil { @@ -68,17 +92,19 @@ func loadBlacklistFile(path string) ([]string, error) { } // 自动重载黑名单 -func startBlacklistReloader(ctx context.Context, path string, interval time.Duration, current **suffixMatcher) { +func startBlacklistReloader(ctx context.Context, path string, interval time.Duration, holder *atomic.Pointer[suffixMatcher]) { if path == "" { return } go func() { var lastMod time.Time + ticker := time.NewTicker(interval) + defer ticker.Stop() for { select { case <-ctx.Done(): return - case <-time.After(interval): + case <-ticker.C: fi, err := os.Stat(path) if err != nil { log.Printf("[blacklist] reload check failed: %v", err) @@ -91,7 +117,7 @@ func startBlacklistReloader(ctx context.Context, path string, interval time.Dura log.Printf("[blacklist] reload failed: %v", err) continue } - *current = newSuffixMatcher(rules) + holder.Store(newSuffixMatcher(rules)) lastMod = modTime log.Printf("[blacklist] reloaded %d rules (modified %s)", len(rules), modTime.Format(time.RFC3339)) } @@ -155,9 +181,8 @@ func makeBlockedUpstream(rcode int, rule string) *dns.Msg { return m } -func initBlacklist(ctx context.Context, listStr, filePath, rcodeStr string) (*suffixMatcher, int) { +func initBlacklist(ctx context.Context, listStr, filePath, rcodeStr string) (*atomic.Pointer[suffixMatcher], int) { var rules []string - if v := strings.TrimSpace(listStr); v != "" { for _, s := range strings.Split(v, ",") { if r := canonicalFQDN(s); r != "" { @@ -165,24 +190,19 @@ func initBlacklist(ctx context.Context, listStr, filePath, rcodeStr string) (*su } } } - - if file := strings.TrimSpace(filePath); file != "" { - fileRules, err := loadBlacklistFile(file) - if err != nil { - log.Fatalf("[fatal] failed to load blacklist-file %q: %v", file, err) - } - rules = append(rules, fileRules...) - } - - bl := newSuffixMatcher(rules) - blRcode := parseRcode(rcodeStr) - if filePath != "" { - startBlacklistReloader(ctx, filePath, 30*time.Second, &bl) + if fs, err := loadBlacklistFile(filePath); err != nil { + log.Printf("[blacklist] load file error: %v", err) + } else { + rules = append(rules, fs...) + } } - - log.Printf("[blacklist] loaded %d rules (file=%v, rcode=%s)", - len(bl.rules), filePath != "", strings.ToUpper(rcodeStr)) - - return bl, blRcode + var holder atomic.Pointer[suffixMatcher] + holder.Store(newSuffixMatcher(rules)) + blRcode := parseRcode(rcodeStr) + if filePath != "" { + startBlacklistReloader(ctx, filePath, 30*time.Second, &holder) + } + log.Printf("[blacklist] loaded %d rules (file=%v, rcode=%s)", len(holder.Load().rules), filePath != "", strings.ToUpper(rcodeStr)) + return &holder, blRcode } diff --git a/dot b/dot index 5515c04..0304963 100644 Binary files a/dot and b/dot differ diff --git a/go.mod b/go.mod index a5bab76..da437fd 100644 --- a/go.mod +++ b/go.mod @@ -5,13 +5,13 @@ go 1.25.3 require ( github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/miekg/dns v1.1.68 + golang.org/x/sync v0.17.0 ) require ( github.com/google/go-cmp v0.7.0 // indirect golang.org/x/mod v0.29.0 // indirect golang.org/x/net v0.46.0 // indirect - golang.org/x/sync v0.17.0 // indirect golang.org/x/sys v0.37.0 // indirect golang.org/x/tools v0.38.0 // indirect ) diff --git a/main.go b/main.go index 0531226..565ad43 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "os/signal" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -564,7 +565,7 @@ func handleDNS( maxParallel int, stripECSBeforeForward bool, allowTCPFallback bool, - bl *suffixMatcher, + blPtr *atomic.Pointer[suffixMatcher], blRcode int, ) dns.HandlerFunc { return func(w dns.ResponseWriter, r *dns.Msg) { @@ -581,7 +582,7 @@ func handleDNS( } // 黑名单拦截:命中则不查上游,直接返回 - if rule, ok := bl.match(q.Name); ok { + if rule, ok := blPtr.Load().match(q.Name); ok { nameCanon := dns.CanonicalName(q.Name) log.Printf("[blacklist] HIT %s rule=%s (no upstream query)", nameCanon, rule) up := makeBlockedUpstream(blRcode, rule) @@ -697,7 +698,7 @@ func main() { startCacheCleaner(ctx) // 加载黑名单规则 - bl, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag) + blPtr, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag) mux := dns.NewServeMux() mux.HandleFunc(".", handleDNS( @@ -707,7 +708,7 @@ func main() { *maxParallel, *stripECSFlag, *allowTCPFallback, - bl, + blPtr, blRcode, )) @@ -733,7 +734,7 @@ func main() { log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | read_timeout=%s | write_timeout=%s | blacklist_rules=%d | blacklist_rcode=%s", upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback, readTimeoutFlag.String(), writeTimeoutFlag.String(), - len(bl.rules), strings.ToUpper(*blacklistRcodeFlag)) + len(blPtr.Load().rules), strings.ToUpper(*blacklistRcodeFlag)) errCh <- srv.ListenAndServe() }()