Files
dot/main.go
aixiao d540b302f1 ```
feat(cache): 优化缓存键生成与缓存写入逻辑

- 引入更精确的缓存键计算方式,包含 QTYPE、QCLASS、DO 和 CD 标志
- 实现负面缓存(NXDOMAIN/NODATA)支持,遵循 RFC 2308 规范
- 改进缓存清理机制,在 TTL 为 0 时主动删除过期条目
- 添加日志初始化函数,支持 verbose 模式显示源码位置
- 重构上游查询逻辑,支持 context 控制超时和 TCP 回退
- 增加 ECS(EDNS Client Subnet)剥离选项以增强隐私保护
- 调整命令行参数默认值及日志输出格式,提升可读性与调试体验
```
2025-10-14 10:28:00 +08:00

509 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"crypto/tls"
"flag"
"fmt"
"log"
"math/rand"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"github.com/miekg/dns"
)
/******************************************************************
* 日志初始化
******************************************************************/
// initLogger configures the prefix flags of the standard logger. Every line
// carries the date, the time and microseconds; verbose mode additionally records
// the short source file name and line number of the log call.
func initLogger(verbose bool) {
	timestamp := log.Ldate | log.Ltime | log.Lmicroseconds
	if !verbose {
		log.SetFlags(timestamp)
		return
	}
	log.SetFlags(timestamp | log.Lshortfile)
}
/******************************************************************
* 缓存结构
******************************************************************/
// cacheEntry is one cached upstream response together with its absolute expiry.
type cacheEntry struct {
	msg      *dns.Msg  // private copy of the full upstream response
	expireAt time.Time // instant after which the entry is stale
}

var (
	// cache maps keys built by cacheKeyFromMsg to *cacheEntry values.
	cache sync.Map
)

const (
	// cacheCleanupInterval is how often the background cleaner sweeps cache.
	cacheCleanupInterval = 5 * time.Minute
)
func startCacheCleaner(ctx context.Context) {
go func() {
ticker := time.NewTicker(cacheCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
n := 0
cache.Range(func(k, v any) bool {
e := v.(*cacheEntry)
if now.After(e.expireAt) {
cache.Delete(k)
n++
}
return true
})
if n > 0 {
log.Printf("[cache] cleaned %d expired entries", n)
}
}
}
}()
}
// cacheKeyFromMsg builds the cache key for question q. The key is made of:
//   - the lower-cased owner name;
//   - the QTYPE and QCLASS;
//   - markers for the DO and CD flags, so DNSSEC and non-DNSSEC answers, and
//     validated versus checking-disabled answers, are cached separately.
//
// Known types and classes use their mnemonic (e.g. "A", "IN"). Codes missing
// from the miekg/dns tables use the RFC 3597 generic forms "TYPEnnn" /
// "CLASSnnn". Without that fallback every unknown code would map to the same
// empty string, and queries for different unknown types of one name would
// share a single cache entry.
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
	var b strings.Builder
	b.Grow(len(q.Name) + 32)
	b.WriteString(strings.ToLower(q.Name))
	b.WriteString("|T=")
	if s, ok := dns.TypeToString[q.Qtype]; ok {
		b.WriteString(s)
	} else {
		fmt.Fprintf(&b, "TYPE%d", q.Qtype)
	}
	b.WriteString("|C=")
	if s, ok := dns.ClassToString[q.Qclass]; ok {
		b.WriteString(s)
	} else {
		fmt.Fprintf(&b, "CLASS%d", q.Qclass)
	}
	if do {
		b.WriteString("|DO")
	}
	if cd {
		b.WriteString("|CD")
	}
	return b.String()
}
// tryCacheRead looks up key. On a hit it returns a private copy of the cached
// response. Every record TTL in the copy is capped at the entry's remaining
// lifetime in whole seconds, so a client never sees a TTL that outlives the
// cache entry; TTLs already below that are kept as-is. Expired entries,
// including those with less than one second left, are deleted and reported as
// a miss.
//
// OPT pseudo-records are skipped. Their TTL field carries the extended RCODE,
// the EDNS version and the DO bit (RFC 6891 §6.1.3), not a lifetime, so
// clamping it would e.g. clear DO on every cached DNSSEC answer.
func tryCacheRead(key string) (*dns.Msg, bool) {
	v, ok := cache.Load(key)
	if !ok {
		return nil, false
	}
	e := v.(*cacheEntry)
	now := time.Now()
	// Check expiry before converting, so a negative duration never reaches the
	// uint32 conversion below.
	if now.After(e.expireAt) {
		cache.Delete(key)
		return nil, false
	}
	remaining := uint32(e.expireAt.Sub(now).Seconds())
	if remaining == 0 {
		cache.Delete(key)
		return nil, false
	}

	out := e.msg.Copy()
	for _, section := range [][]dns.RR{out.Answer, out.Ns, out.Extra} {
		for _, rr := range section {
			if _, isOPT := rr.(*dns.OPT); isOPT {
				continue
			}
			if h := rr.Header(); h.Ttl > remaining {
				h.Ttl = remaining
			}
		}
	}
	return out, true
}
// negativeTTL reports how long a negative response may be cached, following
// RFC 2308. A response is negative when it is NXDOMAIN, or NOERROR with an
// empty Answer section (NODATA). Its TTL is min(SOA record TTL, SOA MINIMUM),
// further capped by maxTTL when maxTTL is at least one second. A non-positive
// or sub-second maxTTL applies no cap. The SOA is searched for in the
// authority section first, then (leniently) in the additional section.
//
// The boolean is false when m is not negative, carries no SOA, or the
// resulting TTL is zero.
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
	isNXDomain := m.Rcode == dns.RcodeNameError
	isNoData := m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0
	if !isNXDomain && !isNoData {
		return 0, false
	}

	// Scan the sections in place instead of append(m.Ns, m.Extra...). The append
	// allocates, and it may write into spare capacity of a slice owned by m.
	var soa *dns.SOA
