Files
dot/main.go
aixiao 4060e83686 ```
docs(readme): 重构 README 文档结构与内容以提升可读性

- 更新项目标题图标并优化描述语句
- 重新组织特性列表为表格形式,增加 ECS 剥离、黑名单过滤等功能说明
- 补充快速开始章节,细化源码构建与 Docker 使用方式
- 调整参数说明表,新增黑名单相关配置项及缓存条目限制
- 增加缓存机制详解、黑名单功能使用示例与架构图
- 更新开发依赖信息与推荐编译参数
- 修正作者信息展示格式并添加仓库链接

feat(cache): 改进缓存键生成逻辑与 EDNS 元数据处理

- 使用 dns.CanonicalName 规范化域名避免重复缓存键
- 缓存条目中保存 EDNS 扩展信息(version, rcode, EDE)
- 修复缓存读取函数返回值,传递完整缓存元数据
- 调整 TTL 计算优先级,仅在必要时检查 Extra 区域
- 黑名单匹配提前拦截请求,跳过上游查询
- 启动日志中显示黑名单规则数量与返回码设置
```
2025-10-15 14:19:55 +08:00

677 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"crypto/tls"
"flag"
"fmt"
"log"
"math/rand"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"github.com/miekg/dns"
lru "github.com/hashicorp/golang-lru/v2"
)
/******************************************************************
 * Logger initialization
 ******************************************************************/

