Files
dot/main.go
aixiao 3cf657bb62 feat(docker): 支持通过构建参数注入编译时间
在 Dockerfile 中添加 BUILD_DATE 构建参数,并通过 ldflags 将编译时间注入到二进制文件中。
同时更新 build.sh 脚本,在构建镜像时传入当前时间作为 BUILD_DATE 参数。

refactor(build): 优化 build.sh 脚本结构与可读性

对 build.sh 脚本中的函数进行了缩进统一和结构调整,提高代码可读性和维护性。
新增 bin 命令用于直接编译并压缩二进制文件。

feat(main): 添加版本与构建信息显示功能

在 main.go 中增加 BuildDate 变量用于存储构建时间,并支持通过 -h 或 --help 参数查看帮助信息,
包括版本号、联系邮箱以及构建日期等元数据。

chore(blacklist): 移除默认黑名单条目

从 blacklist.txt 文件中移除默认的 *.baidu.com 黑名单规则。

docs(readme): 清理 README.md 文件末尾空行

删除 README.md 文件最后多余的空白行,保持文档整洁。
2025-10-15 15:16:09 +08:00

696 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
	"context"
	"crypto/tls"
	"flag"
	"fmt"
	"log"
	"math/rand"
	"net"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	lru "github.com/hashicorp/golang-lru/v2"
	"github.com/miekg/dns"
)
// BuildDate is the build timestamp printed by -h/--help. It defaults to
// "unknown" and is meant to be injected at build time through linker flags
// (e.g. go build -ldflags "-X main.BuildDate=...").
var BuildDate = "unknown"
/******************************************************************
* 日志初始化
******************************************************************/
// initLogger sets the standard logger's prefix to date + time with
// microseconds; verbose mode additionally prefixes the short file:line of
// the call site.
func initLogger(verbose bool) {
	base := log.Ldate | log.Ltime | log.Lmicroseconds
	if !verbose {
		log.SetFlags(base)
		return
	}
	log.SetFlags(base | log.Lshortfile)
}
/******************************************************************
* 缓存结构(支持 TTL + LRU)
******************************************************************/
// cacheEntry is one cached upstream response together with its absolute
// expiry time. msg is stored with OPT/TSIG pseudo-records removed (cacheWrite
// calls stripPseudoExtras), so the EDNS fields below keep what writeReply
// needs to rebuild an OPT record on a cache hit.
type cacheEntry struct {
	msg      *dns.Msg  // cached response, pseudo-records stripped
	expireAt time.Time // entry is stale once time.Now() passes this instant
	// EDNS metadata (to reproduce Extended RCODE / EDE on cache hits)
	ednsPresent  bool             // the upstream response carried an OPT record
	ednsVersion  uint8            // EDNS version from that OPT record
	ednsExtRcode uint16           // value of OPT.ExtendedRcode() at caching time
	ednsEDE      []*dns.EDNS0_EDE // copies of the Extended DNS Error options
}

var (
	// cache is the process-wide response cache; handleDNS keys it with
	// cacheKeyFromMsg. It is created in main from the -cache-size flag.
	cache *lru.Cache[string, *cacheEntry]
	// cacheMutex guards the cache operations performed in this file
	// (lookups, inserts, removals and the periodic sweep).
	cacheMutex sync.RWMutex
)

