package main import ( "bufio" "context" "log" "net" "os" "strings" "sync/atomic" "time" "github.com/miekg/dns" ) type BlacklistTrie struct { root *trieNode } type trieNode struct { children map[string]*trieNode isEnd bool rule string } func newTrie() *BlacklistTrie { return &BlacklistTrie{root: &trieNode{children: make(map[string]*trieNode)}} } func (t *BlacklistTrie) Add(domain string) { name := strings.TrimSuffix(strings.ToLower(domain), ".") parts := strings.Split(name, ".") curr := t.root for i := len(parts) - 1; i >= 0; i-- { p := parts[i] if _, ok := curr.children[p]; !ok { curr.children[p] = &trieNode{children: make(map[string]*trieNode)} } curr = curr.children[p] } curr.isEnd = true curr.rule = name } func (t *BlacklistTrie) Match(domain string) (string, bool) { name := strings.TrimSuffix(strings.ToLower(domain), ".") parts := strings.Split(name, ".") curr := t.root for i := len(parts) - 1; i >= 0; i-- { p := parts[i] next, ok := curr.children[p] if !ok { return "", false } curr = next if curr.isEnd { return curr.rule, true } } return "", false } func initBlacklist(ctx context.Context, path, rcodeStr string) (*atomic.Pointer[BlacklistTrie], int) { var holder atomic.Pointer[BlacklistTrie] var lastMod time.Time rcode := parseRcode(rcodeStr) load := func() { if path == "" { return } info, err := os.Stat(path) if err != nil { log.Printf("[blacklist] Stat error: %v", err) return } // 如果文件没动,不加载 if !info.ModTime().After(lastMod) { return } f, err := os.Open(path) if err != nil { return } defer f.Close() trie := newTrie() scanner := bufio.NewScanner(f) count := 0 for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line == "" || line[0] == '#' || line[0] == ';' { continue } fields := strings.Fields(line) domain := fields[0] if net.ParseIP(domain) != nil && len(fields) > 1 { domain = fields[1] } trie.Add(domain) count++ } holder.Store(trie) lastMod = info.ModTime() log.Printf("[blacklist] Loaded %d rules (updated: %v)", count, lastMod.Format(time.RFC3339)) } load() if path != "" { go func() { // 缩短检查间隔,但因为有文件时间校验,所以不费 CPU ticker := time.NewTicker(5 * time.Minute) for { select { case <-ticker.C: load() case <-ctx.Done(): return } } }() } return &holder, rcode } func parseRcode(s string) int { switch strings.ToUpper(s) { case "NXDOMAIN": return dns.RcodeNameError case "SERVFAIL": return dns.RcodeServerFailure default: return dns.RcodeRefused } } func makeBlockedMsg(rcode int, rule string) *dns.Msg { m := new(dns.Msg) m.Rcode = rcode m.RecursionAvailable = true o := new(dns.OPT) o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT o.Option = append(o.Option, &dns.EDNS0_EDE{InfoCode: 15, ExtraText: "Blocked by policy"}) m.Extra = append(m.Extra, o) return m }