docs(readme): 重构 README 文档结构与内容以提升可读性

- 更新项目标题图标并优化描述语句
- 重新组织特性列表为表格形式,增加 ECS 剥离、黑名单过滤等功能说明
- 补充快速开始章节,细化源码构建与 Docker 使用方式
- 调整参数说明表,新增黑名单相关配置项及缓存条目限制
- 增加缓存机制详解、黑名单功能使用示例与架构图
- 更新开发依赖信息与推荐编译参数
- 修正作者信息展示格式并添加仓库链接

feat(cache): 改进缓存键生成逻辑与 EDNS 元数据处理

- 使用 dns.CanonicalName 规范化域名避免重复缓存键
- 缓存条目中保存 EDNS 扩展信息(version, rcode, EDE)
- 修复缓存读取函数返回值,传递完整缓存元数据
- 调整 TTL 计算优先级,仅在必要时检查 Extra 区域
- 黑名单匹配提前拦截请求,跳过上游查询
- 启动日志中显示黑名单规则数量与返回码设置
```
This commit is contained in:
2025-10-15 14:19:55 +08:00
parent 916a7c8127
commit 4060e83686
5 changed files with 404 additions and 90 deletions

120
main.go
View File

@@ -37,6 +37,11 @@ func initLogger(verbose bool) {
type cacheEntry struct {
msg *dns.Msg
expireAt time.Time
// EDNS metadata (to reproduce Extended RCODE / EDE on cache hits)
ednsPresent bool
ednsVersion uint8
ednsExtRcode uint16
ednsEDE []*dns.EDNS0_EDE
}
var (
@@ -92,7 +97,8 @@ func startCacheCleaner(ctx context.Context) {
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
var b strings.Builder
b.Grow(len(q.Name) + 32)
b.WriteString(strings.ToLower(q.Name))
// 采用规范化域名,避免尾随点/IDNA/大小写造成的重复键
b.WriteString(dns.CanonicalName(q.Name))
b.WriteString("|T=")
b.WriteString(dns.TypeToString[q.Qtype])
b.WriteString("|C=")
@@ -115,20 +121,43 @@ func isPseudo(rr dns.RR) bool {
}
}
// 读取缓存修复Get 在写锁下;在锁外调整 TTL
func tryCacheRead(key string) (*dns.Msg, bool) {
// clone and extract EDNS metadata (present, version, ext-rcode, all EDEs)
func cloneEDE(in *dns.EDNS0_EDE) *dns.EDNS0_EDE {
if in == nil {
return nil
}
cp := *in
return &cp
}
func extractEDNSMeta(m *dns.Msg) (present bool, version uint8, ext uint16, ede []*dns.EDNS0_EDE) {
if o := m.IsEdns0(); o != nil {
present = true
version = o.Version()
ext = uint16(o.ExtendedRcode())
for _, opt := range o.Option {
if e, ok := opt.(*dns.EDNS0_EDE); ok {
ede = append(ede, cloneEDE(e))
}
}
}
return
}
// 读取缓存修复Get 在写锁下;在锁外调整 TTL返回 EDNS 元数据)
func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) {
now := time.Now()
cacheMutex.Lock()
e, ok := cache.Get(key) // Get 会更新 LRU必须在写锁下
if !ok {
cacheMutex.Unlock()
return nil, false
return nil, nil, false
}
if now.After(e.expireAt) {
cache.Remove(key)
cacheMutex.Unlock()
return nil, false
return nil, nil, false
}
// 拷贝副本,在锁外改 TTL减少临界区时间
out := e.msg.Copy()
@@ -140,7 +169,7 @@ func tryCacheRead(key string) (*dns.Msg, bool) {
cacheMutex.Lock()
cache.Remove(key)
cacheMutex.Unlock()
return nil, false
return nil, nil, false
}
for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} {
@@ -153,7 +182,7 @@ func tryCacheRead(key string) (*dns.Msg, bool) {
}
}
}
return out, true
return out, e, true
}
// 计算负面 TTL
@@ -197,7 +226,6 @@ func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
}
if soa == nil {
// 建议:无 SOA 时不做负面缓存(返回 0,false
// 如你更希望兜底可改成return uint32(maxTTL.Seconds()), true
return 0, false
}
@@ -216,7 +244,8 @@ func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
minTTL := uint32(0)
hasTTL := false
for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} {
// 优先 Answer -> Ns若都为空再考虑 Extra排除伪记录
for _, sec := range [][]dns.RR{m.Answer, m.Ns} {
for _, rr := range sec {
if isPseudo(rr) {
continue
@@ -228,6 +257,18 @@ func minRRsetTTL(m *dns.Msg) (uint32, bool) {
}
}
}
if !hasTTL {
for _, rr := range m.Extra {
if isPseudo(rr) {
continue
}
ttl := rr.Header().Ttl
if !hasTTL || ttl < minTTL {
minTTL = ttl
hasTTL = true
}
}
}
return minTTL, hasTTL
}
@@ -245,7 +286,7 @@ func stripPseudoExtras(m *dns.Msg) {
m.Extra = out
}
// 写缓存
// 写缓存(保存 EDNS 元数据,命中时可重建扩展 RCODE/EDE
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
if in == nil {
return
@@ -272,9 +313,18 @@ func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
}
expire := time.Now().Add(time.Duration(ttl) * time.Second)
cp := in.Copy()
// 提取 EDNS 元数据后再剥离伪记录
present, ver, ext, ede := extractEDNSMeta(cp)
stripPseudoExtras(cp)
cacheMutex.Lock()
cache.Add(key, &cacheEntry{msg: cp, expireAt: expire})
cache.Add(key, &cacheEntry{
msg: cp,
expireAt: expire,
ednsPresent: present,
ednsVersion: ver,
ednsExtRcode: ext,
ednsEDE: ede,
})
cacheMutex.Unlock()
}
@@ -401,7 +451,7 @@ func getDOFlag(m *dns.Msg) bool {
return false
}
func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cacheEntry) {
if upstream == nil {
dns.HandleFailed(w, req)
return
@@ -439,6 +489,13 @@ func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
o.Option = append(o.Option, ede)
}
}
} else if meta != nil && meta.ednsPresent {
// Upstream/cached msg has no OPT例如缓存时被剥离用缓存元数据重建
o.SetExtendedRcode(meta.ednsExtRcode)
o.SetVersion(meta.ednsVersion)
for _, e := range meta.ednsEDE {
o.Option = append(o.Option, e)
}
}
extras = append(extras, o)
}
@@ -459,6 +516,8 @@ func handleDNS(
maxParallel int,
stripECSBeforeForward bool,
allowTCPFallback bool,
bl *suffixMatcher,
blRcode int,
) dns.HandlerFunc {
return func(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
@@ -472,11 +531,21 @@ func handleDNS(
if stripECSBeforeForward {
stripECS(r)
}
// 黑名单拦截:命中则不查上游,直接返回
if rule, ok := bl.match(q.Name); ok {
nameCanon := dns.CanonicalName(q.Name)
log.Printf("[blacklist] HIT %s rule=%s (no upstream query)", nameCanon, rule)
up := makeBlockedUpstream(blRcode, rule)
writeReply(w, r, up, nil)
return
}
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
if cached, ok := tryCacheRead(key); ok {
if cachedMsg, cachedMeta, ok := tryCacheRead(key); ok {
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
writeReply(w, r, cached)
writeReply(w, r, cachedMsg, cachedMeta)
return
}
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
@@ -492,7 +561,7 @@ func handleDNS(
for _, ans := range resp.Answer {
log.Printf("[answer] %s", ans.String())
}
writeReply(w, r, resp)
writeReply(w, r, resp, nil)
}
}
@@ -512,6 +581,9 @@ func main() {
maxParallel := flag.Int("max-parallel", 3, "并发上游数量")
stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS")
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退")
blacklistStr := flag.String("blacklist", "", "逗号分隔的黑名单域名(后缀匹配;支持如 *.example.com")
blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;后缀匹配)")
blacklistRcodeFlag := flag.String("blacklist-rcode", "REFUSED", "命中黑名单返回的 RCODEREFUSED|NXDOMAIN|SERVFAIL")
verbose := flag.Bool("v", false, "verbose 日志")
flag.Parse()
@@ -548,8 +620,20 @@ func main() {
startCacheCleaner(ctx)
// 加载黑名单规则
bl, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag)
mux := dns.NewServeMux()
mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel, *stripECSFlag, *allowTCPFallback))
mux.HandleFunc(".", handleDNS(
upstreams,
*cacheTTLFlag,
*timeoutFlag,
*maxParallel,
*stripECSFlag,
*allowTCPFallback,
bl,
blRcode,
))
srv := &dns.Server{
Addr: *addr,
@@ -570,8 +654,8 @@ func main() {
errCh := make(chan error, 1)
go func() {
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v",
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback)
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | blacklist_rules=%d | blacklist_rcode=%s",
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback, len(bl.rules), strings.ToUpper(*blacklistRcodeFlag))
errCh <- srv.ListenAndServe()
}()