package main import ( "context" "crypto/tls" "flag" "log" "os" "os/signal" "strings" "sync" "sync/atomic" "syscall" "time" lru "github.com/hashicorp/golang-lru/v2" "github.com/miekg/dns" "golang.org/x/sync/singleflight" ) var ( BuildDate = "unknown" cache *lru.Cache[string, *cacheEntry] inflight singleflight.Group udpClient *dns.Client tcpClient *dns.Client verbose bool // 全局日志开关 ) const ( defaultCacheSize = 20000 maxUDPSize = 1232 ) type cacheEntry struct { msg *dns.Msg expireAt time.Time ednsPresent bool ednsVersion uint8 ednsExtRcode uint16 ednsEDE []*dns.EDNS0_EDE } func initLogger(v bool) { verbose = v flags := log.Ldate | log.Ltime | log.Lmicroseconds if verbose { flags |= log.Lshortfile } log.SetFlags(flags) } func cacheKey(q dns.Question, r *dns.Msg, ecs string) string { do, cd := "0", "0" if o := r.IsEdns0(); o != nil && o.Do() { do = "1" } if r.CheckingDisabled { cd = "1" } var b strings.Builder b.Grow(len(q.Name) + 32) b.WriteString(dns.TypeToString[q.Qtype]) b.WriteByte('|') b.WriteString(do) b.WriteString(cd) b.WriteByte('|') b.WriteString(ecs) b.WriteByte('|') b.WriteString(strings.ToLower(q.Name)) return b.String() } func queryUpstreams(ctx context.Context, req *dns.Msg, upstreams []string, timeout time.Duration, parallel int) *dns.Msg { cctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() resCh := make(chan *dns.Msg, len(upstreams)) var wg sync.WaitGroup sem := make(chan struct{}, parallel) for _, svr := range upstreams { wg.Add(1) go func(s string) { defer wg.Done() select { case sem <- struct{}{}: defer func() { <-sem }() case <-cctx.Done(): return } uReq := req.Copy() if o := uReq.IsEdns0(); o == nil { uReq.SetEdns0(maxUDPSize, false) } else { o.SetUDPSize(maxUDPSize) } resp, _, err := udpClient.ExchangeContext(cctx, uReq, s) if err == nil && resp != nil && resp.Truncated { resp, _, err = tcpClient.ExchangeContext(cctx, req, s) } if err == nil && resp != nil { resCh <- resp } }(svr) } go func() { wg.Wait() close(resCh) }() for r := range resCh { if r.Rcode != dns.RcodeServerFailure && r.Rcode != dns.RcodeRefused { return r } } return nil } func handleDNS(upstreams []string, maxTTL, timeout time.Duration, parallel int, stripECS bool, blPtr *atomic.Pointer[BlacklistTrie], blRcode int) dns.HandlerFunc { return func(w dns.ResponseWriter, r *dns.Msg) { defer func() { if err := recover(); err != nil { log.Printf("[PANIC] %v", err) dns.HandleFailed(w, r) } }() if len(r.Question) == 0 { dns.HandleFailed(w, r) return } q := r.Question[0] startTime := time.Now() if trie := blPtr.Load(); trie != nil { if rule, hit := trie.Match(q.Name); hit { log.Printf("[BLOCK] %s rule=%s client=%s", q.Name, rule, w.RemoteAddr()) writeReply(w, r, makeBlockedMsg(blRcode, rule), nil) return } } ecs := "" if !stripECS { if o := r.IsEdns0(); o != nil { for _, opt := range o.Option { if s, ok := opt.(*dns.EDNS0_SUBNET); ok { ecs = s.Address.String() } } } } else { stripECSFromMsg(r) } key := cacheKey(q, r, ecs) if msg, meta, ok := tryCacheRead(key); ok { if verbose { log.Printf("[CACHE] HIT %s", q.Name) } writeReply(w, r, msg, meta) return } v, _, _ := inflight.Do(key, func() (any, error) { resp := queryUpstreams(context.Background(), r, upstreams, timeout, parallel) if resp != nil { cacheWrite(key, resp, maxTTL) } return resp, nil }) if resp, _ := v.(*dns.Msg); resp != nil { writeReply(w, r, resp, nil) if verbose { log.Printf("[QUERY] %s %s -> %s (%v)", dns.TypeToString[q.Qtype], q.Name, dns.RcodeToString[resp.Rcode], time.Since(startTime)) } } else { dns.HandleFailed(w, r) } } } func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) { e, ok := cache.Get(key) if !ok || time.Now().After(e.expireAt) { return nil, nil, false } out := e.msg.Copy() ttl := uint32(time.Until(e.expireAt).Seconds()) for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} { for _, rr := range sec { if rr.Header().Rrtype != dns.TypeOPT { rr.Header().Ttl = ttl } } } return out, e, true } func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError { return } var ttl uint32 = uint32(maxTTL.Seconds()) found := false for _, rr := range in.Answer { if rr.Header().Rrtype != dns.TypeOPT && rr.Header().Ttl < ttl { ttl = rr.Header().Ttl found = true } } if !found { for _, rr := range in.Ns { if soa, ok := rr.(*dns.SOA); ok { if soa.Minttl < ttl { ttl = soa.Minttl } found = true } } } if ttl < 10 { return } cp := in.Copy() present, ver, ext, ede := extractEDNSMeta(cp) stripPseudoExtras(cp) cache.Add(key, &cacheEntry{ msg: cp, expireAt: time.Now().Add(time.Duration(ttl) * time.Second), ednsPresent: present, ednsVersion: ver, ednsExtRcode: ext, ednsEDE: ede, }) } func writeReply(w dns.ResponseWriter, req, resp *dns.Msg, meta *cacheEntry) { out := new(dns.Msg) out.SetReply(req) out.Rcode = resp.Rcode out.Answer = resp.Answer out.Ns = resp.Ns out.Extra = resp.Extra if ro := req.IsEdns0(); ro != nil { o := new(dns.OPT) o.Hdr.Name = "." o.Hdr.Rrtype = dns.TypeOPT o.SetUDPSize(ro.UDPSize()) if uo := resp.IsEdns0(); uo != nil { o.Option = uo.Option o.SetExtendedRcode(uint16(uo.ExtendedRcode())) } else if meta != nil && meta.ednsPresent { o.SetExtendedRcode(meta.ednsExtRcode) for _, e := range meta.ednsEDE { o.Option = append(o.Option, e) } } out.Extra = append(out.Extra, o) } out.Compress = true _ = w.WriteMsg(out) } func stripECSFromMsg(m *dns.Msg) { if o := m.IsEdns0(); o != nil { newOpt := make([]dns.EDNS0, 0, len(o.Option)) for _, opt := range o.Option { if opt.Option() != dns.EDNS0SUBNET { newOpt = append(newOpt, opt) } } o.Option = newOpt } } func stripPseudoExtras(m *dns.Msg) { newExtra := make([]dns.RR, 0, len(m.Extra)) for _, rr := range m.Extra { if rr.Header().Rrtype != dns.TypeOPT && rr.Header().Rrtype != dns.TypeTSIG { newExtra = append(newExtra, rr) } } m.Extra = newExtra } func extractEDNSMeta(m *dns.Msg) (bool, uint8, uint16, []*dns.EDNS0_EDE) { if o := m.IsEdns0(); o != nil { var edes []*dns.EDNS0_EDE for _, opt := range o.Option { if e, ok := opt.(*dns.EDNS0_EDE); ok { edes = append(edes, e) } } return true, o.Version(), uint16(o.ExtendedRcode()), edes } return false, 0, 0, nil } func main() { addr := flag.String("addr", ":853", "DoT address") upStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "Upstreams") certFile := flag.String("cert", "server.crt", "TLS Cert") keyFile := flag.String("key", "server.key", "TLS Key") blFile := flag.String("blacklist-file", "", "Blacklist file") blRcodeStr := flag.String("blacklist-rcode", "REFUSED", "RCODE for blocked") v := flag.Bool("v", false, "Verbose logging") flag.Parse() initLogger(*v) cache, _ = lru.New[string, *cacheEntry](defaultCacheSize) udpClient = &dns.Client{Net: "udp", Timeout: 2 * time.Second, SingleInflight: true} tcpClient = &dns.Client{Net: "tcp", Timeout: 3 * time.Second, SingleInflight: true} cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) if err != nil { log.Fatalf("TLS Error: %v", err) } ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() blPtr, blRcode := initBlacklist(ctx, *blFile, *blRcodeStr) mux := dns.NewServeMux() mux.HandleFunc(".", handleDNS(strings.Split(*upStr, ","), 1*time.Hour, 2*time.Second, 3, true, blPtr, blRcode)) server := &dns.Server{ Addr: *addr, Net: "tcp-tls", Handler: mux, TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS13, NextProtos: []string{"dot"}, }, ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, } go func() { log.Printf("🚀 DoT Server started on %s (TLS 1.3)", *addr) if err := server.ListenAndServe(); err != nil { log.Printf("Server exit: %v", err) } }() <-ctx.Done() log.Println("Gracefully shutting down...") // 给 5 秒处理残余请求 shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() server.ShutdownContext(shutdownCtx) }