feat(blacklist): 支持 hosts 风格黑名单文件并优化热重载机制
- 支持解析 hosts 风格的黑名单文件,可识别以 IP 开头的行,并将后续字段作为域名处理 - 增加对行首及行内注释的支持(支持 `//`、`#`、`;` 符号) - 使用 atomic.Pointer 管理黑名单匹配器,提升并发安全性 - 优化黑名单热重载逻辑,使用 time.Ticker 替代 time.After 提高稳定性 - 更新相关依赖引用路径,调整 sync 包导入位置
This commit is contained in:
68
blacklist.go
68
blacklist.go
@@ -4,9 +4,11 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"log"
|
"log"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -54,13 +56,35 @@ func loadBlacklistFile(path string) ([]string, error) {
|
|||||||
if line == "" {
|
if line == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// 行首注释
|
||||||
|
if strings.HasPrefix(line, "//") || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// 行内注释:先 // 再 # ;
|
||||||
|
if i := strings.Index(line, "//"); i >= 0 {
|
||||||
|
line = strings.TrimSpace(line[:i])
|
||||||
|
}
|
||||||
if i := strings.IndexAny(line, "#;"); i >= 0 {
|
if i := strings.IndexAny(line, "#;"); i >= 0 {
|
||||||
line = strings.TrimSpace(line[:i])
|
line = strings.TrimSpace(line[:i])
|
||||||
}
|
}
|
||||||
if r := canonicalFQDN(line); r != "" {
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// hosts 风格:第一个字段是 IP,则其余每个字段视为域名
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
start := 0
|
||||||
|
if net.ParseIP(fields[0]) != nil {
|
||||||
|
start = 1
|
||||||
|
}
|
||||||
|
for _, tok := range fields[start:] {
|
||||||
|
if r := canonicalFQDN(tok); r != "" {
|
||||||
rules = append(rules, r)
|
rules = append(rules, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if err := sc.Err(); err != nil {
|
if err := sc.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -68,17 +92,19 @@ func loadBlacklistFile(path string) ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 自动重载黑名单
|
// 自动重载黑名单
|
||||||
func startBlacklistReloader(ctx context.Context, path string, interval time.Duration, current **suffixMatcher) {
|
func startBlacklistReloader(ctx context.Context, path string, interval time.Duration, holder *atomic.Pointer[suffixMatcher]) {
|
||||||
if path == "" {
|
if path == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
var lastMod time.Time
|
var lastMod time.Time
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-time.After(interval):
|
case <-ticker.C:
|
||||||
fi, err := os.Stat(path)
|
fi, err := os.Stat(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("[blacklist] reload check failed: %v", err)
|
log.Printf("[blacklist] reload check failed: %v", err)
|
||||||
@@ -91,7 +117,7 @@ func startBlacklistReloader(ctx context.Context, path string, interval time.Dura
|
|||||||
log.Printf("[blacklist] reload failed: %v", err)
|
log.Printf("[blacklist] reload failed: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
*current = newSuffixMatcher(rules)
|
holder.Store(newSuffixMatcher(rules))
|
||||||
lastMod = modTime
|
lastMod = modTime
|
||||||
log.Printf("[blacklist] reloaded %d rules (modified %s)", len(rules), modTime.Format(time.RFC3339))
|
log.Printf("[blacklist] reloaded %d rules (modified %s)", len(rules), modTime.Format(time.RFC3339))
|
||||||
}
|
}
|
||||||
@@ -155,9 +181,8 @@ func makeBlockedUpstream(rcode int, rule string) *dns.Msg {
|
|||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func initBlacklist(ctx context.Context, listStr, filePath, rcodeStr string) (*suffixMatcher, int) {
|
func initBlacklist(ctx context.Context, listStr, filePath, rcodeStr string) (*atomic.Pointer[suffixMatcher], int) {
|
||||||
var rules []string
|
var rules []string
|
||||||
|
|
||||||
if v := strings.TrimSpace(listStr); v != "" {
|
if v := strings.TrimSpace(listStr); v != "" {
|
||||||
for _, s := range strings.Split(v, ",") {
|
for _, s := range strings.Split(v, ",") {
|
||||||
if r := canonicalFQDN(s); r != "" {
|
if r := canonicalFQDN(s); r != "" {
|
||||||
@@ -165,24 +190,19 @@ func initBlacklist(ctx context.Context, listStr, filePath, rcodeStr string) (*su
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if file := strings.TrimSpace(filePath); file != "" {
|
|
||||||
fileRules, err := loadBlacklistFile(file)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("[fatal] failed to load blacklist-file %q: %v", file, err)
|
|
||||||
}
|
|
||||||
rules = append(rules, fileRules...)
|
|
||||||
}
|
|
||||||
|
|
||||||
bl := newSuffixMatcher(rules)
|
|
||||||
blRcode := parseRcode(rcodeStr)
|
|
||||||
|
|
||||||
if filePath != "" {
|
if filePath != "" {
|
||||||
startBlacklistReloader(ctx, filePath, 30*time.Second, &bl)
|
if fs, err := loadBlacklistFile(filePath); err != nil {
|
||||||
|
log.Printf("[blacklist] load file error: %v", err)
|
||||||
|
} else {
|
||||||
|
rules = append(rules, fs...)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
log.Printf("[blacklist] loaded %d rules (file=%v, rcode=%s)",
|
var holder atomic.Pointer[suffixMatcher]
|
||||||
len(bl.rules), filePath != "", strings.ToUpper(rcodeStr))
|
holder.Store(newSuffixMatcher(rules))
|
||||||
|
blRcode := parseRcode(rcodeStr)
|
||||||
return bl, blRcode
|
if filePath != "" {
|
||||||
|
startBlacklistReloader(ctx, filePath, 30*time.Second, &holder)
|
||||||
|
}
|
||||||
|
log.Printf("[blacklist] loaded %d rules (file=%v, rcode=%s)", len(holder.Load().rules), filePath != "", strings.ToUpper(rcodeStr))
|
||||||
|
return &holder, blRcode
|
||||||
}
|
}
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -5,13 +5,13 @@ go 1.25.3
|
|||||||
require (
|
require (
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7
|
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||||
github.com/miekg/dns v1.1.68
|
github.com/miekg/dns v1.1.68
|
||||||
|
golang.org/x/sync v0.17.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/google/go-cmp v0.7.0 // indirect
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
golang.org/x/mod v0.29.0 // indirect
|
golang.org/x/mod v0.29.0 // indirect
|
||||||
golang.org/x/net v0.46.0 // indirect
|
golang.org/x/net v0.46.0 // indirect
|
||||||
golang.org/x/sync v0.17.0 // indirect
|
|
||||||
golang.org/x/sys v0.37.0 // indirect
|
golang.org/x/sys v0.37.0 // indirect
|
||||||
golang.org/x/tools v0.38.0 // indirect
|
golang.org/x/tools v0.38.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
11
main.go
11
main.go
@@ -13,6 +13,7 @@ import (
|
|||||||
"os/signal"
|
"os/signal"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -564,7 +565,7 @@ func handleDNS(
|
|||||||
maxParallel int,
|
maxParallel int,
|
||||||
stripECSBeforeForward bool,
|
stripECSBeforeForward bool,
|
||||||
allowTCPFallback bool,
|
allowTCPFallback bool,
|
||||||
bl *suffixMatcher,
|
blPtr *atomic.Pointer[suffixMatcher],
|
||||||
blRcode int,
|
blRcode int,
|
||||||
) dns.HandlerFunc {
|
) dns.HandlerFunc {
|
||||||
return func(w dns.ResponseWriter, r *dns.Msg) {
|
return func(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
@@ -581,7 +582,7 @@ func handleDNS(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 黑名单拦截:命中则不查上游,直接返回
|
// 黑名单拦截:命中则不查上游,直接返回
|
||||||
if rule, ok := bl.match(q.Name); ok {
|
if rule, ok := blPtr.Load().match(q.Name); ok {
|
||||||
nameCanon := dns.CanonicalName(q.Name)
|
nameCanon := dns.CanonicalName(q.Name)
|
||||||
log.Printf("[blacklist] HIT %s rule=%s (no upstream query)", nameCanon, rule)
|
log.Printf("[blacklist] HIT %s rule=%s (no upstream query)", nameCanon, rule)
|
||||||
up := makeBlockedUpstream(blRcode, rule)
|
up := makeBlockedUpstream(blRcode, rule)
|
||||||
@@ -697,7 +698,7 @@ func main() {
|
|||||||
startCacheCleaner(ctx)
|
startCacheCleaner(ctx)
|
||||||
|
|
||||||
// 加载黑名单规则
|
// 加载黑名单规则
|
||||||
bl, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag)
|
blPtr, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag)
|
||||||
|
|
||||||
mux := dns.NewServeMux()
|
mux := dns.NewServeMux()
|
||||||
mux.HandleFunc(".", handleDNS(
|
mux.HandleFunc(".", handleDNS(
|
||||||
@@ -707,7 +708,7 @@ func main() {
|
|||||||
*maxParallel,
|
*maxParallel,
|
||||||
*stripECSFlag,
|
*stripECSFlag,
|
||||||
*allowTCPFallback,
|
*allowTCPFallback,
|
||||||
bl,
|
blPtr,
|
||||||
blRcode,
|
blRcode,
|
||||||
))
|
))
|
||||||
|
|
||||||
@@ -733,7 +734,7 @@ func main() {
|
|||||||
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | read_timeout=%s | write_timeout=%s | blacklist_rules=%d | blacklist_rcode=%s",
|
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | read_timeout=%s | write_timeout=%s | blacklist_rules=%d | blacklist_rcode=%s",
|
||||||
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback,
|
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback,
|
||||||
readTimeoutFlag.String(), writeTimeoutFlag.String(),
|
readTimeoutFlag.String(), writeTimeoutFlag.String(),
|
||||||
len(bl.rules), strings.ToUpper(*blacklistRcodeFlag))
|
len(blPtr.Load().rules), strings.ToUpper(*blacklistRcodeFlag))
|
||||||
errCh <- srv.ListenAndServe()
|
errCh <- srv.ListenAndServe()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user