Files
dot/main.go
2025-10-14 09:26:29 +08:00

357 lines
9.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"crypto/tls"
"flag"
"log"
"math/rand"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"github.com/miekg/dns"
)
/*********************************
 * Cache structures and globals
 *********************************/

// cacheEntry is one cached upstream response together with its expiry.
type cacheEntry struct {
	msg      *dns.Msg  // deep copy of the complete response as it was stored
	expireAt time.Time // absolute expiry instant, derived from the dynamic TTL
}

var (
	// cache maps cache-key strings to *cacheEntry values. A sync.Map is used so
	// request handlers and the background cleaner can share it without extra locks.
	cache sync.Map
)

const (
	// cacheCleanupInterval is the period of the background sweep that drops
	// expired entries.
	cacheCleanupInterval = 5 * time.Minute
)
// startCacheCleaner launches a background goroutine that, every
// cacheCleanupInterval, removes cache entries whose expiry has passed and logs
// how many were removed. The goroutine returns once ctx is cancelled, which is
// how the server shuts it down on exit.
func startCacheCleaner(ctx context.Context) {
	go runCacheCleaner(ctx)
}

// runCacheCleaner is the cleaner loop started by startCacheCleaner.
func runCacheCleaner(ctx context.Context) {
	tick := time.NewTicker(cacheCleanupInterval)
	defer tick.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
			if removed := sweepExpiredEntries(time.Now()); removed > 0 {
				log.Printf("[Cache] Cleaned %d expired entries", removed)
			}
		}
	}
}

// sweepExpiredEntries deletes every cache entry whose expiry is before now and
// returns the number of deleted entries.
func sweepExpiredEntries(now time.Time) int {
	removed := 0
	cache.Range(func(key, val any) bool {
		if entry := val.(*cacheEntry); now.After(entry.expireAt) {
			cache.Delete(key)
			removed++
		}
		return true // keep iterating over the whole map
	})
	return removed
}
// cacheKey builds the cache key for a question: the lower-cased owner name,
// a colon, and the query-type mnemonic (e.g. "example.com.:A").
//
// dns.Type(qtype).String() is used instead of indexing dns.TypeToString
// directly. A direct map lookup yields "" for types without a mnemonic, so
// every unknown type for a name used to collapse onto the same key and share
// one cached answer. The Type form renders such types as "TYPE<n>".
//
// NOTE(review): the key ignores the question class and EDNS options such as
// the DO bit, so answers are shared across them; confirm this is intended.
func cacheKey(name string, qtype uint16) string {
	return strings.ToLower(name) + ":" + dns.Type(qtype).String()
}
// tryCacheRead looks up key in the cache.
//
// On a hit it returns a deep copy of the stored response. Every record TTL in
// the copy is capped at the entry's remaining lifetime, so clients never
// receive a TTL that outlives the cache entry. An expired entry is deleted on
// access and reported as a miss (nil, false).
func tryCacheRead(key string) (*dns.Msg, bool) {
	v, ok := cache.Load(key)
	if !ok {
		return nil, false
	}
	e := v.(*cacheEntry)
	now := time.Now()
	if now.After(e.expireAt) {
		cache.Delete(key)
		return nil, false
	}

	out := e.msg.Copy()
	remaining := uint32(e.expireAt.Sub(now).Seconds())
	for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} {
		for _, rr := range sec {
			hdr := rr.Header()
			// The OPT pseudo-RR reuses the TTL field for the extended RCODE,
			// EDNS version and flags (including the DO bit). It is not a TTL,
			// and clamping it would corrupt the client's EDNS state.
			if hdr.Rrtype == dns.TypeOPT {
				continue
			}
			if hdr.Ttl > remaining {
				hdr.Ttl = remaining
			}
		}
	}
	return out, true
}
// cacheWrite stores a deep copy of in under key.
//
// Caching rules:
//   - Only NOERROR and NXDOMAIN responses are cached.
//   - Truncated responses are skipped because they are incomplete.
//   - The entry lifetime is the smallest TTL among the records, capped at
//     maxTTL when maxTTL is at least one second.
//   - A response without any TTL-bearing record is cached for maxTTL, if one
//     is set.
//   - A computed lifetime of 0 means "do not cache".
//
// Two records get special handling:
//   - The OPT pseudo-RR is ignored. Its TTL field holds EDNS flags, not a TTL,
//     and is usually 0, which would otherwise prevent caching of every EDNS
//     response.
//   - For negative answers (NXDOMAIN, or NOERROR with an empty answer
//     section), an SOA record contributes min(SOA TTL, SOA MINIMUM), as
//     described in RFC 2308.
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
	if in == nil || in.Truncated {
		return
	}
	if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
		return
	}

	negative := in.Rcode == dns.RcodeNameError || len(in.Answer) == 0

	// Smallest TTL across Answer/Ns/Extra.
	var minTTL uint32
	hasTTL := false
	for _, sec := range [][]dns.RR{in.Answer, in.Ns, in.Extra} {
		for _, rr := range sec {
			hdr := rr.Header()
			if hdr.Rrtype == dns.TypeOPT {
				continue
			}
			ttl := hdr.Ttl
			if soa, ok := rr.(*dns.SOA); ok && negative && soa.Minttl < ttl {
				ttl = soa.Minttl
			}
			if !hasTTL || ttl < minTTL {
				minTTL = ttl
				hasTTL = true
			}
		}
	}

	// Convert the configured cap to whole seconds. Non-positive durations mean
	// "no cap" (a direct float->uint32 conversion of a negative value is
	// implementation-defined in Go), and huge values saturate instead of
	// wrapping.
	var cfgTTL uint32
	if secs := maxTTL / time.Second; secs > 0 {
		if secs > time.Duration(^uint32(0)) {
			cfgTTL = ^uint32(0)
		} else {
			cfgTTL = uint32(secs)
		}
	}

	var finalTTL uint32
	switch {
	case !hasTTL && cfgTTL > 0:
		// No record carries a TTL: fall back to the configured cap.
		finalTTL = cfgTTL
	case hasTTL && cfgTTL > 0 && minTTL > cfgTTL:
		finalTTL = cfgTTL
	case hasTTL:
		finalTTL = minTTL
	default:
		return
	}
	if finalTTL == 0 {
		return
	}

	expire := time.Now().Add(time.Duration(finalTTL) * time.Second)
	cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire})
}
/*********************************
 * Upstream queries and concurrency control
 *********************************/

