feat(cache): 引入 LRU 缓存并优化缓存清理与 TTL 处理

- 使用 github.com/hashicorp/golang-lru/v2 替代原生 sync.Map 实现 LRU 缓存
- 修复缓存读写过程中的并发安全问题,使用 RWMutex 保护共享状态
- 调整缓存键结构注释,明确支持 TTL 和 LRU 策略
- 优化负面缓存 TTL 计算逻辑,更准确识别 NODATA 场景
- 在缓存写入前统一剥离伪 RR(如 OPT、TSIG)
- 增加 cache-size 命令行参数,支持配置 LRU 缓存最大条目数
- 移除旧的缓存清理协程中不必要的全量遍历逻辑
- 更新日志输出内容,包含 cache-size 配置项
593 lines
14 KiB
Go
593 lines
14 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"flag"
|
||
"fmt"
|
||
"log"
|
||
"math/rand"
|
||
"os"
|
||
"os/signal"
|
||
"strings"
|
||
"sync"
|
||
"syscall"
|
||
"time"
|
||
|
||
"github.com/miekg/dns"
|
||
|
||
lru "github.com/hashicorp/golang-lru/v2"
|
||
)
|
||
|
||
/******************************************************************
 * Logger initialization
 ******************************************************************/

// initLogger configures the standard logger's prefix flags: date, time and
// microseconds are always enabled; the short "file:line" prefix is added
// only when verbose is true.
func initLogger(verbose bool) {
	base := log.Ldate | log.Ltime | log.Lmicroseconds
	if !verbose {
		log.SetFlags(base)
		return
	}
	log.SetFlags(base | log.Lshortfile)
}
|
||
|
||
/******************************************************************
|
||
* 缓存结构(支持 TTL + LRU)
|
||
******************************************************************/
|
||
|
||
type cacheEntry struct {
|
||
msg *dns.Msg
|
||
expireAt time.Time
|
||
}
|
||
|
||
var (
|
||
cache *lru.Cache[string, *cacheEntry]
|
||
cacheMutex sync.RWMutex
|
||
)
|
||
|
||
// cacheCleanupInterval is how often startCacheCleaner sweeps expired entries.
const cacheCleanupInterval = 5 * time.Minute

