357 lines
9.1 KiB
Go
357 lines
9.1 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"flag"
|
||
"log"
|
||
"math/rand"
|
||
"os"
|
||
"os/signal"
|
||
"strings"
|
||
"sync"
|
||
"syscall"
|
||
"time"
|
||
|
||
"github.com/miekg/dns"
|
||
)
|
||
|
||
/*********************************
|
||
* 缓存结构与全局对象
|
||
*********************************/
|
||
|
||
type cacheEntry struct {
|
||
msg *dns.Msg // 缓存的完整响应(深拷贝)
|
||
expireAt time.Time // 过期时间(由动态 TTL 决定)
|
||
}
|
||
|
||
// 并发安全缓存
|
||
var cache sync.Map
|
||
|
||
// 后台清理:每隔 N 分钟清一次
|
||
const cacheCleanupInterval = 5 * time.Minute
|
||
|
||
// 启动带 context 的缓存清理器(优雅退出)
|
||
func startCacheCleaner(ctx context.Context) {
|
||
go func() {
|
||
ticker := time.NewTicker(cacheCleanupInterval)
|
||
defer ticker.Stop()
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return
|
||
case <-ticker.C:
|
||
now := time.Now()
|
||
n := 0
|
||
cache.Range(func(k, v any) bool {
|
||
e := v.(*cacheEntry)
|
||
if now.After(e.expireAt) {
|
||
cache.Delete(k)
|
||
n++
|
||
}
|
||
return true
|
||
})
|
||
if n > 0 {
|
||
log.Printf("[Cache] Cleaned %d expired entries", n)
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
// 生成缓存键
|
||
func cacheKey(name string, qtype uint16) string {
|
||
return strings.ToLower(name) + ":" + dns.TypeToString[qtype]
|
||
}
|
||
|
||
// 读取缓存;命中则回填剩余 TTL 并返回拷贝
|
||
func tryCacheRead(key string) (*dns.Msg, bool) {
|
||
v, ok := cache.Load(key)
|
||
if !ok {
|
||
return nil, false
|
||
}
|
||
e := v.(*cacheEntry)
|
||
now := time.Now()
|
||
if now.After(e.expireAt) {
|
||
cache.Delete(key)
|
||
return nil, false
|
||
}
|
||
out := e.msg.Copy()
|
||
remaining := uint32(e.expireAt.Sub(now).Seconds())
|
||
// 回填剩余 TTL,避免客户端收到过期 TTL
|
||
for i := range out.Answer {
|
||
if out.Answer[i].Header().Ttl > remaining {
|
||
out.Answer[i].Header().Ttl = remaining
|
||
}
|
||
}
|
||
for i := range out.Ns {
|
||
if out.Ns[i].Header().Ttl > remaining {
|
||
out.Ns[i].Header().Ttl = remaining
|
||
}
|
||
}
|
||
for i := range out.Extra {
|
||
if out.Extra[i].Header().Ttl > remaining {
|
||
out.Extra[i].Header().Ttl = remaining
|
||
}
|
||
}
|
||
return out, true
|
||
}
|
||
|
||
// 写缓存:以「上游最小 TTL」与「配置上限 TTL」取较小值
|
||
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
||
if in == nil {
|
||
return
|
||
}
|
||
// 可按需缓存 NXDOMAIN;这里允许缓存 NOERROR 与 NXDOMAIN
|
||
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
|
||
return
|
||
}
|
||
// 计算报文中的最小 TTL(Answer/Ns/Extra)
|
||
minTTL := uint32(0)
|
||
hasTTL := false
|
||
for _, sec := range [][]dns.RR{in.Answer, in.Ns, in.Extra} {
|
||
for _, rr := range sec {
|
||
ttl := rr.Header().Ttl
|
||
if !hasTTL || ttl < minTTL {
|
||
minTTL = ttl
|
||
hasTTL = true
|
||
}
|
||
}
|
||
}
|
||
// 若无 TTL,可用配置上限作为兜底(也可选择不缓存)
|
||
cfgTTL := uint32(maxTTL.Seconds())
|
||
var finalTTL uint32
|
||
switch {
|
||
case !hasTTL && cfgTTL > 0:
|
||
finalTTL = cfgTTL
|
||
case hasTTL && cfgTTL > 0 && minTTL > cfgTTL:
|
||
finalTTL = cfgTTL
|
||
case hasTTL:
|
||
finalTTL = minTTL
|
||
default:
|
||
return
|
||
}
|
||
if finalTTL == 0 {
|
||
return
|
||
}
|
||
expire := time.Now().Add(time.Duration(finalTTL) * time.Second)
|
||
cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire})
|
||
}
|
||
|
||
/*********************************
|
||
* 上游查询与并发控制
|
||
*********************************/
|
||
|
||
// 全局可复用的 DNS 客户端(默认 UDP)
|
||
var dnsClient *dns.Client
|
||
|
||
// 并发上限通过信号量限制
|
||
func queryUpstreamsLimited(r *dns.Msg, upstreams []string, timeout time.Duration, maxParallel int) *dns.Msg {
|
||
if maxParallel <= 0 {
|
||
maxParallel = 1
|
||
}
|
||
ch := make(chan *dns.Msg, len(upstreams))
|
||
sem := make(chan struct{}, maxParallel)
|
||
|
||
// 在 UDP 上查询,遇到截断再 TCP fallback
|
||
queryOnce := func(server string) *dns.Msg {
|
||
resp, _, err := dnsClient.Exchange(r, server)
|
||
if err == nil && resp != nil && resp.Truncated {
|
||
// UDP 截断,尝试 TCP
|
||
log.Printf("[Info] UDP truncated, retry TCP: %s", server)
|
||
tcpClient := *dnsClient
|
||
tcpClient.Net = "tcp"
|
||
resp, _, err = tcpClient.Exchange(r, server)
|
||
}
|
||
if err != nil || resp == nil {
|
||
if err != nil {
|
||
log.Printf("[Warn] Upstream %s failed: %v", server, err)
|
||
} else {
|
||
log.Printf("[Warn] Upstream %s failed: nil response", server)
|
||
}
|
||
return nil
|
||
}
|
||
// 可选:丢弃 SERVFAIL
|
||
if resp.Rcode == dns.RcodeServerFailure {
|
||
return nil
|
||
}
|
||
return resp
|
||
}
|
||
|
||
for _, server := range upstreams {
|
||
sem <- struct{}{}
|
||
go func(svr string) {
|
||
defer func() { <-sem }()
|
||
resp := queryOnce(svr)
|
||
// 非阻塞/限时写入,防止消费者意外退出导致阻塞
|
||
select {
|
||
case ch <- resp:
|
||
case <-time.After(1 * time.Second):
|
||
}
|
||
}(server)
|
||
}
|
||
|
||
timer := time.NewTimer(timeout)
|
||
defer timer.Stop()
|
||
|
||
// 在超时内返回第一个非空的结果
|
||
for i := 0; i < len(upstreams); i++ {
|
||
select {
|
||
case resp := <-ch:
|
||
if resp != nil {
|
||
return resp
|
||
}
|
||
case <-timer.C:
|
||
log.Printf("[Error] Upstream query timeout after %v", timeout)
|
||
return nil
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
/*********************************
|
||
* DNS 处理逻辑
|
||
*********************************/
|
||
|
||
// shuffled returns a randomly permuted copy of slice using the package-level
// math/rand source. The input slice is never modified, and the result is a
// non-nil (possibly empty) slice with its own backing array.
func shuffled(slice []string) []string {
	perm := append(make([]string, 0, len(slice)), slice...)
	rand.Shuffle(len(perm), func(a, b int) {
		perm[a], perm[b] = perm[b], perm[a]
	})
	return perm
}
|
||
|
||
func handleDNS(upstreams []string, cacheMaxTTL, timeout time.Duration, maxParallel int) dns.HandlerFunc {
|
||
return func(w dns.ResponseWriter, r *dns.Msg) {
|
||
if len(r.Question) == 0 {
|
||
dns.HandleFailed(w, r)
|
||
return
|
||
}
|
||
q := r.Question[0]
|
||
key := cacheKey(q.Name, q.Qtype)
|
||
|
||
// 1) 缓存命中
|
||
if cached, ok := tryCacheRead(key); ok {
|
||
log.Printf("[Cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||
_ = w.WriteMsg(cached)
|
||
return
|
||
}
|
||
log.Printf("[Cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||
|
||
// 2) 随机化上游并并发查询(带 fallback)
|
||
servers := shuffled(upstreams)
|
||
resp := queryUpstreamsLimited(r, servers, timeout, maxParallel)
|
||
if resp == nil {
|
||
log.Printf("[Error] All upstreams failed for %s", q.Name)
|
||
dns.HandleFailed(w, r)
|
||
return
|
||
}
|
||
|
||
// 3) 写入缓存(动态 TTL)
|
||
cacheWrite(key, resp, cacheMaxTTL)
|
||
|
||
// 4) 返回结果
|
||
for _, ans := range resp.Answer {
|
||
log.Printf("[Answer] %s", ans.String())
|
||
}
|
||
_ = w.WriteMsg(resp)
|
||
}
|
||
}
|
||
|
||
/*********************************
|
||
* 主程序(优雅退出)
|
||
*********************************/
|
||
|
||
func main() {
|
||
rand.Seed(time.Now().UnixNano())
|
||
|
||
// 参数
|
||
certFile := flag.String("cert", "aixiao.me.cer", "TLS 证书文件路径 (.cer/.crt)")
|
||
keyFile := flag.String("key", "aixiao.me.key", "TLS 私钥文件路径 (.key)")
|
||
addr := flag.String("addr", ":853", "DoT 服务监听地址(默认 :853)")
|
||
upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)")
|
||
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL(默认 60s;实际取 min(上游最小TTL, 本值))")
|
||
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s)")
|
||
maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限")
|
||
flag.Parse()
|
||
|
||
// 证书
|
||
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
|
||
if err != nil {
|
||
log.Fatalf("failed to load cert: %v", err)
|
||
}
|
||
|
||
// 上游列表
|
||
raw := strings.Split(*upstreamStr, ",")
|
||
upstreams := make([]string, 0, len(raw))
|
||
for _, s := range raw {
|
||
if t := strings.TrimSpace(s); t != "" {
|
||
upstreams = append(upstreams, t)
|
||
}
|
||
}
|
||
if len(upstreams) == 0 {
|
||
log.Fatal("no upstream DNS servers provided")
|
||
}
|
||
|
||
// 全局 DNS 客户端(UDP,扩大 UDPSize;fallback 在查询函数中完成)
|
||
dnsClient = &dns.Client{
|
||
Net: "udp",
|
||
UDPSize: 4096, // 防截断
|
||
Timeout: *timeoutFlag,
|
||
}
|
||
|
||
// context 用于优雅退出与清理协程
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
|
||
// 启动缓存清理器
|
||
startCacheCleaner(ctx)
|
||
|
||
// DNS 处理器
|
||
mux := dns.NewServeMux()
|
||
mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel))
|
||
|
||
// DoT 服务器(TLS 会话缓存 + 安全套件 + TLS1.2+)
|
||
srv := &dns.Server{
|
||
Addr: *addr,
|
||
Net: "tcp-tls",
|
||
TLSConfig: &tls.Config{
|
||
Certificates: []tls.Certificate{cert},
|
||
ClientSessionCache: tls.NewLRUClientSessionCache(256),
|
||
MinVersion: tls.VersionTLS12,
|
||
CipherSuites: []uint16{
|
||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||
},
|
||
},
|
||
Handler: mux,
|
||
}
|
||
|
||
// 捕获退出信号,优雅关闭
|
||
stop := make(chan os.Signal, 1)
|
||
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
log.Printf("🚀 Starting DNS-over-TLS server on %s", *addr)
|
||
log.Printf("Upstreams=%v | MaxTTL=%s | Timeout=%s | MaxParallel=%d",
|
||
upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel)
|
||
errCh <- srv.ListenAndServe()
|
||
}()
|
||
|
||
select {
|
||
case sig := <-stop:
|
||
log.Printf("[Shutdown] Caught signal: %s", sig)
|
||
cancel() // 结束清理器
|
||
// 优雅关闭服务器
|
||
// miekg/dns 提供 Shutdown();若你的版本支持 ShutdownContext,可改用带 ctx 的版本
|
||
if err := srv.Shutdown(); err != nil {
|
||
log.Printf("[Shutdown] server shutdown error: %v", err)
|
||
}
|
||
case err := <-errCh:
|
||
if err != nil {
|
||
log.Fatalf("server error: %v", err)
|
||
}
|
||
}
|
||
log.Println("[Shutdown] Bye.")
|
||
}
|