// dnsClient is the shared upstream client (UDP by default). main assigns it
// before the server starts accepting queries.
var dnsClient *dns.Client

// queryUpstreamOnce sends r to a single upstream server with dnsClient. If the
// UDP answer comes back truncated, it repeats the query over TCP.
//
// It returns nil on transport errors, on a nil response and on SERVFAIL, so
// the caller can fall back to another upstream.
func queryUpstreamOnce(r *dns.Msg, server string) *dns.Msg {
	resp, _, err := dnsClient.Exchange(r, server)
	if err == nil && resp != nil && resp.Truncated {
		log.Printf("[Info] UDP truncated, retry TCP: %s", server)
		// Build a dedicated TCP client rather than copying *dnsClient by
		// value. Only the timeout needs to carry over.
		tcpClient := &dns.Client{Net: "tcp", Timeout: dnsClient.Timeout}
		resp, _, err = tcpClient.Exchange(r, server)
	}
	if err != nil || resp == nil {
		if err != nil {
			log.Printf("[Warn] Upstream %s failed: %v", server, err)
		} else {
			log.Printf("[Warn] Upstream %s failed: nil response", server)
		}
		return nil
	}
	// SERVFAIL is treated as a failure so another upstream can answer.
	if resp.Rcode == dns.RcodeServerFailure {
		return nil
	}
	return resp
}