// initLogger configures the flags of the standard library logger.
//
// Date, time and microseconds are always printed; when verbose is true the
// short file name and line number of the call site are printed as well.
func initLogger(verbose bool) {
	if verbose {
		log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
		return
	}
	log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
}
/******************************************************************
* Cache structures (TTL + LRU)
******************************************************************/
// cacheEntry is the value stored in the LRU cache for one cache key.
// Entries are built once in cacheWrite and only read afterwards.
type cacheEntry struct {
// msg is a private copy of the upstream response; cacheWrite removes the
// OPT/TSIG pseudo-records from its Extra section before storing it.
msg *dns.Msg
// expireAt is the instant after which readers and the cleaner treat the
// entry as expired.
expireAt time.Time
// EDNS metadata (to reproduce Extended RCODE / EDE on cache hits)
ednsPresent bool
ednsVersion uint8
ednsExtRcode uint16
ednsEDE []*dns.EDNS0_EDE
}
var (
// cache is the LRU store of cached responses, keyed by the string produced
// by cacheKeyFromMsg. It is created in main.
cache *lru.Cache[string, *cacheEntry]
// cacheMutex is held around every access to cache made by tryCacheRead,
// cacheWrite and the background cleaner.
cacheMutex sync.RWMutex
)
const (
// cacheCleanupInterval is the period of the background sweep started by
// startCacheCleaner.
cacheCleanupInterval = 5 * time.Minute
defaultCacheSize = 10000 // default maximum number of cache entries
)
// startCacheCleaner 定期清理过期缓存(修复:在删除前二次校验)
func startCacheCleaner(ctx context.Context) {
go func() {
ticker := time.NewTicker(cacheCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
var toDelete []string
cacheMutex.RLock()
for _, k := range cache.Keys() {
if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
toDelete = append(toDelete, k)
}
}
cacheMutex.RUnlock()
if len(toDelete) > 0 {
pruned := 0
cacheMutex.Lock()
for _, k := range toDelete {
if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
cache.Remove(k)
pruned++
}
}
cacheMutex.Unlock()
if pruned > 0 {
log.Printf("[cache] cleaned %d expired entries", pruned)
}
}
}
}
}()
}
// cacheKeyFromMsg builds the cache key for a question.
//
// The key has the form "<name>|T=<type>|C=<class>" followed by "|DO" and/or
// "|CD" when the corresponding flag is set, so answers obtained with
// different DNSSEC flags are cached separately.
//
// The name is passed through dns.CanonicalName so that spellings differing
// only in letter case or in the trailing dot share one entry. Type and class
// are rendered with dns.Type.String / dns.Class.String rather than the
// TypeToString / ClassToString maps: the maps yield "" for codes they do not
// know, which would make every unassigned QTYPE (or QCLASS) of a name collide
// on the same key.
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
	var b strings.Builder
	b.Grow(len(q.Name) + 32)
	b.WriteString(dns.CanonicalName(q.Name))
	b.WriteString("|T=")
	b.WriteString(dns.Type(q.Qtype).String())
	b.WriteString("|C=")
	b.WriteString(dns.Class(q.Qclass).String())
	if do {
		b.WriteString("|DO")
	}
	if cd {
		b.WriteString("|CD")
	}
	return b.String()
}
// isPseudo reports whether rr is an OPT or TSIG record, i.e. a record that
// the callers in this file exclude from TTL handling and from cached data.
func isPseudo(rr dns.RR) bool {
	if _, isOPT := rr.(*dns.OPT); isOPT {
		return true
	}
	_, isTSIG := rr.(*dns.TSIG)
	return isTSIG
}
// cloneEDE returns a newly allocated copy of the given EDE option (a copy of
// the struct value), or nil when in is nil.
func cloneEDE(in *dns.EDNS0_EDE) *dns.EDNS0_EDE {
	if in != nil {
		dup := new(dns.EDNS0_EDE)
		*dup = *in
		return dup
	}
	return nil
}
// extractEDNSMeta reads the EDNS metadata of m.
//
// When m carries an OPT record it returns present=true together with the OPT
// version, its extended RCODE and a copy of every EDE option found in it.
// Without an OPT record all results are zero values.
func extractEDNSMeta(m *dns.Msg) (present bool, version uint8, ext uint16, ede []*dns.EDNS0_EDE) {
	opt := m.IsEdns0()
	if opt == nil {
		return false, 0, 0, nil
	}

	var copies []*dns.EDNS0_EDE
	for _, option := range opt.Option {
		e, isEDE := option.(*dns.EDNS0_EDE)
		if !isEDE {
			continue
		}
		copies = append(copies, cloneEDE(e))
	}
	return true, opt.Version(), uint16(opt.ExtendedRcode()), copies
}
// tryCacheRead looks up key in the cache.
//
// On a hit it returns a copy of the cached message whose record TTLs are
// capped at the entry's remaining lifetime, the cache entry itself (for its
// EDNS metadata) and true. On a miss, or when the entry has expired, it
// returns (nil, nil, false); expired entries are removed on the way.
//
// The lookup runs under the write lock because Get updates the LRU order.
// The TTL rewrite happens on the copy after the lock is released, which keeps
// the critical section short.
func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) {
	now := time.Now()

	cacheMutex.Lock()
	entry, found := cache.Get(key)
	switch {
	case !found:
		cacheMutex.Unlock()
		return nil, nil, false
	case now.After(entry.expireAt):
		cache.Remove(key)
		cacheMutex.Unlock()
		return nil, nil, false
	}
	reply := entry.msg.Copy()
	deadline := entry.expireAt
	cacheMutex.Unlock()

	// Less than one full second left: treat the entry as expired.
	left := uint32(deadline.Sub(now).Seconds())
	if left == 0 {
		cacheMutex.Lock()
		cache.Remove(key)
		cacheMutex.Unlock()
		return nil, nil, false
	}

	capTTLs := func(rrs []dns.RR) {
		for _, rr := range rrs {
			if isPseudo(rr) {
				continue
			}
			if hdr := rr.Header(); hdr.Ttl > left {
				hdr.Ttl = left
			}
		}
	}
	capTTLs(reply.Answer)
	capTTLs(reply.Ns)
	capTTLs(reply.Extra)

	return reply, entry, true
}
// hasAnswerForType reports whether the Answer section of m contains an RRset
// that answers question q.
//
// The search starts at the query name and follows CNAME records found in the
// Answer section, so a response of the form
//
//	www.example.  CNAME  host.example.
//	host.example. A      192.0.2.1
//
// counts as answering an A query for www.example. (comparing owner names
// against the query name only would classify it as "no data"). For QTYPE ANY
// any record at the current owner name counts as an answer. The number of
// alias hops is bounded by the number of answer records, so a CNAME loop
// cannot spin forever.
func hasAnswerForType(m *dns.Msg, q dns.Question) bool {
	owner := q.Name
	for hops := 0; hops <= len(m.Answer); hops++ {
		next := ""
		for _, rr := range m.Answer {
			h := rr.Header()
			if !strings.EqualFold(h.Name, owner) {
				continue
			}
			if h.Rrtype == q.Qtype || q.Qtype == dns.TypeANY {
				return true
			}
			if alias, ok := rr.(*dns.CNAME); ok && next == "" {
				next = alias.Target
			}
		}
		if next == "" {
			// No matching record and no alias to follow from this owner.
			return false
		}
		owner = next
	}
	return false
}
// negativeTTL decides whether m is a negative response and, if so, how long
// it may be cached.
//
// A response is negative when it is NXDOMAIN, or NOERROR with at least one
// question and no answer for that question (NODATA). The TTL is the smaller
// of the SOA record's own TTL and its MINIMUM field, capped at maxTTL when
// maxTTL is at least one second. The SOA is taken from the Authority section
// first and, as a compatibility fallback, from the Additional section.
//
// It returns (ttl, true) only for a negative response with an SOA and a
// non-zero TTL; in every other case the boolean is false.
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
	switch m.Rcode {
	case dns.RcodeNameError:
		// NXDOMAIN is always a negative response.
	case dns.RcodeSuccess:
		// NOERROR is negative only when nothing answers the question.
		if len(m.Question) == 0 || hasAnswerForType(m, m.Question[0]) {
			return 0, false
		}
	default:
		return 0, false
	}

	firstSOA := func(rrs []dns.RR) *dns.SOA {
		for _, rr := range rrs {
			if s, isSOA := rr.(*dns.SOA); isSOA {
				return s
			}
		}
		return nil
	}

	soa := firstSOA(m.Ns)
	if soa == nil {
		// Non-standard placement, tolerated for compatibility.
		soa = firstSOA(m.Extra)
	}
	if soa == nil {
		// Without an SOA there is no negative TTL to derive.
		return 0, false
	}

	ttl := soa.Minttl
	if soa.Hdr.Ttl <= ttl {
		ttl = soa.Hdr.Ttl
	}
	if limit := uint32(maxTTL.Seconds()); limit > 0 && ttl > limit {
		ttl = limit
	}
	return ttl, ttl > 0
}
// minRRsetTTL returns the smallest TTL among the non-pseudo records of m.
//
// The Answer and Authority sections are examined first; the Additional
// section is consulted only when neither of them contributed a record. The
// boolean is false when no record with a TTL was found at all.
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
	var (
		lowest uint32
		seen   bool
	)
	scan := func(rrs []dns.RR) {
		for _, rr := range rrs {
			if isPseudo(rr) {
				continue
			}
			if t := rr.Header().Ttl; !seen || t < lowest {
				lowest, seen = t, true
			}
		}
	}

	scan(m.Answer)
	scan(m.Ns)
	if !seen {
		scan(m.Extra)
	}
	return lowest, seen
}
// stripPseudoExtras removes OPT/TSIG pseudo-records from m.Extra. The filter
// runs in place, reusing the slice's backing array.
func stripPseudoExtras(m *dns.Msg) {
	kept := 0
	for _, rr := range m.Extra {
		if !isPseudo(rr) {
			m.Extra[kept] = rr
			kept++
		}
	}
	m.Extra = m.Extra[:kept]
}
// cacheWrite stores a copy of the response in under key.
//
// Only complete NOERROR and NXDOMAIN responses are cached:
//   - nil and truncated (TC) responses are skipped, because a truncated
//     message is a partial answer that must not be replayed as a complete one;
//   - negative responses use the TTL computed by negativeTTL;
//   - other responses use the smallest record TTL, capped at maxTTL when
//     maxTTL is at least one second;
//   - responses that carry no TTL-bearing record at all (for example a
//     negative answer without SOA) are not cached, since there is nothing to
//     derive a lifetime from (RFC 2308 section 5 advises against caching
//     negative responses without an SOA);
//   - a resulting TTL of zero means "do not cache".
//
// The EDNS metadata (presence, version, extended RCODE, EDE options) is
// captured from the copy before its pseudo-records are stripped, so that a
// later cache hit can rebuild the OPT record.
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
	if in == nil || in.Truncated {
		return
	}
	if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
		return
	}

	ttl, negative := negativeTTL(in, maxTTL)
	if !negative {
		minTTL, has := minRRsetTTL(in)
		if !has {
			// No record to take a TTL from: leave the response uncached.
			return
		}
		if cfgTTL := uint32(maxTTL.Seconds()); cfgTTL > 0 && minTTL > cfgTTL {
			minTTL = cfgTTL
		}
		ttl = minTTL
	}
	if ttl == 0 {
		return
	}

	expire := time.Now().Add(time.Duration(ttl) * time.Second)
	cp := in.Copy()
	// Extract the EDNS metadata first, then drop the pseudo-records.
	present, ver, ext, ede := extractEDNSMeta(cp)
	stripPseudoExtras(cp)

	cacheMutex.Lock()
	cache.Add(key, &cacheEntry{
		msg:          cp,
		expireAt:     expire,
		ednsPresent:  present,
		ednsVersion:  ver,
		ednsExtRcode: ext,
		ednsEDE:      ede,
	})
	cacheMutex.Unlock()
}
/******************************************************************
* Upstream queries
******************************************************************/
// udpClient is the shared client used for upstream exchanges. main assigns it
// before the server starts; the TCP fallback in queryUpstreamsLimited works
// on a copy of it with Net set to "tcp".
var udpClient *dns.Client
// shuffled returns a new slice holding the elements of xs in random order.
// The input slice is left untouched.
func shuffled(xs []string) []string {
	dup := append(make([]string, 0, len(xs)), xs...)
	rand.Shuffle(len(dup), func(a, b int) {
		dup[a], dup[b] = dup[b], dup[a]
	})
	return dup
}
// clampEDNSForUpstream returns a copy of in whose OPT record advertises the
// given UDP payload size. When in has no OPT record, one is appended to the
// copy's Extra section. A size of 0 leaves the advertised size as it is. The
// original message is not modified.
func clampEDNSForUpstream(in *dns.Msg, size uint16) *dns.Msg {
	out := in.Copy()

	opt := out.IsEdns0()
	if opt == nil {
		opt = new(dns.OPT)
		opt.Hdr.Name, opt.Hdr.Rrtype = ".", dns.TypeOPT
		out.Extra = append(out.Extra, opt)
	}
	if size != 0 {
		opt.SetUDPSize(size)
	}
	return out
}
// queryUpstreamsLimited sends req to the given upstream servers and returns
// the first usable response, or nil when none could be obtained.
//
// Parameters:
//   - ctx: parent context; the whole operation is additionally bounded by timeout.
//   - req: the client query; it is copied before being sent upstream.
//   - upstreams: server addresses; they are tried in random order.
//   - timeout: overall deadline for all exchanges together.
//   - maxParallel: maximum number of exchanges in flight (values <= 0 mean 1).
//   - allowTCPFallback: retry over TCP when a UDP response is truncated.
//
// Responses with SERVFAIL, REFUSED or FORMERR are treated as failures so that
// another upstream gets a chance to answer.
//
// Every worker that starts an exchange reports exactly one result, nil on
// failure. This lets the collector return as soon as all upstreams have
// failed instead of sitting out the full timeout. The result channel is
// buffered with one slot per server, so these sends never block, even after
// the collector has already returned.
func queryUpstreamsLimited(
	ctx context.Context,
	req *dns.Msg,
	upstreams []string,
	timeout time.Duration,
	maxParallel int,
	allowTCPFallback bool,
) *dns.Msg {
	if maxParallel <= 0 {
		maxParallel = 1
	}
	servers := shuffled(upstreams)

	cctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel() // also stops the remaining workers once a winner is returned

	type result struct{ msg *dns.Msg }
	ch := make(chan result, len(servers))
	sem := make(chan struct{}, maxParallel)

	// execOne performs one exchange and returns the response, or nil when the
	// exchange failed or the upstream answered with a failure RCODE.
	execOne := func(svr string) *dns.Msg {
		upReq := clampEDNSForUpstream(req, 1232) // could be made a flag
		resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr)
		if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
			log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
			tcpClient := *udpClient
			tcpClient.Net = "tcp"
			resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr)
		}
		if err != nil || resp == nil {
			// Errors caused by our own cancellation/timeout are not logged.
			if err != nil && cctx.Err() == nil {
				log.Printf("[upstream] %s error: %v", svr, err)
			}
			return nil
		}
		switch resp.Rcode {
		case dns.RcodeServerFailure, dns.RcodeRefused, dns.RcodeFormatError:
			return nil
		}
		return resp
	}

	for _, s := range servers {
		s := s
		go func() {
			select {
			case sem <- struct{}{}:
				defer func() { <-sem }()
			case <-cctx.Done():
				// Never started; the collector observes cctx.Done itself.
				return
			}
			ch <- result{msg: execOne(s)}
		}()
	}

	for range servers {
		select {
		case r := <-ch:
			if r.msg != nil {
				return r.msg
			}
		case <-cctx.Done():
			log.Printf("[upstream] timeout after %v", timeout)
			return nil
		}
	}

	// Every upstream reported a failure. If the deadline was the cause, the
	// failure results may have been drained before Done was observed above.
	if cctx.Err() != nil {
		log.Printf("[upstream] timeout after %v", timeout)
	}
	return nil
}
/******************************************************************
 * EDNS / response construction
 ******************************************************************/

