Files
dot/main.go
aixiao f5bf77927d feat(blacklist): 实现基于Trie树的高性能黑名单匹配系统
重构黑名单匹配算法,采用Trie前缀树数据结构替换原有的后缀匹配,
将百万级域名匹配复杂度从O(n)降至O(L),显著提升性能。

同时优化黑名单文件加载机制,支持hosts格式和通配符匹配,
并实现文件修改自动重载功能,提升系统的灵活性和实用性。

refactor: 重构README文档结构和内容展示

更新项目介绍文档,优化整体布局结构,添加项目徽章标识,
精简功能特性描述,改进快速开始指南,提供更清晰的使用说明。

chore(deps): 更新项目依赖库至最新版本

升级github.com/miekg/dns至v1.1.72版本,
更新golang.org/x/net至v0.52.0版本,
升级golang.org/x/sync至v0.20.0版本,
以及其他相关依赖库的版本更新。
2026-03-20 14:39:29 +08:00

370 lines
8.4 KiB
Go

package main
import (
"context"
"crypto/tls"
"flag"
"log"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/miekg/dns"
"golang.org/x/sync/singleflight"
)
// BuildDate identifies the build. It defaults to "unknown"; presumably it is
// overridden at link time (e.g. -ldflags "-X main.BuildDate=...") — confirm.
var BuildDate = "unknown"

// Answer cache and in-flight deduplication shared by every request handler.
var (
	cache    *lru.Cache[string, *cacheEntry] // initialised in main
	inflight singleflight.Group              // collapses concurrent lookups per cache key
)

// Upstream transports: UDP first, TCP when the UDP answer is truncated.
var (
	udpClient *dns.Client
	tcpClient *dns.Client
)

// verbose is the global logging switch, set by initLogger.
var verbose bool
// defaultCacheSize is the entry capacity of the LRU answer cache.
const defaultCacheSize = 20000