// queryUpstreamsLimited queries upstreams in the given order, keeping at most
// maxParallel queries in flight (maxParallel <= 0 is treated as 1).
//
// It returns the first non-nil response that arrives within timeout. It
// returns nil if every upstream failed or the timeout expired.
//
// The timeout covers the whole call, including dispatch. Once an answer
// arrives (or the timeout fires), no further upstreams are started. Queries
// already in flight finish on their own and their results are discarded.
func queryUpstreamsLimited(r *dns.Msg, upstreams []string, timeout time.Duration, maxParallel int) *dns.Msg {
	if maxParallel <= 0 {
		maxParallel = 1
	}
	if len(upstreams) == 0 {
		return nil
	}

	// One buffered slot per upstream, so a worker's send never blocks, even
	// after this function has returned.
	results := make(chan *dns.Msg, len(upstreams))
	// done is closed on return and tells the dispatcher to stop launching.
	done := make(chan struct{})
	defer close(done)

	// Start the timer before dispatching. Previously it started only after
	// every upstream had been launched, which could take several full query
	// timeouts when maxParallel < len(upstreams).
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	go func() {
		sem := make(chan struct{}, maxParallel)
		for _, server := range upstreams {
			select {
			case sem <- struct{}{}:
			case <-done:
				return
			}
			// select picks randomly when both cases are ready; re-check so no
			// extra upstream is started after the caller has finished.
			select {
			case <-done:
				<-sem
				return
			default:
			}
			go func(svr string) {
				defer func() { <-sem }()
				// Each worker gets its own copy of the request, so concurrent
				// Exchange calls never share one *dns.Msg.
				results <- queryUpstreamOnce(r.Copy(), svr)
			}(server)
		}
	}()

	// Return the first non-nil result within the timeout.
	for i := 0; i < len(upstreams); i++ {
		select {
		case resp := <-results:
			if resp != nil {
				return resp
			}
		case <-timer.C:
			log.Printf("[Error] Upstream query timeout after %v", timeout)
			return nil
		}
	}
	return nil
}
/*********************************
 * DNS handling logic
 *********************************/

