Files
dot/main.go
aixiao 916a7c8127 ```
feat(cache): 引入 LRU 缓存并优化缓存清理与 TTL 处理

- 使用 github.com/hashicorp/golang-lru/v2 替代原生 sync.Map 实现 LRU 缓存
- 修复缓存读写过程中的并发安全问题,使用 RWMutex 保护共享状态
- 调整缓存键结构注释,明确支持 TTL 和 LRU 策略
- 优化负面缓存 TTL 计算逻辑,更准确识别 NODATA 场景
- 在缓存写入前统一剥离伪 RR(如 OPT、TSIG)
- 增加 cache-size 命令行参数,支持配置 LRU 缓存最大条目数
- 移除旧的缓存清理协程中不必要的全量遍历逻辑
- 更新日志输出内容,包含 cache-size 配置项
```
2025-10-14 16:37:39 +08:00

593 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"crypto/tls"
"flag"
"fmt"
"log"
"math/rand"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"github.com/miekg/dns"
lru "github.com/hashicorp/golang-lru/v2"
)
/******************************************************************
* 日志初始化
******************************************************************/
// initLogger configures the standard logger so every line is prefixed with
// the date and the time at microsecond precision. In verbose mode the short
// source file name and line number of the logging call are added as well.
func initLogger(verbose bool) {
	const base = log.Ldate | log.Ltime | log.Lmicroseconds
	if verbose {
		log.SetFlags(base | log.Lshortfile)
		return
	}
	log.SetFlags(base)
}
/******************************************************************
* 缓存结构(支持 TTL + LRU)
******************************************************************/
// cacheEntry is one cached response plus the wall-clock instant after which
// it must no longer be served. The message is a private copy taken by
// cacheWrite; readers copy it again before rewriting TTLs.
type cacheEntry struct {
	msg      *dns.Msg  // stored response (OPT/TSIG already stripped)
	expireAt time.Time // absolute expiry time
}

// cache is the bounded LRU store of responses, keyed by cacheKeyFromMsg.
// It is created in main with the size given by -cache-size.
var cache *lru.Cache[string, *cacheEntry]

// cacheMutex guards cache: the cleaner scans it (Keys/Peek) under the read
// lock, while Get (which updates recency), Add and Remove take the write lock.
var cacheMutex sync.RWMutex

// cacheCleanupInterval is how often the background cleaner sweeps the cache.
const cacheCleanupInterval = 5 * time.Minute

// defaultCacheSize is the default maximum number of cache entries
// (default value of the -cache-size flag).
const defaultCacheSize = 10000
// startCacheCleaner launches a background goroutine that, every
// cacheCleanupInterval, removes cache entries whose expiry time has passed.
// It returns immediately; the goroutine exits once ctx is cancelled.
func startCacheCleaner(ctx context.Context) {
	go func() {
		ticker := time.NewTicker(cacheCleanupInterval)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
			}
			if n := pruneExpired(time.Now()); n > 0 {
				log.Printf("[cache] cleaned %d expired entries", n)
			}
		}
	}()
}