// stripECS removes every EDNS Client Subnet option from the OPT record of m,
// keeping all other options. Messages without an OPT record are left as is.
func stripECS(m *dns.Msg) {
	opt := m.IsEdns0()
	if opt == nil {
		return
	}

	var remaining []dns.EDNS0
	for _, option := range opt.Option {
		if _, isECS := option.(*dns.EDNS0_SUBNET); isECS {
			continue
		}
		remaining = append(remaining, option)
	}
	opt.Option = remaining
}
// getDOFlag reports whether m carries an OPT record with the DO (DNSSEC OK)
// bit set.
func getDOFlag(m *dns.Msg) bool {
	opt := m.IsEdns0()
	return opt != nil && opt.Do()
}
// writeReply builds the reply to req from an upstream (or cached) message and
// writes it to w.
//
// A nil upstream produces a failure reply via dns.HandleFailed. Otherwise the
// reply takes RA, AD, RCODE and the Answer/Authority sections from upstream,
// CD from the request, and the Additional section of upstream without its
// pseudo-records.
//
// An OPT record is added only when the request carried one. It echoes the
// request's UDP size and DO bit; extended RCODE, version and EDE options come
// from upstream's OPT record when present, otherwise from meta (the cached
// EDNS metadata) when meta says the original response had EDNS.
func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cacheEntry) {
	if upstream == nil {
		dns.HandleFailed(w, req)
		return
	}

	reply := new(dns.Msg)
	reply.SetReply(req)
	reply.Authoritative = false
	reply.RecursionAvailable = upstream.RecursionAvailable
	reply.AuthenticatedData = upstream.AuthenticatedData
	reply.CheckingDisabled = req.CheckingDisabled
	reply.Rcode = upstream.Rcode
	reply.Answer, reply.Ns = upstream.Answer, upstream.Ns

	var additional []dns.RR
	for _, rr := range upstream.Extra {
		if isPseudo(rr) {
			continue
		}
		additional = append(additional, rr)
	}

	if reqOPT := req.IsEdns0(); reqOPT != nil {
		opt := &dns.OPT{}
		opt.Hdr.Name = "."
		opt.Hdr.Rrtype = dns.TypeOPT
		opt.SetUDPSize(reqOPT.UDPSize())
		if reqOPT.Do() {
			opt.SetDo()
		}

		upOPT := upstream.IsEdns0()
		switch {
		case upOPT != nil:
			opt.SetExtendedRcode(uint16(upOPT.ExtendedRcode()))
			opt.SetVersion(uint8(upOPT.Version()))
			for _, option := range upOPT.Option {
				if e, isEDE := option.(*dns.EDNS0_EDE); isEDE {
					opt.Option = append(opt.Option, e)
				}
			}
		case meta != nil && meta.ednsPresent:
			// The message has no OPT (e.g. stripped when cached): rebuild
			// it from the cached metadata.
			opt.SetExtendedRcode(meta.ednsExtRcode)
			opt.SetVersion(meta.ednsVersion)
			for _, e := range meta.ednsEDE {
				opt.Option = append(opt.Option, e)
			}
		}
		additional = append(additional, opt)
	}

	reply.Extra = additional
	reply.Compress = true
	if err := w.WriteMsg(reply); err != nil {
		log.Printf("[write] WriteMsg error: %v", err)
	}
}
/******************************************************************
 * Main request handler
 ******************************************************************/

