```
feat(cache): 实现并发安全的 DNS 响应缓存机制 新增基于 sync.Map 的缓存结构,支持 TTL 管理与定时清理。 完善缓存键生成逻辑,包含 DO/CD 标志避免上下文污染。 增强缓存读写处理: - 自动跳过伪 RR(OPT/TSIG)防止干扰 TTL 计算 - 写入前剥离传输层细节提升通用性 - 支持负面缓存(NXDOMAIN/NODATA)并遵循 RFC 2308 - 回填剩余 TTL 并确保对外不可变 优化上游查询模块: - 并发向多个上游竞速,支持 UDP 截断后 TCP 回退 - 过滤不良 RCODE(SERVFAIL/REFUSED/FORMERR)提升稳定性 - 使用信号量控制最大并发数,改善资源利用率 - 快速失败机制减少无效等待 其他改进: - 完善日志记录,区分缓存命中/未命中及响应内容 - 显式构建客户端兼容的 EDNS0 选项 - 增加注释说明关键设计决策和行为边界 ```
This commit is contained in:
198
main.go
198
main.go
@@ -21,6 +21,8 @@ import (
|
||||
* 日志初始化
|
||||
******************************************************************/
|
||||
|
||||
// initLogger 根据 -v 开关设置日志格式;verbose=true 时附加源码文件与行号。
|
||||
// 说明:Lmicroseconds 便于排查毫秒级时序问题。
|
||||
func initLogger(verbose bool) {
|
||||
flags := log.Ldate | log.Ltime | log.Lmicroseconds
|
||||
if verbose {
|
||||
@@ -33,15 +35,21 @@ func initLogger(verbose bool) {
|
||||
* 缓存结构
|
||||
******************************************************************/
|
||||
|
||||
// cacheEntry 表示一条缓存项:保存上游响应的副本,以及过期时间。
|
||||
// 注:将完整 *dns.Msg 存入缓存,便于原样复用 Answer/Ns/Extra。
|
||||
type cacheEntry struct {
|
||||
msg *dns.Msg // 上游完整响应(拷贝存储)
|
||||
expireAt time.Time // 过期时间
|
||||
}
|
||||
|
||||
// cache 使用 sync.Map 作为并发安全的键值存储。
|
||||
// 注:此实现为 TTL 驱动的简单缓存;如需上限/淘汰策略,可叠加 LRU。
|
||||
var cache sync.Map
|
||||
|
||||
const cacheCleanupInterval = 5 * time.Minute
|
||||
|
||||
// startCacheCleaner 定时清理过期项,避免缓存无限增长。
|
||||
// 使用 ctx 控制生命周期,与主服务一同退出。
|
||||
func startCacheCleaner(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(cacheCleanupInterval)
|
||||
@@ -69,7 +77,8 @@ func startCacheCleaner(ctx context.Context) {
|
||||
}()
|
||||
}
|
||||
|
||||
// 计算缓存键:name + type + class + DO + CD
|
||||
// 计算缓存键:name + type + class + DO + CD。
|
||||
// 说明:包含 DO/CD 可避免不同验证上下文污染同一键。
|
||||
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
|
||||
var b strings.Builder
|
||||
b.Grow(len(q.Name) + 32)
|
||||
@@ -87,7 +96,19 @@ func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// 命中缓存:回填剩余 TTL(不超过记录自身 TTL)
|
||||
// 识别伪 RR(OPT/TSIG);这些记录的“TTL 字段”并非真实 TTL,不可参与 TTL 计算或改写。
|
||||
// OPT:其 TTL 字段承载扩展 RCODE 与 DO 位等标志;TSIG:签名,不应缓存或改写。
|
||||
func isPseudo(rr dns.RR) bool {
|
||||
switch rr.(type) {
|
||||
case *dns.OPT, *dns.TSIG:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// tryCacheRead 尝试读取缓存并回填“剩余 TTL”;对 Answer/Ns/Extra 普通 RR 截断 TTL,跳过伪 RR。
|
||||
// 返回响应副本,保证对外不可变(防止调用方修改缓存内部对象)。
|
||||
func tryCacheRead(key string) (*dns.Msg, bool) {
|
||||
v, ok := cache.Load(key)
|
||||
if !ok {
|
||||
@@ -105,6 +126,7 @@ func tryCacheRead(key string) (*dns.Msg, bool) {
|
||||
cache.Delete(key)
|
||||
return nil, false
|
||||
}
|
||||
// 回填剩余 TTL(不增加,只做上限截断)
|
||||
for i := range out.Answer {
|
||||
if out.Answer[i].Header().Ttl > remaining {
|
||||
out.Answer[i].Header().Ttl = remaining
|
||||
@@ -116,6 +138,9 @@ func tryCacheRead(key string) (*dns.Msg, bool) {
|
||||
}
|
||||
}
|
||||
for i := range out.Extra {
|
||||
if isPseudo(out.Extra[i]) {
|
||||
continue
|
||||
}
|
||||
if out.Extra[i].Header().Ttl > remaining {
|
||||
out.Extra[i].Header().Ttl = remaining
|
||||
}
|
||||
@@ -123,13 +148,15 @@ func tryCacheRead(key string) (*dns.Msg, bool) {
|
||||
return out, true
|
||||
}
|
||||
|
||||
// 负面缓存 TTL(RFC 2308):NXDOMAIN 或 NODATA 使用 SOA.MINIMUM 与 SOA TTL 的较小者,再与配置上限取 min
|
||||
// negativeTTL 依据 RFC 2308 计算负面缓存 TTL。
|
||||
// 对 NXDOMAIN 或 NODATA(NOERROR + Answer 为空)取 min(SOA.TTL, SOA.MINIMUM),再与配置上限取 min。
|
||||
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
|
||||
// NXDOMAIN,或 NOERROR 但 Answer 为空(NODATA)
|
||||
if m.Rcode != dns.RcodeNameError && !(m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0) {
|
||||
return 0, false
|
||||
}
|
||||
var soa *dns.SOA
|
||||
// SOA 可能出现在 Authority(Ns) 或 Additional(Extra)
|
||||
for _, rr := range append(m.Ns, m.Extra...) {
|
||||
if s, ok := rr.(*dns.SOA); ok {
|
||||
soa = s
|
||||
@@ -137,6 +164,7 @@ func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
|
||||
}
|
||||
}
|
||||
if soa == nil {
|
||||
// 无 SOA 无法可靠计算负面 TTL;此时不缓存或由上限兜底(本实现选择不缓存)
|
||||
return 0, false
|
||||
}
|
||||
ttl := soa.Hdr.Ttl
|
||||
@@ -150,12 +178,16 @@ func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
|
||||
return ttl, ttl > 0
|
||||
}
|
||||
|
||||
// 普通(正向)响应的最小 TTL(Answer/Ns/Extra)
|
||||
// minRRsetTTL 获取普通(正向)响应的最小 TTL(Answer/Ns/Extra 中的普通 RR),跳过伪 RR。
|
||||
// 用于决定缓存过期时间的上限(与配置上限再取 min)。
|
||||
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
|
||||
minTTL := uint32(0)
|
||||
hasTTL := false
|
||||
for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} {
|
||||
for _, rr := range sec {
|
||||
if isPseudo(rr) {
|
||||
continue
|
||||
}
|
||||
ttl := rr.Header().Ttl
|
||||
if !hasTTL || ttl < minTTL {
|
||||
minTTL = ttl
|
||||
@@ -166,28 +198,49 @@ func minRRsetTTL(m *dns.Msg) (uint32, bool) {
|
||||
return minTTL, hasTTL
|
||||
}
|
||||
|
||||
// 写缓存:先处理负面缓存,再处理正面缓存
|
||||
// stripPseudoExtras 从消息中剥离伪 RR(OPT/TSIG)。
|
||||
// 用途:缓存前剥离,避免将传输层细节或签名内容写入缓存。
|
||||
func stripPseudoExtras(m *dns.Msg) {
|
||||
if len(m.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := m.Extra[:0]
|
||||
for _, rr := range m.Extra {
|
||||
if isPseudo(rr) {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
m.Extra = out
|
||||
}
|
||||
|
||||
// cacheWrite 写缓存:优先处理负面缓存;正面缓存取 min(应答中最小 TTL, 配置上限)。
|
||||
// 写入前统一剥离伪 RR,保证缓存与传输解耦。
|
||||
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
||||
if in == nil {
|
||||
return
|
||||
}
|
||||
// 仅缓存 NOERROR / NXDOMAIN,其余不缓存
|
||||
// 仅缓存 NOERROR / NXDOMAIN,其余不缓存(如 SERVFAIL/REFUSED 等)
|
||||
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
|
||||
return
|
||||
}
|
||||
// 负面缓存
|
||||
if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 {
|
||||
expire := time.Now().Add(time.Duration(ttl) * time.Second)
|
||||
cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire})
|
||||
cp := in.Copy()
|
||||
stripPseudoExtras(cp)
|
||||
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
|
||||
return
|
||||
}
|
||||
// 正向缓存:minTTL 与 maxTTL 取较小
|
||||
minTTL, ok := minRRsetTTL(in)
|
||||
if !ok {
|
||||
// 没有 TTL 时可用上限兜底(也可选择不缓存)
|
||||
// 没有 TTL 时可用上限兜底(也可选择不缓存,这里选择兜底)
|
||||
if maxTTL > 0 {
|
||||
expire := time.Now().Add(maxTTL)
|
||||
cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire})
|
||||
cp := in.Copy()
|
||||
stripPseudoExtras(cp)
|
||||
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -200,16 +253,19 @@ func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
||||
return
|
||||
}
|
||||
expire := time.Now().Add(time.Duration(finalTTL) * time.Second)
|
||||
cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire})
|
||||
cp := in.Copy()
|
||||
stripPseudoExtras(cp)
|
||||
cache.Store(key, &cacheEntry{msg: cp, expireAt: expire})
|
||||
}
|
||||
|
||||
/******************************************************************
|
||||
* 上游查询(带 context 取消、并发上限、UDP→TCP 回退)
|
||||
******************************************************************/
|
||||
|
||||
// 全局可复用 UDP 客户端
|
||||
// 全局可复用 UDP 客户端(Net=udp,UDPSize 放大以承载更大的响应)
|
||||
var udpClient *dns.Client
|
||||
|
||||
// shuffled 打乱上游列表,避免固定顺序导致单点拥塞或偏置。
|
||||
func shuffled(xs []string) []string {
|
||||
out := make([]string, len(xs))
|
||||
copy(out, xs)
|
||||
@@ -217,6 +273,10 @@ func shuffled(xs []string) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
// queryUpstreamsLimited 并发向多个上游发起查询,返回首个有效结果。
|
||||
// - timeout:整个查询窗口的上限(基于子 context)
|
||||
// - maxParallel:同时在飞请求上限
|
||||
// - allowTCPFallback:若 UDP 截断(TC 位)则回退 TCP 重试
|
||||
func queryUpstreamsLimited(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
@@ -230,28 +290,29 @@ func queryUpstreamsLimited(
|
||||
}
|
||||
servers := shuffled(upstreams)
|
||||
|
||||
// 每次查询一个带超时的子 context;拿到首个有效结果后取消。
|
||||
// 每次查询一个带超时的子 context;拿到首个有效结果后 cancel,取消其他请求。
|
||||
cctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
type result struct {
|
||||
msg *dns.Msg
|
||||
}
|
||||
ch := make(chan result, len(servers))
|
||||
sem := make(chan struct{}, maxParallel)
|
||||
ch := make(chan result, len(servers)) // 缓冲至多保存所有返回,避免阻塞
|
||||
sem := make(chan struct{}, maxParallel) // 并发信号量,限制同时在飞请求数
|
||||
|
||||
// execOne 对单个上游发起查询;对 req 使用 Copy() 防止在多 goroutine 下共享同一 *dns.Msg。
|
||||
execOne := func(svr string) {
|
||||
defer func() { <-sem }()
|
||||
// UDP 查询(带 context)
|
||||
resp, _, err := udpClient.ExchangeContext(cctx, req, svr)
|
||||
// UDP 查询(带 context);使用 req.Copy() 防止并发读写
|
||||
resp, _, err := udpClient.ExchangeContext(cctx, req.Copy(), svr)
|
||||
if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
|
||||
// TCP 回退
|
||||
// TCP 回退(响应被截断)
|
||||
log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
|
||||
tcpClient := *udpClient
|
||||
tcpClient.Net = "tcp"
|
||||
resp, _, err = tcpClient.ExchangeContext(cctx, req, svr)
|
||||
resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr)
|
||||
}
|
||||
if err != nil || resp == nil {
|
||||
// 错误仅在未被上层取消时打印,避免超时后噪声
|
||||
if err != nil {
|
||||
if cctx.Err() == nil {
|
||||
log.Printf("[upstream] %s error: %v", svr, err)
|
||||
@@ -261,8 +322,8 @@ func queryUpstreamsLimited(
|
||||
}
|
||||
return
|
||||
}
|
||||
// 丢弃 SERVFAIL
|
||||
if resp.Rcode == dns.RcodeServerFailure {
|
||||
// 过滤差 RCODE(不参与竞速):SERVFAIL / REFUSED / FORMERR
|
||||
if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError {
|
||||
return
|
||||
}
|
||||
select {
|
||||
@@ -271,12 +332,21 @@ func queryUpstreamsLimited(
|
||||
}
|
||||
}
|
||||
|
||||
// “快乐眼球”式启动:生产者不阻塞,获取配额在 goroutine 内部完成;这样可以更快进入接收循环并在首个成功后取消其余。
|
||||
for _, s := range servers {
|
||||
sem <- struct{}{}
|
||||
go execOne(s)
|
||||
s := s
|
||||
go func() {
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
defer func() { <-sem }()
|
||||
case <-cctx.Done():
|
||||
return
|
||||
}
|
||||
execOne(s)
|
||||
}()
|
||||
}
|
||||
|
||||
// 返回第一个非空结果,并 cancel 其他 goroutine
|
||||
// 返回第一个非空结果,并 cancel 其他 goroutine。
|
||||
for i := 0; i < len(servers); i++ {
|
||||
select {
|
||||
case r := <-ch:
|
||||
@@ -296,7 +366,8 @@ func queryUpstreamsLimited(
|
||||
* EDNS(0) / ECS 处理
|
||||
******************************************************************/
|
||||
|
||||
// 去除 EDNS Client Subnet(避免缓存污染与隐私泄露)
|
||||
// stripECS 从请求中去除 EDNS Client Subnet(ECS),减少缓存污染并保护隐私。
|
||||
// 注:ECS 会导致上游按地理/网络分区返回不同答案,不适合集中缓存。
|
||||
func stripECS(m *dns.Msg) {
|
||||
if o := m.IsEdns0(); o != nil {
|
||||
var kept []dns.EDNS0
|
||||
@@ -309,7 +380,7 @@ func stripECS(m *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
// 获取 DO/EDNS
|
||||
// getDOFlag 读取 DO(DNSSEC OK)位,构成缓存键的一部分。
|
||||
func getDOFlag(m *dns.Msg) bool {
|
||||
if o := m.IsEdns0(); o != nil {
|
||||
return o.Do()
|
||||
@@ -321,6 +392,8 @@ func getDOFlag(m *dns.Msg) bool {
|
||||
* 响应构造:使用客户端请求头构造 reply,复制上游内容
|
||||
******************************************************************/
|
||||
|
||||
// writeReply 根据客户端请求构造响应骨架:复制上游的 Answer/Ns/(非伪)Extra,
|
||||
// 并按照**客户端请求**重建 OPT(UDPSize/DO),避免直接采纳上游的 OPT/TSIG。
|
||||
func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
|
||||
if upstream == nil {
|
||||
dns.HandleFailed(w, req)
|
||||
@@ -335,7 +408,28 @@ func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
|
||||
out.Rcode = upstream.Rcode
|
||||
out.Answer = upstream.Answer
|
||||
out.Ns = upstream.Ns
|
||||
out.Extra = upstream.Extra
|
||||
|
||||
// 复制上游的非伪 RR 额外记录;OPT/TSIG 不透传
|
||||
extras := make([]dns.RR, 0, len(upstream.Extra))
|
||||
for _, rr := range upstream.Extra {
|
||||
if isPseudo(rr) {
|
||||
continue
|
||||
}
|
||||
extras = append(extras, rr)
|
||||
}
|
||||
// 基于客户端请求镜像 EDNS(UDPSize + DO),保持传输一致性
|
||||
if ro := req.IsEdns0(); ro != nil {
|
||||
o := new(dns.OPT)
|
||||
o.Hdr.Name = "."
|
||||
o.Hdr.Rrtype = dns.TypeOPT
|
||||
o.SetUDPSize(ro.UDPSize())
|
||||
if ro.Do() {
|
||||
o.SetDo()
|
||||
}
|
||||
// 如需转发 NSID/COOKIE 等可在此附加;为减少复杂性与缓存污染,建议保守最小集合
|
||||
extras = append(extras, o)
|
||||
}
|
||||
out.Extra = extras
|
||||
out.Compress = true
|
||||
if err := w.WriteMsg(out); err != nil {
|
||||
log.Printf("[write] WriteMsg error: %v", err)
|
||||
@@ -346,6 +440,8 @@ func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
|
||||
* 处理器
|
||||
******************************************************************/
|
||||
|
||||
// handleDNS 为每个请求执行:日志 → (可选)剥离 ECS → 缓存命中 → 上游查询 → 写缓存 → 回写。
|
||||
// 注意:缓存键包含 DO/CD;同时通过 tryCacheRead 回填剩余 TTL。
|
||||
func handleDNS(
|
||||
upstreams []string,
|
||||
cacheMaxTTL, timeout time.Duration,
|
||||
@@ -360,7 +456,7 @@ func handleDNS(
|
||||
}
|
||||
q := r.Question[0]
|
||||
|
||||
// 记录请求
|
||||
// 基本访问日志:类型/域名/ID/CD/DO/来源 IP:Port
|
||||
log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
|
||||
dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())
|
||||
|
||||
@@ -369,10 +465,10 @@ func handleDNS(
|
||||
stripECS(r)
|
||||
}
|
||||
|
||||
// 缓存键
|
||||
// 缓存键(域名小写 + QTYPE/QCLASS + DO/CD)
|
||||
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
|
||||
|
||||
// 1) 缓存命中
|
||||
// 1) 缓存命中:命中即快速返回
|
||||
if cached, ok := tryCacheRead(key); ok {
|
||||
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||||
writeReply(w, r, cached)
|
||||
@@ -380,7 +476,7 @@ func handleDNS(
|
||||
}
|
||||
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||||
|
||||
// 2) 上游查询(带 context 取消 & TCP 可选回退)
|
||||
// 2) 上游查询(带 context 取消 & TCP 可选回退);并发向多个上游竞速
|
||||
ctx := context.Background()
|
||||
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
|
||||
if resp == nil {
|
||||
@@ -389,10 +485,10 @@ func handleDNS(
|
||||
return
|
||||
}
|
||||
|
||||
// 3) 写缓存
|
||||
// 3) 写缓存(负面/正面均处理;剥离伪 RR)
|
||||
cacheWrite(key, resp, cacheMaxTTL)
|
||||
|
||||
// 4) 回写
|
||||
// 4) 回写给客户端,并打印 Answer 方便调试
|
||||
for _, ans := range resp.Answer {
|
||||
log.Printf("[answer] %s", ans.String())
|
||||
}
|
||||
@@ -400,14 +496,10 @@ func handleDNS(
|
||||
}
|
||||
}
|
||||
|
||||
/******************************************************************
|
||||
* 主程序
|
||||
******************************************************************/
|
||||
|
||||
func main() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
rand.Seed(time.Now().UnixNano()) // 随机源:用于上游列表打乱
|
||||
|
||||
// 参数
|
||||
// 命令行参数:可根据部署场景调整默认值
|
||||
certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)")
|
||||
keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)")
|
||||
addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853)")
|
||||
@@ -422,13 +514,13 @@ func main() {
|
||||
|
||||
initLogger(*verbose)
|
||||
|
||||
// 证书
|
||||
// 加载 TLS 证书/私钥;用于 DoT(RFC 7858)监听
|
||||
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
|
||||
if err != nil {
|
||||
log.Fatalf("[fatal] failed to load cert: %v", err)
|
||||
}
|
||||
|
||||
// 上游
|
||||
// 解析上游地址:支持“host”或“host:port”,缺省端口补 53。
|
||||
var upstreams []string
|
||||
for _, s := range strings.Split(*upstreamStr, ",") {
|
||||
if t := strings.TrimSpace(s); t != "" {
|
||||
@@ -442,20 +534,20 @@ func main() {
|
||||
log.Fatal("[fatal] no upstream DNS servers provided")
|
||||
}
|
||||
|
||||
// 全局 UDP 客户端(UDPSize 放大;不设置 Timeout,使用 ExchangeContext 控制)
|
||||
// 全局 UDP 客户端(不设置 Client.Timeout,改用 ExchangeContext 控制超时)
|
||||
udpClient = &dns.Client{
|
||||
Net: "udp",
|
||||
UDPSize: 4096,
|
||||
UDPSize: 4096, // 放大到 4K,减小 UDP 截断概率
|
||||
}
|
||||
|
||||
// context 用于优雅退出
|
||||
// context 用于优雅退出(SIGINT/SIGTERM 收到后取消)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// 启动缓存清理器
|
||||
// 启动缓存清理器(后台 goroutine)
|
||||
startCacheCleaner(ctx)
|
||||
|
||||
// mux/handler
|
||||
// 注册处理器:使用 ServeMux 将所有查询交给 handleDNS
|
||||
mux := dns.NewServeMux()
|
||||
mux.HandleFunc(".", handleDNS(
|
||||
upstreams,
|
||||
@@ -466,7 +558,8 @@ func main() {
|
||||
*allowTCPFallback,
|
||||
))
|
||||
|
||||
// DoT 服务(启用 TLS1.3;不手动设 CipherSuites 以使用 Go 默认安全套件)
|
||||
// 构造 DoT(tcp-tls)服务器:显式开启 TLS1.2/1.3,增加读写超时防止慢连接。
|
||||
// NextProtos: "dot"(可选,部分客户端用于 ALPN 检测)
|
||||
srv := &dns.Server{
|
||||
Addr: *addr,
|
||||
Net: "tcp-tls",
|
||||
@@ -474,14 +567,18 @@ func main() {
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS12, // 允许 1.2/1.3(默认启用 1.3)
|
||||
// 不设置 CipherSuites,交由 Go 自动选择(TLS1.3 有自身套件)
|
||||
NextProtos: []string{"dot"}, // 可选:显式 ALPN
|
||||
},
|
||||
Handler: mux,
|
||||
Handler: mux,
|
||||
ReadTimeout: 10 * time.Second, // 防止慢连接
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
// 捕获信号优雅退出
|
||||
// 捕获信号以便优雅退出(关闭监听、结束后台协程)
|
||||
stop := make(chan os.Signal, 1)
|
||||
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// 异步启动服务,错误通过 errCh 返回
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
|
||||
@@ -490,11 +587,12 @@ func main() {
|
||||
errCh <- srv.ListenAndServe()
|
||||
}()
|
||||
|
||||
// 等待退出信号或服务器错误
|
||||
select {
|
||||
case sig := <-stop:
|
||||
log.Printf("[shutdown] caught signal: %s", sig)
|
||||
cancel()
|
||||
// miekg/dns 提供 Shutdown();部分版本没有 ShutdownContext,这里用 Shutdown()
|
||||
// miekg/dns 提供 Shutdown();部分版本无 ShutdownContext,这里用 Shutdown()
|
||||
if err := srv.Shutdown(); err != nil {
|
||||
log.Printf("[shutdown] server shutdown error: %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user