```
docs(readme): 重构 README 文档结构与内容以提升可读性 - 更新项目标题图标并优化描述语句 - 重新组织特性列表为表格形式,增加 ECS 剥离、黑名单过滤等功能说明 - 补充快速开始章节,细化源码构建与 Docker 使用方式 - 调整参数说明表,新增黑名单相关配置项及缓存条目限制 - 增加缓存机制详解、黑名单功能使用示例与架构图 - 更新开发依赖信息与推荐编译参数 - 修正作者信息展示格式并添加仓库链接 feat(cache): 改进缓存键生成逻辑与 EDNS 元数据处理 - 使用 dns.CanonicalName 规范化域名避免重复缓存键 - 缓存条目中保存 EDNS 扩展信息(version, rcode, EDE) - 修复缓存读取函数返回值,传递完整缓存元数据 - 调整 TTL 计算优先级,仅在必要时检查 Extra 区域 - 黑名单匹配提前拦截请求,跳过上游查询 - 启动日志中显示黑名单规则数量与返回码设置 ```
This commit is contained in:
120
main.go
120
main.go
@@ -37,6 +37,11 @@ func initLogger(verbose bool) {
|
||||
type cacheEntry struct {
|
||||
msg *dns.Msg
|
||||
expireAt time.Time
|
||||
// EDNS metadata (to reproduce Extended RCODE / EDE on cache hits)
|
||||
ednsPresent bool
|
||||
ednsVersion uint8
|
||||
ednsExtRcode uint16
|
||||
ednsEDE []*dns.EDNS0_EDE
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -92,7 +97,8 @@ func startCacheCleaner(ctx context.Context) {
|
||||
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
|
||||
var b strings.Builder
|
||||
b.Grow(len(q.Name) + 32)
|
||||
b.WriteString(strings.ToLower(q.Name))
|
||||
// 采用规范化域名,避免尾随点/IDNA/大小写造成的重复键
|
||||
b.WriteString(dns.CanonicalName(q.Name))
|
||||
b.WriteString("|T=")
|
||||
b.WriteString(dns.TypeToString[q.Qtype])
|
||||
b.WriteString("|C=")
|
||||
@@ -115,20 +121,43 @@ func isPseudo(rr dns.RR) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// 读取缓存(修复:Get 在写锁下;在锁外调整 TTL)
|
||||
func tryCacheRead(key string) (*dns.Msg, bool) {
|
||||
// clone and extract EDNS metadata (present, version, ext-rcode, all EDEs)
|
||||
func cloneEDE(in *dns.EDNS0_EDE) *dns.EDNS0_EDE {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
cp := *in
|
||||
return &cp
|
||||
}
|
||||
|
||||
func extractEDNSMeta(m *dns.Msg) (present bool, version uint8, ext uint16, ede []*dns.EDNS0_EDE) {
|
||||
if o := m.IsEdns0(); o != nil {
|
||||
present = true
|
||||
version = o.Version()
|
||||
ext = uint16(o.ExtendedRcode())
|
||||
for _, opt := range o.Option {
|
||||
if e, ok := opt.(*dns.EDNS0_EDE); ok {
|
||||
ede = append(ede, cloneEDE(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 读取缓存(修复:Get 在写锁下;在锁外调整 TTL;返回 EDNS 元数据)
|
||||
func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) {
|
||||
now := time.Now()
|
||||
|
||||
cacheMutex.Lock()
|
||||
e, ok := cache.Get(key) // Get 会更新 LRU,必须在写锁下
|
||||
if !ok {
|
||||
cacheMutex.Unlock()
|
||||
return nil, false
|
||||
return nil, nil, false
|
||||
}
|
||||
if now.After(e.expireAt) {
|
||||
cache.Remove(key)
|
||||
cacheMutex.Unlock()
|
||||
return nil, false
|
||||
return nil, nil, false
|
||||
}
|
||||
// 拷贝副本,在锁外改 TTL,减少临界区时间
|
||||
out := e.msg.Copy()
|
||||
@@ -140,7 +169,7 @@ func tryCacheRead(key string) (*dns.Msg, bool) {
|
||||
cacheMutex.Lock()
|
||||
cache.Remove(key)
|
||||
cacheMutex.Unlock()
|
||||
return nil, false
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} {
|
||||
@@ -153,7 +182,7 @@ func tryCacheRead(key string) (*dns.Msg, bool) {
|
||||
}
|
||||
}
|
||||
}
|
||||
return out, true
|
||||
return out, e, true
|
||||
}
|
||||
|
||||
// 计算负面 TTL
|
||||
@@ -197,7 +226,6 @@ func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
|
||||
}
|
||||
if soa == nil {
|
||||
// 建议:无 SOA 时不做负面缓存(返回 0,false)
|
||||
// 如你更希望兜底,可改成:return uint32(maxTTL.Seconds()), true
|
||||
return 0, false
|
||||
}
|
||||
|
||||
@@ -216,7 +244,8 @@ func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
|
||||
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
|
||||
minTTL := uint32(0)
|
||||
hasTTL := false
|
||||
for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} {
|
||||
// 优先 Answer -> Ns;若都为空,再考虑 Extra(排除伪记录)
|
||||
for _, sec := range [][]dns.RR{m.Answer, m.Ns} {
|
||||
for _, rr := range sec {
|
||||
if isPseudo(rr) {
|
||||
continue
|
||||
@@ -228,6 +257,18 @@ func minRRsetTTL(m *dns.Msg) (uint32, bool) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasTTL {
|
||||
for _, rr := range m.Extra {
|
||||
if isPseudo(rr) {
|
||||
continue
|
||||
}
|
||||
ttl := rr.Header().Ttl
|
||||
if !hasTTL || ttl < minTTL {
|
||||
minTTL = ttl
|
||||
hasTTL = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return minTTL, hasTTL
|
||||
}
|
||||
|
||||
@@ -245,7 +286,7 @@ func stripPseudoExtras(m *dns.Msg) {
|
||||
m.Extra = out
|
||||
}
|
||||
|
||||
// 写缓存
|
||||
// 写缓存(保存 EDNS 元数据,命中时可重建扩展 RCODE/EDE)
|
||||
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
||||
if in == nil {
|
||||
return
|
||||
@@ -272,9 +313,18 @@ func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
||||
}
|
||||
expire := time.Now().Add(time.Duration(ttl) * time.Second)
|
||||
cp := in.Copy()
|
||||
// 提取 EDNS 元数据后再剥离伪记录
|
||||
present, ver, ext, ede := extractEDNSMeta(cp)
|
||||
stripPseudoExtras(cp)
|
||||
cacheMutex.Lock()
|
||||
cache.Add(key, &cacheEntry{msg: cp, expireAt: expire})
|
||||
cache.Add(key, &cacheEntry{
|
||||
msg: cp,
|
||||
expireAt: expire,
|
||||
ednsPresent: present,
|
||||
ednsVersion: ver,
|
||||
ednsExtRcode: ext,
|
||||
ednsEDE: ede,
|
||||
})
|
||||
cacheMutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -401,7 +451,7 @@ func getDOFlag(m *dns.Msg) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
|
||||
func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cacheEntry) {
|
||||
if upstream == nil {
|
||||
dns.HandleFailed(w, req)
|
||||
return
|
||||
@@ -439,6 +489,13 @@ func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
|
||||
o.Option = append(o.Option, ede)
|
||||
}
|
||||
}
|
||||
} else if meta != nil && meta.ednsPresent {
|
||||
// Upstream/cached msg has no OPT(例如缓存时被剥离),用缓存元数据重建
|
||||
o.SetExtendedRcode(meta.ednsExtRcode)
|
||||
o.SetVersion(meta.ednsVersion)
|
||||
for _, e := range meta.ednsEDE {
|
||||
o.Option = append(o.Option, e)
|
||||
}
|
||||
}
|
||||
extras = append(extras, o)
|
||||
}
|
||||
@@ -459,6 +516,8 @@ func handleDNS(
|
||||
maxParallel int,
|
||||
stripECSBeforeForward bool,
|
||||
allowTCPFallback bool,
|
||||
bl *suffixMatcher,
|
||||
blRcode int,
|
||||
) dns.HandlerFunc {
|
||||
return func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
@@ -472,11 +531,21 @@ func handleDNS(
|
||||
if stripECSBeforeForward {
|
||||
stripECS(r)
|
||||
}
|
||||
|
||||
// 黑名单拦截:命中则不查上游,直接返回
|
||||
if rule, ok := bl.match(q.Name); ok {
|
||||
nameCanon := dns.CanonicalName(q.Name)
|
||||
log.Printf("[blacklist] HIT %s rule=%s (no upstream query)", nameCanon, rule)
|
||||
up := makeBlockedUpstream(blRcode, rule)
|
||||
writeReply(w, r, up, nil)
|
||||
return
|
||||
}
|
||||
|
||||
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
|
||||
|
||||
if cached, ok := tryCacheRead(key); ok {
|
||||
if cachedMsg, cachedMeta, ok := tryCacheRead(key); ok {
|
||||
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||||
writeReply(w, r, cached)
|
||||
writeReply(w, r, cachedMsg, cachedMeta)
|
||||
return
|
||||
}
|
||||
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||||
@@ -492,7 +561,7 @@ func handleDNS(
|
||||
for _, ans := range resp.Answer {
|
||||
log.Printf("[answer] %s", ans.String())
|
||||
}
|
||||
writeReply(w, r, resp)
|
||||
writeReply(w, r, resp, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -512,6 +581,9 @@ func main() {
|
||||
maxParallel := flag.Int("max-parallel", 3, "并发上游数量")
|
||||
stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS")
|
||||
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退")
|
||||
blacklistStr := flag.String("blacklist", "", "逗号分隔的黑名单域名(后缀匹配;支持如 *.example.com)")
|
||||
blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;后缀匹配)")
|
||||
blacklistRcodeFlag := flag.String("blacklist-rcode", "REFUSED", "命中黑名单返回的 RCODE:REFUSED|NXDOMAIN|SERVFAIL")
|
||||
verbose := flag.Bool("v", false, "verbose 日志")
|
||||
flag.Parse()
|
||||
|
||||
@@ -548,8 +620,20 @@ func main() {
|
||||
|
||||
startCacheCleaner(ctx)
|
||||
|
||||
// 加载黑名单规则
|
||||
bl, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag)
|
||||
|
||||
mux := dns.NewServeMux()
|
||||
mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel, *stripECSFlag, *allowTCPFallback))
|
||||
mux.HandleFunc(".", handleDNS(
|
||||
upstreams,
|
||||
*cacheTTLFlag,
|
||||
*timeoutFlag,
|
||||
*maxParallel,
|
||||
*stripECSFlag,
|
||||
*allowTCPFallback,
|
||||
bl,
|
||||
blRcode,
|
||||
))
|
||||
|
||||
srv := &dns.Server{
|
||||
Addr: *addr,
|
||||
@@ -570,8 +654,8 @@ func main() {
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
|
||||
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v",
|
||||
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback)
|
||||
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | blacklist_rules=%d | blacklist_rcode=%s",
|
||||
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback, len(bl.rules), strings.ToUpper(*blacklistRcodeFlag))
|
||||
errCh <- srv.ListenAndServe()
|
||||
}()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user