findSOA:
	for _, section := range [][]dns.RR{m.Ns, m.Extra} {
		for _, rr := range section {
			if s, ok := rr.(*dns.SOA); ok {
				soa = s
				break findSOA
			}
		}
	}
	if soa == nil {
		return 0, false
	}

	ttl := soa.Hdr.Ttl
	if soa.Minttl < ttl {
		ttl = soa.Minttl
	}
	// Only convert a positive duration: converting a negative float to uint32
	// is implementation-defined in Go.
	if maxTTL > 0 {
		if capTTL := uint32(maxTTL.Seconds()); capTTL > 0 && ttl > capTTL {
			ttl = capTTL
		}
	}
	return ttl, ttl > 0
}
// minRRsetTTL returns the smallest TTL among the records in the Answer,
// Authority (Ns) and Additional (Extra) sections of m. The boolean is false
// when there is no record to take a TTL from.
//
// OPT pseudo-records are ignored. Their TTL field holds EDNS flags (extended
// RCODE, version, DO bit), not a lifetime. It is 0 whenever DO is unset, so
// counting it would make every such EDNS response uncacheable.
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
	var (
		minTTL uint32
		found  bool
	)
	for _, section := range [][]dns.RR{m.Answer, m.Ns, m.Extra} {
		for _, rr := range section {
			if _, isOPT := rr.(*dns.OPT); isOPT {
				continue
			}
			if ttl := rr.Header().Ttl; !found || ttl < minTTL {
				minTTL, found = ttl, true
			}
		}
	}
	return minTTL, found
}
// cacheWrite stores a private copy of the upstream response in under key, with
// a lifetime derived from its TTLs:
//
//   - Only NOERROR and NXDOMAIN responses are cached. Truncated (TC) responses
//     are never cached, because their content is incomplete.
//   - Negative responses (NXDOMAIN, or NOERROR with an empty Answer section)
//     use negativeTTL (RFC 2308). They are not cached when no SOA is present
//     (RFC 2308 §5) or when the SOA-derived TTL is zero. They never fall
//     through to the positive path, which would otherwise cache them for the
//     SOA record's own TTL.
//   - Positive responses live for the smallest record TTL, capped by maxTTL
//     when maxTTL is at least one second. A resulting TTL of 0 means "do not
//     cache".
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
	if in == nil || in.Truncated {
		return
	}
	if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
		return
	}

	store := func(ttl uint32) {
		cache.Store(key, &cacheEntry{
			msg:      in.Copy(),
			expireAt: time.Now().Add(time.Duration(ttl) * time.Second),
		})
	}

	// Negative answer: cache only with an SOA-derived TTL.
	if in.Rcode == dns.RcodeNameError || len(in.Answer) == 0 {
		if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 {
			store(ttl)
		}
		return
	}

	// Positive answer: min(record TTLs, maxTTL).
	finalTTL, ok := minRRsetTTL(in)
	if !ok {
		return
	}
	// Only convert a positive duration: a negative float→uint32 conversion is
	// implementation-defined.
	if maxTTL > 0 {
		if capTTL := uint32(maxTTL.Seconds()); capTTL > 0 && finalTTL > capTTL {
			finalTTL = capTTL
		}
	}
	if finalTTL == 0 {
		return
	}
	store(finalTTL)
}
/******************************************************************
* 上游查询(带 context 取消、并发上限、UDP→TCP 回退)
******************************************************************/
var (
	// udpClient is the shared UDP client used for every upstream exchange. main
	// creates it before the server starts. Per-query deadlines come from the
	// context passed to ExchangeContext, not from fields on the client.
	udpClient *dns.Client
)
// shuffled returns a randomly permuted copy of xs, using the global math/rand
// source. The input slice is not modified, and the result is never nil, even
// for empty input.
func shuffled(xs []string) []string {
	perm := append(make([]string, 0, len(xs)), xs...)
	swap := func(a, b int) { perm[a], perm[b] = perm[b], perm[a] }
	rand.Shuffle(len(perm), swap)
	return perm
}
// queryUpstreamsLimited forwards req to the upstream servers in random order,
// with at most maxParallel exchanges in flight, and returns the first usable
// response. Any Rcode except SERVFAIL counts as usable. It returns nil as soon
// as every upstream has failed, or when timeout elapses first. The timeout
// bounds the whole call, including servers still waiting for a slot.
//
// When a UDP reply is truncated and allowTCPFallback is true, the same server is
// retried over TCP. A non-positive maxParallel is treated as 1. Returning
// cancels the shared context, which aborts any exchanges still in flight and
// stops further launches.
func queryUpstreamsLimited(
	ctx context.Context,
	req *dns.Msg,
	upstreams []string,
	timeout time.Duration,
	maxParallel int,
	allowTCPFallback bool,
) *dns.Msg {
	if maxParallel <= 0 {
		maxParallel = 1
	}
	servers := shuffled(upstreams)
	if len(servers) == 0 {
		return nil
	}

	cctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	// Each launched worker sends exactly one value (nil on failure). The buffer
	// has a slot per server, so workers never block on send. The collector can
	// therefore tell "everyone failed" apart from "still waiting".
	results := make(chan *dns.Msg, len(servers))
	sem := make(chan struct{}, maxParallel)

	// exchange queries one server, with optional TCP fallback, and returns a
	// usable response or nil.
	exchange := func(svr string) *dns.Msg {
		resp, _, err := udpClient.ExchangeContext(cctx, req, svr)
		if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
			log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
			tcpClient := *udpClient
			tcpClient.Net = "tcp"
			resp, _, err = tcpClient.ExchangeContext(cctx, req, svr)
		}
		switch {
		case err != nil:
			// Errors caused by our own cancellation/timeout are expected noise.
			if cctx.Err() == nil {
				log.Printf("[upstream] %s error: %v", svr, err)
			}
			return nil
		case resp == nil:
			log.Printf("[upstream] %s nil response", svr)
			return nil
		case resp.Rcode == dns.RcodeServerFailure:
			return nil
		}
		return resp
	}

	// Launch from a separate goroutine. The collector below can then return on
	// the first good answer instead of blocking on a semaphore slot to queue
	// the next server.
	go func() {
		for _, s := range servers {
			select {
			case sem <- struct{}{}:
			case <-cctx.Done():
				return
			}
			go func(svr string) {
				defer func() { <-sem }()
				results <- exchange(svr)
			}(s)
		}
	}()

	for pending := len(servers); pending > 0; pending-- {
		select {
		case msg := <-results:
			if msg != nil {
				return msg
			}
		case <-cctx.Done():
			log.Printf("[upstream] timeout after %v", timeout)
			return nil
		}
	}
	// Every upstream reported a failure before the deadline.
	return nil
}
/******************************************************************
* EDNS(0) / ECS 处理
******************************************************************/
// stripECS removes every EDNS Client Subnet option from m's OPT record, in
// place. This keeps the client's subnet from leaking to upstreams and from
// influencing cached answers. Other EDNS options are kept in their original
// order. Messages without an OPT record are left untouched.
func stripECS(m *dns.Msg) {
	opt := m.IsEdns0()
	if opt == nil {
		return
	}
	var filtered []dns.EDNS0
	for _, option := range opt.Option {
		if _, isECS := option.(*dns.EDNS0_SUBNET); isECS {
			continue
		}
		filtered = append(filtered, option)
	}
	opt.Option = filtered
}
// getDOFlag reports whether m carries an EDNS(0) OPT record with the DNSSEC OK
// (DO) bit set. It returns false when there is no OPT record.
func getDOFlag(m *dns.Msg) bool {
	opt := m.IsEdns0()
	return opt != nil && opt.Do()
}
/******************************************************************
* 响应构造:使用客户端请求头构造 reply复制上游内容
******************************************************************/
// writeReply answers req with the content of upstream. The reply header comes
// from the client's request via SetReply, so a cached or forwarded upstream
// message can serve any client. RA, AD, TC and the Rcode are copied from
// upstream, while CD mirrors the client's request. A nil upstream yields
// dns.HandleFailed. Write errors are logged, not returned.
//
// When the client sent no EDNS(0) OPT record, any OPT in upstream.Extra is
// dropped. RFC 6891 §7 forbids an OPT in replies to non-EDNS requests, and the
// cache key does not distinguish EDNS from non-EDNS clients, so a cached
// response may well carry one.
func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
	if upstream == nil {
		dns.HandleFailed(w, req)
		return
	}
	out := new(dns.Msg)
	out.SetReply(req)
	out.Authoritative = false
	out.RecursionAvailable = upstream.RecursionAvailable
	out.AuthenticatedData = upstream.AuthenticatedData
	// Propagate truncation so a partial upstream answer (e.g. TCP fallback
	// disabled) is not presented to the client as complete.
	out.Truncated = upstream.Truncated
	out.CheckingDisabled = req.CheckingDisabled // mirror the client's CD bit
	out.Rcode = upstream.Rcode
	out.Answer = upstream.Answer
	out.Ns = upstream.Ns
	out.Extra = upstream.Extra
	if req.IsEdns0() == nil {
		// Build a new slice; upstream.Extra may belong to a cache entry.
		extra := make([]dns.RR, 0, len(upstream.Extra))
		for _, rr := range upstream.Extra {
			if _, isOPT := rr.(*dns.OPT); !isOPT {
				extra = append(extra, rr)
			}
		}
		out.Extra = extra
	}
	out.Compress = true
	if err := w.WriteMsg(out); err != nil {
		log.Printf("[write] WriteMsg error: %v", err)
	}
}
/******************************************************************
* 处理器
******************************************************************/
// handleDNS returns the DNS handler for the DoT server. For each request it:
//  1. rejects messages with no question via dns.HandleFailed;
//  2. logs the request;
//  3. optionally strips ECS (stripECSBeforeForward);
//  4. answers from the cache when possible;
//  5. otherwise queries the upstreams (timeout, maxParallel, allowTCPFallback),
//     caches the result with the cacheMaxTTL cap and writes it back.
//
// If no upstream produces a usable answer, the client gets dns.HandleFailed.
// Only the first question in the message is used.
func handleDNS(
	upstreams []string,
	cacheMaxTTL, timeout time.Duration,
	maxParallel int,
	stripECSBeforeForward bool,
	allowTCPFallback bool,
) dns.HandlerFunc {
	return func(w dns.ResponseWriter, r *dns.Msg) {
		if len(r.Question) == 0 {
			dns.HandleFailed(w, r)
			return
		}
		question := r.Question[0]
		qtype := dns.TypeToString[question.Qtype]
		// Stripping ECS below only touches EDNS options, not the DO bit, so one
		// reading serves both the log line and the cache key.
		do := getDOFlag(r)

		log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
			qtype, question.Name, r.Id, r.CheckingDisabled, do, w.RemoteAddr())

		if stripECSBeforeForward {
			stripECS(r)
		}

		key := cacheKeyFromMsg(question, do, r.CheckingDisabled)
		if cached, hit := tryCacheRead(key); hit {
			log.Printf("[cache] HIT %s %s", qtype, question.Name)
			writeReply(w, r, cached)
			return
		}
		log.Printf("[cache] MISS %s %s", qtype, question.Name)

		resp := queryUpstreamsLimited(context.Background(), r, upstreams, timeout, maxParallel, allowTCPFallback)
		if resp == nil {
			log.Printf("[error] all upstreams failed for %s", question.Name)
			dns.HandleFailed(w, r)
			return
		}

		cacheWrite(key, resp, cacheMaxTTL)
		for _, rr := range resp.Answer {
			log.Printf("[answer] %s", rr.String())
		}
		writeReply(w, r, resp)
	}
}
/******************************************************************
* 主程序
******************************************************************/
// main parses flags, loads the TLS certificate and builds the upstream list.
// It then starts the cache cleaner and serves DNS-over-TLS until SIGINT/SIGTERM
// or a server error.
func main() {
	rand.Seed(time.Now().UnixNano())

	// Command-line flags.
	certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)")
	keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)")
	addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853")
	upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)")
	cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL默认 60s实际取 min(上游最小TTL, 本值)")
	timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s")
	maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限")
	stripECSFlag := flag.Bool("strip-ecs", true, "转发上游前去除 EDNS Client Subnet")
	allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时允许 TCP 回退")
	verbose := flag.Bool("v", false, "verbose 日志(包含源码位置)")
	flag.Parse()
	initLogger(*verbose)

	// TLS certificate for the DoT listener.
	cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
	if err != nil {
		log.Fatalf("[fatal] failed to load cert: %v", err)
	}

	// Upstreams: comma-separated, blank items ignored, port 53 added when absent.
	var upstreams []string
	for _, s := range strings.Split(*upstreamStr, ",") {
		if t := strings.TrimSpace(s); t != "" {
			upstreams = append(upstreams, normalizeUpstream(t))
		}
	}
	if len(upstreams) == 0 {
		log.Fatal("[fatal] no upstream DNS servers provided")
	}

	// Shared UDP client with an enlarged buffer. No Timeout is set here:
	// deadlines come from the context given to ExchangeContext.
	udpClient = &dns.Client{
		Net:     "udp",
		UDPSize: 4096,
	}

	// Cancelled on shutdown to stop background work.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	startCacheCleaner(ctx)

	mux := dns.NewServeMux()
	mux.HandleFunc(".", handleDNS(
		upstreams,
		*cacheTTLFlag,
		*timeoutFlag,
		*maxParallel,
		*stripECSFlag,
		*allowTCPFallback,
	))

	// DoT server. TLS 1.2 is the minimum; Go's defaults pick cipher suites and
	// enable TLS 1.3.
	srv := &dns.Server{
		Addr: *addr,
		Net:  "tcp-tls",
		TLSConfig: &tls.Config{
			Certificates: []tls.Certificate{cert},
			MinVersion:   tls.VersionTLS12,
		},
		Handler: mux,
	}

	// Graceful shutdown on SIGINT/SIGTERM.
	stop := make(chan os.Signal, 1)
	signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
	errCh := make(chan error, 1)
	go func() {
		log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
		log.Printf(" upstreams=%v | cache_max_ttl=%s | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v",
			upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback)
		errCh <- srv.ListenAndServe()
	}()

	select {
	case sig := <-stop:
		log.Printf("[shutdown] caught signal: %s", sig)
		cancel()
		// Shutdown() is used because not every miekg/dns version has ShutdownContext.
		if err := srv.Shutdown(); err != nil {
			log.Printf("[shutdown] server shutdown error: %v", err)
		}
	case err := <-errCh:
		if err != nil {
			log.Fatalf("[fatal] server error: %v", err)
		}
	}
	log.Println("[bye] server stopped.")
}

// normalizeUpstream turns an upstream spec into a host:port dial address. It
// appends the default DNS port 53 when no port is given, and brackets IPv6
// literals. The earlier "contains ':' means has a port" rule treated a bare
// IPv6 address such as "2001:4860:4860::8888" as host:port and produced an
// undialable address.
func normalizeUpstream(s string) string {
	switch {
	case strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]"):
		// Bracketed IPv6 literal without a port, e.g. "[2001:db8::1]".
		return fmt.Sprintf("%s:53", s)
	case strings.HasPrefix(s, "["):
		// Bracketed IPv6 literal with a port, e.g. "[2001:db8::1]:5353".
		return s
	case strings.Count(s, ":") > 1:
		// Bare IPv6 literal: a host:port pair would contain exactly one colon.
		return fmt.Sprintf("[%s]:53", s)
	case !strings.Contains(s, ":"):
		// IPv4 address or hostname without a port.
		return fmt.Sprintf("%s:53", s)
	default:
		// Already host:port.
		return s
	}
}