// pruneExpired removes every entry that expired before now and reports how
// many were removed. Candidates are collected under the read lock; each one
// is re-checked under the write lock before removal, so an entry refreshed
// by cacheWrite in between is kept.
func pruneExpired(now time.Time) int {
	var stale []string
	cacheMutex.RLock()
	for _, k := range cache.Keys() {
		if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
			stale = append(stale, k)
		}
	}
	cacheMutex.RUnlock()

	if len(stale) == 0 {
		return 0
	}

	cacheMutex.Lock()
	defer cacheMutex.Unlock()
	removed := 0
	for _, k := range stale {
		if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
			cache.Remove(k)
			removed++
		}
	}
	return removed
}
// cacheKeyFromMsg builds the cache key for question q. The key combines the
// lower-cased owner name, the query type, the query class, and the DO/CD
// flags. DO and CD are part of the key because they change what an upstream
// returns, so such answers must not share an entry.
//
// Type and class are rendered with dns.Type(..).String() and
// dns.Class(..).String(). For codes missing from dns.TypeToString and
// dns.ClassToString these fall back to "TYPEnnn" and "CLASSnnn". Indexing
// those maps directly yields "" for unknown codes. In that case two different
// unknown or private-use query types for the same name would map to the same
// key, and one could be answered from the other's cached response.
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
	var b strings.Builder
	b.Grow(len(q.Name) + 32)
	b.WriteString(strings.ToLower(q.Name))
	b.WriteString("|T=")
	b.WriteString(dns.Type(q.Qtype).String())
	b.WriteString("|C=")
	b.WriteString(dns.Class(q.Qclass).String())
	if do {
		b.WriteString("|DO")
	}
	if cd {
		b.WriteString("|CD")
	}
	return b.String()
}
// isPseudo reports whether rr is an EDNS0 OPT or a TSIG record. These are
// per-message pseudo-records, so the cache strips them and leaves their TTL
// fields alone.
func isPseudo(rr dns.RR) bool {
	_, isOPT := rr.(*dns.OPT)
	_, isTSIG := rr.(*dns.TSIG)
	return isOPT || isTSIG
}
// tryCacheRead looks up key in the LRU cache. On a hit it returns a private
// copy of the cached response. Every non-pseudo record TTL in the copy is
// clamped to the whole seconds the entry still has to live.
//
// Entries that have expired, or have less than one whole second left (they
// would be served with TTL 0), are evicted and reported as a miss. The
// eviction happens under the same lock as the lookup. Previously the lock was
// dropped and re-acquired before Remove, so a fresh entry stored for the same
// key by a concurrent cacheWrite could be deleted by mistake.
func tryCacheRead(key string) (*dns.Msg, bool) {
	now := time.Now()

	cacheMutex.Lock()
	e, ok := cache.Get(key) // Get updates LRU recency, so it takes the write lock.
	if !ok {
		cacheMutex.Unlock()
		return nil, false
	}
	left := e.expireAt.Sub(now)
	if left < time.Second {
		cache.Remove(key)
		cacheMutex.Unlock()
		return nil, false
	}
	out := e.msg.Copy()
	cacheMutex.Unlock()

	// Rewrite TTLs on the private copy outside the critical section.
	remaining := uint32(left.Seconds())
	for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} {
		for _, rr := range sec {
			if isPseudo(rr) {
				continue
			}
			if h := rr.Header(); h.Ttl > remaining {
				h.Ttl = remaining
			}
		}
	}
	return out, true
}
// hasAnswerForType reports whether m's Answer section actually answers q.
//
// It follows the CNAME chain that starts at q.Name, so an answer such as
// "www CNAME cdn; cdn A 192.0.2.1" counts as answering an A query for www.
// For an ANY query, any record at the (possibly aliased) name counts. Owner
// names are compared case-insensitively. A chain that ends without a record of
// the requested type means the response is NODATA.
func hasAnswerForType(m *dns.Msg, q dns.Question) bool {
	name := q.Name
	// Each round follows one CNAME, so len(m.Answer)+1 rounds cover any
	// well-formed chain and also stop a CNAME loop from spinning forever.
	for hop := 0; hop <= len(m.Answer); hop++ {
		next := ""
		for _, rr := range m.Answer {
			h := rr.Header()
			if !strings.EqualFold(h.Name, name) {
				continue
			}
			if h.Rrtype == q.Qtype || q.Qtype == dns.TypeANY {
				return true
			}
			if c, ok := rr.(*dns.CNAME); ok && next == "" {
				next = c.Target
			}
		}
		if next == "" {
			return false
		}
		name = next
	}
	return false
}

// negativeTTL classifies m following RFC 2308 and returns how long it may be
// cached as a negative answer.
//
// The boolean reports whether m is a negative response:
//   - NXDOMAIN, or
//   - NOERROR whose Answer section does not answer the first question (NODATA).
//
// For a negative response the TTL is min(SOA TTL, SOA MINIMUM), capped at
// maxTTL when maxTTL is positive. The SOA is taken from the Authority
// section, or, as a lenient fallback, from the Additional section.
//
// A negative response whose TTL works out to 0 is reported as (0, true),
// meaning "negative, do not cache". This covers responses with no SOA at all
// and SOAs that allow no caching. Reporting false there would make the caller
// treat the response as positive. It would then be cached with the SOA
// record's own TTL or the configured maximum, which RFC 2308 §5 advises
// against for SOA-less negative answers.
//
// Anything that is not a negative response yields (0, false).
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
	switch {
	case m.Rcode == dns.RcodeNameError:
		// NXDOMAIN is always negative.
	case m.Rcode == dns.RcodeSuccess && len(m.Question) > 0 && !hasAnswerForType(m, m.Question[0]):
		// NODATA: the name exists but has nothing of the requested type.
	default:
		return 0, false
	}

	soa := firstSOA(m.Ns)
	if soa == nil {
		// Non-standard, but some servers put the SOA in Additional.
		soa = firstSOA(m.Extra)
	}
	if soa == nil {
		// No SOA: negative but not cacheable. To fall back to the configured
		// maximum instead, return uint32(maxTTL.Seconds()), true here.
		return 0, true
	}

	ttl := soa.Hdr.Ttl
	if soa.Minttl < ttl {
		ttl = soa.Minttl
	}
	if capTTL := uint32(maxTTL.Seconds()); capTTL > 0 && ttl > capTTL {
		ttl = capTTL
	}
	return ttl, true
}