const (
	cacheCleanupInterval = 5 * time.Minute // period of the startCacheCleaner sweep
	defaultCacheSize     = 10000           // default maximum number of cache entries (-cache-size)
)
// startCacheCleaner 定期清理过期缓存(修复:在删除前二次校验)
func startCacheCleaner(ctx context.Context) {
go func() {
ticker := time.NewTicker(cacheCleanupInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
now := time.Now()
var toDelete []string
cacheMutex.RLock()
for _, k := range cache.Keys() {
if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
toDelete = append(toDelete, k)
}
}
cacheMutex.RUnlock()
if len(toDelete) > 0 {
pruned := 0
cacheMutex.Lock()
for _, k := range toDelete {
if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
cache.Remove(k)
pruned++
}
}
cacheMutex.Unlock()
if pruned > 0 {
log.Printf("[cache] cleaned %d expired entries", pruned)
}
}
}
}
}()
}
// cacheKeyFromMsg builds the cache key for question q. The key combines:
//   - the owner name in dns.CanonicalName form (lower-cased ASCII, fully
//     qualified), so case and trailing-dot variants share one entry;
//   - the query type and class;
//   - the DO and CD flags, since they can change what the upstream returns
//     (e.g. whether DNSSEC records are included).
//
// Type and class are rendered with dns.Type(...).String() and
// dns.Class(...).String(). These fall back to "TYPEnnn"/"CLASSnnn" for codes
// missing from the mnemonic tables. Indexing dns.TypeToString/ClassToString
// directly yields "" for such codes, which made every unknown QTYPE share one
// key, so one type's cached answer could be served for another.
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
	var b strings.Builder
	b.Grow(len(q.Name) + 32)
	b.WriteString(dns.CanonicalName(q.Name))
	b.WriteString("|T=")
	b.WriteString(dns.Type(q.Qtype).String())
	b.WriteString("|C=")
	b.WriteString(dns.Class(q.Qclass).String())
	if do {
		b.WriteString("|DO")
	}
	if cd {
		b.WriteString("|CD")
	}
	return b.String()
}
// isPseudo reports whether rr is an EDNS OPT or a TSIG record. Such records
// are transport metadata rather than DNS data; callers skip them when
// computing or rewriting TTLs and when storing messages in the cache.
func isPseudo(rr dns.RR) bool {
	_, isOPT := rr.(*dns.OPT)
	_, isTSIG := rr.(*dns.TSIG)
	return isOPT || isTSIG
}
// cloneEDE returns a copy of an Extended DNS Error option struct, so a cache
// entry does not share the option with the upstream message it came from.
// A nil input yields nil.
func cloneEDE(in *dns.EDNS0_EDE) *dns.EDNS0_EDE {
	if in != nil {
		dup := *in
		return &dup
	}
	return nil
}
// extractEDNSMeta inspects m's OPT record, if any. It reports whether one is
// present, its EDNS version, its extended-RCODE value (as returned by
// OPT.ExtendedRcode), and copies of every Extended DNS Error option it
// carries. Without an OPT record every result is the zero value.
func extractEDNSMeta(m *dns.Msg) (present bool, version uint8, ext uint16, ede []*dns.EDNS0_EDE) {
	opt := m.IsEdns0()
	if opt == nil {
		return false, 0, 0, nil
	}
	for _, option := range opt.Option {
		if e, isEDE := option.(*dns.EDNS0_EDE); isEDE {
			ede = append(ede, cloneEDE(e))
		}
	}
	return true, opt.Version(), uint16(opt.ExtendedRcode()), ede
}
// tryCacheRead looks up key in the cache. On a fresh hit it returns:
//   - a private copy of the cached message, with every non-pseudo RR's TTL
//     capped at the entry's remaining lifetime in whole seconds;
//   - the live *cacheEntry, for its EDNS metadata (callers must treat it as
//     read-only);
//   - true.
//
// Entries that are expired, or have less than one whole second left, are
// removed and reported as a miss.
//
// The expiry decision is made while the write lock taken for Get is still
// held. Previously an entry found to have 0s remaining was removed in a
// second, separate critical section. That could delete a fresh entry another
// goroutine had stored under the same key in between.
func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) {
	now := time.Now()

	cacheMutex.Lock()
	// Get (unlike Peek) updates LRU recency, so it runs under the write lock.
	e, ok := cache.Get(key)
	if !ok {
		cacheMutex.Unlock()
		return nil, nil, false
	}
	// Whole seconds of lifetime left; 0 also covers already-expired entries.
	var remaining uint32
	if now.Before(e.expireAt) {
		remaining = uint32(e.expireAt.Sub(now).Seconds())
	}
	if remaining == 0 {
		cache.Remove(key)
		cacheMutex.Unlock()
		return nil, nil, false
	}
	out := e.msg.Copy()
	cacheMutex.Unlock()

	// TTL rewriting only touches our private copy, so it runs outside the lock.
	for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} {
		for _, rr := range sec {
			if isPseudo(rr) {
				continue
			}
			if h := rr.Header(); h.Ttl > remaining {
				h.Ttl = remaining
			}
		}
	}
	return out, e, true
}
// hasAnswerForType reports whether m's Answer section actually answers q.
// That is the case when it holds an RR of the queried type owned by q.Name,
// or by a name reached from q.Name through the CNAME chain in the Answer
// section. For QTYPE=ANY, any RR owned by a name on that chain counts.
//
// Following the chain matters for negativeTTL's NODATA detection. A response
// like "www CNAME edge; edge A 192.0.2.1" used to be treated as "no answer",
// because the A record is owned by the CNAME target rather than by q.Name.
func hasAnswerForType(m *dns.Msg, q dns.Question) bool {
	name := q.Name
	// Each hop follows one CNAME, so len(m.Answer)+1 iterations suffice;
	// the bound also stops CNAME loops from spinning forever.
	for hops := 0; hops <= len(m.Answer); hops++ {
		next := ""
		for _, rr := range m.Answer {
			h := rr.Header()
			if !strings.EqualFold(h.Name, name) {
				continue
			}
			if h.Rrtype == q.Qtype || q.Qtype == dns.TypeANY {
				return true
			}
			if cname, ok := rr.(*dns.CNAME); ok && next == "" {
				next = cname.Target
			}
		}
		if next == "" {
			return false
		}
		name = next
	}
	return false
}
// negativeTTL decides whether m is a negative answer (RFC 2308) and, if so,
// for how many seconds it may be cached.
//
// m is negative when it is NXDOMAIN, or NOERROR with a question that
// hasAnswerForType reports as unanswered (NODATA).
//
// The TTL is min(SOA TTL, SOA MINIMUM) of the first SOA record found. The
// Authority section is searched first; the Additional section is a lenient
// fallback for non-conforming servers. The result is then capped at maxTTL
// when maxTTL is at least one second.
//
// It returns (0, false) when m is not negative, carries no SOA (no negative
// caching without SOA), or the resulting TTL is 0.
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
	isNXDomain := m.Rcode == dns.RcodeNameError
	isNoData := m.Rcode == dns.RcodeSuccess && len(m.Question) > 0 && !hasAnswerForType(m, m.Question[0])
	if !isNXDomain && !isNoData {
		return 0, false
	}

	var soa *dns.SOA
