feat(blacklist): support hosts-style blacklist files and improve hot reloading

- Parse hosts-style blacklist files: lines beginning with an IP address are recognized and the remaining fields are treated as domains (see the parsing sketch below)
- Support whole-line and inline comments (`//`, `#`, `;`)
- Manage the blacklist matcher through an atomic.Pointer for better concurrency safety
- Drive the blacklist hot-reload loop with time.Ticker instead of time.After for more stable scheduling
- Update the related dependency references and adjust where the sync package is imported
2025-10-17 10:46:48 +08:00
parent 224f575e68
commit 1ab273e2a8
4 changed files with 52 additions and 31 deletions
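
For readers skimming the diff, here is a minimal, self-contained sketch of the hosts-style parsing described in the commit message. `canonicalFQDN` is the project's own helper; the stub below is an assumption added only to make the example runnable.

```go
package main

import (
	"fmt"
	"net"
	"strings"
)

// canonicalFQDN is a stand-in (assumption) for the repository's helper:
// it trims the token and normalizes it to a fully-qualified, lower-case name.
func canonicalFQDN(s string) string {
	s = strings.ToLower(strings.TrimSpace(s))
	if s == "" {
		return ""
	}
	if !strings.HasSuffix(s, ".") {
		s += "."
	}
	return s
}

// parseLine applies the rules described above: skip whole-line comments,
// strip inline comments (// first, then # and ;), and, if the first field
// is an IP (hosts style), treat every remaining field as a domain.
func parseLine(line string) []string {
	line = strings.TrimSpace(line)
	if line == "" || strings.HasPrefix(line, "//") ||
		strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") {
		return nil
	}
	if i := strings.Index(line, "//"); i >= 0 {
		line = strings.TrimSpace(line[:i])
	}
	if i := strings.IndexAny(line, "#;"); i >= 0 {
		line = strings.TrimSpace(line[:i])
	}
	fields := strings.Fields(line)
	if len(fields) == 0 {
		return nil
	}
	start := 0
	if net.ParseIP(fields[0]) != nil {
		start = 1 // hosts style: skip the leading IP
	}
	var rules []string
	for _, tok := range fields[start:] {
		if r := canonicalFQDN(tok); r != "" {
			rules = append(rules, r)
		}
	}
	return rules
}

func main() {
	fmt.Println(parseLine("0.0.0.0 ads.example.com tracker.example.net # blocked"))
	// Output: [ads.example.com. tracker.example.net.]
}
```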


```diff
@@ -4,9 +4,11 @@ import (
 	"bufio"
 	"context"
 	"log"
+	"net"
 	"os"
 	"sort"
 	"strings"
+	"sync/atomic"
 	"time"

 	"github.com/miekg/dns"
@@ -54,11 +56,33 @@ func loadBlacklistFile(path string) ([]string, error) {
 		if line == "" {
 			continue
 		}
+		// whole-line comments
+		if strings.HasPrefix(line, "//") || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") {
+			continue
+		}
+		// inline comments: strip // first, then # and ;
+		if i := strings.Index(line, "//"); i >= 0 {
+			line = strings.TrimSpace(line[:i])
+		}
 		if i := strings.IndexAny(line, "#;"); i >= 0 {
 			line = strings.TrimSpace(line[:i])
 		}
-		if r := canonicalFQDN(line); r != "" {
-			rules = append(rules, r)
+		if line == "" {
+			continue
+		}
+		// hosts style: if the first field is an IP, treat every remaining field as a domain
+		fields := strings.Fields(line)
+		if len(fields) == 0 {
+			continue
+		}
+		start := 0
+		if net.ParseIP(fields[0]) != nil {
+			start = 1
+		}
+		for _, tok := range fields[start:] {
+			if r := canonicalFQDN(tok); r != "" {
+				rules = append(rules, r)
+			}
 		}
 	}
 	if err := sc.Err(); err != nil {
@@ -68,17 +92,19 @@ func loadBlacklistFile(path string) ([]string, error) {
 }

 // automatically reload the blacklist
-func startBlacklistReloader(ctx context.Context, path string, interval time.Duration, current **suffixMatcher) {
+func startBlacklistReloader(ctx context.Context, path string, interval time.Duration, holder *atomic.Pointer[suffixMatcher]) {
 	if path == "" {
 		return
 	}
 	go func() {
 		var lastMod time.Time
+		ticker := time.NewTicker(interval)
+		defer ticker.Stop()
 		for {
 			select {
 			case <-ctx.Done():
 				return
-			case <-time.After(interval):
+			case <-ticker.C:
 				fi, err := os.Stat(path)
 				if err != nil {
 					log.Printf("[blacklist] reload check failed: %v", err)
@@ -91,7 +117,7 @@ func startBlacklistReloader(ctx context.Context, path string, interval time.Dura
 					log.Printf("[blacklist] reload failed: %v", err)
 					continue
 				}
-				*current = newSuffixMatcher(rules)
+				holder.Store(newSuffixMatcher(rules))
 				lastMod = modTime
 				log.Printf("[blacklist] reloaded %d rules (modified %s)", len(rules), modTime.Format(time.RFC3339))
 			}
@@ -155,9 +181,8 @@ func makeBlockedUpstream(rcode int, rule string) *dns.Msg {
 	return m
 }

-func initBlacklist(ctx context.Context, listStr, filePath, rcodeStr string) (*suffixMatcher, int) {
+func initBlacklist(ctx context.Context, listStr, filePath, rcodeStr string) (*atomic.Pointer[suffixMatcher], int) {
 	var rules []string
 	if v := strings.TrimSpace(listStr); v != "" {
 		for _, s := range strings.Split(v, ",") {
 			if r := canonicalFQDN(s); r != "" {
@@ -165,24 +190,19 @@ func initBlacklist(ctx context.Context, listStr, filePath, rcodeStr string) (*su
 			}
 		}
 	}
-	if file := strings.TrimSpace(filePath); file != "" {
-		fileRules, err := loadBlacklistFile(file)
-		if err != nil {
-			log.Fatalf("[fatal] failed to load blacklist-file %q: %v", file, err)
-		}
-		rules = append(rules, fileRules...)
-	}
-	bl := newSuffixMatcher(rules)
-	blRcode := parseRcode(rcodeStr)
 	if filePath != "" {
-		startBlacklistReloader(ctx, filePath, 30*time.Second, &bl)
+		if fs, err := loadBlacklistFile(filePath); err != nil {
+			log.Printf("[blacklist] load file error: %v", err)
+		} else {
+			rules = append(rules, fs...)
+		}
 	}
-	log.Printf("[blacklist] loaded %d rules (file=%v, rcode=%s)",
-		len(bl.rules), filePath != "", strings.ToUpper(rcodeStr))
-	return bl, blRcode
+	var holder atomic.Pointer[suffixMatcher]
+	holder.Store(newSuffixMatcher(rules))
+	blRcode := parseRcode(rcodeStr)
+	if filePath != "" {
+		startBlacklistReloader(ctx, filePath, 30*time.Second, &holder)
+	}
+	log.Printf("[blacklist] loaded %d rules (file=%v, rcode=%s)", len(holder.Load().rules), filePath != "", strings.ToUpper(rcodeStr))
+	return &holder, blRcode
 }
```
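
The atomic.Pointer plus time.Ticker pattern introduced above can be shown in isolation. The sketch below is not the project's code: it assumes a trivial `matcher` stand-in for `suffixMatcher` and only illustrates how readers call `Load()` lock-free while a ticker-driven goroutine periodically `Store()`s a freshly built matcher.

```go
package main

import (
	"context"
	"fmt"
	"sync/atomic"
	"time"
)

// matcher is a stand-in (assumption) for the project's suffixMatcher.
type matcher struct{ rules []string }

func (m *matcher) match(name string) bool {
	for _, r := range m.rules {
		if name == r {
			return true
		}
	}
	return false
}

func main() {
	var holder atomic.Pointer[matcher]
	holder.Store(&matcher{rules: []string{"ads.example.com."}})

	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
	defer cancel()

	// Reloader: one ticker is reused across iterations instead of allocating
	// a new timer with time.After on every pass through the loop.
	go func() {
		ticker := time.NewTicker(100 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				// Build a new matcher and publish it atomically;
				// readers never observe a half-updated value.
				holder.Store(&matcher{rules: []string{"ads.example.com.", "tracker.example.net."}})
			}
		}
	}()

	// Reader: Load() is safe to call concurrently with Store().
	for i := 0; i < 3; i++ {
		time.Sleep(120 * time.Millisecond)
		fmt.Println(holder.Load().match("tracker.example.net."))
	}
}
```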

dot (binary file, not shown)

go.mod (2 lines changed)

```diff
@@ -5,13 +5,13 @@ go 1.25.3

 require (
 	github.com/hashicorp/golang-lru/v2 v2.0.7
 	github.com/miekg/dns v1.1.68
-	golang.org/x/sync v0.17.0
 )

 require (
 	github.com/google/go-cmp v0.7.0 // indirect
 	golang.org/x/mod v0.29.0 // indirect
 	golang.org/x/net v0.46.0 // indirect
+	golang.org/x/sync v0.17.0 // indirect
 	golang.org/x/sys v0.37.0 // indirect
 	golang.org/x/tools v0.38.0 // indirect
 )
```

main.go (11 lines changed)

```diff
@@ -13,6 +13,7 @@ import (
 	"os/signal"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"syscall"
 	"time"
@@ -564,7 +565,7 @@ func handleDNS(
 	maxParallel int,
 	stripECSBeforeForward bool,
 	allowTCPFallback bool,
-	bl *suffixMatcher,
+	blPtr *atomic.Pointer[suffixMatcher],
 	blRcode int,
 ) dns.HandlerFunc {
 	return func(w dns.ResponseWriter, r *dns.Msg) {
@@ -581,7 +582,7 @@
 		}

 		// blacklist interception: on a hit, reply directly without querying upstream
-		if rule, ok := bl.match(q.Name); ok {
+		if rule, ok := blPtr.Load().match(q.Name); ok {
 			nameCanon := dns.CanonicalName(q.Name)
 			log.Printf("[blacklist] HIT %s rule=%s (no upstream query)", nameCanon, rule)
 			up := makeBlockedUpstream(blRcode, rule)
@@ -697,7 +698,7 @@ func main() {
 	startCacheCleaner(ctx)

 	// load the blacklist rules
-	bl, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag)
+	blPtr, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag)

 	mux := dns.NewServeMux()
 	mux.HandleFunc(".", handleDNS(
@@ -707,7 +708,7 @@
 		*maxParallel,
 		*stripECSFlag,
 		*allowTCPFallback,
-		bl,
+		blPtr,
 		blRcode,
 	))
@@ -733,7 +734,7 @@ func main() {
 		log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | read_timeout=%s | write_timeout=%s | blacklist_rules=%d | blacklist_rcode=%s",
 			upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback,
 			readTimeoutFlag.String(), writeTimeoutFlag.String(),
-			len(bl.rules), strings.ToUpper(*blacklistRcodeFlag))
+			len(blPtr.Load().rules), strings.ToUpper(*blacklistRcodeFlag))
 		errCh <- srv.ListenAndServe()
 	}()
```