// shuffled returns a new slice containing the elements of slice in random
// order, using the global math/rand source. The input is never modified, and
// the result is a non-nil slice even when the input is empty or nil.
func shuffled(slice []string) []string {
	result := append(make([]string, 0, len(slice)), slice...)
	swap := func(a, b int) { result[a], result[b] = result[b], result[a] }
	rand.Shuffle(len(result), swap)
	return result
}
// handleDNS returns the request handler for the DoT server.
//
// Requests without a question are answered via dns.HandleFailed. Otherwise
// the handler:
//  1. serves the answer from the cache when possible;
//  2. on a miss, queries the upstreams in random order
//     (queryUpstreamsLimited);
//  3. caches the answer (cacheWrite);
//  4. relays it to the client.
//
// If no upstream answered, the request is answered via dns.HandleFailed.
func handleDNS(upstreams []string, cacheMaxTTL, timeout time.Duration, maxParallel int) dns.HandlerFunc {
	return func(w dns.ResponseWriter, r *dns.Msg) {
		if len(r.Question) == 0 {
			dns.HandleFailed(w, r)
			return
		}
		q := r.Question[0]
		key := cacheKey(q.Name, q.Qtype)

		// 1) Cache hit.
		if cached, ok := tryCacheRead(key); ok {
			log.Printf("[Cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
			// The stored message still carries the header of the query that
			// populated the cache. Rebind it to this request so the client can
			// match the response by ID and sees its own question and RD/CD bits.
			cached.Id = r.Id
			cached.RecursionDesired = r.RecursionDesired
			cached.CheckingDisabled = r.CheckingDisabled
			cached.Question = append([]dns.Question(nil), r.Question...)
			if err := w.WriteMsg(cached); err != nil {
				log.Printf("[Warn] Write response for %s failed: %v", q.Name, err)
			}
			return
		}
		log.Printf("[Cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)

		// 2) Query the upstreams in random order, in parallel (with fallback).
		servers := shuffled(upstreams)
		resp := queryUpstreamsLimited(r, servers, timeout, maxParallel)
		if resp == nil {
			log.Printf("[Error] All upstreams failed for %s", q.Name)
			dns.HandleFailed(w, r)
			return
		}

		// 3) Store in the cache (dynamic TTL).
		cacheWrite(key, resp, cacheMaxTTL)

		// 4) Return the result.
		for _, ans := range resp.Answer {
			log.Printf("[Answer] %s", ans.String())
		}
		if err := w.WriteMsg(resp); err != nil {
			log.Printf("[Warn] Write response for %s failed: %v", q.Name, err)
		}
	}
}
/*********************************
 * Entry point (graceful shutdown)
 *********************************/

// parseUpstreams splits a comma-separated list of upstream addresses. It trims
// whitespace around each entry and drops empty entries.
func parseUpstreams(s string) []string {
	parts := strings.Split(s, ",")
	out := make([]string, 0, len(parts))
	for _, p := range parts {
		if t := strings.TrimSpace(p); t != "" {
			out = append(out, t)
		}
	}
	return out
}

// main parses flags, loads the TLS key pair, starts the cache cleaner and
// serves DNS-over-TLS. It runs until SIGINT/SIGTERM arrives or the server
// returns an error.
func main() {
	// Seed the global math/rand source used by shuffled.
	rand.Seed(time.Now().UnixNano())

	// Flags.
	certFile := flag.String("cert", "aixiao.me.cer", "TLS 证书文件路径 (.cer/.crt)")
	keyFile := flag.String("key", "aixiao.me.key", "TLS 私钥文件路径 (.key)")
	addr := flag.String("addr", ":853", "DoT 服务监听地址(默认 :853")
	upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)")
	cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL默认 60s实际取 min(上游最小TTL, 本值)")
	timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s")
	maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限")
	flag.Parse()

	// A non-positive timeout makes every upstream query time out immediately.
	if *timeoutFlag <= 0 {
		log.Fatalf("invalid -timeout %v: must be positive", *timeoutFlag)
	}

	// Certificate.
	cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
	if err != nil {
		log.Fatalf("failed to load cert: %v", err)
	}

	// Upstream list.
	upstreams := parseUpstreams(*upstreamStr)
	if len(upstreams) == 0 {
		log.Fatal("no upstream DNS servers provided")
	}

	// Shared upstream client. It speaks UDP; the TCP fallback for truncated
	// answers happens in the query code.
	dnsClient = &dns.Client{
		Net:     "udp",
		UDPSize: 4096, // larger UDP buffer, intended to avoid truncation
		Timeout: *timeoutFlag,
	}

	// ctx drives graceful shutdown of the cache cleaner.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	startCacheCleaner(ctx)

	// DNS handler.
	mux := dns.NewServeMux()
	mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel))

	// DoT server with TLS 1.2 or later. ClientSessionCache was dropped: crypto/tls
	// only consults it on the client side, so it had no effect on this server.
	srv := &dns.Server{
		Addr: *addr,
		Net:  "tcp-tls",
		TLSConfig: &tls.Config{
			Certificates: []tls.Certificate{cert},
			MinVersion:   tls.VersionTLS12,
			// This list only applies to TLS 1.2; TLS 1.3 suites are not
			// configurable. ECDSA suites are included so an ECDSA certificate
			// can also complete a TLS 1.2 handshake. The previous list was
			// RSA-only.
			CipherSuites: []uint16{
				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
			},
		},
		Handler: mux,
	}

	// Catch termination signals for a graceful shutdown.
	stop := make(chan os.Signal, 1)
	signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)

	// Buffered so the serving goroutine can always deliver its result, even
	// after main has stopped listening for it.
	errCh := make(chan error, 1)
	go func() {
		log.Printf("🚀 Starting DNS-over-TLS server on %s", *addr)
		log.Printf("Upstreams=%v | MaxTTL=%s | Timeout=%s | MaxParallel=%d",
			upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel)
		errCh <- srv.ListenAndServe()
	}()

	select {
	case sig := <-stop:
		log.Printf("[Shutdown] Caught signal: %s", sig)
		cancel() // stop the cache cleaner
		// NOTE(review): newer miekg/dns versions also offer ShutdownContext,
		// which bounds the shutdown with a deadline.
		if err := srv.Shutdown(); err != nil {
			log.Printf("[Shutdown] server shutdown error: %v", err)
		}
	case err := <-errCh:
		if err != nil {
			log.Fatalf("server error: %v", err)
		}
	}
	log.Println("[Shutdown] Bye.")
}