- 引入 `golang.org/x/net/idna` 实现 Unicode 域名转 ASCII(Punycode) - 黑名单加载支持通配符格式如 `*.example.com` - 支持解析 hosts 风格的文件(每行首列为 IP 地址时,其余列为域名) - 扩展 Scanner 缓冲区至 2MB 以适应大型 hosts 文件 - 注释处理优化,兼容 `#` 和 `;` 分隔符 - 加载后对规则排序并去重,提升匹配效率与一致性 fix(cache): 调整负面响应缓存逻辑与上游查询并发控制 - 明确区分 NXDOMAIN 与 NODATA 并正确处理 SOA 缺失情况 - 查询上游时引入更可靠的并发限制与超时机制 - UDP 截断时自动回退 TCP 查询 - 过滤无效 RCODE(如 SERVFAIL、REFUSED 等),防止污染结果 - 区分“全部失败”与“部分完成但无有效响应”,增强调试日志信息
791 lines
20 KiB
Go
791 lines
20 KiB
Go
// main.go
|
||
package main
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"flag"
|
||
"fmt"
|
||
"log"
|
||
"math/rand"
|
||
"net"
|
||
"os"
|
||
"os/signal"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"syscall"
|
||
"time"
|
||
|
||
lru "github.com/hashicorp/golang-lru/v2"
|
||
"github.com/miekg/dns"
|
||
"golang.org/x/sync/singleflight"
|
||
)
|
||
|
||
// BuildDate is the build timestamp printed by -h/-help. It defaults to
// "unknown" and is meant to be injected at compile time (presumably via
// -ldflags "-X main.BuildDate=..." — confirm against the build scripts).
var BuildDate = "unknown"
|
||
|
||
/******************************************************************
|
||
* 日志初始化
|
||
******************************************************************/
|
||
// initLogger configures the standard logger's line prefix. Every line gets the
// date, the time and microseconds; verbose mode additionally appends the short
// source file name and line number of the logging call.
func initLogger(verbose bool) {
	base := log.Ldate | log.Ltime | log.Lmicroseconds
	if !verbose {
		log.SetFlags(base)
		return
	}
	log.SetFlags(base | log.Lshortfile)
}
|
||
|
||
/******************************************************************
|
||
* 缓存结构(支持 TTL + LRU)
|
||
******************************************************************/
|
||
|
||
// cacheEntry is one cached upstream response plus the absolute instant after
// which it must no longer be served.
type cacheEntry struct {
	// msg is a private copy of the upstream response; cacheWrite removes its
	// pseudo-RRs (OPT/TSIG) before storing it.
	msg *dns.Msg
	// expireAt is checked by tryCacheRead and by the background cleaner.
	expireAt time.Time

	// EDNS metadata captured before the OPT record was stripped, so a cache
	// hit can reproduce the Extended RCODE / EDE options (see writeReply).
	ednsPresent bool
	// ednsVersion is recorded, but writeReply takes the EDNS version from the
	// request rather than from here.
	ednsVersion  uint8
	ednsExtRcode uint16
	ednsEDE      []*dns.EDNS0_EDE
}
|
||
|
||
var (
|
||
cache *lru.Cache[string, *cacheEntry]
|
||
cacheMutex sync.RWMutex
|
||
inflight singleflight.Group
|
||
)
|
||
|
||
// cacheCleanupInterval is how often the background cleaner sweeps expired entries.
const cacheCleanupInterval = 5 * time.Minute

