重构黑名单匹配算法，采用Trie前缀树数据结构替换原有的后缀匹配，将百万级域名匹配复杂度从O(n)降至O(L)，显著提升性能。同时优化黑名单文件加载机制，支持hosts格式和通配符匹配，并实现文件修改自动重载功能，提升系统的灵活性和实用性。 refactor: 重构README文档结构和内容展示 更新项目介绍文档，优化整体布局结构，添加项目徽章标识，精简功能特性描述，改进快速开始指南，提供更清晰的使用说明。 chore(deps): 更新项目依赖库至最新版本 升级github.com/miekg/dns至v1.1.72版本，更新golang.org/x/net至v0.52.0版本，升级golang.org/x/sync至v0.20.0版本，以及其他相关依赖库的版本更新。
370 lines
8.4 KiB
Go
370 lines
8.4 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"flag"
|
|
"log"
|
|
"os"
|
|
"os/signal"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"syscall"
|
|
"time"
|
|
|
|
lru "github.com/hashicorp/golang-lru/v2"
|
|
"github.com/miekg/dns"
|
|
"golang.org/x/sync/singleflight"
|
|
)
|
|
|
|
var (
|
|
BuildDate = "unknown"
|
|
cache *lru.Cache[string, *cacheEntry]
|
|
inflight singleflight.Group
|
|
udpClient *dns.Client
|
|
tcpClient *dns.Client
|
|
verbose bool // 全局日志开关
|
|
)
|
|
|
|
// Tunables used by main and queryUpstreams.
const (
	// defaultCacheSize is the capacity handed to lru.New in main.
	defaultCacheSize = 20000

	// maxUDPSize is the EDNS0 UDP payload size set on every query that
	// queryUpstreams sends upstream.
	maxUDPSize = 1232
)
|
|
|
|
type cacheEntry struct {
|
|
msg *dns.Msg
|
|
expireAt time.Time
|
|
ednsPresent bool
|
|
ednsVersion uint8
|
|
ednsExtRcode uint16
|
|
ednsEDE []*dns.EDNS0_EDE
|
|
}
|
|
|
|
func initLogger(v bool) {
|
|
verbose = v
|
|
flags := log.Ldate | log.Ltime | log.Lmicroseconds
|
|
if verbose {
|
|
flags |= log.Lshortfile
|
|
}
|
|
log.SetFlags(flags)
|
|
}
|
|
|
|
func cacheKey(q dns.Question, r *dns.Msg, ecs string) string {
|
|
do, cd := "0", "0"
|
|
if o := r.IsEdns0(); o != nil && o.Do() {
|
|
do = "1"
|
|
}
|
|
if r.CheckingDisabled {
|
|
cd = "1"
|
|
}
|
|
|
|
var b strings.Builder
|
|
b.Grow(len(q.Name) + 32)
|
|
b.WriteString(dns.TypeToString[q.Qtype])
|
|
b.WriteByte('|')
|
|
b.WriteString(do)
|
|
b.WriteString(cd)
|
|
b.WriteByte('|')
|
|
b.WriteString(ecs)
|
|
b.WriteByte('|')
|
|
b.WriteString(strings.ToLower(q.Name))
|
|
return b.String()
|
|
}
|
|
|
|
func queryUpstreams(ctx context.Context, req *dns.Msg, upstreams []string, timeout time.Duration, parallel int) *dns.Msg {
|
|
cctx, cancel := context.WithTimeout(ctx, timeout)
|
|
defer cancel()
|
|
|
|
resCh := make(chan *dns.Msg, len(upstreams))
|
|
var wg sync.WaitGroup
|
|
sem := make(chan struct{}, parallel)
|
|
|
|
for _, svr := range upstreams {
|
|
wg.Add(1)
|
|
go func(s string) {
|
|
defer wg.Done()
|
|
select {
|
|
case sem <- struct{}{}:
|
|
defer func() { <-sem }()
|
|
case <-cctx.Done():
|
|
return
|
|
}
|
|
|
|
uReq := req.Copy()
|
|
if o := uReq.IsEdns0(); o == nil {
|
|
uReq.SetEdns0(maxUDPSize, false)
|
|
} else {
|
|
o.SetUDPSize(maxUDPSize)
|
|
}
|
|
|
|
resp, _, err := udpClient.ExchangeContext(cctx, uReq, s)
|
|
if err == nil && resp != nil && resp.Truncated {
|
|
resp, _, err = tcpClient.ExchangeContext(cctx, req, s)
|
|
}
|
|
|
|
if err == nil && resp != nil {
|
|
resCh <- resp
|
|
}
|
|
}(svr)
|
|
}
|
|
|
|
go func() {
|
|
wg.Wait()
|
|
close(resCh)
|
|
}()
|
|
|
|
for r := range resCh {
|
|
if r.Rcode != dns.RcodeServerFailure && r.Rcode != dns.RcodeRefused {
|
|
return r
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func handleDNS(upstreams []string, maxTTL, timeout time.Duration, parallel int, stripECS bool, blPtr *atomic.Pointer[BlacklistTrie], blRcode int) dns.HandlerFunc {
|
|
return func(w dns.ResponseWriter, r *dns.Msg) {
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
log.Printf("[PANIC] %v", err)
|
|
dns.HandleFailed(w, r)
|
|
}
|
|
}()
|
|
|
|
if len(r.Question) == 0 {
|
|
dns.HandleFailed(w, r)
|
|
return
|
|
}
|
|
|
|
q := r.Question[0]
|
|
startTime := time.Now()
|
|
|
|
if trie := blPtr.Load(); trie != nil {
|
|
if rule, hit := trie.Match(q.Name); hit {
|
|
log.Printf("[BLOCK] %s rule=%s client=%s", q.Name, rule, w.RemoteAddr())
|
|
writeReply(w, r, makeBlockedMsg(blRcode, rule), nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
ecs := ""
|
|
if !stripECS {
|
|
if o := r.IsEdns0(); o != nil {
|
|
for _, opt := range o.Option {
|
|
if s, ok := opt.(*dns.EDNS0_SUBNET); ok {
|
|
ecs = s.Address.String()
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
stripECSFromMsg(r)
|
|
}
|
|
|
|
key := cacheKey(q, r, ecs)
|
|
|
|
if msg, meta, ok := tryCacheRead(key); ok {
|
|
if verbose {
|
|
log.Printf("[CACHE] HIT %s", q.Name)
|
|
}
|
|
writeReply(w, r, msg, meta)
|
|
return
|
|
}
|
|
|
|
v, _, _ := inflight.Do(key, func() (any, error) {
|
|
resp := queryUpstreams(context.Background(), r, upstreams, timeout, parallel)
|
|
if resp != nil {
|
|
cacheWrite(key, resp, maxTTL)
|
|
}
|
|
return resp, nil
|
|
})
|
|
|
|
if resp, _ := v.(*dns.Msg); resp != nil {
|
|
writeReply(w, r, resp, nil)
|
|
if verbose {
|
|
log.Printf("[QUERY] %s %s -> %s (%v)", dns.TypeToString[q.Qtype], q.Name, dns.RcodeToString[resp.Rcode], time.Since(startTime))
|
|
}
|
|
} else {
|
|
dns.HandleFailed(w, r)
|
|
}
|
|
}
|
|
}
|
|
|
|
func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) {
|
|
e, ok := cache.Get(key)
|
|
if !ok || time.Now().After(e.expireAt) {
|
|
return nil, nil, false
|
|
}
|
|
out := e.msg.Copy()
|
|
ttl := uint32(time.Until(e.expireAt).Seconds())
|
|
for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} {
|
|
for _, rr := range sec {
|
|
if rr.Header().Rrtype != dns.TypeOPT {
|
|
rr.Header().Ttl = ttl
|
|
}
|
|
}
|
|
}
|
|
return out, e, true
|
|
}
|
|
|
|
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
|
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
|
|
return
|
|
}
|
|
|
|
var ttl uint32 = uint32(maxTTL.Seconds())
|
|
found := false
|
|
for _, rr := range in.Answer {
|
|
if rr.Header().Rrtype != dns.TypeOPT && rr.Header().Ttl < ttl {
|
|
ttl = rr.Header().Ttl
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
for _, rr := range in.Ns {
|
|
if soa, ok := rr.(*dns.SOA); ok {
|
|
if soa.Minttl < ttl {
|
|
ttl = soa.Minttl
|
|
}
|
|
found = true
|
|
}
|
|
}
|
|
}
|
|
if ttl < 10 {
|
|
return
|
|
}
|
|
|
|
cp := in.Copy()
|
|
present, ver, ext, ede := extractEDNSMeta(cp)
|
|
stripPseudoExtras(cp)
|
|
|
|
cache.Add(key, &cacheEntry{
|
|
msg: cp,
|
|
expireAt: time.Now().Add(time.Duration(ttl) * time.Second),
|
|
ednsPresent: present,
|
|
ednsVersion: ver,
|
|
ednsExtRcode: ext,
|
|
ednsEDE: ede,
|
|
})
|
|
}
|
|
|
|
func writeReply(w dns.ResponseWriter, req, resp *dns.Msg, meta *cacheEntry) {
|
|
out := new(dns.Msg)
|
|
out.SetReply(req)
|
|
out.Rcode = resp.Rcode
|
|
out.Answer = resp.Answer
|
|
out.Ns = resp.Ns
|
|
out.Extra = resp.Extra
|
|
|
|
if ro := req.IsEdns0(); ro != nil {
|
|
o := new(dns.OPT)
|
|
o.Hdr.Name = "."
|
|
o.Hdr.Rrtype = dns.TypeOPT
|
|
o.SetUDPSize(ro.UDPSize())
|
|
if uo := resp.IsEdns0(); uo != nil {
|
|
o.Option = uo.Option
|
|
o.SetExtendedRcode(uint16(uo.ExtendedRcode()))
|
|
} else if meta != nil && meta.ednsPresent {
|
|
o.SetExtendedRcode(meta.ednsExtRcode)
|
|
for _, e := range meta.ednsEDE {
|
|
o.Option = append(o.Option, e)
|
|
}
|
|
}
|
|
out.Extra = append(out.Extra, o)
|
|
}
|
|
out.Compress = true
|
|
_ = w.WriteMsg(out)
|
|
}
|
|
|
|
func stripECSFromMsg(m *dns.Msg) {
|
|
if o := m.IsEdns0(); o != nil {
|
|
newOpt := make([]dns.EDNS0, 0, len(o.Option))
|
|
for _, opt := range o.Option {
|
|
if opt.Option() != dns.EDNS0SUBNET {
|
|
newOpt = append(newOpt, opt)
|
|
}
|
|
}
|
|
o.Option = newOpt
|
|
}
|
|
}
|
|
|
|
func stripPseudoExtras(m *dns.Msg) {
|
|
newExtra := make([]dns.RR, 0, len(m.Extra))
|
|
for _, rr := range m.Extra {
|
|
if rr.Header().Rrtype != dns.TypeOPT && rr.Header().Rrtype != dns.TypeTSIG {
|
|
newExtra = append(newExtra, rr)
|
|
}
|
|
}
|
|
m.Extra = newExtra
|
|
}
|
|
|
|
func extractEDNSMeta(m *dns.Msg) (bool, uint8, uint16, []*dns.EDNS0_EDE) {
|
|
if o := m.IsEdns0(); o != nil {
|
|
var edes []*dns.EDNS0_EDE
|
|
for _, opt := range o.Option {
|
|
if e, ok := opt.(*dns.EDNS0_EDE); ok {
|
|
edes = append(edes, e)
|
|
}
|
|
}
|
|
return true, o.Version(), uint16(o.ExtendedRcode()), edes
|
|
}
|
|
return false, 0, 0, nil
|
|
}
|
|
|
|
func main() {
|
|
addr := flag.String("addr", ":853", "DoT address")
|
|
upStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "Upstreams")
|
|
certFile := flag.String("cert", "server.crt", "TLS Cert")
|
|
keyFile := flag.String("key", "server.key", "TLS Key")
|
|
blFile := flag.String("blacklist-file", "", "Blacklist file")
|
|
blRcodeStr := flag.String("blacklist-rcode", "REFUSED", "RCODE for blocked")
|
|
v := flag.Bool("v", false, "Verbose logging")
|
|
flag.Parse()
|
|
|
|
initLogger(*v)
|
|
|
|
cache, _ = lru.New[string, *cacheEntry](defaultCacheSize)
|
|
udpClient = &dns.Client{Net: "udp", Timeout: 2 * time.Second, SingleInflight: true}
|
|
tcpClient = &dns.Client{Net: "tcp", Timeout: 3 * time.Second, SingleInflight: true}
|
|
|
|
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
|
|
if err != nil {
|
|
log.Fatalf("TLS Error: %v", err)
|
|
}
|
|
|
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
|
defer stop()
|
|
|
|
blPtr, blRcode := initBlacklist(ctx, *blFile, *blRcodeStr)
|
|
|
|
mux := dns.NewServeMux()
|
|
mux.HandleFunc(".", handleDNS(strings.Split(*upStr, ","), 1*time.Hour, 2*time.Second, 3, true, blPtr, blRcode))
|
|
|
|
server := &dns.Server{
|
|
Addr: *addr,
|
|
Net: "tcp-tls",
|
|
Handler: mux,
|
|
TLSConfig: &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
MinVersion: tls.VersionTLS13,
|
|
NextProtos: []string{"dot"},
|
|
},
|
|
ReadTimeout: 10 * time.Second,
|
|
WriteTimeout: 10 * time.Second,
|
|
}
|
|
|
|
go func() {
|
|
log.Printf("🚀 DoT Server started on %s (TLS 1.3)", *addr)
|
|
if err := server.ListenAndServe(); err != nil {
|
|
log.Printf("Server exit: %v", err)
|
|
}
|
|
}()
|
|
|
|
<-ctx.Done()
|
|
log.Println("Gracefully shutting down...")
|
|
|
|
// 给 5 秒处理残余请求
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
server.ShutdownContext(shutdownCtx)
|
|
}
|