feat(blacklist): 实现基于Trie树的高性能黑名单匹配系统
重构黑名单匹配算法,采用Trie前缀树数据结构替换原有的后缀匹配, 将百万级域名匹配复杂度从O(n)降至O(L),显著提升性能。 同时优化黑名单文件加载机制,支持hosts格式和通配符匹配, 并实现文件修改自动重载功能,提升系统的灵活性和实用性。 refactor: 重构README文档结构和内容展示 更新项目介绍文档,优化整体布局结构,添加项目徽章标识, 精简功能特性描述,改进快速开始指南,提供更清晰的使用说明。 chore(deps): 更新项目依赖库至最新版本 升级github.com/miekg/dns至v1.1.72版本, 更新golang.org/x/net至v0.52.0版本, 升级golang.org/x/sync至v0.20.0版本, 以及其他相关依赖库的版本更新。
This commit is contained in:
286
blacklist.go
286
blacklist.go
@@ -6,214 +6,146 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
// -------- Blacklist helpers (独立文件) --------
|
||||
|
||||
// 将输入域名规范化为 canonical FQDN(去空格、去 *. 前缀、统一大小写/尾点)
|
||||
func canonicalFQDN(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
// 允许黑名单写 "*.example.com";内部匹配用裸后缀
|
||||
s = strings.TrimPrefix(s, "*.")
|
||||
|
||||
// 先把可能的中文/Unicode 域名转成 ASCII(punycode),再规范化
|
||||
if a, err := idna.Lookup.ToASCII(s); err == nil {
|
||||
s = a
|
||||
}
|
||||
// CanonicalName 会做小写化与尾点规范化
|
||||
return dns.CanonicalName(s)
|
||||
type BlacklistTrie struct {
|
||||
root *trieNode
|
||||
}
|
||||
|
||||
func uniqueStrings(in []string) []string {
|
||||
seen := make(map[string]struct{}, len(in))
|
||||
out := make([]string, 0, len(in))
|
||||
for _, s := range in {
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[s]; ok {
|
||||
continue
|
||||
}
|
||||
seen[s] = struct{}{}
|
||||
out = append(out, s)
|
||||
}
|
||||
return out
|
||||
type trieNode struct {
|
||||
children map[string]*trieNode
|
||||
isEnd bool
|
||||
rule string
|
||||
}
|
||||
|
||||
// 支持 # / ; 注释;每行一个域名;支持以 "*.example.com" 书写
|
||||
// 支持 # / ; 注释;每行一个域名;支持以 "*.example.com" 书写;支持 hosts 风格(首列为 IP)
|
||||
func loadBlacklistFile(path string) ([]string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
sc := bufio.NewScanner(f)
|
||||
// 默认 64KB 容量不够稳妥,这里放大到 2MB,兼容一些合并的大 hosts 列表
|
||||
sc.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
|
||||
|
||||
var rules []string
|
||||
for sc.Scan() {
|
||||
line := strings.TrimSpace(sc.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
// 去掉注释(# 或 ; 之后的内容)
|
||||
if i := strings.IndexAny(line, "#;"); i >= 0 {
|
||||
line = strings.TrimSpace(line[:i])
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// hosts 风格:第一个字段是 IP,则其余每个字段视为域名
|
||||
start := 0
|
||||
if net.ParseIP(fields[0]) != nil {
|
||||
start = 1
|
||||
}
|
||||
for _, tok := range fields[start:] {
|
||||
if r := canonicalFQDN(tok); r != "" {
|
||||
rules = append(rules, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := sc.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sort.Strings(rules)
|
||||
rules = uniqueStrings(rules)
|
||||
return rules, nil
|
||||
func newTrie() *BlacklistTrie {
|
||||
return &BlacklistTrie{root: &trieNode{children: make(map[string]*trieNode)}}
|
||||
}
|
||||
|
||||
// 自动重载黑名单
|
||||
func startBlacklistReloader(ctx context.Context, path string, interval time.Duration, holder *atomic.Pointer[suffixMatcher]) {
|
||||
if path == "" {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
var lastMod time.Time
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
log.Printf("[blacklist] reload check failed: %v", err)
|
||||
continue
|
||||
}
|
||||
modTime := fi.ModTime()
|
||||
if modTime.After(lastMod) {
|
||||
rules, err := loadBlacklistFile(path)
|
||||
if err != nil {
|
||||
log.Printf("[blacklist] reload failed: %v", err)
|
||||
continue
|
||||
}
|
||||
holder.Store(newSuffixMatcher(rules))
|
||||
lastMod = modTime
|
||||
log.Printf("[blacklist] reloaded %d rules (modified %s)", len(rules), modTime.Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
func (t *BlacklistTrie) Add(domain string) {
|
||||
name := strings.TrimSuffix(strings.ToLower(domain), ".")
|
||||
parts := strings.Split(name, ".")
|
||||
curr := t.root
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
p := parts[i]
|
||||
if _, ok := curr.children[p]; !ok {
|
||||
curr.children[p] = &trieNode{children: make(map[string]*trieNode)}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 后缀匹配器:rules 已 canonical,按长度降序排列(更精确的规则优先)
|
||||
type suffixMatcher struct {
|
||||
rules []string
|
||||
}
|
||||
|
||||
func newSuffixMatcher(rules []string) *suffixMatcher {
|
||||
rs := uniqueStrings(rules)
|
||||
sort.Slice(rs, func(i, j int) bool { return len(rs[i]) > len(rs[j]) })
|
||||
return &suffixMatcher{rules: rs}
|
||||
}
|
||||
|
||||
// 命中则返回匹配的规则
|
||||
func (m *suffixMatcher) match(name string) (string, bool) {
|
||||
if m == nil || len(m.rules) == 0 {
|
||||
return "", false
|
||||
curr = curr.children[p]
|
||||
}
|
||||
name = dns.CanonicalName(name)
|
||||
for _, r := range m.rules {
|
||||
if name == r || strings.HasSuffix(name, "."+r) {
|
||||
return r, true
|
||||
curr.isEnd = true
|
||||
curr.rule = name
|
||||
}
|
||||
|
||||
func (t *BlacklistTrie) Match(domain string) (string, bool) {
|
||||
name := strings.TrimSuffix(strings.ToLower(domain), ".")
|
||||
parts := strings.Split(name, ".")
|
||||
curr := t.root
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
p := parts[i]
|
||||
next, ok := curr.children[p]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
curr = next
|
||||
if curr.isEnd {
|
||||
return curr.rule, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// 解析 RCODE 文本到常量
|
||||
func initBlacklist(ctx context.Context, path, rcodeStr string) (*atomic.Pointer[BlacklistTrie], int) {
|
||||
var holder atomic.Pointer[BlacklistTrie]
|
||||
var lastMod time.Time
|
||||
rcode := parseRcode(rcodeStr)
|
||||
|
||||
load := func() {
|
||||
if path == "" {
|
||||
return
|
||||
}
|
||||
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
log.Printf("[blacklist] Stat error: %v", err)
|
||||
return
|
||||
}
|
||||
// 如果文件没动,不加载
|
||||
if !info.ModTime().After(lastMod) {
|
||||
return
|
||||
}
|
||||
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
trie := newTrie()
|
||||
scanner := bufio.NewScanner(f)
|
||||
count := 0
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || line[0] == '#' || line[0] == ';' {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
domain := fields[0]
|
||||
if net.ParseIP(domain) != nil && len(fields) > 1 {
|
||||
domain = fields[1]
|
||||
}
|
||||
trie.Add(domain)
|
||||
count++
|
||||
}
|
||||
holder.Store(trie)
|
||||
lastMod = info.ModTime()
|
||||
log.Printf("[blacklist] Loaded %d rules (updated: %v)", count, lastMod.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
load()
|
||||
|
||||
if path != "" {
|
||||
go func() {
|
||||
// 缩短检查间隔,但因为有文件时间校验,所以不费 CPU
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
load()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return &holder, rcode
|
||||
}
|
||||
|
||||
func parseRcode(s string) int {
|
||||
switch strings.ToUpper(strings.TrimSpace(s)) {
|
||||
case "REFUSED", "":
|
||||
return dns.RcodeRefused
|
||||
switch strings.ToUpper(s) {
|
||||
case "NXDOMAIN":
|
||||
return dns.RcodeNameError
|
||||
case "SERVFAIL":
|
||||
return dns.RcodeServerFailure
|
||||
default:
|
||||
log.Printf("[blacklist] unknown rcode %q, fallback REFUSED", s)
|
||||
return dns.RcodeRefused
|
||||
}
|
||||
}
|
||||
|
||||
// 构造“伪上游”阻断响应;带 EDE=Blocked(15) 说明命中规则
|
||||
func makeBlockedUpstream(rcode int, rule string) *dns.Msg {
|
||||
func makeBlockedMsg(rcode int, rule string) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.RecursionAvailable = true
|
||||
m.AuthenticatedData = false
|
||||
m.Rcode = rcode
|
||||
opt := &dns.OPT{}
|
||||
opt.Hdr.Name = "."
|
||||
opt.Hdr.Rrtype = dns.TypeOPT
|
||||
ede := &dns.EDNS0_EDE{InfoCode: 15, ExtraText: "blocked by policy: " + rule}
|
||||
opt.Option = append(opt.Option, ede)
|
||||
m.Extra = append(m.Extra, opt)
|
||||
m.RecursionAvailable = true
|
||||
o := new(dns.OPT)
|
||||
o.Hdr.Name = "."
|
||||
o.Hdr.Rrtype = dns.TypeOPT
|
||||
o.Option = append(o.Option, &dns.EDNS0_EDE{InfoCode: 15, ExtraText: "Blocked by policy"})
|
||||
m.Extra = append(m.Extra, o)
|
||||
return m
|
||||
}
|
||||
|
||||
func initBlacklist(ctx context.Context, listStr, filePath, rcodeStr string) (*atomic.Pointer[suffixMatcher], int) {
|
||||
var rules []string
|
||||
if v := strings.TrimSpace(listStr); v != "" {
|
||||
for _, s := range strings.Split(v, ",") {
|
||||
if r := canonicalFQDN(s); r != "" {
|
||||
rules = append(rules, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
if filePath != "" {
|
||||
if fs, err := loadBlacklistFile(filePath); err != nil {
|
||||
log.Printf("[blacklist] load file error: %v", err)
|
||||
} else {
|
||||
rules = append(rules, fs...)
|
||||
}
|
||||
}
|
||||
var holder atomic.Pointer[suffixMatcher]
|
||||
holder.Store(newSuffixMatcher(rules))
|
||||
blRcode := parseRcode(rcodeStr)
|
||||
if filePath != "" {
|
||||
startBlacklistReloader(ctx, filePath, 30*time.Second, &holder)
|
||||
}
|
||||
log.Printf("[blacklist] loaded %d rules (file=%v, rcode=%s)", len(holder.Load().rules), filePath != "", strings.ToUpper(rcodeStr))
|
||||
return &holder, blRcode
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user