// handleDNS returns the handler that serves every incoming query.
//
// For each request the handler:
//  1. fails the request when it has no question;
//  2. logs it and, if stripECSBeforeForward is set, removes ECS options;
//  3. answers straight from the blacklist on a match, without contacting
//     any upstream;
//  4. serves a cache hit if one exists;
//  5. otherwise queries the upstreams, caches the response and relays it.
//
// Only the first question of a request is considered.
func handleDNS(
	upstreams []string,
	cacheMaxTTL, timeout time.Duration,
	maxParallel int,
	stripECSBeforeForward bool,
	allowTCPFallback bool,
	bl *suffixMatcher,
	blRcode int,
) dns.HandlerFunc {
	return func(w dns.ResponseWriter, r *dns.Msg) {
		if len(r.Question) == 0 {
			dns.HandleFailed(w, r)
			return
		}

		question := r.Question[0]
		qtypeName := dns.TypeToString[question.Qtype]
		log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
			qtypeName, question.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())

		if stripECSBeforeForward {
			stripECS(r)
		}

		// Blacklist: a match is answered locally, no upstream query is made.
		if rule, blocked := bl.match(question.Name); blocked {
			log.Printf("[blacklist] HIT %s rule=%s (no upstream query)", dns.CanonicalName(question.Name), rule)
			writeReply(w, r, makeBlockedUpstream(blRcode, rule), nil)
			return
		}

		key := cacheKeyFromMsg(question, getDOFlag(r), r.CheckingDisabled)
		if cachedMsg, cachedMeta, hit := tryCacheRead(key); hit {
			log.Printf("[cache] HIT %s %s", qtypeName, question.Name)
			writeReply(w, r, cachedMsg, cachedMeta)
			return
		}
		log.Printf("[cache] MISS %s %s", qtypeName, question.Name)

		resp := queryUpstreamsLimited(context.Background(), r, upstreams, timeout, maxParallel, allowTCPFallback)
		if resp == nil {
			log.Printf("[error] all upstreams failed for %s", question.Name)
			dns.HandleFailed(w, r)
			return
		}

		cacheWrite(key, resp, cacheMaxTTL)
		for _, ans := range resp.Answer {
			log.Printf("[answer] %s", ans.String())
		}
		writeReply(w, r, resp, nil)
	}
}
/******************************************************************
 * Entry point
 ******************************************************************/