// firstSOA returns the first SOA record in rrs, or nil if there is none.
func firstSOA(rrs []dns.RR) *dns.SOA {
	for _, rr := range rrs {
		if soa, ok := rr.(*dns.SOA); ok {
			return soa
		}
	}
	return nil
}
// minRRsetTTL returns the smallest TTL of any non-pseudo record in the
// Answer, Authority and Additional sections of m. The boolean is false when
// no such record exists; the TTL is 0 in that case.
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
	var (
		lowest uint32
		found  bool
	)
	scan := func(rrs []dns.RR) {
		for _, rr := range rrs {
			if isPseudo(rr) {
				continue
			}
			if ttl := rr.Header().Ttl; !found || ttl < lowest {
				lowest, found = ttl, true
			}
		}
	}
	scan(m.Answer)
	scan(m.Ns)
	scan(m.Extra)
	return lowest, found
}
// stripPseudoExtras removes OPT and TSIG pseudo-records from m's Additional
// section. It filters in place, reusing the section's backing array.
func stripPseudoExtras(m *dns.Msg) {
	if len(m.Extra) == 0 {
		return
	}
	kept := 0
	for _, rr := range m.Extra {
		if isPseudo(rr) {
			continue
		}
		m.Extra[kept] = rr
		kept++
	}
	m.Extra = m.Extra[:kept]
}
// cacheWrite stores a private copy of the upstream response in under key.
//
// Only NOERROR and NXDOMAIN responses are considered. Truncated (TC=1)
// responses are never cached: their sections may be incomplete, and
// writeReply does not carry the TC bit into the reply. A cached partial
// answer would therefore be served to clients as if it were complete until it
// expired.
//
// Lifetime selection:
//   - If negativeTTL reports ok, its TTL is used.
//   - Otherwise the smallest TTL of any non-pseudo record is used, capped at
//     maxTTL when maxTTL > 0.
//   - If the response has no records, maxTTL is used.
//
// A lifetime of 0 means "do not cache". OPT/TSIG records are stripped from
// the stored copy.
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
	if in == nil || in.Truncated {
		return
	}
	if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
		return
	}

	ttl, ok := negativeTTL(in, maxTTL)
	if !ok {
		cfgTTL := uint32(maxTTL.Seconds())
		if minTTL, has := minRRsetTTL(in); has {
			ttl = minTTL
			if cfgTTL > 0 && ttl > cfgTTL {
				ttl = cfgTTL
			}
		} else {
			ttl = cfgTTL
		}
	}
	if ttl == 0 {
		return
	}

	expire := time.Now().Add(time.Duration(ttl) * time.Second)
	cp := in.Copy()
	stripPseudoExtras(cp)

	cacheMutex.Lock()
	cache.Add(key, &cacheEntry{msg: cp, expireAt: expire})
	cacheMutex.Unlock()
}
/******************************************************************
* 上游查询
******************************************************************/
// udpClient is the shared client for upstream queries. main assigns it
// before the server starts accepting requests.
var (
	udpClient *dns.Client
)
// shuffled returns a randomly permuted copy of xs using the global
// math/rand source. The input slice itself is never reordered.
func shuffled(xs []string) []string {
	perm := make([]string, len(xs))
	copy(perm, xs)
	swap := func(i, j int) {
		perm[i], perm[j] = perm[j], perm[i]
	}
	rand.Shuffle(len(perm), swap)
	return perm
}
// clampEDNSForUpstream returns a copy of in whose EDNS0 OPT record advertises
// a UDP payload size of size. If in carries no OPT record, a new one is
// appended to the copy. A size of 0 leaves the advertised size unchanged.
// The original message is not modified.
func clampEDNSForUpstream(in *dns.Msg, size uint16) *dns.Msg {
	out := in.Copy()
	opt := out.IsEdns0()
	if opt == nil {
		opt = new(dns.OPT)
		opt.Hdr.Name = "."
		opt.Hdr.Rrtype = dns.TypeOPT
		out.Extra = append(out.Extra, opt)
	}
	if size != 0 {
		opt.SetUDPSize(size)
	}
	return out
}
// queryUpstreamsLimited sends req to the upstream servers in random order.
// At most maxParallel exchanges run at a time (maxParallel <= 0 means 1).
// It returns the first usable response, meaning anything other than
// SERVFAIL, REFUSED or FORMERR.
//
// It returns nil when every upstream has failed, or when timeout (or ctx)
// expires first. Each per-server goroutine reports exactly one result, nil on
// failure. Previously failures were silent, so a round in which all upstreams
// failed fast still waited out the whole timeout; now it ends immediately.
func queryUpstreamsLimited(
	ctx context.Context,
	req *dns.Msg,
	upstreams []string,
	timeout time.Duration,
	maxParallel int,
	allowTCPFallback bool,
) *dns.Msg {
	if maxParallel <= 0 {
		maxParallel = 1
	}
	servers := shuffled(upstreams)
	cctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel() // also aborts exchanges still in flight when we return

	// Buffered to len(servers): every goroutine sends exactly once, so no
	// sender ever blocks, even after this function has returned.
	results := make(chan *dns.Msg, len(servers))
	sem := make(chan struct{}, maxParallel)

	for _, s := range servers {
		s := s
		go func() {
			select {
			case sem <- struct{}{}:
			case <-cctx.Done():
				results <- nil
				return
			}
			defer func() { <-sem }()
			results <- exchangeWithUpstream(cctx, req, s, allowTCPFallback)
		}()
	}

	for range servers {
		select {
		case msg := <-results:
			if msg != nil {
				return msg
			}
		case <-cctx.Done():
			log.Printf("[upstream] timeout after %v", timeout)
			return nil
		}
	}
	return nil
}

