commit 0840090e5b32e8aec021a2a1f38f71db74a29f6c Author: aixiao Date: Tue Oct 14 09:26:29 2025 +0800 init diff --git a/dot b/dot new file mode 100644 index 0000000..73ec45e Binary files /dev/null and b/dot differ diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..2977add --- /dev/null +++ b/go.mod @@ -0,0 +1,12 @@ +module dot + +go 1.25.2 + +require ( + github.com/miekg/dns v1.1.68 // indirect + golang.org/x/mod v0.24.0 // indirect + golang.org/x/net v0.40.0 // indirect + golang.org/x/sync v0.14.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/tools v0.33.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3b3e82e --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA= +github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= +golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= +golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= +golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= diff --git a/main.go b/main.go new file mode 100644 index 0000000..85e5624 --- /dev/null +++ b/main.go @@ -0,0 +1,356 @@ +package main + +import ( + "context" + "crypto/tls" + "flag" + "log" + "math/rand" + "os" + "os/signal" + "strings" + "sync" + "syscall" + "time" + + "github.com/miekg/dns" +) + +/********************************* + * 缓存结构与全局对象 + *********************************/ + +type cacheEntry struct { + msg *dns.Msg // 缓存的完整响应(深拷贝) + expireAt time.Time // 过期时间(由动态 TTL 决定) +} + +// 并发安全缓存 +var cache sync.Map + +// 后台清理:每隔 N 分钟清一次 +const cacheCleanupInterval = 5 * time.Minute + +// 启动带 context 的缓存清理器(优雅退出) +func startCacheCleaner(ctx context.Context) { + go func() { + ticker := time.NewTicker(cacheCleanupInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + now := time.Now() + n := 0 + cache.Range(func(k, v any) bool { + e := v.(*cacheEntry) + if now.After(e.expireAt) { + cache.Delete(k) + n++ + } + return true + }) + if n > 0 { + log.Printf("[Cache] Cleaned %d expired entries", n) + } + } + } + }() +} + +// 生成缓存键 +func cacheKey(name string, qtype uint16) string { + return strings.ToLower(name) + ":" + dns.TypeToString[qtype] +} + +// 读取缓存;命中则回填剩余 TTL 并返回拷贝 +func tryCacheRead(key string) (*dns.Msg, bool) { + v, ok := cache.Load(key) + if !ok { + return nil, false + } + e := v.(*cacheEntry) + now := time.Now() + if now.After(e.expireAt) { + cache.Delete(key) + return nil, false + } + out := e.msg.Copy() + remaining := uint32(e.expireAt.Sub(now).Seconds()) + // 回填剩余 TTL,避免客户端收到过期 TTL + for i := range out.Answer { + if out.Answer[i].Header().Ttl > remaining { + out.Answer[i].Header().Ttl = remaining + } + } + for i := range out.Ns { + if out.Ns[i].Header().Ttl > remaining { + out.Ns[i].Header().Ttl = remaining + } + } + for i := range out.Extra { + if out.Extra[i].Header().Ttl > remaining { + out.Extra[i].Header().Ttl = remaining + } + } + return out, true +} + +// 写缓存:以「上游最小 TTL」与「配置上限 TTL」取较小值 +func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) { + if in == nil { + return + } + // 可按需缓存 NXDOMAIN;这里允许缓存 NOERROR 与 NXDOMAIN + if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError { + return + } + // 计算报文中的最小 TTL(Answer/Ns/Extra) + minTTL := uint32(0) + hasTTL := false + for _, sec := range [][]dns.RR{in.Answer, in.Ns, in.Extra} { + for _, rr := range sec { + ttl := rr.Header().Ttl + if !hasTTL || ttl < minTTL { + minTTL = ttl + hasTTL = true + } + } + } + // 若无 TTL,可用配置上限作为兜底(也可选择不缓存) + cfgTTL := uint32(maxTTL.Seconds()) + var finalTTL uint32 + switch { + case !hasTTL && cfgTTL > 0: + finalTTL = cfgTTL + case hasTTL && cfgTTL > 0 && minTTL > cfgTTL: + finalTTL = cfgTTL + case hasTTL: + finalTTL = minTTL + default: + return + } + if finalTTL == 0 { + return + } + expire := time.Now().Add(time.Duration(finalTTL) * time.Second) + cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire}) +} + +/********************************* + * 上游查询与并发控制 + *********************************/ + +// 全局可复用的 DNS 客户端(默认 UDP) +var dnsClient *dns.Client + +// 并发上限通过信号量限制 +func queryUpstreamsLimited(r *dns.Msg, upstreams []string, timeout time.Duration, maxParallel int) *dns.Msg { + if maxParallel <= 0 { + maxParallel = 1 + } + ch := make(chan *dns.Msg, len(upstreams)) + sem := make(chan struct{}, maxParallel) + + // 在 UDP 上查询,遇到截断再 TCP fallback + queryOnce := func(server string) *dns.Msg { + resp, _, err := dnsClient.Exchange(r, server) + if err == nil && resp != nil && resp.Truncated { + // UDP 截断,尝试 TCP + log.Printf("[Info] UDP truncated, retry TCP: %s", server) + tcpClient := *dnsClient + tcpClient.Net = "tcp" + resp, _, err = tcpClient.Exchange(r, server) + } + if err != nil || resp == nil { + if err != nil { + log.Printf("[Warn] Upstream %s failed: %v", server, err) + } else { + log.Printf("[Warn] Upstream %s failed: nil response", server) + } + return nil + } + // 可选:丢弃 SERVFAIL + if resp.Rcode == dns.RcodeServerFailure { + return nil + } + return resp + } + + for _, server := range upstreams { + sem <- struct{}{} + go func(svr string) { + defer func() { <-sem }() + resp := queryOnce(svr) + // 非阻塞/限时写入,防止消费者意外退出导致阻塞 + select { + case ch <- resp: + case <-time.After(1 * time.Second): + } + }(server) + } + + timer := time.NewTimer(timeout) + defer timer.Stop() + + // 在超时内返回第一个非空的结果 + for i := 0; i < len(upstreams); i++ { + select { + case resp := <-ch: + if resp != nil { + return resp + } + case <-timer.C: + log.Printf("[Error] Upstream query timeout after %v", timeout) + return nil + } + } + return nil +} + +/********************************* + * DNS 处理逻辑 + *********************************/ + +func shuffled(slice []string) []string { + out := make([]string, len(slice)) + copy(out, slice) + rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] }) + return out +} + +func handleDNS(upstreams []string, cacheMaxTTL, timeout time.Duration, maxParallel int) dns.HandlerFunc { + return func(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + dns.HandleFailed(w, r) + return + } + q := r.Question[0] + key := cacheKey(q.Name, q.Qtype) + + // 1) 缓存命中 + if cached, ok := tryCacheRead(key); ok { + log.Printf("[Cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name) + _ = w.WriteMsg(cached) + return + } + log.Printf("[Cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name) + + // 2) 随机化上游并并发查询(带 fallback) + servers := shuffled(upstreams) + resp := queryUpstreamsLimited(r, servers, timeout, maxParallel) + if resp == nil { + log.Printf("[Error] All upstreams failed for %s", q.Name) + dns.HandleFailed(w, r) + return + } + + // 3) 写入缓存(动态 TTL) + cacheWrite(key, resp, cacheMaxTTL) + + // 4) 返回结果 + for _, ans := range resp.Answer { + log.Printf("[Answer] %s", ans.String()) + } + _ = w.WriteMsg(resp) + } +} + +/********************************* + * 主程序(优雅退出) + *********************************/ + +func main() { + rand.Seed(time.Now().UnixNano()) + + // 参数 + certFile := flag.String("cert", "aixiao.me.cer", "TLS 证书文件路径 (.cer/.crt)") + keyFile := flag.String("key", "aixiao.me.key", "TLS 私钥文件路径 (.key)") + addr := flag.String("addr", ":853", "DoT 服务监听地址(默认 :853)") + upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)") + cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL(默认 60s;实际取 min(上游最小TTL, 本值))") + timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s)") + maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限") + flag.Parse() + + // 证书 + cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) + if err != nil { + log.Fatalf("failed to load cert: %v", err) + } + + // 上游列表 + raw := strings.Split(*upstreamStr, ",") + upstreams := make([]string, 0, len(raw)) + for _, s := range raw { + if t := strings.TrimSpace(s); t != "" { + upstreams = append(upstreams, t) + } + } + if len(upstreams) == 0 { + log.Fatal("no upstream DNS servers provided") + } + + // 全局 DNS 客户端(UDP,扩大 UDPSize;fallback 在查询函数中完成) + dnsClient = &dns.Client{ + Net: "udp", + UDPSize: 4096, // 防截断 + Timeout: *timeoutFlag, + } + + // context 用于优雅退出与清理协程 + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // 启动缓存清理器 + startCacheCleaner(ctx) + + // DNS 处理器 + mux := dns.NewServeMux() + mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel)) + + // DoT 服务器(TLS 会话缓存 + 安全套件 + TLS1.2+) + srv := &dns.Server{ + Addr: *addr, + Net: "tcp-tls", + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientSessionCache: tls.NewLRUClientSessionCache(256), + MinVersion: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + }, + }, + Handler: mux, + } + + // 捕获退出信号,优雅关闭 + stop := make(chan os.Signal, 1) + signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM) + + errCh := make(chan error, 1) + go func() { + log.Printf("🚀 Starting DNS-over-TLS server on %s", *addr) + log.Printf("Upstreams=%v | MaxTTL=%s | Timeout=%s | MaxParallel=%d", + upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel) + errCh <- srv.ListenAndServe() + }() + + select { + case sig := <-stop: + log.Printf("[Shutdown] Caught signal: %s", sig) + cancel() // 结束清理器 + // 优雅关闭服务器 + // miekg/dns 提供 Shutdown();若你的版本支持 ShutdownContext,可改用带 ctx 的版本 + if err := srv.Shutdown(); err != nil { + log.Printf("[Shutdown] server shutdown error: %v", err) + } + case err := <-errCh: + if err != nil { + log.Fatalf("server error: %v", err) + } + } + log.Println("[Shutdown] Bye.") +}