package main import ( "bufio" "context" "log" "net" "os" "sort" "strings" "sync/atomic" "time" "github.com/miekg/dns" ) // -------- Blacklist helpers (独立文件) -------- // 将输入域名规范化为 canonical FQDN(去空格、去 *. 前缀、统一大小写/尾点) func canonicalFQDN(s string) string { s = strings.TrimSpace(s) if s == "" { return "" } s = strings.TrimPrefix(s, "*.") return dns.CanonicalName(s) } func uniqueStrings(in []string) []string { seen := make(map[string]struct{}, len(in)) out := make([]string, 0, len(in)) for _, s := range in { if s == "" { continue } if _, ok := seen[s]; ok { continue } seen[s] = struct{}{} out = append(out, s) } return out } // 支持 # / ; 注释;每行一个域名;支持以 "*.example.com" 书写 func loadBlacklistFile(path string) ([]string, error) { f, err := os.Open(path) if err != nil { return nil, err } defer f.Close() sc := bufio.NewScanner(f) var rules []string for sc.Scan() { line := strings.TrimSpace(sc.Text()) if line == "" { continue } // 行首注释 if strings.HasPrefix(line, "//") || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") { continue } // 行内注释:先 // 再 # ; if i := strings.Index(line, "//"); i >= 0 { line = strings.TrimSpace(line[:i]) } if i := strings.IndexAny(line, "#;"); i >= 0 { line = strings.TrimSpace(line[:i]) } if line == "" { continue } // hosts 风格:第一个字段是 IP,则其余每个字段视为域名 fields := strings.Fields(line) if len(fields) == 0 { continue } start := 0 if net.ParseIP(fields[0]) != nil { start = 1 } for _, tok := range fields[start:] { if r := canonicalFQDN(tok); r != "" { rules = append(rules, r) } } } if err := sc.Err(); err != nil { return nil, err } return uniqueStrings(rules), nil } // 自动重载黑名单 func startBlacklistReloader(ctx context.Context, path string, interval time.Duration, holder *atomic.Pointer[suffixMatcher]) { if path == "" { return } go func() { var lastMod time.Time ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: fi, err := os.Stat(path) if err != nil { log.Printf("[blacklist] reload check failed: %v", err) continue } modTime := fi.ModTime() if modTime.After(lastMod) { rules, err := loadBlacklistFile(path) if err != nil { log.Printf("[blacklist] reload failed: %v", err) continue } holder.Store(newSuffixMatcher(rules)) lastMod = modTime log.Printf("[blacklist] reloaded %d rules (modified %s)", len(rules), modTime.Format(time.RFC3339)) } } } }() } // 后缀匹配器:rules 已 canonical,按长度降序排列(更精确的规则优先) type suffixMatcher struct { rules []string } func newSuffixMatcher(rules []string) *suffixMatcher { rs := uniqueStrings(rules) sort.Slice(rs, func(i, j int) bool { return len(rs[i]) > len(rs[j]) }) return &suffixMatcher{rules: rs} } // 命中则返回匹配的规则 func (m *suffixMatcher) match(name string) (string, bool) { if m == nil || len(m.rules) == 0 { return "", false } name = dns.CanonicalName(name) for _, r := range m.rules { if name == r || strings.HasSuffix(name, "."+r) { return r, true } } return "", false } // 解析 RCODE 文本到常量 func parseRcode(s string) int { switch strings.ToUpper(strings.TrimSpace(s)) { case "REFUSED", "": return dns.RcodeRefused case "NXDOMAIN": return dns.RcodeNameError case "SERVFAIL": return dns.RcodeServerFailure default: log.Printf("[blacklist] unknown rcode %q, fallback REFUSED", s) return dns.RcodeRefused } } // 构造“伪上游”阻断响应;带 EDE=Blocked(15) 说明命中规则 func makeBlockedUpstream(rcode int, rule string) *dns.Msg { m := new(dns.Msg) m.RecursionAvailable = true m.AuthenticatedData = false m.Rcode = rcode opt := &dns.OPT{} opt.Hdr.Name = "." opt.Hdr.Rrtype = dns.TypeOPT ede := &dns.EDNS0_EDE{InfoCode: 15, ExtraText: "blocked by policy: " + rule} opt.Option = append(opt.Option, ede) m.Extra = append(m.Extra, opt) return m } func initBlacklist(ctx context.Context, listStr, filePath, rcodeStr string) (*atomic.Pointer[suffixMatcher], int) { var rules []string if v := strings.TrimSpace(listStr); v != "" { for _, s := range strings.Split(v, ",") { if r := canonicalFQDN(s); r != "" { rules = append(rules, r) } } } if filePath != "" { if fs, err := loadBlacklistFile(filePath); err != nil { log.Printf("[blacklist] load file error: %v", err) } else { rules = append(rules, fs...) } } var holder atomic.Pointer[suffixMatcher] holder.Store(newSuffixMatcher(rules)) blRcode := parseRcode(rcodeStr) if filePath != "" { startBlacklistReloader(ctx, filePath, 30*time.Second, &holder) } log.Printf("[blacklist] loaded %d rules (file=%v, rcode=%s)", len(holder.Load().rules), filePath != "", strings.ToUpper(rcodeStr)) return &holder, blRcode }