// exchangeWithUpstream performs one exchange with svr and returns the
// response, or nil when it is unusable.
//
// The query goes over UDP with the advertised EDNS buffer set to 1232 bytes,
// a size commonly recommended to avoid IP fragmentation (could be made a
// flag). When allowTCPFallback is set, a truncated reply triggers a retry over
// TCP. It returns nil on transport errors and on SERVFAIL, REFUSED or FORMERR
// answers. Errors are logged unless they are caused by ctx ending.
func exchangeWithUpstream(ctx context.Context, req *dns.Msg, svr string, allowTCPFallback bool) *dns.Msg {
	upReq := clampEDNSForUpstream(req, 1232)
	resp, _, err := udpClient.ExchangeContext(ctx, upReq, svr)
	if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
		log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
		resp, _, err = tcpClientFrom(udpClient).ExchangeContext(ctx, req.Copy(), svr)
	}
	if err != nil || resp == nil {
		if err != nil && ctx.Err() == nil {
			log.Printf("[upstream] %s error: %v", svr, err)
		}
		return nil
	}
	switch resp.Rcode {
	case dns.RcodeServerFailure, dns.RcodeRefused, dns.RcodeFormatError:
		return nil
	}
	return resp
}

// tcpClientFrom builds a TCP client with base's dialer, timeouts, TLS config
// and TSIG secrets.
//
// It deliberately does not copy *base. dns.Client may hold unexported
// synchronisation state: the mutex-guarded singleflight group behind
// SingleInflight, in miekg/dns versions that implement it. Copying that while
// other goroutines use base is a data race and is flagged by go vet's
// copylocks check. Fields not listed here are not propagated; main does not
// set any of them.
func tcpClientFrom(base *dns.Client) *dns.Client {
	return &dns.Client{
		Net:          "tcp",
		UDPSize:      base.UDPSize,
		TLSConfig:    base.TLSConfig,
		Dialer:       base.Dialer,
		Timeout:      base.Timeout,
		DialTimeout:  base.DialTimeout,
		ReadTimeout:  base.ReadTimeout,
		WriteTimeout: base.WriteTimeout,
		TsigSecret:   base.TsigSecret,
	}
}
/******************************************************************
* EDNS / 响应构造
******************************************************************/
// stripECS removes every EDNS Client Subnet (ECS) option from m's OPT record
// in place. All other options are kept in their original order. Messages
// without EDNS0 are left untouched.
func stripECS(m *dns.Msg) {
	opt := m.IsEdns0()
	if opt == nil {
		return
	}
	var rest []dns.EDNS0
	for _, o := range opt.Option {
		if _, ecs := o.(*dns.EDNS0_SUBNET); ecs {
			continue
		}
		rest = append(rest, o)
	}
	opt.Option = rest
}
// getDOFlag reports whether m carries an EDNS0 OPT record with the DNSSEC OK
// (DO) bit set. Messages without EDNS0 yield false.
func getDOFlag(m *dns.Msg) bool {
	opt := m.IsEdns0()
	return opt != nil && opt.Do()
}
// writeReply answers req on w using upstream, which is either a fresh
// upstream response or a cache hit. A nil upstream produces the
// dns.HandleFailed reply.
//
// Where each part of the reply comes from:
//   - From req (via SetReply): ID, question and CD bit.
//   - From upstream: RA, AD, the rcode, and the Answer and Authority sections.
//   - Additional: upstream OPT/TSIG records are dropped. If the client used
//     EDNS0, a fresh OPT record is appended.
func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
	if upstream == nil {
		dns.HandleFailed(w, req)
		return
	}

	reply := new(dns.Msg)
	reply.SetReply(req)
	reply.Authoritative = false
	reply.RecursionAvailable = upstream.RecursionAvailable
	reply.AuthenticatedData = upstream.AuthenticatedData
	reply.CheckingDisabled = req.CheckingDisabled
	reply.Rcode = upstream.Rcode
	reply.Answer = upstream.Answer
	reply.Ns = upstream.Ns

	var extra []dns.RR
	for _, rr := range upstream.Extra {
		if isPseudo(rr) {
			continue
		}
		extra = append(extra, rr)
	}
	if clientOPT := req.IsEdns0(); clientOPT != nil {
		extra = append(extra, replyOPT(clientOPT, upstream.IsEdns0()))
	}
	reply.Extra = extra
	reply.Compress = true

	if err := w.WriteMsg(reply); err != nil {
		log.Printf("[write] WriteMsg error: %v", err)
	}
}