search:
	for _, section := range [][]dns.RR{m.Ns, m.Extra} {
		for _, rr := range section {
			if s, ok := rr.(*dns.SOA); ok {
				soa = s
				break search
			}
		}
	}
	if soa == nil {
		return 0, false
	}

	ttl := soa.Minttl
	if soa.Hdr.Ttl < ttl {
		ttl = soa.Hdr.Ttl
	}
	if limit := uint32(maxTTL.Seconds()); limit > 0 && ttl > limit {
		ttl = limit
	}
	return ttl, ttl != 0
}
// minRRsetTTL returns the smallest TTL among the non-pseudo RRs of the Answer
// and Authority sections. Only when those two sections contain no such RR
// does it fall back to the Additional section. The boolean is false when no
// section contributed a TTL.
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
	lowest := func(sections ...[]dns.RR) (uint32, bool) {
		var best uint32
		seen := false
		for _, sec := range sections {
			for _, rr := range sec {
				if isPseudo(rr) {
					continue
				}
				if t := rr.Header().Ttl; !seen || t < best {
					best, seen = t, true
				}
			}
		}
		return best, seen
	}

	if ttl, ok := lowest(m.Answer, m.Ns); ok {
		return ttl, true
	}
	return lowest(m.Extra)
}
// stripPseudoExtras removes OPT/TSIG pseudo-records from m.Extra in place and
// keeps the remaining RRs in their original order. It reuses m.Extra's
// backing array, so it should only be applied to a message the caller owns
// (cacheWrite runs it on a fresh copy).
func stripPseudoExtras(m *dns.Msg) {
	if len(m.Extra) == 0 {
		return
	}
	kept := 0
	for _, rr := range m.Extra {
		if !isPseudo(rr) {
			m.Extra[kept] = rr
			kept++
		}
	}
	m.Extra = m.Extra[:kept]
}
// cacheWrite stores a copy of the upstream response in under key.
//
// Only complete NOERROR and NXDOMAIN responses are cached. Truncated (TC)
// responses are skipped because they may be missing records.
//
// The lifetime is chosen as follows:
//   - for negative answers with an SOA, the TTL computed by negativeTTL;
//   - otherwise, the smallest RR TTL from minRRsetTTL, capped at maxTTL when
//     maxTTL is at least one second.
//
// A response with no TTL-bearing RR at all is not cached. Such a message is
// an empty (negative) answer without SOA. negativeTTL already declines to
// cache those, and RFC 2308 §5 says they should not be cached. Previously they
// fell through to a maxTTL-long lifetime, which silently undid that policy.
//
// EDNS metadata is extracted before OPT/TSIG records are stripped, so cache
// hits can rebuild the extended RCODE and EDE options.
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
	if in == nil || in.Truncated {
		return
	}
	if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
		return
	}

	ttl, ok := negativeTTL(in, maxTTL)
	if !ok {
		minTTL, has := minRRsetTTL(in)
		if !has {
			// Nothing carries a TTL: empty negative answer without SOA.
			return
		}
		if cfgTTL := uint32(maxTTL.Seconds()); cfgTTL > 0 && minTTL > cfgTTL {
			minTTL = cfgTTL
		}
		ttl = minTTL
	}
	if ttl == 0 {
		return
	}

	expire := time.Now().Add(time.Duration(ttl) * time.Second)
	cp := in.Copy()
	// Capture EDNS metadata first: stripPseudoExtras removes the OPT record.
	present, ver, ext, ede := extractEDNSMeta(cp)
	stripPseudoExtras(cp)

	cacheMutex.Lock()
	cache.Add(key, &cacheEntry{
		msg:          cp,
		expireAt:     expire,
		ednsPresent:  present,
		ednsVersion:  ver,
		ednsExtRcode: ext,
		ednsEDE:      ede,
	})
	cacheMutex.Unlock()
}
/******************************************************************
* 上游查询
******************************************************************/
// udpClient is the shared client used for upstream queries. main assigns it
// (UDP, 4096-byte UDPSize) before the DNS server is started.
var udpClient *dns.Client
// shuffled returns a randomly ordered copy of xs, drawn from the global
// math/rand source; xs itself is left untouched. The result is never nil.
func shuffled(xs []string) []string {
	perm := append(make([]string, 0, len(xs)), xs...)
	rand.Shuffle(len(perm), func(a, b int) {
		perm[a], perm[b] = perm[b], perm[a]
	})
	return perm
}
// clampEDNSForUpstream returns a copy of in whose OPT record advertises size
// as the EDNS UDP payload size. If in has no OPT record, a new one (owner
// ".") is appended to the copy. A size of 0 leaves the payload size as it was.
// The caller's message is never modified.
func clampEDNSForUpstream(in *dns.Msg, size uint16) *dns.Msg {
	dup := in.Copy()
	opt := dup.IsEdns0()
	if opt == nil {
		opt = &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
		dup.Extra = append(dup.Extra, opt)
	}
	if size != 0 {
		opt.SetUDPSize(size)
	}
	return dup
}
// queryUpstreamsLimited sends req to the upstream servers in random order,
// running at most maxParallel exchanges at a time. It returns the first
// usable response, or nil if every upstream failed or timeout elapsed.
//
// Each exchange is sent over UDP with the EDNS payload size set to 1232
// bytes. When allowTCPFallback is set, a truncated UDP answer is retried over
// TCP with the same query. SERVFAIL, REFUSED and FORMERR responses count as
// failures.
//
// Every worker that gets to run reports exactly one result, nil on failure.
// This lets the function return as soon as all upstreams have failed.
// Previously failed workers reported nothing, so the caller always waited
// out the full timeout in that case.
func queryUpstreamsLimited(
	ctx context.Context,
	req *dns.Msg,
	upstreams []string,
	timeout time.Duration,
	maxParallel int,
	allowTCPFallback bool,
) *dns.Msg {
	if maxParallel <= 0 {
		maxParallel = 1
	}
	servers := shuffled(upstreams)
	cctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	type result struct{ msg *dns.Msg }
	// One slot per server: each worker sends at most once, so sends never
	// block and workers still running after we return cannot leak.
	ch := make(chan result, len(servers))
	sem := make(chan struct{}, maxParallel)

	// exchangeOne queries svr and returns its response, or nil if the
	// exchange failed or the response is not usable.
	exchangeOne := func(svr string) *dns.Msg {
		upReq := clampEDNSForUpstream(req, 1232) // TODO: consider making this a flag
		resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr)
		if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
			log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
			// Use a dedicated TCP client instead of copying *udpClient by
			// value, and resend the same query that went over UDP.
			tcpClient := &dns.Client{
				Net:          "tcp",
				Dialer:       udpClient.Dialer,
				Timeout:      udpClient.Timeout,
				DialTimeout:  udpClient.DialTimeout,
				ReadTimeout:  udpClient.ReadTimeout,
				WriteTimeout: udpClient.WriteTimeout,
			}
			resp, _, err = tcpClient.ExchangeContext(cctx, upReq, svr)
		}
		if err != nil || resp == nil {
			// Errors caused by our own cancellation/timeout are not worth logging.
			if err != nil && cctx.Err() == nil {
				log.Printf("[upstream] %s error: %v", svr, err)
			}
			return nil
		}
		switch resp.Rcode {
		case dns.RcodeServerFailure, dns.RcodeRefused, dns.RcodeFormatError:
			return nil
		}
		return resp
	}

	for _, s := range servers {
		s := s
		go func() {
			select {
			case sem <- struct{}{}:
				defer func() { <-sem }()
			case <-cctx.Done():
				// Gave up before getting a slot; the receiver sees cctx.Done too.
				return
			}
			ch <- result{msg: exchangeOne(s)}
		}()
	}

	for i := 0; i < len(servers); i++ {
		select {
		case r := <-ch:
			if r.msg != nil {
				cancel()
				return r.msg
			}
		case <-cctx.Done():
			log.Printf("[upstream] timeout after %v", timeout)
			return nil
		}
	}
	return nil
}
/******************************************************************
* EDNS / 响应构造
******************************************************************/
// stripECS removes every EDNS Client Subnet option from m's OPT record in
// place, keeping the other options in order. Messages without an OPT record
// are left untouched.
func stripECS(m *dns.Msg) {
	opt := m.IsEdns0()
	if opt == nil {
		return
	}
	var filtered []dns.EDNS0
	for _, option := range opt.Option {
		if _, isSubnet := option.(*dns.EDNS0_SUBNET); isSubnet {
			continue
		}
		filtered = append(filtered, option)
	}
	opt.Option = filtered
}
// getDOFlag reports whether m carries an OPT record with the DNSSEC OK (DO)
// bit set; messages without EDNS report false.
func getDOFlag(m *dns.Msg) bool {
	opt := m.IsEdns0()
	return opt != nil && opt.Do()
}
// writeReply answers req using the contents of upstream. upstream may be a
// live upstream response, a cached message or a synthesized blacklist reply.
// A nil upstream is answered with dns.HandleFailed.
//
// Header and sections:
//   - header flags and RCODE come from upstream, except that CD echoes the
//     request and AA is always cleared;
//   - OPT/TSIG pseudo-records from upstream are dropped.
//
// If the client used EDNS, a fresh OPT record is added. It echoes the
// client's UDP size and DO bit and carries upstream's EDNS version, extended
// RCODE and EDE options. These are taken from upstream's own OPT record when
// present, otherwise from meta (cached messages are stored without their OPT).
//
// An extended RCODE (> 15) can only be encoded through an OPT record. For a
// client that sent no EDNS it is therefore downgraded to SERVFAIL; otherwise
// packing the reply fails and the client receives nothing at all.
func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cacheEntry) {
	if upstream == nil {
		dns.HandleFailed(w, req)
		return
	}
	out := new(dns.Msg)
	out.SetReply(req)
	out.Authoritative = false
	out.RecursionAvailable = upstream.RecursionAvailable
	out.AuthenticatedData = upstream.AuthenticatedData
	out.CheckingDisabled = req.CheckingDisabled
	out.Rcode = upstream.Rcode
	out.Answer = upstream.Answer
	out.Ns = upstream.Ns

	var extras []dns.RR
	for _, rr := range upstream.Extra {
		if !isPseudo(rr) {
			extras = append(extras, rr)
		}
	}

	ro := req.IsEdns0()
	if ro == nil && out.Rcode > 0xF {
		log.Printf("[write] extended rcode %d needs EDNS but client sent none; replying SERVFAIL", out.Rcode)
		out.Rcode = dns.RcodeServerFailure
	}
	if ro != nil {
		o := new(dns.OPT)
		o.Hdr.Name = "."
		o.Hdr.Rrtype = dns.TypeOPT
		o.SetUDPSize(ro.UDPSize())
		if ro.Do() {
			o.SetDo()
		}
		if uo := upstream.IsEdns0(); uo != nil {
			o.SetExtendedRcode(uint16(uo.ExtendedRcode()))
			o.SetVersion(uo.Version())
			for _, opt := range uo.Option {
				if ede, ok := opt.(*dns.EDNS0_EDE); ok {
					o.Option = append(o.Option, ede)
				}
			}
		} else if meta != nil && meta.ednsPresent {
			// upstream has no OPT (e.g. stripped when cached): rebuild from meta.
			o.SetExtendedRcode(meta.ednsExtRcode)
			o.SetVersion(meta.ednsVersion)
			for _, e := range meta.ednsEDE {
				o.Option = append(o.Option, e)
			}
		}
		extras = append(extras, o)
	}

	out.Extra = extras
	out.Compress = true
	if err := w.WriteMsg(out); err != nil {
		log.Printf("[write] WriteMsg error: %v", err)
	}
}
/******************************************************************
* 主处理器
******************************************************************/
// handleDNS builds the handler used for every incoming query. For each
// request it:
//  1. logs the query;
//  2. optionally strips EDNS Client Subnet from it;
//  3. answers blacklisted names locally, without contacting any upstream;
//  4. serves fresh cache hits;
//  5. otherwise queries the upstreams, caches the result and relays it.
//
// Requests with no question, or for which no upstream answered, are passed
// to dns.HandleFailed.
func handleDNS(
	upstreams []string,
	cacheMaxTTL, timeout time.Duration,
	maxParallel int,
	stripECSBeforeForward bool,
	allowTCPFallback bool,
	bl *suffixMatcher,
	blRcode int,
) dns.HandlerFunc {
	return func(w dns.ResponseWriter, r *dns.Msg) {
		if len(r.Question) == 0 {
			dns.HandleFailed(w, r)
			return
		}
		q := r.Question[0]
		qtype := dns.TypeToString[q.Qtype]
		log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
			qtype, q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())

		if stripECSBeforeForward {
			stripECS(r)
		}

		// Blacklisted names are answered locally; no upstream is contacted.
		if rule, hit := bl.match(q.Name); hit {
			log.Printf("[blacklist] HIT %s rule=%s (no upstream query)", dns.CanonicalName(q.Name), rule)
			writeReply(w, r, makeBlockedUpstream(blRcode, rule), nil)
			return
		}

		key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
		if msg, meta, hit := tryCacheRead(key); hit {
			log.Printf("[cache] HIT %s %s", qtype, q.Name)
			writeReply(w, r, msg, meta)
			return
		}
		log.Printf("[cache] MISS %s %s", qtype, q.Name)

		resp := queryUpstreamsLimited(context.Background(), r, upstreams, timeout, maxParallel, allowTCPFallback)
		if resp == nil {
			log.Printf("[error] all upstreams failed for %s", q.Name)
			dns.HandleFailed(w, r)
			return
		}

		cacheWrite(key, resp, cacheMaxTTL)
		for _, rr := range resp.Answer {
			log.Printf("[answer] %s", rr.String())
		}
		writeReply(w, r, resp, nil)
	}
}
/******************************************************************
* 主函数
******************************************************************/
// normalizeUpstreamAddr turns one -upstream entry into a "host:port" dial
// address, appending the default DNS port 53 when none is given.
//
// A plain "contains ':'" test mistakes bare IPv6 literals for host:port
// pairs, so those were never given a port. This helper handles them:
//   - "2001:db8::1" and "[2001:db8::1]" become "[2001:db8::1]:53";
//   - "[2001:db8::1]:5353" and "9.9.9.9:9953" are kept unchanged;
//   - "9.9.9.9" becomes "9.9.9.9:53".
func normalizeUpstreamAddr(s string) string {
	if _, _, err := net.SplitHostPort(s); err == nil {
		return s
	}
	return net.JoinHostPort(strings.Trim(s, "[]"), "53")
}