// defaultCacheSize is the default maximum number of LRU entries (-cache-size).
const defaultCacheSize = 10000
|
||
|
||
// startCacheCleaner 定期清理过期缓存(在删除前二次校验)
|
||
func startCacheCleaner(ctx context.Context) {
|
||
go func() {
|
||
ticker := time.NewTicker(cacheCleanupInterval)
|
||
defer ticker.Stop()
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return
|
||
case <-ticker.C:
|
||
now := time.Now()
|
||
var toDelete []string
|
||
|
||
cacheMutex.RLock()
|
||
for _, k := range cache.Keys() {
|
||
if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
|
||
toDelete = append(toDelete, k)
|
||
}
|
||
}
|
||
cacheMutex.RUnlock()
|
||
|
||
if len(toDelete) > 0 {
|
||
pruned := 0
|
||
cacheMutex.Lock()
|
||
for _, k := range toDelete {
|
||
if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
|
||
cache.Remove(k)
|
||
pruned++
|
||
}
|
||
}
|
||
cacheMutex.Unlock()
|
||
if pruned > 0 {
|
||
log.Printf("[cache] cleaned %d expired entries", pruned)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
|
||
var b strings.Builder
|
||
b.Grow(len(q.Name) + 32)
|
||
// 采用规范化域名,避免尾随点/IDNA/大小写造成的重复键
|
||
b.WriteString(dns.CanonicalName(q.Name))
|
||
b.WriteString("|T=")
|
||
b.WriteString(dns.TypeToString[q.Qtype])
|
||
b.WriteString("|C=")
|
||
b.WriteString(dns.ClassToString[q.Qclass])
|
||
if do {
|
||
b.WriteString("|DO")
|
||
}
|
||
if cd {
|
||
b.WriteString("|CD")
|
||
}
|
||
return b.String()
|
||
}
|
||
|
||
// ecsKeyPart 在未 strip ECS 时,把 ECS 归一化后的“网络”信息并入缓存 key
|
||
// 为最小改动:当 strip=true(即启用去 ECS)时直接返回空字符串
|
||
func ecsKeyPart(m *dns.Msg, strip bool) string {
|
||
if strip {
|
||
return ""
|
||
}
|
||
o := m.IsEdns0()
|
||
if o == nil {
|
||
return ""
|
||
}
|
||
for _, opt := range o.Option {
|
||
s, ok := opt.(*dns.EDNS0_SUBNET)
|
||
if !ok {
|
||
continue
|
||
}
|
||
fam := s.Family
|
||
pfx := int(s.SourceNetmask)
|
||
addr := append(net.IP(nil), s.Address...) // 拷贝以免原切片被改
|
||
switch fam {
|
||
case 1: // IPv4
|
||
ip := addr.To4()
|
||
if ip != nil {
|
||
mask := net.CIDRMask(pfx, 32)
|
||
ip = ip.Mask(mask)
|
||
return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, ip)
|
||
}
|
||
case 2: // IPv6
|
||
ip := addr.To16()
|
||
if ip != nil {
|
||
mask := net.CIDRMask(pfx, 128)
|
||
for i := 0; i < 16; i++ {
|
||
ip[i] &= mask[i]
|
||
}
|
||
return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, ip)
|
||
}
|
||
}
|
||
// 回退:不做掩码
|
||
return fmt.Sprintf("|ECS=%d/%d/%x", fam, s.SourceNetmask, addr)
|
||
}
|
||
return ""
|
||
}
|
||
|
||
func isPseudo(rr dns.RR) bool {
|
||
switch rr.(type) {
|
||
case *dns.OPT, *dns.TSIG:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// clone and extract EDNS metadata (present, version, ext-rcode, all EDEs)
|
||
func cloneEDE(in *dns.EDNS0_EDE) *dns.EDNS0_EDE {
|
||
if in == nil {
|
||
return nil
|
||
}
|
||
cp := *in
|
||
return &cp
|
||
}
|
||
|
||
func extractEDNSMeta(m *dns.Msg) (present bool, version uint8, ext uint16, ede []*dns.EDNS0_EDE) {
|
||
if o := m.IsEdns0(); o != nil {
|
||
present = true
|
||
version = o.Version()
|
||
ext = uint16(o.ExtendedRcode())
|
||
for _, opt := range o.Option {
|
||
if e, ok := opt.(*dns.EDNS0_EDE); ok {
|
||
ede = append(ede, cloneEDE(e))
|
||
}
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
// 读取缓存(Get 在写锁下;在锁外调整 TTL;返回 EDNS 元数据)
|
||
func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) {
|
||
now := time.Now()
|
||
|
||
cacheMutex.Lock()
|
||
e, ok := cache.Get(key) // Get 会更新 LRU,必须在写锁下
|
||
if !ok {
|
||
cacheMutex.Unlock()
|
||
return nil, nil, false
|
||
}
|
||
if now.After(e.expireAt) {
|
||
cache.Remove(key)
|
||
cacheMutex.Unlock()
|
||
return nil, nil, false
|
||
}
|
||
// 拷贝副本,在锁外改 TTL,减少临界区时间
|
||
out := e.msg.Copy()
|
||
expireAt := e.expireAt
|
||
cacheMutex.Unlock()
|
||
|
||
remaining := uint32(expireAt.Sub(now).Seconds())
|
||
if remaining == 0 {
|
||
cacheMutex.Lock()
|
||
cache.Remove(key)
|
||
cacheMutex.Unlock()
|
||
return nil, nil, false
|
||
}
|
||
|
||
for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} {
|
||
for _, rr := range sec {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
if rr.Header().Ttl > remaining {
|
||
rr.Header().Ttl = remaining
|
||
}
|
||
}
|
||
}
|
||
return out, e, true
|
||
}
|
||
|
||
// hasAnswerForType 判断报文中是否存在回答“请求类型”的 RRset
|
||
func hasAnswerForType(m *dns.Msg, q dns.Question) bool {
|
||
for _, rr := range m.Answer {
|
||
h := rr.Header()
|
||
if h.Rrtype == q.Qtype && strings.EqualFold(h.Name, q.Name) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// 计算负面 TTL(正确识别 NODATA,包括 CNAME 等场景)
|
||
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
|
||
// NXDOMAIN:肯定是负面
|
||
if m.Rcode != dns.RcodeNameError {
|
||
// 不是 NXDOMAIN,则仅当 NOERROR 但没有“匹配 QTYPE 的答案”时才是 NODATA
|
||
if m.Rcode != dns.RcodeSuccess || len(m.Question) == 0 || hasAnswerForType(m, m.Question[0]) {
|
||
return 0, false
|
||
}
|
||
}
|
||
|
||
// 按 RFC 2308,从 Authority(Ns)优先取 SOA(多数实现都只放在 Authority)
|
||
var soa *dns.SOA
|
||
for _, rr := range m.Ns {
|
||
if s, ok := rr.(*dns.SOA); ok {
|
||
soa = s
|
||
break
|
||
}
|
||
}
|
||
// 兼容性:偶尔也有人把 SOA 放 Extra(不规范,但为了兼容可以兜底看看)
|
||
if soa == nil {
|
||
for _, rr := range m.Extra {
|
||
if s, ok := rr.(*dns.SOA); ok {
|
||
soa = s
|
||
break
|
||
}
|
||
}
|
||
}
|
||
if soa == nil {
|
||
// 建议:无 SOA 时不做负面缓存(返回 0,false)
|
||
return 0, false
|
||
}
|
||
|
||
// 负面 TTL 取 min(SOA.MINIMUM, SOA 自身 TTL),再与配置上限比较
|
||
ttl := soa.Hdr.Ttl
|
||
if soa.Minttl < ttl {
|
||
ttl = soa.Minttl
|
||
}
|
||
capTTL := uint32(maxTTL.Seconds())
|
||
if capTTL > 0 && ttl > capTTL {
|
||
ttl = capTTL
|
||
}
|
||
return ttl, ttl > 0
|
||
}
|
||
|
||
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
|
||
minTTL := uint32(0)
|
||
hasTTL := false
|
||
// 优先 Answer -> Ns;若都为空,再考虑 Extra(排除伪记录)
|
||
for _, sec := range [][]dns.RR{m.Answer, m.Ns} {
|
||
for _, rr := range sec {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
ttl := rr.Header().Ttl
|
||
if !hasTTL || ttl < minTTL {
|
||
minTTL = ttl
|
||
hasTTL = true
|
||
}
|
||
}
|
||
}
|
||
if !hasTTL {
|
||
for _, rr := range m.Extra {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
ttl := rr.Header().Ttl
|
||
if !hasTTL || ttl < minTTL {
|
||
minTTL = ttl
|
||
hasTTL = true
|
||
}
|
||
}
|
||
}
|
||
return minTTL, hasTTL
|
||
}
|
||
|
||
func stripPseudoExtras(m *dns.Msg) {
|
||
if len(m.Extra) == 0 {
|
||
return
|
||
}
|
||
out := m.Extra[:0]
|
||
for _, rr := range m.Extra {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
out = append(out, rr)
|
||
}
|
||
m.Extra = out
|
||
}
|
||
|
||
// 写缓存(保存 EDNS 元数据,命中时可重建扩展 RCODE/EDE)
|
||
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
||
if in == nil {
|
||
return
|
||
}
|
||
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
|
||
return
|
||
}
|
||
var ttl uint32
|
||
var ok bool
|
||
|
||
// 判断负面响应(NXDOMAIN 或 NODATA)
|
||
neg, isNodata := in.Rcode == dns.RcodeNameError, false
|
||
if in.Rcode == dns.RcodeSuccess && len(in.Question) > 0 && !hasAnswerForType(in, in.Question[0]) {
|
||
isNodata = true
|
||
}
|
||
|
||
if ttl, ok = negativeTTL(in, maxTTL); !ok {
|
||
if neg || isNodata {
|
||
return // 负面但无 SOA → 不缓存
|
||
}
|
||
|
||
minTTL, has := minRRsetTTL(in)
|
||
if has {
|
||
cfgTTL := uint32(maxTTL.Seconds())
|
||
if cfgTTL > 0 && minTTL > cfgTTL {
|
||
minTTL = cfgTTL
|
||
}
|
||
ttl = minTTL
|
||
} else {
|
||
ttl = uint32(maxTTL.Seconds())
|
||
}
|
||
}
|
||
if ttl == 0 {
|
||
return
|
||
}
|
||
expire := time.Now().Add(time.Duration(ttl) * time.Second)
|
||
cp := in.Copy()
|
||
// 提取 EDNS 元数据后再剥离伪记录
|
||
present, ver, ext, ede := extractEDNSMeta(cp)
|
||
stripPseudoExtras(cp)
|
||
cacheMutex.Lock()
|
||
cache.Add(key, &cacheEntry{
|
||
msg: cp,
|
||
expireAt: expire,
|
||
ednsPresent: present,
|
||
ednsVersion: ver,
|
||
ednsExtRcode: ext,
|
||
ednsEDE: ede,
|
||
})
|
||
cacheMutex.Unlock()
|
||
}
|
||
|
||
/******************************************************************
|
||
* 上游查询
|
||
******************************************************************/
|
||
// udpClient is the shared upstream client. It is assigned in main() (Net
// "udp") and used by queryUpstreamsLimited for upstream exchanges.
var udpClient *dns.Client
|
||
|
||
// shuffled returns a randomly permuted copy of xs, drawn from the global
// math/rand source. xs itself is left untouched, and the result is never nil
// (an empty input gives an empty, non-nil slice).
func shuffled(xs []string) []string {
	perm := append(make([]string, 0, len(xs)), xs...)
	rand.Shuffle(len(perm), func(a, b int) {
		perm[a], perm[b] = perm[b], perm[a]
	})
	return perm
}
|
||
|
||
// clampEDNSForUpstream 返回一个 msg 副本,把 EDNS UDP size 夹到给定大小
|
||
func clampEDNSForUpstream(in *dns.Msg, size uint16) *dns.Msg {
|
||
m := in.Copy()
|
||
o := m.IsEdns0()
|
||
if o == nil {
|
||
o = &dns.OPT{}
|
||
o.Hdr.Name = "."
|
||
o.Hdr.Rrtype = dns.TypeOPT
|
||
m.Extra = append(m.Extra, o)
|
||
}
|
||
if size > 0 {
|
||
o.SetUDPSize(size)
|
||
}
|
||
return m
|
||
}
|
||
|
||
func queryUpstreamsLimited(
|
||
ctx context.Context,
|
||
req *dns.Msg,
|
||
upstreams []string,
|
||
timeout time.Duration,
|
||
maxParallel int,
|
||
allowTCPFallback bool,
|
||
) *dns.Msg {
|
||
if maxParallel <= 0 {
|
||
maxParallel = 1
|
||
}
|
||
servers := shuffled(upstreams)
|
||
|
||
cctx, cancel := context.WithTimeout(ctx, timeout)
|
||
defer cancel()
|
||
|
||
type result struct {
|
||
msg *dns.Msg
|
||
}
|
||
ch := make(chan result, len(servers))
|
||
done := make(chan struct{}, len(servers))
|
||
sem := make(chan struct{}, maxParallel)
|
||
|
||
// 单个上游执行
|
||
execOne := func(svr string) {
|
||
// 并发限流(可被超时取消)
|
||
select {
|
||
case sem <- struct{}{}:
|
||
defer func() { <-sem }()
|
||
case <-cctx.Done():
|
||
// 超时/取消,直接放弃
|
||
return
|
||
}
|
||
defer func() { done <- struct{}{} }()
|
||
|
||
// 为 UDP 上游把 EDNS UDP size 夹到 1232,降低分片风险
|
||
upReq := clampEDNSForUpstream(req, 1232)
|
||
|
||
// 先走 UDP
|
||
resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr)
|
||
// 截断且允许回退则走 TCP
|
||
if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
|
||
log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
|
||
tcpClient := *udpClient
|
||
tcpClient.Net = "tcp"
|
||
resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr)
|
||
}
|
||
// 失败直接返回(但不写入 ch);只在未超时情况下打印错误
|
||
if err != nil || resp == nil {
|
||
if err != nil && cctx.Err() == nil {
|
||
log.Printf("[upstream] %s: %v", svr, err)
|
||
}
|
||
return
|
||
}
|
||
// 过滤不可用的错误 RCODE(避免造成“假性超时”的错觉)
|
||
if resp.Rcode == dns.RcodeServerFailure ||
|
||
resp.Rcode == dns.RcodeRefused ||
|
||
resp.Rcode == dns.RcodeFormatError {
|
||
return
|
||
}
|
||
|
||
// 投递可用结果(若已经超时则丢弃)
|
||
select {
|
||
case ch <- result{msg: resp}:
|
||
case <-cctx.Done():
|
||
}
|
||
}
|
||
|
||
// 并发发起
|
||
for _, s := range servers {
|
||
s := s
|
||
go execOne(s)
|
||
}
|
||
|
||
finished := 0
|
||
total := len(servers)
|
||
|
||
// 聚合:首个可用响应直接返回;区分“真超时”与“无可用结果”
|
||
for finished < total {
|
||
select {
|
||
case r := <-ch:
|
||
if r.msg != nil {
|
||
cancel()
|
||
return r.msg
|
||
}
|
||
case <-done:
|
||
finished++
|
||
case <-cctx.Done():
|
||
log.Printf("[upstream] timeout after %v (finished=%d/%d)", timeout, finished, total)
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// 所有上游都结束,但没有一个可用
|
||
log.Printf("[upstream] no acceptable upstream response (finished=%d/%d)", finished, total)
|
||
return nil
|
||
}
|
||
|
||
/******************************************************************
|
||
* EDNS / 响应构造
|
||
******************************************************************/
|
||
func stripECS(m *dns.Msg) {
|
||
if o := m.IsEdns0(); o != nil {
|
||
var kept []dns.EDNS0
|
||
for _, e := range o.Option {
|
||
if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS {
|
||
kept = append(kept, e)
|
||
}
|
||
}
|
||
o.Option = kept
|
||
}
|
||
}
|
||
|
||
func getDOFlag(m *dns.Msg) bool {
|
||
if o := m.IsEdns0(); o != nil {
|
||
return o.Do()
|
||
}
|
||
return false
|
||
}
|
||
|
||
func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cacheEntry) {
|
||
if upstream == nil {
|
||
dns.HandleFailed(w, req)
|
||
return
|
||
}
|
||
out := new(dns.Msg)
|
||
out.SetReply(req)
|
||
out.Authoritative = false
|
||
// RA 语义修正:RA 表示“服务器是否支持递归”,与客户端 RD 无关
|
||
out.RecursionAvailable = upstream.RecursionAvailable
|
||
out.AuthenticatedData = upstream.AuthenticatedData
|
||
out.CheckingDisabled = req.CheckingDisabled
|
||
out.Rcode = upstream.Rcode
|
||
out.Answer = upstream.Answer
|
||
out.Ns = upstream.Ns
|
||
|
||
var extras []dns.RR
|
||
for _, rr := range upstream.Extra {
|
||
if !isPseudo(rr) {
|
||
extras = append(extras, rr)
|
||
}
|
||
}
|
||
|
||
if ro := req.IsEdns0(); ro != nil {
|
||
o := new(dns.OPT)
|
||
o.Hdr.Name = "."
|
||
o.Hdr.Rrtype = dns.TypeOPT
|
||
o.SetUDPSize(ro.UDPSize())
|
||
if ro.Do() {
|
||
o.SetDo()
|
||
}
|
||
// 优先使用“请求”的 EDNS 版本;扩展 RCODE/EDE 来自上游/缓存
|
||
o.SetVersion(ro.Version())
|
||
if uo := upstream.IsEdns0(); uo != nil {
|
||
o.SetExtendedRcode(uint16(uo.ExtendedRcode()))
|
||
for _, opt := range uo.Option {
|
||
if ede, ok := opt.(*dns.EDNS0_EDE); ok {
|
||
o.Option = append(o.Option, ede)
|
||
}
|
||
}
|
||
} else if meta != nil && meta.ednsPresent {
|
||
// Upstream/cached msg has no OPT(例如缓存时被剥离),用缓存元数据重建
|
||
o.SetExtendedRcode(uint16(meta.ednsExtRcode))
|
||
for _, e := range meta.ednsEDE {
|
||
o.Option = append(o.Option, e)
|
||
}
|
||
}
|
||
extras = append(extras, o)
|
||
}
|
||
out.Extra = extras
|
||
out.Compress = true
|
||
|
||
if err := w.WriteMsg(out); err != nil {
|
||
log.Printf("[write] WriteMsg error: %v", err)
|
||
}
|
||
}
|
||
|
||
/******************************************************************
|
||
* 主处理器
|
||
******************************************************************/
|
||
func handleDNS(
|
||
upstreams []string,
|
||
cacheMaxTTL, timeout time.Duration,
|
||
maxParallel int,
|
||
stripECSBeforeForward bool,
|
||
allowTCPFallback bool,
|
||
blPtr *atomic.Pointer[suffixMatcher],
|
||
blRcode int,
|
||
) dns.HandlerFunc {
|
||
return func(w dns.ResponseWriter, r *dns.Msg) {
|
||
if len(r.Question) == 0 {
|
||
dns.HandleFailed(w, r)
|
||
return
|
||
}
|
||
q := r.Question[0]
|
||
log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
|
||
dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())
|
||
|
||
if stripECSBeforeForward {
|
||
stripECS(r)
|
||
}
|
||
|
||
// 黑名单拦截:命中则不查上游,直接返回
|
||
if rule, ok := blPtr.Load().match(q.Name); ok {
|
||
nameCanon := dns.CanonicalName(q.Name)
|
||
log.Printf("[blacklist] HIT %s rule=%s (no upstream query)", nameCanon, rule)
|
||
up := makeBlockedUpstream(blRcode, rule)
|
||
writeReply(w, r, up, nil)
|
||
return
|
||
}
|
||
|
||
// 缓存 key:基础 + (未 strip 时的)ECS 片段
|
||
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled) + ecsKeyPart(r, stripECSBeforeForward)
|
||
|
||
if cachedMsg, cachedMeta, ok := tryCacheRead(key); ok {
|
||
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||
writeReply(w, r, cachedMsg, cachedMeta)
|
||
return
|
||
}
|
||
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||
|
||
// 使用 singleflight 合并相同 key 的并发查询,避免上游雪崩
|
||
v, _, _ := inflight.Do(key, func() (any, error) {
|
||
ctx := context.Background()
|
||
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
|
||
if resp != nil {
|
||
cacheWrite(key, resp, cacheMaxTTL)
|
||
}
|
||
return resp, nil
|
||
})
|
||
resp, _ := v.(*dns.Msg)
|
||
if resp == nil {
|
||
log.Printf("[error] all upstreams failed for %s", q.Name)
|
||
dns.HandleFailed(w, r)
|
||
return
|
||
}
|
||
for _, ans := range resp.Answer {
|
||
log.Printf("[answer] %s", ans.String())
|
||
}
|
||
writeReply(w, r, resp, nil)
|
||
}
|
||
}
|
||
|
||
/******************************************************************
|
||
* 主函数
|
||
******************************************************************/
|
||
func main() {
|
||
rand.Seed(time.Now().UnixNano())
|
||
|
||
var help bool
|
||
certFile := flag.String("cert", "server.crt", "TLS 证书文件路径")
|
||
keyFile := flag.String("key", "server.key", "TLS 私钥文件路径")
|
||
addr := flag.String("addr", ":853", "DoT 监听地址")
|
||
upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表")
|
||
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL")
|
||
cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限")
|
||
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时")
|
||
readTimeoutFlag := flag.Duration("read-timeout", 0, "DoT 连接读超时(0=不限制)")
|
||
writeTimeoutFlag := flag.Duration("write-timeout", 0, "DoT 连接写超时(0=不限制)")
|
||
maxParallel := flag.Int("max-parallel", 3, "并发上游数量")
|
||
stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS")
|
||
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退")
|
||
blacklistStr := flag.String("blacklist", "", "逗号分隔的黑名单域名(后缀匹配;支持如 *.example.com)")
|
||
blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;可接受 hosts 风格)")
|
||
blacklistRcodeFlag := flag.String("blacklist-rcode", "REFUSED", "命中黑名单返回的 RCODE:REFUSED|NXDOMAIN|SERVFAIL")
|
||
verbose := flag.Bool("v", false, "verbose 日志")
|
||
|
||
flag.BoolVar(&help, "h", false, "")
|
||
flag.BoolVar(&help, "help", false, "帮助信息")
|
||
|
||
flag.Parse()
|
||
|
||
if help {
|
||
fmt.Printf(
|
||
"\t\tDNS-over-TLS (DoT)\n"+
|
||
"\tVersion 0.1\n"+
|
||
"\tE-mail: aixiao@aixiao.me\n"+
|
||
"\tBuild Date: %s\n", BuildDate)
|
||
|
||
flag.Usage()
|
||
fmt.Printf("\n")
|
||
os.Exit(0)
|
||
}
|
||
|
||
initLogger(*verbose)
|
||
|
||
var err error
|
||
cache, err = lru.New[string, *cacheEntry](*cacheSizeFlag)
|
||
if err != nil {
|
||
log.Fatalf("[fatal] failed to init LRU cache: %v", err)
|
||
}
|
||
|
||
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
|
||
if err != nil {
|
||
log.Fatalf("[fatal] failed to load cert: %v", err)
|
||
}
|
||
|
||
var upstreams []string
|
||
for _, s := range strings.Split(*upstreamStr, ",") {
|
||
if t := strings.TrimSpace(s); t != "" {
|
||
if !strings.Contains(t, ":") {
|
||
t = fmt.Sprintf("%s:53", t)
|
||
}
|
||
upstreams = append(upstreams, t)
|
||
}
|
||
}
|
||
if len(upstreams) == 0 {
|
||
log.Fatal("[fatal] no upstream DNS servers provided")
|
||
}
|
||
|
||
// 如需更保守的 UDP 尺寸以减少分片,可将 UDPSize 改为 1232
|
||
udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true}
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
|
||
startCacheCleaner(ctx)
|
||
|
||
// 加载黑名单规则
|
||
blPtr, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag)
|
||
|
||
mux := dns.NewServeMux()
|
||
mux.HandleFunc(".", handleDNS(
|
||
upstreams,
|
||
*cacheTTLFlag,
|
||
*timeoutFlag,
|
||
*maxParallel,
|
||
*stripECSFlag,
|
||
*allowTCPFallback,
|
||
blPtr,
|
||
blRcode,
|
||
))
|
||
|
||
srv := &dns.Server{
|
||
Addr: *addr,
|
||
Net: "tcp-tls",
|
||
TLSConfig: &tls.Config{
|
||
Certificates: []tls.Certificate{cert},
|
||
MinVersion: tls.VersionTLS13, // TLS 1.3
|
||
NextProtos: []string{"dot"},
|
||
},
|
||
Handler: mux,
|
||
ReadTimeout: *readTimeoutFlag,
|
||
WriteTimeout: *writeTimeoutFlag,
|
||
}
|
||
|
||
stop := make(chan os.Signal, 1)
|
||
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
|
||
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | read_timeout=%s | write_timeout=%s | blacklist_rules=%d | blacklist_rcode=%s",
|
||
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback,
|
||
readTimeoutFlag.String(), writeTimeoutFlag.String(),
|
||
len(blPtr.Load().rules), strings.ToUpper(*blacklistRcodeFlag))
|
||
errCh <- srv.ListenAndServe()
|
||
}()
|
||
|
||
select {
|
||
case sig := <-stop:
|
||
log.Printf("[shutdown] caught signal: %s", sig)
|
||
cancel()
|
||
if err := srv.Shutdown(); err != nil {
|
||
log.Printf("[shutdown] server shutdown error: %v", err)
|
||
}
|
||
case err := <-errCh:
|
||
if err != nil {
|
||
log.Fatalf("[fatal] server error: %v", err)
|
||
}
|
||
}
|
||
|
||
log.Println("[bye] server stopped.")
|
||
}
|