// replyOPT builds the OPT record for a reply. The UDP size and DO bit come
// from the client's OPT. When the upstream sent an OPT, its extended rcode,
// EDNS version and Extended DNS Error options are copied as well.
func replyOPT(client, upstream *dns.OPT) *dns.OPT {
	opt := new(dns.OPT)
	opt.Hdr.Name = "."
	opt.Hdr.Rrtype = dns.TypeOPT
	opt.SetUDPSize(client.UDPSize())
	if client.Do() {
		opt.SetDo()
	}
	if upstream == nil {
		return opt
	}
	opt.SetExtendedRcode(uint16(upstream.ExtendedRcode()))
	opt.SetVersion(uint8(upstream.Version()))
	for _, o := range upstream.Option {
		if ede, ok := o.(*dns.EDNS0_EDE); ok {
			opt.Option = append(opt.Option, ede)
		}
	}
	return opt
}
/******************************************************************
* 主处理器
******************************************************************/
// handleDNS returns the handler that serves every incoming query. For each
// request it:
//  1. logs the request;
//  2. optionally strips ECS;
//  3. answers from the cache when possible;
//  4. otherwise forwards to the upstreams, then caches and returns the result.
//
// Requests without a question, and requests for which every upstream fails,
// get the dns.HandleFailed reply.
func handleDNS(
	upstreams []string,
	cacheMaxTTL, timeout time.Duration,
	maxParallel int,
	stripECSBeforeForward bool,
	allowTCPFallback bool,
) dns.HandlerFunc {
	return func(w dns.ResponseWriter, r *dns.Msg) {
		if len(r.Question) == 0 {
			dns.HandleFailed(w, r)
			return
		}
		q := r.Question[0]
		qtype := dns.TypeToString[q.Qtype]
		log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
			qtype, q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())

		if stripECSBeforeForward {
			stripECS(r)
		}

		key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
		if hit, ok := tryCacheRead(key); ok {
			log.Printf("[cache] HIT %s %s", qtype, q.Name)
			writeReply(w, r, hit)
			return
		}
		log.Printf("[cache] MISS %s %s", qtype, q.Name)

		resp := queryUpstreamsLimited(context.Background(), r, upstreams, timeout, maxParallel, allowTCPFallback)
		if resp == nil {
			log.Printf("[error] all upstreams failed for %s", q.Name)
			dns.HandleFailed(w, r)
			return
		}

		cacheWrite(key, resp, cacheMaxTTL)
		for _, rr := range resp.Answer {
			log.Printf("[answer] %s", rr.String())
		}
		writeReply(w, r, resp)
	}
}
/******************************************************************
* 主函数
******************************************************************/
// main runs the DoT proxy:
//  1. parses flags and initialises logging and the LRU cache;
//  2. loads the TLS certificate and builds the upstream list;
//  3. serves DNS-over-TLS until SIGINT/SIGTERM or a fatal listener error.
func main() {
	// Go 1.20+ seeds math/rand automatically; the explicit seed keeps the
	// upstream shuffle random on older toolchains.
	rand.Seed(time.Now().UnixNano())

	certFile := flag.String("cert", "server.crt", "TLS 证书文件路径")
	keyFile := flag.String("key", "server.key", "TLS 私钥文件路径")
	addr := flag.String("addr", ":853", "DoT 监听地址")
	upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表")
	cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL")
	cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限")
	timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时")
	maxParallel := flag.Int("max-parallel", 3, "并发上游数量")
	stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS")
	allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退")
	verbose := flag.Bool("v", false, "verbose 日志")
	flag.Parse()
	initLogger(*verbose)

	var err error
	cache, err = lru.New[string, *cacheEntry](*cacheSizeFlag)
	if err != nil {
		log.Fatalf("[fatal] failed to init LRU cache: %v", err)
	}

	cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
	if err != nil {
		log.Fatalf("[fatal] failed to load cert: %v", err)
	}

	upstreams := parseUpstreams(*upstreamStr)
	if len(upstreams) == 0 {
		log.Fatal("[fatal] no upstream DNS servers provided")
	}

	udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	startCacheCleaner(ctx)

	mux := dns.NewServeMux()
	mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel, *stripECSFlag, *allowTCPFallback))
	srv := &dns.Server{
		Addr: *addr,
		Net:  "tcp-tls",
		TLSConfig: &tls.Config{
			Certificates: []tls.Certificate{cert},
			MinVersion:   tls.VersionTLS12,
			NextProtos:   []string{"dot"},
		},
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	stop := make(chan os.Signal, 1)
	signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
	errCh := make(chan error, 1)
	go func() {
		log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
		log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v",
			upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback)
		errCh <- srv.ListenAndServe()
	}()

	select {
	case sig := <-stop:
		log.Printf("[shutdown] caught signal: %s", sig)
		cancel()
		if err := srv.Shutdown(); err != nil {
			log.Printf("[shutdown] server shutdown error: %v", err)
		}
	case err := <-errCh:
		if err != nil {
			log.Fatalf("[fatal] server error: %v", err)
		}
	}
	log.Println("[bye] server stopped.")
}