// main parses flags and initializes the logger, LRU cache, TLS certificate,
// upstream list and blacklist. It then serves DNS-over-TLS on -addr until
// SIGINT/SIGTERM (graceful shutdown) or a fatal listener error.
// -h/--help prints version and build information followed by flag usage.
func main() {
	// Seed the global math/rand source used by shuffled. Go 1.20+ seeds it
	// automatically and deprecates rand.Seed. On older toolchains the
	// upstream order would otherwise repeat identically on every run.
	rand.Seed(time.Now().UnixNano())

	var help bool
	certFile := flag.String("cert", "server.crt", "TLS 证书文件路径")
	keyFile := flag.String("key", "server.key", "TLS 私钥文件路径")
	addr := flag.String("addr", ":853", "DoT 监听地址")
	upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表")
	cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL")
	cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限")
	timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时")
	maxParallel := flag.Int("max-parallel", 3, "并发上游数量")
	stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS")
	allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退")
	blacklistStr := flag.String("blacklist", "", "逗号分隔的黑名单域名(后缀匹配;支持如 *.example.com")
	blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;后缀匹配)")
	blacklistRcodeFlag := flag.String("blacklist-rcode", "REFUSED", "命中黑名单返回的 RCODEREFUSED|NXDOMAIN|SERVFAIL")
	verbose := flag.Bool("v", false, "verbose 日志")
	flag.BoolVar(&help, "h", false, "")
	flag.BoolVar(&help, "help", false, "帮助信息")
	flag.Parse()

	if help {
		fmt.Printf(
			"\t\tDNS-over-TLS (DoT)\n"+
				"\tVersion 0.1\n"+
				"\tE-mail: aixiao@aixiao.me\n"+
				"\tBuild Date: %s\n", BuildDate)
		flag.Usage()
		fmt.Printf("\n")
		os.Exit(0)
	}

	initLogger(*verbose)

	var err error
	cache, err = lru.New[string, *cacheEntry](*cacheSizeFlag)
	if err != nil {
		log.Fatalf("[fatal] failed to init LRU cache: %v", err)
	}

	cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
	if err != nil {
		log.Fatalf("[fatal] failed to load cert: %v", err)
	}

	// Entries may be "host", "host:port" or IPv4/IPv6 literals (bracketed or
	// not); entries without a port default to 53.
	var upstreams []string
	for _, s := range strings.Split(*upstreamStr, ",") {
		if t := strings.TrimSpace(s); t != "" {
			upstreams = append(upstreams, normalizeUpstreamAddr(t))
		}
	}
	if len(upstreams) == 0 {
		log.Fatal("[fatal] no upstream DNS servers provided")
	}

	udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	startCacheCleaner(ctx)

	// Load blacklist rules.
	bl, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag)

	mux := dns.NewServeMux()
	mux.HandleFunc(".", handleDNS(
		upstreams,
		*cacheTTLFlag,
		*timeoutFlag,
		*maxParallel,
		*stripECSFlag,
		*allowTCPFallback,
		bl,
		blRcode,
	))

	srv := &dns.Server{
		Addr: *addr,
		Net:  "tcp-tls",
		TLSConfig: &tls.Config{
			Certificates: []tls.Certificate{cert},
			MinVersion:   tls.VersionTLS12,
			NextProtos:   []string{"dot"},
		},
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}

	stop := make(chan os.Signal, 1)
	signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)

	errCh := make(chan error, 1)
	go func() {
		log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
		log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | blacklist_rules=%d | blacklist_rcode=%s",
			upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback, len(bl.rules), strings.ToUpper(*blacklistRcodeFlag))
		errCh <- srv.ListenAndServe()
	}()

	select {
	case sig := <-stop:
		log.Printf("[shutdown] caught signal: %s", sig)
		cancel()
		if err := srv.Shutdown(); err != nil {
			log.Printf("[shutdown] server shutdown error: %v", err)
		}
	case err := <-errCh:
		if err != nil {
			log.Fatalf("[fatal] server error: %v", err)
		}
	}
	log.Println("[bye] server stopped.")
}