// maxUDPSize is the EDNS UDP payload size advertised on upstream queries
// (1232 is the value commonly recommended since DNS Flag Day 2020).
const maxUDPSize = 1232
// cacheEntry is one cached upstream answer. The stored message has its OPT and
// TSIG pseudo-records removed, so the EDNS details needed to rebuild a client's
// OPT record are kept alongside it.
type cacheEntry struct {
	msg      *dns.Msg  // private copy of the answer, without pseudo-records
	expireAt time.Time // absolute expiry; reads after this instant are misses

	// EDNS metadata captured from the upstream OPT record before stripping.
	ednsEDE      []*dns.EDNS0_EDE // Extended DNS Error options
	ednsExtRcode uint16           // extended RCODE as reported by OPT.ExtendedRcode
	ednsVersion  uint8            // EDNS version of the upstream OPT
	ednsPresent  bool             // whether the upstream reply carried an OPT at all
}
// initLogger records the global verbosity switch and configures the standard
// logger. Every line carries date, time and microseconds; verbose mode also
// prefixes the source file and line number.
func initLogger(v bool) {
	verbose = v
	const base = log.Ldate | log.Ltime | log.Lmicroseconds
	if !verbose {
		log.SetFlags(base)
		return
	}
	log.SetFlags(base | log.Lshortfile)
}
// cacheKey builds the LRU and singleflight key for question q of request r.
//
// The key is "TYPE/CLASS|<DO><CD>|<ecs>|<lower-cased name>": the query type
// and class, the request's DO and CD bits (both change what an upstream
// returns), the ECS address string ecs (empty when ECS is not part of the
// key) and the owner name folded to lower case.
func cacheKey(q dns.Question, r *dns.Msg, ecs string) string {
	do, cd := byte('0'), byte('0')
	if o := r.IsEdns0(); o != nil && o.Do() {
		do = '1'
	}
	if r.CheckingDisabled {
		cd = '1'
	}

	var b strings.Builder
	b.Grow(len(q.Name) + len(ecs) + 32)
	// dns.Type/dns.Class render unknown values as "TYPE<n>"/"CLASS<n>", so
	// unassigned qtypes no longer all map to "" (as TypeToString lookups did),
	// and the class keeps e.g. CH TXT apart from IN TXT.
	b.WriteString(dns.Type(q.Qtype).String())
	b.WriteByte('/')
	b.WriteString(dns.Class(q.Qclass).String())
	b.WriteByte('|')
	b.WriteByte(do)
	b.WriteByte(cd)
	b.WriteByte('|')
	b.WriteString(ecs)
	b.WriteByte('|')
	b.WriteString(strings.ToLower(q.Name))
	return b.String()
}
// queryUpstreams sends req to every upstream in upstreams, with at most
// parallel exchanges in flight at once. It returns the first reply whose RCODE
// is neither SERVFAIL nor REFUSED, or nil if none arrives before timeout.
//
// Each exchange uses a private copy of req advertising maxUDPSize via EDNS. A
// truncated UDP reply is retried over TCP with that same copy.
func queryUpstreams(ctx context.Context, req *dns.Msg, upstreams []string, timeout time.Duration, parallel int) *dns.Msg {
	if parallel < 1 {
		// A zero-capacity semaphore would park every worker until the
		// deadline, so treat non-positive values as "one at a time".
		parallel = 1
	}
	cctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// Buffered to len(upstreams) so workers never block on send, even after
	// this function has already returned a winner.
	resCh := make(chan *dns.Msg, len(upstreams))
	sem := make(chan struct{}, parallel)
	var wg sync.WaitGroup
	for _, svr := range upstreams {
		wg.Add(1)
		go func(s string) {
			defer wg.Done()
			select {
			case sem <- struct{}{}:
				defer func() { <-sem }()
			case <-cctx.Done():
				return
			}
			// Every worker packs its own copy: Msg.Pack rewrites the OPT
			// record's extended-RCODE bits, so sharing req across goroutines
			// (as the TCP retry previously did) is a data race.
			uReq := req.Copy()
			if o := uReq.IsEdns0(); o == nil {
				uReq.SetEdns0(maxUDPSize, false)
			} else {
				o.SetUDPSize(maxUDPSize)
			}
			resp, _, err := udpClient.ExchangeContext(cctx, uReq, s)
			if err == nil && resp != nil && resp.Truncated {
				resp, _, err = tcpClient.ExchangeContext(cctx, uReq, s)
			}
			if err == nil && resp != nil {
				resCh <- resp
			}
		}(svr)
	}
	go func() {
		wg.Wait()
		close(resCh)
	}()
	for r := range resCh {
		// A SERVFAIL/REFUSED from one upstream must not win the race; keep
		// waiting for a better answer from another.
		if r.Rcode != dns.RcodeServerFailure && r.Rcode != dns.RcodeRefused {
			return r
		}
	}
	return nil
}
// handleDNS builds the DoT request handler.
//
// For each request it:
//  1. fails queries that have no question;
//  2. answers blacklisted names with makeBlockedMsg(blRcode, rule);
//  3. strips ECS from the request, or records its address for the cache key;
//  4. serves from the LRU cache on a hit;
//  5. otherwise resolves via queryUpstreams, sharing one lookup per cache key
//     through singleflight and caching successful results.
//
// A panic in the handler is logged and answered with dns.HandleFailed.
func handleDNS(upstreams []string, maxTTL, timeout time.Duration, parallel int, stripECS bool, blPtr *atomic.Pointer[BlacklistTrie], blRcode int) dns.HandlerFunc {
	return func(w dns.ResponseWriter, r *dns.Msg) {
		defer func() {
			if p := recover(); p != nil {
				log.Printf("[PANIC] %v", p)
				dns.HandleFailed(w, r)
			}
		}()

		if len(r.Question) == 0 {
			dns.HandleFailed(w, r)
			return
		}
		question := r.Question[0]
		began := time.Now()

		// The blacklist is consulted before any ECS handling or cache lookup.
		if trie := blPtr.Load(); trie != nil {
			rule, blocked := trie.Match(question.Name)
			if blocked {
				log.Printf("[BLOCK] %s rule=%s client=%s", question.Name, rule, w.RemoteAddr())
				writeReply(w, r, makeBlockedMsg(blRcode, rule), nil)
				return
			}
		}

		var subnet string
		if stripECS {
			stripECSFromMsg(r)
		} else if opt := r.IsEdns0(); opt != nil {
			// If several ECS options are present, the last one wins.
			for _, e := range opt.Option {
				if ecsOpt, isECS := e.(*dns.EDNS0_SUBNET); isECS {
					subnet = ecsOpt.Address.String()
				}
			}
		}

		key := cacheKey(question, r, subnet)
		if cached, meta, hit := tryCacheRead(key); hit {
			if verbose {
				log.Printf("[CACHE] HIT %s", question.Name)
			}
			writeReply(w, r, cached, meta)
			return
		}

		shared, _, _ := inflight.Do(key, func() (any, error) {
			answer := queryUpstreams(context.Background(), r, upstreams, timeout, parallel)
			if answer != nil {
				cacheWrite(key, answer, maxTTL)
			}
			return answer, nil
		})

		answer, _ := shared.(*dns.Msg)
		if answer == nil {
			dns.HandleFailed(w, r)
			return
		}
		writeReply(w, r, answer, nil)
		if verbose {
			log.Printf("[QUERY] %s %s -> %s (%v)", dns.TypeToString[question.Qtype], question.Name, dns.RcodeToString[answer.Rcode], time.Since(began))
		}
	}
}
// tryCacheRead looks up key in the answer cache.
//
// On a hit it returns:
//   - a private copy of the cached message, with every non-OPT record's TTL
//     rewritten to the entry's remaining lifetime in whole seconds;
//   - the entry itself, for its EDNS metadata;
//   - true.
//
// It returns false when the key is absent or the entry has expired.
func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) {
	// Peek first: Get would mark an expired entry as most recently used, so
	// stale entries for names that cannot be re-cached would linger at the
	// head of the LRU and push out live ones.
	e, ok := cache.Peek(key)
	if !ok {
		return nil, nil, false
	}
	now := time.Now()
	if now.After(e.expireAt) {
		return nil, nil, false
	}
	cache.Get(key) // refresh recency only for live entries

	out := e.msg.Copy()
	// Same instant as the expiry check, so the remaining lifetime is >= 0.
	ttl := uint32(e.expireAt.Sub(now) / time.Second)
	for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} {
		for _, rr := range sec {
			if rr.Header().Rrtype != dns.TypeOPT {
				rr.Header().Ttl = ttl
			}
		}
	}
	return out, e, true
}
// cacheWrite stores a private, pseudo-record-free copy of in under key when
// the response is cacheable.
//
// Only NOERROR and NXDOMAIN responses are cached. The lifetime is capped at
// maxTTL and otherwise derived from the response:
//   - answer records bound it by their smallest TTL;
//   - negative answers (NXDOMAIN, or an empty answer section) are further
//     bounded by min(SOA TTL, SOA MINIMUM) from the authority section
//     (RFC 2308 §5). A negative answer without an SOA is not cached.
//
// Results below 10 seconds are not cached. The upstream OPT's EDNS version,
// extended RCODE and EDE options are kept in the entry for replay on hits.
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
	if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
		return
	}
	if maxTTL <= 0 {
		// Nothing can be cached; also avoids converting a negative value.
		return
	}
	ttl := uint32(maxTTL / time.Second)

	hasAnswer := false
	for _, rr := range in.Answer {
		h := rr.Header()
		if h.Rrtype == dns.TypeOPT {
			continue
		}
		hasAnswer = true
		if h.Ttl < ttl {
			ttl = h.Ttl
		}
	}

	if !hasAnswer || in.Rcode == dns.RcodeNameError {
		hasSOA := false
		for _, rr := range in.Ns {
			soa, ok := rr.(*dns.SOA)
			if !ok {
				continue
			}
			hasSOA = true
			if soa.Hdr.Ttl < ttl {
				ttl = soa.Hdr.Ttl
			}
			if soa.Minttl < ttl {
				ttl = soa.Minttl
			}
		}
		if !hasSOA {
			// Without an SOA there is no negative-caching TTL to honour.
			return
		}
	}

	if ttl < 10 {
		return
	}
	cp := in.Copy()
	present, ver, ext, ede := extractEDNSMeta(cp)
	stripPseudoExtras(cp)
	cache.Add(key, &cacheEntry{
		msg:          cp,
		expireAt:     time.Now().Add(time.Duration(ttl) * time.Second),
		ednsPresent:  present,
		ednsVersion:  ver,
		ednsExtRcode: ext,
		ednsEDE:      ede,
	})
}
// writeReply answers req on w using the sections and RCODE of resp.
//
// The header comes from req via SetReply. The answer and authority sections
// are shared read-only with resp. resp's additional section is copied into a
// fresh slice without OPT/TSIG pseudo-records, for three reasons:
//   - those records belong to the upstream hop;
//   - keeping them produced a second OPT (invalid per RFC 6891) or leaked an
//     OPT to clients that never sent EDNS;
//   - resp may be shared by concurrent singleflight waiters, so appending to
//     its slice would race.
//
// When req carried EDNS, a new OPT is attached that echoes the client's UDP
// size and DO bit (RFC 3225). Its options come from resp's own OPT, minus
// per-connection options (COOKIE, PADDING, TCP keepalive). For cached answers,
// which have no OPT, the extended RCODE and EDE options saved in meta are used.
// Without client EDNS an extended RCODE cannot be encoded, so it is reported
// as SERVFAIL instead of letting Pack fail.
func writeReply(w dns.ResponseWriter, req, resp *dns.Msg, meta *cacheEntry) {
	out := new(dns.Msg)
	out.SetReply(req)
	out.Rcode = resp.Rcode
	out.Answer = resp.Answer
	out.Ns = resp.Ns

	out.Extra = make([]dns.RR, 0, len(resp.Extra)+1)
	for _, rr := range resp.Extra {
		switch rr.Header().Rrtype {
		case dns.TypeOPT, dns.TypeTSIG:
			continue
		}
		out.Extra = append(out.Extra, rr)
	}

	if ro := req.IsEdns0(); ro != nil {
		o := new(dns.OPT)
		o.Hdr.Name = "."
		o.Hdr.Rrtype = dns.TypeOPT
		o.SetUDPSize(ro.UDPSize())
		if ro.Do() {
			o.SetDo()
		}
		if uo := resp.IsEdns0(); uo != nil {
			for _, opt := range uo.Option {
				switch opt.Option() {
				case dns.EDNS0COOKIE, dns.EDNS0PADDING, dns.EDNS0TCPKEEPALIVE:
					// Per-connection options describe the upstream hop only.
					continue
				}
				o.Option = append(o.Option, opt)
			}
			o.SetExtendedRcode(uint16(uo.ExtendedRcode()))
		} else if meta != nil && meta.ednsPresent {
			o.SetExtendedRcode(meta.ednsExtRcode)
			for _, e := range meta.ednsEDE {
				o.Option = append(o.Option, e)
			}
		}
		out.Extra = append(out.Extra, o)
	} else if out.Rcode > 0xF {
		// RCODEs above 4 bits need an OPT record; Pack would otherwise fail
		// with ErrExtendedRcode and the client would get no answer at all.
		out.Rcode = dns.RcodeServerFailure
	}

	out.Compress = true
	if err := w.WriteMsg(out); err != nil && verbose {
		log.Printf("[WRITE] %s: %v", w.RemoteAddr(), err)
	}
}
// stripECSFromMsg removes every EDNS Client Subnet option from m's OPT record.
// Other options keep their order. The OPT's option list is replaced with a
// newly allocated slice. Messages without an OPT record are left untouched.
func stripECSFromMsg(m *dns.Msg) {
	opt := m.IsEdns0()
	if opt == nil {
		return
	}
	kept := make([]dns.EDNS0, 0, len(opt.Option))
	for i := range opt.Option {
		if opt.Option[i].Option() == dns.EDNS0SUBNET {
			continue
		}
		kept = append(kept, opt.Option[i])
	}
	opt.Option = kept
}
// stripPseudoExtras replaces m.Extra with a newly allocated slice that keeps
// every additional record except the OPT and TSIG pseudo-records, in order.
func stripPseudoExtras(m *dns.Msg) {
	filtered := make([]dns.RR, 0, len(m.Extra))
	for _, rr := range m.Extra {
		switch rr.Header().Rrtype {
		case dns.TypeOPT, dns.TypeTSIG:
			// pseudo-record: drop
		default:
			filtered = append(filtered, rr)
		}
	}
	m.Extra = filtered
}
// extractEDNSMeta reports whether m carries an OPT record. If it does, it also
// returns the EDNS version, the extended RCODE (as given by OPT.ExtendedRcode)
// and every Extended DNS Error option in order. Without an OPT it returns
// zero values and a nil slice.
func extractEDNSMeta(m *dns.Msg) (bool, uint8, uint16, []*dns.EDNS0_EDE) {
	opt := m.IsEdns0()
	if opt == nil {
		return false, 0, 0, nil
	}
	var errs []*dns.EDNS0_EDE
	for _, option := range opt.Option {
		ede, isEDE := option.(*dns.EDNS0_EDE)
		if !isEDE {
			continue
		}
		errs = append(errs, ede)
	}
	return true, opt.Version(), uint16(opt.ExtendedRcode()), errs
}
// main parses flags and initialises logging, the answer cache, the upstream
// clients, the TLS certificate and the blacklist. It then serves
// DNS-over-TLS (TLS 1.3, ALPN "dot") until SIGINT/SIGTERM or until the
// listener fails. On a signal, the server gets a 5-second grace period to
// shut down.
func main() {
	addr := flag.String("addr", ":853", "DoT address")
	upStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "Upstreams")
	certFile := flag.String("cert", "server.crt", "TLS Cert")
	keyFile := flag.String("key", "server.key", "TLS Key")
	blFile := flag.String("blacklist-file", "", "Blacklist file")
	blRcodeStr := flag.String("blacklist-rcode", "REFUSED", "RCODE for blocked")
	v := flag.Bool("v", false, "Verbose logging")
	flag.Parse()
	initLogger(*v)

	var err error
	if cache, err = lru.New[string, *cacheEntry](defaultCacheSize); err != nil {
		log.Fatalf("Cache Error: %v", err)
	}
	// SingleInflight is omitted: it is deprecated and a no-op in miekg/dns.
	// Duplicate lookups are already collapsed by the inflight group.
	udpClient = &dns.Client{Net: "udp", Timeout: 2 * time.Second}
	tcpClient = &dns.Client{Net: "tcp", Timeout: 3 * time.Second}

	// Accept "a:53, b:53" and ignore empty items such as a trailing comma.
	var upstreams []string
	for _, u := range strings.Split(*upStr, ",") {
		if u = strings.TrimSpace(u); u != "" {
			upstreams = append(upstreams, u)
		}
	}
	if len(upstreams) == 0 {
		log.Fatal("Upstream Error: no upstream servers configured")
	}

	cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
	if err != nil {
		log.Fatalf("TLS Error: %v", err)
	}
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()
	blPtr, blRcode := initBlacklist(ctx, *blFile, *blRcodeStr)
	mux := dns.NewServeMux()
	mux.HandleFunc(".", handleDNS(upstreams, 1*time.Hour, 2*time.Second, 3, true, blPtr, blRcode))
	server := &dns.Server{
		Addr:    *addr,
		Net:     "tcp-tls",
		Handler: mux,
		TLSConfig: &tls.Config{
			Certificates: []tls.Certificate{cert},
			MinVersion:   tls.VersionTLS13,
			NextProtos:   []string{"dot"},
		},
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	serveErr := make(chan error, 1)
	go func() {
		log.Printf("🚀 DoT Server started on %s (TLS 1.3)", *addr)
		serveErr <- server.ListenAndServe()
	}()

	select {
	case <-ctx.Done():
	case err := <-serveErr:
		// The listener stopped on its own (e.g. address already in use).
		// Without this branch main would wait forever for a signal while
		// serving nothing.
		if err != nil {
			log.Fatalf("Server exit: %v", err)
		}
		log.Println("Server exit")
		return
	}

	log.Println("Gracefully shutting down...")
	// Allow 5 seconds for in-flight requests to finish.
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := server.ShutdownContext(shutdownCtx); err != nil {
		log.Printf("Shutdown error: %v", err)
	}
}