// parseUpstreams turns the comma-separated -upstream value into dial
// addresses. Surrounding whitespace is trimmed, empty items are skipped, and
// port 53 is added where no port is given.
func parseUpstreams(list string) []string {
	var out []string
	for _, item := range strings.Split(list, ",") {
		if a := strings.TrimSpace(item); a != "" {
			out = append(out, withDefaultDNSPort(a))
		}
	}
	return out
}

// withDefaultDNSPort appends port 53 to addr when addr has no port.
//
// A plain "contains ':'" test is not enough. A bare IPv6 literal such as
// 2001:4860:4860::8888 contains colons but no port, and dialling it as-is
// fails. Bare IPv6 literals are therefore bracketed, for example
// "[2001:4860:4860::8888]:53". To give an IPv6 upstream a non-default port,
// write it as "[addr]:port".
func withDefaultDNSPort(addr string) string {
	switch colons := strings.Count(addr, ":"); {
	case strings.HasPrefix(addr, "["):
		if strings.Contains(addr, "]:") {
			return addr // "[v6]:port"
		}
		return fmt.Sprintf("%s:53", addr) // "[v6]" without a port
	case colons == 0:
		return fmt.Sprintf("%s:53", addr) // IPv4 address or host name
	case colons == 1:
		return addr // "host:port"
	default:
		return fmt.Sprintf("[%s]:53", addr) // bare IPv6 literal
	}
}