// defaultCacheSize is the default for -cache-size, the maximum number of
// entries kept in the LRU.
const defaultCacheSize = 10000
|
||
|
||
// startCacheCleaner 定期清理过期缓存(修复:在删除前二次校验)
|
||
func startCacheCleaner(ctx context.Context) {
|
||
go func() {
|
||
ticker := time.NewTicker(cacheCleanupInterval)
|
||
defer ticker.Stop()
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return
|
||
case <-ticker.C:
|
||
now := time.Now()
|
||
var toDelete []string
|
||
|
||
cacheMutex.RLock()
|
||
for _, k := range cache.Keys() {
|
||
if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
|
||
toDelete = append(toDelete, k)
|
||
}
|
||
}
|
||
cacheMutex.RUnlock()
|
||
|
||
if len(toDelete) > 0 {
|
||
pruned := 0
|
||
cacheMutex.Lock()
|
||
for _, k := range toDelete {
|
||
if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
|
||
cache.Remove(k)
|
||
pruned++
|
||
}
|
||
}
|
||
cacheMutex.Unlock()
|
||
if pruned > 0 {
|
||
log.Printf("[cache] cleaned %d expired entries", pruned)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
|
||
var b strings.Builder
|
||
b.Grow(len(q.Name) + 32)
|
||
b.WriteString(strings.ToLower(q.Name))
|
||
b.WriteString("|T=")
|
||
b.WriteString(dns.TypeToString[q.Qtype])
|
||
b.WriteString("|C=")
|
||
b.WriteString(dns.ClassToString[q.Qclass])
|
||
if do {
|
||
b.WriteString("|DO")
|
||
}
|
||
if cd {
|
||
b.WriteString("|CD")
|
||
}
|
||
return b.String()
|
||
}
|
||
|
||
func isPseudo(rr dns.RR) bool {
|
||
switch rr.(type) {
|
||
case *dns.OPT, *dns.TSIG:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// 读取缓存(修复:Get 在写锁下;在锁外调整 TTL)
|
||
func tryCacheRead(key string) (*dns.Msg, bool) {
|
||
now := time.Now()
|
||
|
||
cacheMutex.Lock()
|
||
e, ok := cache.Get(key) // Get 会更新 LRU,必须在写锁下
|
||
if !ok {
|
||
cacheMutex.Unlock()
|
||
return nil, false
|
||
}
|
||
if now.After(e.expireAt) {
|
||
cache.Remove(key)
|
||
cacheMutex.Unlock()
|
||
return nil, false
|
||
}
|
||
// 拷贝副本,在锁外改 TTL,减少临界区时间
|
||
out := e.msg.Copy()
|
||
expireAt := e.expireAt
|
||
cacheMutex.Unlock()
|
||
|
||
remaining := uint32(expireAt.Sub(now).Seconds())
|
||
if remaining == 0 {
|
||
cacheMutex.Lock()
|
||
cache.Remove(key)
|
||
cacheMutex.Unlock()
|
||
return nil, false
|
||
}
|
||
|
||
for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} {
|
||
for _, rr := range sec {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
if rr.Header().Ttl > remaining {
|
||
rr.Header().Ttl = remaining
|
||
}
|
||
}
|
||
}
|
||
return out, true
|
||
}
|
||
|
||
// 计算负面 TTL
|
||
// hasAnswerForType 判断报文中是否存在回答“请求类型”的 RRset
|
||
func hasAnswerForType(m *dns.Msg, q dns.Question) bool {
|
||
for _, rr := range m.Answer {
|
||
h := rr.Header()
|
||
if h.Rrtype == q.Qtype && strings.EqualFold(h.Name, q.Name) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// 计算负面 TTL(修复:正确识别 NODATA,包括 CNAME 等场景)
|
||
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
|
||
// NXDOMAIN:肯定是负面
|
||
if m.Rcode != dns.RcodeNameError {
|
||
// 不是 NXDOMAIN,则仅当 NOERROR 但没有“匹配 QTYPE 的答案”时才是 NODATA
|
||
if m.Rcode != dns.RcodeSuccess || len(m.Question) == 0 || hasAnswerForType(m, m.Question[0]) {
|
||
return 0, false
|
||
}
|
||
}
|
||
|
||
// 按 RFC 2308,从 Authority(Ns)优先取 SOA(多数实现都只放在 Authority)
|
||
var soa *dns.SOA
|
||
for _, rr := range m.Ns {
|
||
if s, ok := rr.(*dns.SOA); ok {
|
||
soa = s
|
||
break
|
||
}
|
||
}
|
||
// 兼容性:偶尔也有人把 SOA 放 Extra(不规范,但为了兼容可以兜底看看)
|
||
if soa == nil {
|
||
for _, rr := range m.Extra {
|
||
if s, ok := rr.(*dns.SOA); ok {
|
||
soa = s
|
||
break
|
||
}
|
||
}
|
||
}
|
||
if soa == nil {
|
||
// 建议:无 SOA 时不做负面缓存(返回 0,false)
|
||
// 如你更希望兜底,可改成:return uint32(maxTTL.Seconds()), true
|
||
return 0, false
|
||
}
|
||
|
||
// 负面 TTL 取 min(SOA.MINIMUM, SOA 自身 TTL),再与配置上限比较
|
||
ttl := soa.Hdr.Ttl
|
||
if soa.Minttl < ttl {
|
||
ttl = soa.Minttl
|
||
}
|
||
capTTL := uint32(maxTTL.Seconds())
|
||
if capTTL > 0 && ttl > capTTL {
|
||
ttl = capTTL
|
||
}
|
||
return ttl, ttl > 0
|
||
}
|
||
|
||
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
|
||
minTTL := uint32(0)
|
||
hasTTL := false
|
||
for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} {
|
||
for _, rr := range sec {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
ttl := rr.Header().Ttl
|
||
if !hasTTL || ttl < minTTL {
|
||
minTTL = ttl
|
||
hasTTL = true
|
||
}
|
||
}
|
||
}
|
||
return minTTL, hasTTL
|
||
}
|
||
|
||
func stripPseudoExtras(m *dns.Msg) {
|
||
if len(m.Extra) == 0 {
|
||
return
|
||
}
|
||
out := m.Extra[:0]
|
||
for _, rr := range m.Extra {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
out = append(out, rr)
|
||
}
|
||
m.Extra = out
|
||
}
|
||
|
||
// 写缓存
|
||
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
||
if in == nil {
|
||
return
|
||
}
|
||
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
|
||
return
|
||
}
|
||
var ttl uint32
|
||
var ok bool
|
||
if ttl, ok = negativeTTL(in, maxTTL); !ok {
|
||
minTTL, has := minRRsetTTL(in)
|
||
if has {
|
||
cfgTTL := uint32(maxTTL.Seconds())
|
||
if cfgTTL > 0 && minTTL > cfgTTL {
|
||
minTTL = cfgTTL
|
||
}
|
||
ttl = minTTL
|
||
} else {
|
||
ttl = uint32(maxTTL.Seconds())
|
||
}
|
||
}
|
||
if ttl == 0 {
|
||
return
|
||
}
|
||
expire := time.Now().Add(time.Duration(ttl) * time.Second)
|
||
cp := in.Copy()
|
||
stripPseudoExtras(cp)
|
||
cacheMutex.Lock()
|
||
cache.Add(key, &cacheEntry{msg: cp, expireAt: expire})
|
||
cacheMutex.Unlock()
|
||
}
|
||
|
||
/******************************************************************
|
||
* 上游查询
|
||
******************************************************************/
|
||
var udpClient *dns.Client
|
||
|
||
// shuffled returns a new slice containing the elements of xs in a random
// order chosen by math/rand. xs itself is never modified, and the result is
// non-nil even for empty input.
func shuffled(xs []string) []string {
	perm := append(make([]string, 0, len(xs)), xs...)
	rand.Shuffle(len(perm), func(a, b int) {
		perm[a], perm[b] = perm[b], perm[a]
	})
	return perm
}
|
||
|
||
// clampEDNSForUpstream 返回一个 msg 副本,把 EDNS UDP size 夹到给定大小
|
||
func clampEDNSForUpstream(in *dns.Msg, size uint16) *dns.Msg {
|
||
m := in.Copy()
|
||
o := m.IsEdns0()
|
||
if o == nil {
|
||
o = &dns.OPT{}
|
||
o.Hdr.Name = "."
|
||
o.Hdr.Rrtype = dns.TypeOPT
|
||
m.Extra = append(m.Extra, o)
|
||
}
|
||
if size > 0 {
|
||
o.SetUDPSize(size)
|
||
}
|
||
return m
|
||
}
|
||
|
||
func queryUpstreamsLimited(
|
||
ctx context.Context,
|
||
req *dns.Msg,
|
||
upstreams []string,
|
||
timeout time.Duration,
|
||
maxParallel int,
|
||
allowTCPFallback bool,
|
||
) *dns.Msg {
|
||
if maxParallel <= 0 {
|
||
maxParallel = 1
|
||
}
|
||
servers := shuffled(upstreams)
|
||
|
||
cctx, cancel := context.WithTimeout(ctx, timeout)
|
||
defer cancel()
|
||
|
||
type result struct{ msg *dns.Msg }
|
||
ch := make(chan result, len(servers))
|
||
sem := make(chan struct{}, maxParallel)
|
||
|
||
execOne := func(svr string) {
|
||
upReq := clampEDNSForUpstream(req, 1232) // 或做成 flag
|
||
resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr)
|
||
if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
|
||
log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
|
||
tcpClient := *udpClient
|
||
tcpClient.Net = "tcp"
|
||
|
||
resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr)
|
||
}
|
||
if err != nil || resp == nil {
|
||
if err != nil && cctx.Err() == nil {
|
||
log.Printf("[upstream] %s error: %v", svr, err)
|
||
}
|
||
return
|
||
}
|
||
if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError {
|
||
return
|
||
}
|
||
select {
|
||
case ch <- result{msg: resp}:
|
||
case <-cctx.Done():
|
||
}
|
||
}
|
||
|
||
for _, s := range servers {
|
||
s := s
|
||
go func() {
|
||
select {
|
||
case sem <- struct{}{}:
|
||
defer func() { <-sem }()
|
||
case <-cctx.Done():
|
||
return
|
||
}
|
||
execOne(s)
|
||
}()
|
||
}
|
||
|
||
for i := 0; i < len(servers); i++ {
|
||
select {
|
||
case r := <-ch:
|
||
if r.msg != nil {
|
||
cancel()
|
||
return r.msg
|
||
}
|
||
case <-cctx.Done():
|
||
log.Printf("[upstream] timeout after %v", timeout)
|
||
return nil
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
/******************************************************************
|
||
* EDNS / 响应构造
|
||
******************************************************************/
|
||
func stripECS(m *dns.Msg) {
|
||
if o := m.IsEdns0(); o != nil {
|
||
var kept []dns.EDNS0
|
||
for _, e := range o.Option {
|
||
if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS {
|
||
kept = append(kept, e)
|
||
}
|
||
}
|
||
o.Option = kept
|
||
}
|
||
}
|
||
|
||
func getDOFlag(m *dns.Msg) bool {
|
||
if o := m.IsEdns0(); o != nil {
|
||
return o.Do()
|
||
}
|
||
return false
|
||
}
|
||
|
||
func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
|
||
if upstream == nil {
|
||
dns.HandleFailed(w, req)
|
||
return
|
||
}
|
||
out := new(dns.Msg)
|
||
out.SetReply(req)
|
||
out.Authoritative = false
|
||
out.RecursionAvailable = upstream.RecursionAvailable
|
||
out.AuthenticatedData = upstream.AuthenticatedData
|
||
out.CheckingDisabled = req.CheckingDisabled
|
||
out.Rcode = upstream.Rcode
|
||
out.Answer = upstream.Answer
|
||
out.Ns = upstream.Ns
|
||
|
||
var extras []dns.RR
|
||
for _, rr := range upstream.Extra {
|
||
if !isPseudo(rr) {
|
||
extras = append(extras, rr)
|
||
}
|
||
}
|
||
|
||
if ro := req.IsEdns0(); ro != nil {
|
||
o := new(dns.OPT)
|
||
o.Hdr.Name = "."
|
||
o.Hdr.Rrtype = dns.TypeOPT
|
||
o.SetUDPSize(ro.UDPSize())
|
||
if ro.Do() {
|
||
o.SetDo()
|
||
}
|
||
if uo := upstream.IsEdns0(); uo != nil {
|
||
o.SetExtendedRcode(uint16(uo.ExtendedRcode()))
|
||
o.SetVersion(uint8(uo.Version()))
|
||
for _, opt := range uo.Option {
|
||
if ede, ok := opt.(*dns.EDNS0_EDE); ok {
|
||
o.Option = append(o.Option, ede)
|
||
}
|
||
}
|
||
}
|
||
extras = append(extras, o)
|
||
}
|
||
out.Extra = extras
|
||
out.Compress = true
|
||
|
||
if err := w.WriteMsg(out); err != nil {
|
||
log.Printf("[write] WriteMsg error: %v", err)
|
||
}
|
||
}
|
||
|
||
/******************************************************************
|
||
* 主处理器
|
||
******************************************************************/
|
||
func handleDNS(
|
||
upstreams []string,
|
||
cacheMaxTTL, timeout time.Duration,
|
||
maxParallel int,
|
||
stripECSBeforeForward bool,
|
||
allowTCPFallback bool,
|
||
) dns.HandlerFunc {
|
||
return func(w dns.ResponseWriter, r *dns.Msg) {
|
||
if len(r.Question) == 0 {
|
||
dns.HandleFailed(w, r)
|
||
return
|
||
}
|
||
q := r.Question[0]
|
||
log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
|
||
dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())
|
||
|
||
if stripECSBeforeForward {
|
||
stripECS(r)
|
||
}
|
||
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
|
||
|
||
if cached, ok := tryCacheRead(key); ok {
|
||
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||
writeReply(w, r, cached)
|
||
return
|
||
}
|
||
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||
|
||
ctx := context.Background()
|
||
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
|
||
if resp == nil {
|
||
log.Printf("[error] all upstreams failed for %s", q.Name)
|
||
dns.HandleFailed(w, r)
|
||
return
|
||
}
|
||
cacheWrite(key, resp, cacheMaxTTL)
|
||
for _, ans := range resp.Answer {
|
||
log.Printf("[answer] %s", ans.String())
|
||
}
|
||
writeReply(w, r, resp)
|
||
}
|
||
}
|
||
|
||
/******************************************************************
|
||
* 主函数
|
||
******************************************************************/
|
||
func main() {
|
||
rand.Seed(time.Now().UnixNano())
|
||
|
||
certFile := flag.String("cert", "server.crt", "TLS 证书文件路径")
|
||
keyFile := flag.String("key", "server.key", "TLS 私钥文件路径")
|
||
addr := flag.String("addr", ":853", "DoT 监听地址")
|
||
upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表")
|
||
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL")
|
||
cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限")
|
||
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时")
|
||
maxParallel := flag.Int("max-parallel", 3, "并发上游数量")
|
||
stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS")
|
||
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退")
|
||
verbose := flag.Bool("v", false, "verbose 日志")
|
||
flag.Parse()
|
||
|
||
initLogger(*verbose)
|
||
|
||
var err error
|
||
cache, err = lru.New[string, *cacheEntry](*cacheSizeFlag)
|
||
if err != nil {
|
||
log.Fatalf("[fatal] failed to init LRU cache: %v", err)
|
||
}
|
||
|
||
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
|
||
if err != nil {
|
||
log.Fatalf("[fatal] failed to load cert: %v", err)
|
||
}
|
||
|
||
var upstreams []string
|
||
for _, s := range strings.Split(*upstreamStr, ",") {
|
||
if t := strings.TrimSpace(s); t != "" {
|
||
if !strings.Contains(t, ":") {
|
||
t = fmt.Sprintf("%s:53", t)
|
||
}
|
||
upstreams = append(upstreams, t)
|
||
}
|
||
}
|
||
if len(upstreams) == 0 {
|
||
log.Fatal("[fatal] no upstream DNS servers provided")
|
||
}
|
||
|
||
udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true}
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
|
||
startCacheCleaner(ctx)
|
||
|
||
mux := dns.NewServeMux()
|
||
mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel, *stripECSFlag, *allowTCPFallback))
|
||
|
||
srv := &dns.Server{
|
||
Addr: *addr,
|
||
Net: "tcp-tls",
|
||
TLSConfig: &tls.Config{
|
||
Certificates: []tls.Certificate{cert},
|
||
MinVersion: tls.VersionTLS12,
|
||
NextProtos: []string{"dot"},
|
||
},
|
||
Handler: mux,
|
||
ReadTimeout: 10 * time.Second,
|
||
WriteTimeout: 10 * time.Second,
|
||
}
|
||
|
||
stop := make(chan os.Signal, 1)
|
||
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
|
||
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v",
|
||
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback)
|
||
errCh <- srv.ListenAndServe()
|
||
}()
|
||
|
||
select {
|
||
case sig := <-stop:
|
||
log.Printf("[shutdown] caught signal: %s", sig)
|
||
cancel()
|
||
if err := srv.Shutdown(); err != nil {
|
||
log.Printf("[shutdown] server shutdown error: %v", err)
|
||
}
|
||
case err := <-errCh:
|
||
if err != nil {
|
||
log.Fatalf("[fatal] server error: %v", err)
|
||
}
|
||
}
|
||
|
||
log.Println("[bye] server stopped.")
|
||
}
|