```
feat(cache): 优化缓存键生成与缓存写入逻辑 - 引入更精确的缓存键计算方式,包含 QTYPE、QCLASS、DO 和 CD 标志 - 实现负面缓存(NXDOMAIN/NODATA)支持,遵循 RFC 2308 规范 - 改进缓存清理机制,在 TTL 为 0 时主动删除过期条目 - 添加日志初始化函数,支持 verbose 模式显示源码位置 - 重构上游查询逻辑,支持 context 控制超时和 TCP 回退 - 增加 ECS(EDNS Client Subnet)剥离选项以增强隐私保护 - 调整命令行参数默认值及日志输出格式,提升可读性与调试体验 ```
This commit is contained in:
474
main.go
474
main.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
@@ -16,22 +17,31 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
/*********************************
|
/******************************************************************
|
||||||
* 缓存结构与全局对象
|
* 日志初始化
|
||||||
*********************************/
|
******************************************************************/
|
||||||
|
|
||||||
type cacheEntry struct {
|
func initLogger(verbose bool) {
|
||||||
msg *dns.Msg // 缓存的完整响应(深拷贝)
|
flags := log.Ldate | log.Ltime | log.Lmicroseconds
|
||||||
expireAt time.Time // 过期时间(由动态 TTL 决定)
|
if verbose {
|
||||||
|
flags |= log.Lshortfile
|
||||||
|
}
|
||||||
|
log.SetFlags(flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
/******************************************************************
|
||||||
|
* 缓存结构
|
||||||
|
******************************************************************/
|
||||||
|
|
||||||
|
type cacheEntry struct {
|
||||||
|
msg *dns.Msg // 上游完整响应(拷贝存储)
|
||||||
|
expireAt time.Time // 过期时间
|
||||||
}
|
}
|
||||||
|
|
||||||
// 并发安全缓存
|
|
||||||
var cache sync.Map
|
var cache sync.Map
|
||||||
|
|
||||||
// 后台清理:每隔 N 分钟清一次
|
|
||||||
const cacheCleanupInterval = 5 * time.Minute
|
const cacheCleanupInterval = 5 * time.Minute
|
||||||
|
|
||||||
// 启动带 context 的缓存清理器(优雅退出)
|
|
||||||
func startCacheCleaner(ctx context.Context) {
|
func startCacheCleaner(ctx context.Context) {
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(cacheCleanupInterval)
|
ticker := time.NewTicker(cacheCleanupInterval)
|
||||||
@@ -52,19 +62,32 @@ func startCacheCleaner(ctx context.Context) {
|
|||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
log.Printf("[Cache] Cleaned %d expired entries", n)
|
log.Printf("[cache] cleaned %d expired entries", n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成缓存键
|
// 计算缓存键:name + type + class + DO + CD
|
||||||
func cacheKey(name string, qtype uint16) string {
|
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
|
||||||
return strings.ToLower(name) + ":" + dns.TypeToString[qtype]
|
var b strings.Builder
|
||||||
|
b.Grow(len(q.Name) + 32)
|
||||||
|
b.WriteString(strings.ToLower(q.Name))
|
||||||
|
b.WriteString("|T=")
|
||||||
|
b.WriteString(dns.TypeToString[q.Qtype])
|
||||||
|
b.WriteString("|C=")
|
||||||
|
b.WriteString(dns.ClassToString[q.Qclass])
|
||||||
|
if do {
|
||||||
|
b.WriteString("|DO")
|
||||||
|
}
|
||||||
|
if cd {
|
||||||
|
b.WriteString("|CD")
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 读取缓存;命中则回填剩余 TTL 并返回拷贝
|
// 命中缓存:回填剩余 TTL(不超过记录自身 TTL)
|
||||||
func tryCacheRead(key string) (*dns.Msg, bool) {
|
func tryCacheRead(key string) (*dns.Msg, bool) {
|
||||||
v, ok := cache.Load(key)
|
v, ok := cache.Load(key)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -78,7 +101,10 @@ func tryCacheRead(key string) (*dns.Msg, bool) {
|
|||||||
}
|
}
|
||||||
out := e.msg.Copy()
|
out := e.msg.Copy()
|
||||||
remaining := uint32(e.expireAt.Sub(now).Seconds())
|
remaining := uint32(e.expireAt.Sub(now).Seconds())
|
||||||
// 回填剩余 TTL,避免客户端收到过期 TTL
|
if remaining == 0 {
|
||||||
|
cache.Delete(key)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
for i := range out.Answer {
|
for i := range out.Answer {
|
||||||
if out.Answer[i].Header().Ttl > remaining {
|
if out.Answer[i].Header().Ttl > remaining {
|
||||||
out.Answer[i].Header().Ttl = remaining
|
out.Answer[i].Header().Ttl = remaining
|
||||||
@@ -97,19 +123,38 @@ func tryCacheRead(key string) (*dns.Msg, bool) {
|
|||||||
return out, true
|
return out, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// 写缓存:以「上游最小 TTL」与「配置上限 TTL」取较小值
|
// 负面缓存 TTL(RFC 2308):NXDOMAIN 或 NODATA 使用 SOA.MINIMUM 与 SOA TTL 的较小者,再与配置上限取 min
|
||||||
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
|
||||||
if in == nil {
|
// NXDOMAIN,或 NOERROR 但 Answer 为空(NODATA)
|
||||||
return
|
if m.Rcode != dns.RcodeNameError && !(m.Rcode == dns.RcodeSuccess && len(m.Answer) == 0) {
|
||||||
|
return 0, false
|
||||||
}
|
}
|
||||||
// 可按需缓存 NXDOMAIN;这里允许缓存 NOERROR 与 NXDOMAIN
|
var soa *dns.SOA
|
||||||
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
|
for _, rr := range append(m.Ns, m.Extra...) {
|
||||||
return
|
if s, ok := rr.(*dns.SOA); ok {
|
||||||
|
soa = s
|
||||||
|
break
|
||||||
}
|
}
|
||||||
// 计算报文中的最小 TTL(Answer/Ns/Extra)
|
}
|
||||||
|
if soa == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
ttl := soa.Hdr.Ttl
|
||||||
|
if soa.Minttl < ttl {
|
||||||
|
ttl = soa.Minttl
|
||||||
|
}
|
||||||
|
capTTL := uint32(maxTTL.Seconds())
|
||||||
|
if capTTL > 0 && ttl > capTTL {
|
||||||
|
ttl = capTTL
|
||||||
|
}
|
||||||
|
return ttl, ttl > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// 普通(正向)响应的最小 TTL(Answer/Ns/Extra)
|
||||||
|
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
|
||||||
minTTL := uint32(0)
|
minTTL := uint32(0)
|
||||||
hasTTL := false
|
hasTTL := false
|
||||||
for _, sec := range [][]dns.RR{in.Answer, in.Ns, in.Extra} {
|
for _, sec := range [][]dns.RR{m.Answer, m.Ns, m.Extra} {
|
||||||
for _, rr := range sec {
|
for _, rr := range sec {
|
||||||
ttl := rr.Header().Ttl
|
ttl := rr.Header().Ttl
|
||||||
if !hasTTL || ttl < minTTL {
|
if !hasTTL || ttl < minTTL {
|
||||||
@@ -118,19 +163,39 @@ func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 若无 TTL,可用配置上限作为兜底(也可选择不缓存)
|
return minTTL, hasTTL
|
||||||
cfgTTL := uint32(maxTTL.Seconds())
|
}
|
||||||
var finalTTL uint32
|
|
||||||
switch {
|
// 写缓存:先处理负面缓存,再处理正面缓存
|
||||||
case !hasTTL && cfgTTL > 0:
|
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
||||||
finalTTL = cfgTTL
|
if in == nil {
|
||||||
case hasTTL && cfgTTL > 0 && minTTL > cfgTTL:
|
|
||||||
finalTTL = cfgTTL
|
|
||||||
case hasTTL:
|
|
||||||
finalTTL = minTTL
|
|
||||||
default:
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 仅缓存 NOERROR / NXDOMAIN,其余不缓存
|
||||||
|
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 负面缓存
|
||||||
|
if ttl, ok := negativeTTL(in, maxTTL); ok && ttl > 0 {
|
||||||
|
expire := time.Now().Add(time.Duration(ttl) * time.Second)
|
||||||
|
cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 正向缓存:minTTL 与 maxTTL 取较小
|
||||||
|
minTTL, ok := minRRsetTTL(in)
|
||||||
|
if !ok {
|
||||||
|
// 没有 TTL 时可用上限兜底(也可选择不缓存)
|
||||||
|
if maxTTL > 0 {
|
||||||
|
expire := time.Now().Add(maxTTL)
|
||||||
|
cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfgTTL := uint32(maxTTL.Seconds())
|
||||||
|
finalTTL := minTTL
|
||||||
|
if cfgTTL > 0 && finalTTL > cfgTTL {
|
||||||
|
finalTTL = cfgTTL
|
||||||
|
}
|
||||||
if finalTTL == 0 {
|
if finalTTL == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -138,219 +203,306 @@ func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
|||||||
cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire})
|
cache.Store(key, &cacheEntry{msg: in.Copy(), expireAt: expire})
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************
|
/******************************************************************
|
||||||
* 上游查询与并发控制
|
* 上游查询(带 context 取消、并发上限、UDP→TCP 回退)
|
||||||
*********************************/
|
******************************************************************/
|
||||||
|
|
||||||
// 全局可复用的 DNS 客户端(默认 UDP)
|
// 全局可复用 UDP 客户端
|
||||||
var dnsClient *dns.Client
|
var udpClient *dns.Client
|
||||||
|
|
||||||
// 并发上限通过信号量限制
|
func shuffled(xs []string) []string {
|
||||||
func queryUpstreamsLimited(r *dns.Msg, upstreams []string, timeout time.Duration, maxParallel int) *dns.Msg {
|
out := make([]string, len(xs))
|
||||||
if maxParallel <= 0 {
|
copy(out, xs)
|
||||||
maxParallel = 1
|
|
||||||
}
|
|
||||||
ch := make(chan *dns.Msg, len(upstreams))
|
|
||||||
sem := make(chan struct{}, maxParallel)
|
|
||||||
|
|
||||||
// 在 UDP 上查询,遇到截断再 TCP fallback
|
|
||||||
queryOnce := func(server string) *dns.Msg {
|
|
||||||
resp, _, err := dnsClient.Exchange(r, server)
|
|
||||||
if err == nil && resp != nil && resp.Truncated {
|
|
||||||
// UDP 截断,尝试 TCP
|
|
||||||
log.Printf("[Info] UDP truncated, retry TCP: %s", server)
|
|
||||||
tcpClient := *dnsClient
|
|
||||||
tcpClient.Net = "tcp"
|
|
||||||
resp, _, err = tcpClient.Exchange(r, server)
|
|
||||||
}
|
|
||||||
if err != nil || resp == nil {
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[Warn] Upstream %s failed: %v", server, err)
|
|
||||||
} else {
|
|
||||||
log.Printf("[Warn] Upstream %s failed: nil response", server)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// 可选:丢弃 SERVFAIL
|
|
||||||
if resp.Rcode == dns.RcodeServerFailure {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return resp
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, server := range upstreams {
|
|
||||||
sem <- struct{}{}
|
|
||||||
go func(svr string) {
|
|
||||||
defer func() { <-sem }()
|
|
||||||
resp := queryOnce(svr)
|
|
||||||
// 非阻塞/限时写入,防止消费者意外退出导致阻塞
|
|
||||||
select {
|
|
||||||
case ch <- resp:
|
|
||||||
case <-time.After(1 * time.Second):
|
|
||||||
}
|
|
||||||
}(server)
|
|
||||||
}
|
|
||||||
|
|
||||||
timer := time.NewTimer(timeout)
|
|
||||||
defer timer.Stop()
|
|
||||||
|
|
||||||
// 在超时内返回第一个非空的结果
|
|
||||||
for i := 0; i < len(upstreams); i++ {
|
|
||||||
select {
|
|
||||||
case resp := <-ch:
|
|
||||||
if resp != nil {
|
|
||||||
return resp
|
|
||||||
}
|
|
||||||
case <-timer.C:
|
|
||||||
log.Printf("[Error] Upstream query timeout after %v", timeout)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
/*********************************
|
|
||||||
* DNS 处理逻辑
|
|
||||||
*********************************/
|
|
||||||
|
|
||||||
func shuffled(slice []string) []string {
|
|
||||||
out := make([]string, len(slice))
|
|
||||||
copy(out, slice)
|
|
||||||
rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] })
|
rand.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] })
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleDNS(upstreams []string, cacheMaxTTL, timeout time.Duration, maxParallel int) dns.HandlerFunc {
|
func queryUpstreamsLimited(
|
||||||
|
ctx context.Context,
|
||||||
|
req *dns.Msg,
|
||||||
|
upstreams []string,
|
||||||
|
timeout time.Duration,
|
||||||
|
maxParallel int,
|
||||||
|
allowTCPFallback bool,
|
||||||
|
) *dns.Msg {
|
||||||
|
if maxParallel <= 0 {
|
||||||
|
maxParallel = 1
|
||||||
|
}
|
||||||
|
servers := shuffled(upstreams)
|
||||||
|
|
||||||
|
// 每次查询一个带超时的子 context;拿到首个有效结果后取消。
|
||||||
|
cctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
type result struct {
|
||||||
|
msg *dns.Msg
|
||||||
|
}
|
||||||
|
ch := make(chan result, len(servers))
|
||||||
|
sem := make(chan struct{}, maxParallel)
|
||||||
|
|
||||||
|
execOne := func(svr string) {
|
||||||
|
defer func() { <-sem }()
|
||||||
|
// UDP 查询(带 context)
|
||||||
|
resp, _, err := udpClient.ExchangeContext(cctx, req, svr)
|
||||||
|
if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
|
||||||
|
// TCP 回退
|
||||||
|
log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
|
||||||
|
tcpClient := *udpClient
|
||||||
|
tcpClient.Net = "tcp"
|
||||||
|
resp, _, err = tcpClient.ExchangeContext(cctx, req, svr)
|
||||||
|
}
|
||||||
|
if err != nil || resp == nil {
|
||||||
|
if err != nil {
|
||||||
|
if cctx.Err() == nil {
|
||||||
|
log.Printf("[upstream] %s error: %v", svr, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Printf("[upstream] %s nil response", svr)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 丢弃 SERVFAIL
|
||||||
|
if resp.Rcode == dns.RcodeServerFailure {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case ch <- result{msg: resp}:
|
||||||
|
case <-cctx.Done():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range servers {
|
||||||
|
sem <- struct{}{}
|
||||||
|
go execOne(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回第一个非空结果,并 cancel 其他 goroutine
|
||||||
|
for i := 0; i < len(servers); i++ {
|
||||||
|
select {
|
||||||
|
case r := <-ch:
|
||||||
|
if r.msg != nil {
|
||||||
|
cancel()
|
||||||
|
return r.msg
|
||||||
|
}
|
||||||
|
case <-cctx.Done():
|
||||||
|
log.Printf("[upstream] timeout after %v", timeout)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/******************************************************************
|
||||||
|
* EDNS(0) / ECS 处理
|
||||||
|
******************************************************************/
|
||||||
|
|
||||||
|
// 去除 EDNS Client Subnet(避免缓存污染与隐私泄露)
|
||||||
|
func stripECS(m *dns.Msg) {
|
||||||
|
if o := m.IsEdns0(); o != nil {
|
||||||
|
var kept []dns.EDNS0
|
||||||
|
for _, e := range o.Option {
|
||||||
|
if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS {
|
||||||
|
kept = append(kept, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
o.Option = kept
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取 DO/EDNS
|
||||||
|
func getDOFlag(m *dns.Msg) bool {
|
||||||
|
if o := m.IsEdns0(); o != nil {
|
||||||
|
return o.Do()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
/******************************************************************
|
||||||
|
* 响应构造:使用客户端请求头构造 reply,复制上游内容
|
||||||
|
******************************************************************/
|
||||||
|
|
||||||
|
func writeReply(w dns.ResponseWriter, req, upstream *dns.Msg) {
|
||||||
|
if upstream == nil {
|
||||||
|
dns.HandleFailed(w, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out := new(dns.Msg)
|
||||||
|
out.SetReply(req)
|
||||||
|
out.Authoritative = false
|
||||||
|
out.RecursionAvailable = upstream.RecursionAvailable
|
||||||
|
out.AuthenticatedData = upstream.AuthenticatedData
|
||||||
|
out.CheckingDisabled = req.CheckingDisabled // 反映客户端 CD 位
|
||||||
|
out.Rcode = upstream.Rcode
|
||||||
|
out.Answer = upstream.Answer
|
||||||
|
out.Ns = upstream.Ns
|
||||||
|
out.Extra = upstream.Extra
|
||||||
|
out.Compress = true
|
||||||
|
if err := w.WriteMsg(out); err != nil {
|
||||||
|
log.Printf("[write] WriteMsg error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/******************************************************************
|
||||||
|
* 处理器
|
||||||
|
******************************************************************/
|
||||||
|
|
||||||
|
func handleDNS(
|
||||||
|
upstreams []string,
|
||||||
|
cacheMaxTTL, timeout time.Duration,
|
||||||
|
maxParallel int,
|
||||||
|
stripECSBeforeForward bool,
|
||||||
|
allowTCPFallback bool,
|
||||||
|
) dns.HandlerFunc {
|
||||||
return func(w dns.ResponseWriter, r *dns.Msg) {
|
return func(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
dns.HandleFailed(w, r)
|
dns.HandleFailed(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
q := r.Question[0]
|
q := r.Question[0]
|
||||||
key := cacheKey(q.Name, q.Qtype)
|
|
||||||
|
// 记录请求
|
||||||
|
log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
|
||||||
|
dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())
|
||||||
|
|
||||||
|
// 可选:去除 ECS(推荐)
|
||||||
|
if stripECSBeforeForward {
|
||||||
|
stripECS(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 缓存键
|
||||||
|
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
|
||||||
|
|
||||||
// 1) 缓存命中
|
// 1) 缓存命中
|
||||||
if cached, ok := tryCacheRead(key); ok {
|
if cached, ok := tryCacheRead(key); ok {
|
||||||
log.Printf("[Cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
|
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||||||
_ = w.WriteMsg(cached)
|
writeReply(w, r, cached)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Printf("[Cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
|
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||||||
|
|
||||||
// 2) 随机化上游并并发查询(带 fallback)
|
// 2) 上游查询(带 context 取消 & TCP 可选回退)
|
||||||
servers := shuffled(upstreams)
|
ctx := context.Background()
|
||||||
resp := queryUpstreamsLimited(r, servers, timeout, maxParallel)
|
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
log.Printf("[Error] All upstreams failed for %s", q.Name)
|
log.Printf("[error] all upstreams failed for %s", q.Name)
|
||||||
dns.HandleFailed(w, r)
|
dns.HandleFailed(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3) 写入缓存(动态 TTL)
|
// 3) 写缓存
|
||||||
cacheWrite(key, resp, cacheMaxTTL)
|
cacheWrite(key, resp, cacheMaxTTL)
|
||||||
|
|
||||||
// 4) 返回结果
|
// 4) 回写
|
||||||
for _, ans := range resp.Answer {
|
for _, ans := range resp.Answer {
|
||||||
log.Printf("[Answer] %s", ans.String())
|
log.Printf("[answer] %s", ans.String())
|
||||||
}
|
}
|
||||||
_ = w.WriteMsg(resp)
|
writeReply(w, r, resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************
|
/******************************************************************
|
||||||
* 主程序(优雅退出)
|
* 主程序
|
||||||
*********************************/
|
******************************************************************/
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
rand.Seed(time.Now().UnixNano())
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
|
||||||
// 参数
|
// 参数
|
||||||
certFile := flag.String("cert", "aixiao.me.cer", "TLS 证书文件路径 (.cer/.crt)")
|
certFile := flag.String("cert", "server.crt", "TLS 证书文件路径 (.cer/.crt)")
|
||||||
keyFile := flag.String("key", "aixiao.me.key", "TLS 私钥文件路径 (.key)")
|
keyFile := flag.String("key", "server.key", "TLS 私钥文件路径 (.key)")
|
||||||
addr := flag.String("addr", ":853", "DoT 服务监听地址(默认 :853)")
|
addr := flag.String("addr", ":853", "DoT 监听地址(默认 :853)")
|
||||||
upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)")
|
upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表(逗号分隔)")
|
||||||
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL(默认 60s;实际取 min(上游最小TTL, 本值))")
|
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL(默认 60s;实际取 min(上游最小TTL, 本值))")
|
||||||
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s)")
|
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时(默认 3s)")
|
||||||
maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限")
|
maxParallel := flag.Int("max-parallel", 3, "并发查询的上游数量上限")
|
||||||
|
stripECSFlag := flag.Bool("strip-ecs", true, "转发上游前去除 EDNS Client Subnet")
|
||||||
|
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时允许 TCP 回退")
|
||||||
|
verbose := flag.Bool("v", false, "verbose 日志(包含源码位置)")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
initLogger(*verbose)
|
||||||
|
|
||||||
// 证书
|
// 证书
|
||||||
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
|
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to load cert: %v", err)
|
log.Fatalf("[fatal] failed to load cert: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 上游列表
|
// 上游
|
||||||
raw := strings.Split(*upstreamStr, ",")
|
var upstreams []string
|
||||||
upstreams := make([]string, 0, len(raw))
|
for _, s := range strings.Split(*upstreamStr, ",") {
|
||||||
for _, s := range raw {
|
|
||||||
if t := strings.TrimSpace(s); t != "" {
|
if t := strings.TrimSpace(s); t != "" {
|
||||||
|
if !strings.Contains(t, ":") {
|
||||||
|
t = fmt.Sprintf("%s:53", t)
|
||||||
|
}
|
||||||
upstreams = append(upstreams, t)
|
upstreams = append(upstreams, t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(upstreams) == 0 {
|
if len(upstreams) == 0 {
|
||||||
log.Fatal("no upstream DNS servers provided")
|
log.Fatal("[fatal] no upstream DNS servers provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 全局 DNS 客户端(UDP,扩大 UDPSize;fallback 在查询函数中完成)
|
// 全局 UDP 客户端(UDPSize 放大;不设置 Timeout,使用 ExchangeContext 控制)
|
||||||
dnsClient = &dns.Client{
|
udpClient = &dns.Client{
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
UDPSize: 4096, // 防截断
|
UDPSize: 4096,
|
||||||
Timeout: *timeoutFlag,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// context 用于优雅退出与清理协程
|
// context 用于优雅退出
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// 启动缓存清理器
|
// 启动缓存清理器
|
||||||
startCacheCleaner(ctx)
|
startCacheCleaner(ctx)
|
||||||
|
|
||||||
// DNS 处理器
|
// mux/handler
|
||||||
mux := dns.NewServeMux()
|
mux := dns.NewServeMux()
|
||||||
mux.HandleFunc(".", handleDNS(upstreams, *cacheTTLFlag, *timeoutFlag, *maxParallel))
|
mux.HandleFunc(".", handleDNS(
|
||||||
|
upstreams,
|
||||||
|
*cacheTTLFlag,
|
||||||
|
*timeoutFlag,
|
||||||
|
*maxParallel,
|
||||||
|
*stripECSFlag,
|
||||||
|
*allowTCPFallback,
|
||||||
|
))
|
||||||
|
|
||||||
// DoT 服务器(TLS 会话缓存 + 安全套件 + TLS1.2+)
|
// DoT 服务(启用 TLS1.3;不手动设 CipherSuites 以使用 Go 默认安全套件)
|
||||||
srv := &dns.Server{
|
srv := &dns.Server{
|
||||||
Addr: *addr,
|
Addr: *addr,
|
||||||
Net: "tcp-tls",
|
Net: "tcp-tls",
|
||||||
TLSConfig: &tls.Config{
|
TLSConfig: &tls.Config{
|
||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
ClientSessionCache: tls.NewLRUClientSessionCache(256),
|
MinVersion: tls.VersionTLS12, // 允许 1.2/1.3(默认启用 1.3)
|
||||||
MinVersion: tls.VersionTLS12,
|
// 不设置 CipherSuites,交由 Go 自动选择(TLS1.3 有自身套件)
|
||||||
CipherSuites: []uint16{
|
|
||||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
|
||||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 捕获退出信号,优雅关闭
|
// 捕获信号优雅退出
|
||||||
stop := make(chan os.Signal, 1)
|
stop := make(chan os.Signal, 1)
|
||||||
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
errCh := make(chan error, 1)
|
errCh := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
log.Printf("🚀 Starting DNS-over-TLS server on %s", *addr)
|
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
|
||||||
log.Printf("Upstreams=%v | MaxTTL=%s | Timeout=%s | MaxParallel=%d",
|
log.Printf(" upstreams=%v | cache_max_ttl=%s | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v",
|
||||||
upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel)
|
upstreams, cacheTTLFlag.String(), timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback)
|
||||||
errCh <- srv.ListenAndServe()
|
errCh <- srv.ListenAndServe()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case sig := <-stop:
|
case sig := <-stop:
|
||||||
log.Printf("[Shutdown] Caught signal: %s", sig)
|
log.Printf("[shutdown] caught signal: %s", sig)
|
||||||
cancel() // 结束清理器
|
cancel()
|
||||||
// 优雅关闭服务器
|
// miekg/dns 提供 Shutdown();部分版本没有 ShutdownContext,这里用 Shutdown()
|
||||||
// miekg/dns 提供 Shutdown();若你的版本支持 ShutdownContext,可改用带 ctx 的版本
|
|
||||||
if err := srv.Shutdown(); err != nil {
|
if err := srv.Shutdown(); err != nil {
|
||||||
log.Printf("[Shutdown] server shutdown error: %v", err)
|
log.Printf("[shutdown] server shutdown error: %v", err)
|
||||||
}
|
}
|
||||||
case err := <-errCh:
|
case err := <-errCh:
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("server error: %v", err)
|
log.Fatalf("[fatal] server error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Println("[Shutdown] Bye.")
|
|
||||||
|
log.Println("[bye] server stopped.")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user