// main parses the command-line flags, creates the cache, loads the TLS key
// pair and the blacklist, starts the background cache cleaner and serves
// DNS-over-TLS until SIGINT/SIGTERM is received or the server fails.
//
// Upstream addresses given without a port get ":53" appended. IPv6 literals
// are recognised as well: a bare literal such as 2001:db8::1 is wrapped in
// brackets before the port is added, and a bracketed literal without a port
// ("[2001:db8::1]") only gets the port. (A plain "contains a colon" test
// would mistake the colons of an IPv6 address for a port separator and leave
// the address unusable.)
func main() {
	rand.Seed(time.Now().UnixNano())

	certFile := flag.String("cert", "server.crt", "TLS 证书文件路径")
	keyFile := flag.String("key", "server.key", "TLS 私钥文件路径")
	addr := flag.String("addr", ":853", "DoT 监听地址")
	upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表")
	cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL")
	cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限")
	timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时")
	maxParallel := flag.Int("max-parallel", 3, "并发上游数量")
	stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS")
	allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退")
	blacklistStr := flag.String("blacklist", "", "逗号分隔的黑名单域名(后缀匹配;支持如 *.example.com")
	blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;后缀匹配)")
	blacklistRcodeFlag := flag.String("blacklist-rcode", "REFUSED", "命中黑名单返回的 RCODEREFUSED|NXDOMAIN|SERVFAIL")
	verbose := flag.Bool("v", false, "verbose 日志")
	flag.Parse()

	initLogger(*verbose)

	var err error
	cache, err = lru.New[string, *cacheEntry](*cacheSizeFlag)
	if err != nil {
		log.Fatalf("[fatal] failed to init LRU cache: %v", err)
	}

	cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
	if err != nil {
		log.Fatalf("[fatal] failed to load cert: %v", err)
	}

	var upstreams []string
	for _, s := range strings.Split(*upstreamStr, ",") {
		t := strings.TrimSpace(s)
		if t == "" {
			continue
		}
		switch {
		case strings.HasPrefix(t, "["):
			// Bracketed IPv6 literal: add the port only when none follows "]".
			if strings.HasSuffix(t, "]") {
				t = fmt.Sprintf("%s:53", t)
			}
		case strings.Count(t, ":") > 1:
			// Bare IPv6 literal: brackets are required before a port.
			t = fmt.Sprintf("[%s]:53", t)
		case !strings.Contains(t, ":"):
			// Host name or IPv4 address without a port.
			t = fmt.Sprintf("%s:53", t)
		}
		upstreams = append(upstreams, t)
	}
	if len(upstreams) == 0 {
		log.Fatal("[fatal] no upstream DNS servers provided")
	}

	udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	startCacheCleaner(ctx)

	// Load the blacklist rules.
	bl, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag)

	mux := dns.NewServeMux()
	mux.HandleFunc(".", handleDNS(
		upstreams,
		*cacheTTLFlag,
		*timeoutFlag,
		*maxParallel,
		*stripECSFlag,
		*allowTCPFallback,
		bl,
		blRcode,
	))

	srv := &dns.Server{
		Addr: *addr,
		Net:  "tcp-tls",
		TLSConfig: &tls.Config{
			Certificates: []tls.Certificate{cert},
			MinVersion:   tls.VersionTLS12,
			NextProtos:   []string{"dot"},
		},
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	stop := make(chan os.Signal, 1)
	signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)

	// Buffered so the serving goroutine can exit even when main has already
	// left the select below.
	errCh := make(chan error, 1)
	go func() {
		log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
		log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | blacklist_rules=%d | blacklist_rcode=%s",
			upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback, len(bl.rules), strings.ToUpper(*blacklistRcodeFlag))
		errCh <- srv.ListenAndServe()
	}()

	select {
	case sig := <-stop:
		log.Printf("[shutdown] caught signal: %s", sig)
		cancel()
		if err := srv.Shutdown(); err != nil {
			log.Printf("[shutdown] server shutdown error: %v", err)
		}
	case err := <-errCh:
		if err != nil {
			log.Fatalf("[fatal] server error: %v", err)
		}
	}
	log.Println("[bye] server stopped.")
}