在 Dockerfile 中添加 BUILD_DATE 构建参数,并通过 ldflags 将编译时间注入到二进制文件中。 同时更新 build.sh 脚本,在构建镜像时传入当前时间作为 BUILD_DATE 参数。 refactor(build): 优化 build.sh 脚本结构与可读性 对 build.sh 脚本中的函数进行了缩进统一和结构调整,提高代码可读性和维护性。 新增 bin 命令用于直接编译并压缩二进制文件。 feat(main): 添加版本与构建信息显示功能 在 main.go 中增加 BuildDate 变量用于存储构建时间,并支持通过 -h 或 --help 参数查看帮助信息, 包括版本号、联系邮箱以及构建日期等元数据。 chore(blacklist): 移除默认黑名单条目 从 blacklist.txt 文件中移除默认的 *.baidu.com 黑名单规则。 docs(readme): 清理 README.md 文件末尾空行 删除 README.md 文件最后多余的空白行,保持文档整洁。
696 lines
17 KiB
Go
696 lines
17 KiB
Go
package main
|
||
|
||
import (
	"context"
	"crypto/tls"
	"flag"
	"fmt"
	"log"
	"math/rand"
	"net"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	lru "github.com/hashicorp/golang-lru/v2"
	"github.com/miekg/dns"
)
|
||
|
||
// BuildDate is the build timestamp printed by -h/--help. It stays "unknown"
// unless overridden at link time (the build passes it via -ldflags, e.g.
// -X main.BuildDate=...).
var BuildDate = "unknown" // injected at compile time
|
||
|
||
/******************************************************************
|
||
* 日志初始化
|
||
******************************************************************/
|
||
// initLogger configures the standard logger. Every line carries the date,
// the time and microseconds; verbose mode additionally prefixes the short
// source file name and line number of the log call.
func initLogger(verbose bool) {
	base := log.Ldate | log.Ltime | log.Lmicroseconds
	if !verbose {
		log.SetFlags(base)
		return
	}
	log.SetFlags(base | log.Lshortfile)
}
|
||
|
||
/******************************************************************
|
||
* 缓存结构(支持 TTL + LRU)
|
||
******************************************************************/
|
||
|
||
// cacheEntry is one cached DNS response plus the EDNS information needed to
// rebuild the OPT record when the entry is served.
type cacheEntry struct {
	// msg is the cached response; cacheWrite stores a copy with OPT/TSIG
	// pseudo-records stripped.
	msg *dns.Msg
	// expireAt is the absolute time after which the entry is stale.
	expireAt time.Time
	// EDNS metadata (to reproduce Extended RCODE / EDE on cache hits),
	// captured before the OPT record was stripped from msg.

	// ednsPresent reports whether the original response carried an OPT record.
	ednsPresent bool
	// ednsVersion is the EDNS version of that OPT record.
	ednsVersion uint8
	// ednsExtRcode is the OPT's ExtendedRcode() value.
	ednsExtRcode uint16
	// ednsEDE holds copies of the Extended DNS Error options, in order.
	ednsEDE []*dns.EDNS0_EDE
}
|
||
|
||
var (
	// cache maps cacheKeyFromMsg keys to responses; main creates it with
	// capacity -cache-size.
	cache *lru.Cache[string, *cacheEntry]
	// cacheMutex serializes access to cache: the write lock is taken around
	// Get/Add/Remove, the read lock around Keys/Peek scans.
	// NOTE(review): golang-lru's Cache is itself goroutine-safe; this lock
	// mainly makes multi-step sequences (check-then-remove) atomic — confirm.
	cacheMutex sync.RWMutex
)

const (
	// cacheCleanupInterval is the period of the background expiry sweep.
	cacheCleanupInterval = 5 * time.Minute
	// defaultCacheSize is the default maximum number of cache entries
	// (default of the -cache-size flag).
	defaultCacheSize = 10000
)
|
||
|
||
// startCacheCleaner 定期清理过期缓存(修复:在删除前二次校验)
|
||
func startCacheCleaner(ctx context.Context) {
|
||
go func() {
|
||
ticker := time.NewTicker(cacheCleanupInterval)
|
||
defer ticker.Stop()
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return
|
||
case <-ticker.C:
|
||
now := time.Now()
|
||
var toDelete []string
|
||
|
||
cacheMutex.RLock()
|
||
for _, k := range cache.Keys() {
|
||
if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
|
||
toDelete = append(toDelete, k)
|
||
}
|
||
}
|
||
cacheMutex.RUnlock()
|
||
|
||
if len(toDelete) > 0 {
|
||
pruned := 0
|
||
cacheMutex.Lock()
|
||
for _, k := range toDelete {
|
||
if v, ok := cache.Peek(k); ok && now.After(v.expireAt) {
|
||
cache.Remove(k)
|
||
pruned++
|
||
}
|
||
}
|
||
cacheMutex.Unlock()
|
||
if pruned > 0 {
|
||
log.Printf("[cache] cleaned %d expired entries", pruned)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
func cacheKeyFromMsg(q dns.Question, do, cd bool) string {
|
||
var b strings.Builder
|
||
b.Grow(len(q.Name) + 32)
|
||
// 采用规范化域名,避免尾随点/IDNA/大小写造成的重复键
|
||
b.WriteString(dns.CanonicalName(q.Name))
|
||
b.WriteString("|T=")
|
||
b.WriteString(dns.TypeToString[q.Qtype])
|
||
b.WriteString("|C=")
|
||
b.WriteString(dns.ClassToString[q.Qclass])
|
||
if do {
|
||
b.WriteString("|DO")
|
||
}
|
||
if cd {
|
||
b.WriteString("|CD")
|
||
}
|
||
return b.String()
|
||
}
|
||
|
||
func isPseudo(rr dns.RR) bool {
|
||
switch rr.(type) {
|
||
case *dns.OPT, *dns.TSIG:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
// clone and extract EDNS metadata (present, version, ext-rcode, all EDEs)
|
||
func cloneEDE(in *dns.EDNS0_EDE) *dns.EDNS0_EDE {
|
||
if in == nil {
|
||
return nil
|
||
}
|
||
cp := *in
|
||
return &cp
|
||
}
|
||
|
||
func extractEDNSMeta(m *dns.Msg) (present bool, version uint8, ext uint16, ede []*dns.EDNS0_EDE) {
|
||
if o := m.IsEdns0(); o != nil {
|
||
present = true
|
||
version = o.Version()
|
||
ext = uint16(o.ExtendedRcode())
|
||
for _, opt := range o.Option {
|
||
if e, ok := opt.(*dns.EDNS0_EDE); ok {
|
||
ede = append(ede, cloneEDE(e))
|
||
}
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
// 读取缓存(修复:Get 在写锁下;在锁外调整 TTL;返回 EDNS 元数据)
|
||
func tryCacheRead(key string) (*dns.Msg, *cacheEntry, bool) {
|
||
now := time.Now()
|
||
|
||
cacheMutex.Lock()
|
||
e, ok := cache.Get(key) // Get 会更新 LRU,必须在写锁下
|
||
if !ok {
|
||
cacheMutex.Unlock()
|
||
return nil, nil, false
|
||
}
|
||
if now.After(e.expireAt) {
|
||
cache.Remove(key)
|
||
cacheMutex.Unlock()
|
||
return nil, nil, false
|
||
}
|
||
// 拷贝副本,在锁外改 TTL,减少临界区时间
|
||
out := e.msg.Copy()
|
||
expireAt := e.expireAt
|
||
cacheMutex.Unlock()
|
||
|
||
remaining := uint32(expireAt.Sub(now).Seconds())
|
||
if remaining == 0 {
|
||
cacheMutex.Lock()
|
||
cache.Remove(key)
|
||
cacheMutex.Unlock()
|
||
return nil, nil, false
|
||
}
|
||
|
||
for _, sec := range [][]dns.RR{out.Answer, out.Ns, out.Extra} {
|
||
for _, rr := range sec {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
if rr.Header().Ttl > remaining {
|
||
rr.Header().Ttl = remaining
|
||
}
|
||
}
|
||
}
|
||
return out, e, true
|
||
}
|
||
|
||
// 计算负面 TTL
|
||
// hasAnswerForType 判断报文中是否存在回答“请求类型”的 RRset
|
||
func hasAnswerForType(m *dns.Msg, q dns.Question) bool {
|
||
for _, rr := range m.Answer {
|
||
h := rr.Header()
|
||
if h.Rrtype == q.Qtype && strings.EqualFold(h.Name, q.Name) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// 计算负面 TTL(修复:正确识别 NODATA,包括 CNAME 等场景)
|
||
func negativeTTL(m *dns.Msg, maxTTL time.Duration) (uint32, bool) {
|
||
// NXDOMAIN:肯定是负面
|
||
if m.Rcode != dns.RcodeNameError {
|
||
// 不是 NXDOMAIN,则仅当 NOERROR 但没有“匹配 QTYPE 的答案”时才是 NODATA
|
||
if m.Rcode != dns.RcodeSuccess || len(m.Question) == 0 || hasAnswerForType(m, m.Question[0]) {
|
||
return 0, false
|
||
}
|
||
}
|
||
|
||
// 按 RFC 2308,从 Authority(Ns)优先取 SOA(多数实现都只放在 Authority)
|
||
var soa *dns.SOA
|
||
for _, rr := range m.Ns {
|
||
if s, ok := rr.(*dns.SOA); ok {
|
||
soa = s
|
||
break
|
||
}
|
||
}
|
||
// 兼容性:偶尔也有人把 SOA 放 Extra(不规范,但为了兼容可以兜底看看)
|
||
if soa == nil {
|
||
for _, rr := range m.Extra {
|
||
if s, ok := rr.(*dns.SOA); ok {
|
||
soa = s
|
||
break
|
||
}
|
||
}
|
||
}
|
||
if soa == nil {
|
||
// 建议:无 SOA 时不做负面缓存(返回 0,false)
|
||
return 0, false
|
||
}
|
||
|
||
// 负面 TTL 取 min(SOA.MINIMUM, SOA 自身 TTL),再与配置上限比较
|
||
ttl := soa.Hdr.Ttl
|
||
if soa.Minttl < ttl {
|
||
ttl = soa.Minttl
|
||
}
|
||
capTTL := uint32(maxTTL.Seconds())
|
||
if capTTL > 0 && ttl > capTTL {
|
||
ttl = capTTL
|
||
}
|
||
return ttl, ttl > 0
|
||
}
|
||
|
||
func minRRsetTTL(m *dns.Msg) (uint32, bool) {
|
||
minTTL := uint32(0)
|
||
hasTTL := false
|
||
// 优先 Answer -> Ns;若都为空,再考虑 Extra(排除伪记录)
|
||
for _, sec := range [][]dns.RR{m.Answer, m.Ns} {
|
||
for _, rr := range sec {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
ttl := rr.Header().Ttl
|
||
if !hasTTL || ttl < minTTL {
|
||
minTTL = ttl
|
||
hasTTL = true
|
||
}
|
||
}
|
||
}
|
||
if !hasTTL {
|
||
for _, rr := range m.Extra {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
ttl := rr.Header().Ttl
|
||
if !hasTTL || ttl < minTTL {
|
||
minTTL = ttl
|
||
hasTTL = true
|
||
}
|
||
}
|
||
}
|
||
return minTTL, hasTTL
|
||
}
|
||
|
||
func stripPseudoExtras(m *dns.Msg) {
|
||
if len(m.Extra) == 0 {
|
||
return
|
||
}
|
||
out := m.Extra[:0]
|
||
for _, rr := range m.Extra {
|
||
if isPseudo(rr) {
|
||
continue
|
||
}
|
||
out = append(out, rr)
|
||
}
|
||
m.Extra = out
|
||
}
|
||
|
||
// 写缓存(保存 EDNS 元数据,命中时可重建扩展 RCODE/EDE)
|
||
func cacheWrite(key string, in *dns.Msg, maxTTL time.Duration) {
|
||
if in == nil {
|
||
return
|
||
}
|
||
if in.Rcode != dns.RcodeSuccess && in.Rcode != dns.RcodeNameError {
|
||
return
|
||
}
|
||
var ttl uint32
|
||
var ok bool
|
||
if ttl, ok = negativeTTL(in, maxTTL); !ok {
|
||
minTTL, has := minRRsetTTL(in)
|
||
if has {
|
||
cfgTTL := uint32(maxTTL.Seconds())
|
||
if cfgTTL > 0 && minTTL > cfgTTL {
|
||
minTTL = cfgTTL
|
||
}
|
||
ttl = minTTL
|
||
} else {
|
||
ttl = uint32(maxTTL.Seconds())
|
||
}
|
||
}
|
||
if ttl == 0 {
|
||
return
|
||
}
|
||
expire := time.Now().Add(time.Duration(ttl) * time.Second)
|
||
cp := in.Copy()
|
||
// 提取 EDNS 元数据后再剥离伪记录
|
||
present, ver, ext, ede := extractEDNSMeta(cp)
|
||
stripPseudoExtras(cp)
|
||
cacheMutex.Lock()
|
||
cache.Add(key, &cacheEntry{
|
||
msg: cp,
|
||
expireAt: expire,
|
||
ednsPresent: present,
|
||
ednsVersion: ver,
|
||
ednsExtRcode: ext,
|
||
ednsEDE: ede,
|
||
})
|
||
cacheMutex.Unlock()
|
||
}
|
||
|
||
/******************************************************************
|
||
* 上游查询
|
||
******************************************************************/
|
||
// udpClient is the shared client for upstream queries. It is created in main
// (Net "udp", UDPSize 4096), and queryUpstreamsLimited derives its TCP
// fallback client from it.
var udpClient *dns.Client
|
||
|
||
// shuffled returns a randomly permuted copy of xs, drawn from math/rand's
// global source. xs itself is left untouched.
func shuffled(xs []string) []string {
	perm := make([]string, len(xs))
	copy(perm, xs)
	swap := func(a, b int) { perm[a], perm[b] = perm[b], perm[a] }
	rand.Shuffle(len(perm), swap)
	return perm
}
|
||
|
||
// clampEDNSForUpstream 返回一个 msg 副本,把 EDNS UDP size 夹到给定大小
|
||
func clampEDNSForUpstream(in *dns.Msg, size uint16) *dns.Msg {
|
||
m := in.Copy()
|
||
o := m.IsEdns0()
|
||
if o == nil {
|
||
o = &dns.OPT{}
|
||
o.Hdr.Name = "."
|
||
o.Hdr.Rrtype = dns.TypeOPT
|
||
m.Extra = append(m.Extra, o)
|
||
}
|
||
if size > 0 {
|
||
o.SetUDPSize(size)
|
||
}
|
||
return m
|
||
}
|
||
|
||
func queryUpstreamsLimited(
|
||
ctx context.Context,
|
||
req *dns.Msg,
|
||
upstreams []string,
|
||
timeout time.Duration,
|
||
maxParallel int,
|
||
allowTCPFallback bool,
|
||
) *dns.Msg {
|
||
if maxParallel <= 0 {
|
||
maxParallel = 1
|
||
}
|
||
servers := shuffled(upstreams)
|
||
|
||
cctx, cancel := context.WithTimeout(ctx, timeout)
|
||
defer cancel()
|
||
|
||
type result struct{ msg *dns.Msg }
|
||
ch := make(chan result, len(servers))
|
||
sem := make(chan struct{}, maxParallel)
|
||
|
||
execOne := func(svr string) {
|
||
upReq := clampEDNSForUpstream(req, 1232) // 或做成 flag
|
||
resp, _, err := udpClient.ExchangeContext(cctx, upReq, svr)
|
||
if err == nil && resp != nil && resp.Truncated && allowTCPFallback {
|
||
log.Printf("[upstream] UDP truncated, retry TCP: %s", svr)
|
||
tcpClient := *udpClient
|
||
tcpClient.Net = "tcp"
|
||
|
||
resp, _, err = tcpClient.ExchangeContext(cctx, req.Copy(), svr)
|
||
}
|
||
if err != nil || resp == nil {
|
||
if err != nil && cctx.Err() == nil {
|
||
log.Printf("[upstream] %s error: %v", svr, err)
|
||
}
|
||
return
|
||
}
|
||
if resp.Rcode == dns.RcodeServerFailure || resp.Rcode == dns.RcodeRefused || resp.Rcode == dns.RcodeFormatError {
|
||
return
|
||
}
|
||
select {
|
||
case ch <- result{msg: resp}:
|
||
case <-cctx.Done():
|
||
}
|
||
}
|
||
|
||
for _, s := range servers {
|
||
s := s
|
||
go func() {
|
||
select {
|
||
case sem <- struct{}{}:
|
||
defer func() { <-sem }()
|
||
case <-cctx.Done():
|
||
return
|
||
}
|
||
execOne(s)
|
||
}()
|
||
}
|
||
|
||
for i := 0; i < len(servers); i++ {
|
||
select {
|
||
case r := <-ch:
|
||
if r.msg != nil {
|
||
cancel()
|
||
return r.msg
|
||
}
|
||
case <-cctx.Done():
|
||
log.Printf("[upstream] timeout after %v", timeout)
|
||
return nil
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
/******************************************************************
|
||
* EDNS / 响应构造
|
||
******************************************************************/
|
||
func stripECS(m *dns.Msg) {
|
||
if o := m.IsEdns0(); o != nil {
|
||
var kept []dns.EDNS0
|
||
for _, e := range o.Option {
|
||
if _, isECS := e.(*dns.EDNS0_SUBNET); !isECS {
|
||
kept = append(kept, e)
|
||
}
|
||
}
|
||
o.Option = kept
|
||
}
|
||
}
|
||
|
||
func getDOFlag(m *dns.Msg) bool {
|
||
if o := m.IsEdns0(); o != nil {
|
||
return o.Do()
|
||
}
|
||
return false
|
||
}
|
||
|
||
func writeReply(w dns.ResponseWriter, req *dns.Msg, upstream *dns.Msg, meta *cacheEntry) {
|
||
if upstream == nil {
|
||
dns.HandleFailed(w, req)
|
||
return
|
||
}
|
||
out := new(dns.Msg)
|
||
out.SetReply(req)
|
||
out.Authoritative = false
|
||
out.RecursionAvailable = upstream.RecursionAvailable
|
||
out.AuthenticatedData = upstream.AuthenticatedData
|
||
out.CheckingDisabled = req.CheckingDisabled
|
||
out.Rcode = upstream.Rcode
|
||
out.Answer = upstream.Answer
|
||
out.Ns = upstream.Ns
|
||
|
||
var extras []dns.RR
|
||
for _, rr := range upstream.Extra {
|
||
if !isPseudo(rr) {
|
||
extras = append(extras, rr)
|
||
}
|
||
}
|
||
|
||
if ro := req.IsEdns0(); ro != nil {
|
||
o := new(dns.OPT)
|
||
o.Hdr.Name = "."
|
||
o.Hdr.Rrtype = dns.TypeOPT
|
||
o.SetUDPSize(ro.UDPSize())
|
||
if ro.Do() {
|
||
o.SetDo()
|
||
}
|
||
if uo := upstream.IsEdns0(); uo != nil {
|
||
o.SetExtendedRcode(uint16(uo.ExtendedRcode()))
|
||
o.SetVersion(uint8(uo.Version()))
|
||
for _, opt := range uo.Option {
|
||
if ede, ok := opt.(*dns.EDNS0_EDE); ok {
|
||
o.Option = append(o.Option, ede)
|
||
}
|
||
}
|
||
} else if meta != nil && meta.ednsPresent {
|
||
// Upstream/cached msg has no OPT(例如缓存时被剥离),用缓存元数据重建
|
||
o.SetExtendedRcode(meta.ednsExtRcode)
|
||
o.SetVersion(meta.ednsVersion)
|
||
for _, e := range meta.ednsEDE {
|
||
o.Option = append(o.Option, e)
|
||
}
|
||
}
|
||
extras = append(extras, o)
|
||
}
|
||
out.Extra = extras
|
||
out.Compress = true
|
||
|
||
if err := w.WriteMsg(out); err != nil {
|
||
log.Printf("[write] WriteMsg error: %v", err)
|
||
}
|
||
}
|
||
|
||
/******************************************************************
|
||
* 主处理器
|
||
******************************************************************/
|
||
func handleDNS(
|
||
upstreams []string,
|
||
cacheMaxTTL, timeout time.Duration,
|
||
maxParallel int,
|
||
stripECSBeforeForward bool,
|
||
allowTCPFallback bool,
|
||
bl *suffixMatcher,
|
||
blRcode int,
|
||
) dns.HandlerFunc {
|
||
return func(w dns.ResponseWriter, r *dns.Msg) {
|
||
if len(r.Question) == 0 {
|
||
dns.HandleFailed(w, r)
|
||
return
|
||
}
|
||
q := r.Question[0]
|
||
log.Printf("[req] %s %s (id=%d cd=%v do=%v from=%s)",
|
||
dns.TypeToString[q.Qtype], q.Name, r.Id, r.CheckingDisabled, getDOFlag(r), w.RemoteAddr())
|
||
|
||
if stripECSBeforeForward {
|
||
stripECS(r)
|
||
}
|
||
|
||
// 黑名单拦截:命中则不查上游,直接返回
|
||
if rule, ok := bl.match(q.Name); ok {
|
||
nameCanon := dns.CanonicalName(q.Name)
|
||
log.Printf("[blacklist] HIT %s rule=%s (no upstream query)", nameCanon, rule)
|
||
up := makeBlockedUpstream(blRcode, rule)
|
||
writeReply(w, r, up, nil)
|
||
return
|
||
}
|
||
|
||
key := cacheKeyFromMsg(q, getDOFlag(r), r.CheckingDisabled)
|
||
|
||
if cachedMsg, cachedMeta, ok := tryCacheRead(key); ok {
|
||
log.Printf("[cache] HIT %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||
writeReply(w, r, cachedMsg, cachedMeta)
|
||
return
|
||
}
|
||
log.Printf("[cache] MISS %s %s", dns.TypeToString[q.Qtype], q.Name)
|
||
|
||
ctx := context.Background()
|
||
resp := queryUpstreamsLimited(ctx, r, upstreams, timeout, maxParallel, allowTCPFallback)
|
||
if resp == nil {
|
||
log.Printf("[error] all upstreams failed for %s", q.Name)
|
||
dns.HandleFailed(w, r)
|
||
return
|
||
}
|
||
cacheWrite(key, resp, cacheMaxTTL)
|
||
for _, ans := range resp.Answer {
|
||
log.Printf("[answer] %s", ans.String())
|
||
}
|
||
writeReply(w, r, resp, nil)
|
||
}
|
||
}
|
||
|
||
/******************************************************************
|
||
* 主函数
|
||
******************************************************************/
|
||
func main() {
|
||
rand.Seed(time.Now().UnixNano())
|
||
|
||
var help bool
|
||
certFile := flag.String("cert", "server.crt", "TLS 证书文件路径")
|
||
keyFile := flag.String("key", "server.key", "TLS 私钥文件路径")
|
||
addr := flag.String("addr", ":853", "DoT 监听地址")
|
||
upstreamStr := flag.String("upstream", "8.8.8.8:53,1.1.1.1:53", "上游 DNS 列表")
|
||
cacheTTLFlag := flag.Duration("cache-ttl", 60*time.Second, "最大缓存 TTL")
|
||
cacheSizeFlag := flag.Int("cache-size", defaultCacheSize, "LRU 缓存大小上限")
|
||
timeoutFlag := flag.Duration("timeout", 3*time.Second, "上游查询超时")
|
||
maxParallel := flag.Int("max-parallel", 3, "并发上游数量")
|
||
stripECSFlag := flag.Bool("strip-ecs", true, "去除 ECS")
|
||
allowTCPFallback := flag.Bool("tcp-fallback", true, "UDP 截断时 TCP 回退")
|
||
blacklistStr := flag.String("blacklist", "", "逗号分隔的黑名单域名(后缀匹配;支持如 *.example.com)")
|
||
blacklistFile := flag.String("blacklist-file", "", "黑名单文件路径(每行一个域名;支持 # 或 ; 注释;后缀匹配)")
|
||
blacklistRcodeFlag := flag.String("blacklist-rcode", "REFUSED", "命中黑名单返回的 RCODE:REFUSED|NXDOMAIN|SERVFAIL")
|
||
verbose := flag.Bool("v", false, "verbose 日志")
|
||
|
||
flag.BoolVar(&help, "h", false, "")
|
||
flag.BoolVar(&help, "help", false, "帮助信息")
|
||
|
||
flag.Parse()
|
||
|
||
if help {
|
||
fmt.Printf(
|
||
"\t\tDNS-over-TLS (DoT)\n"+
|
||
"\tVersion 0.1\n"+
|
||
"\tE-mail: aixiao@aixiao.me\n"+
|
||
"\tBuild Date: %s\n", BuildDate)
|
||
|
||
flag.Usage()
|
||
fmt.Printf("\n")
|
||
os.Exit(0)
|
||
}
|
||
|
||
initLogger(*verbose)
|
||
|
||
var err error
|
||
cache, err = lru.New[string, *cacheEntry](*cacheSizeFlag)
|
||
if err != nil {
|
||
log.Fatalf("[fatal] failed to init LRU cache: %v", err)
|
||
}
|
||
|
||
cert, err := tls.LoadX509KeyPair(*certFile, *keyFile)
|
||
if err != nil {
|
||
log.Fatalf("[fatal] failed to load cert: %v", err)
|
||
}
|
||
|
||
var upstreams []string
|
||
for _, s := range strings.Split(*upstreamStr, ",") {
|
||
if t := strings.TrimSpace(s); t != "" {
|
||
if !strings.Contains(t, ":") {
|
||
t = fmt.Sprintf("%s:53", t)
|
||
}
|
||
upstreams = append(upstreams, t)
|
||
}
|
||
}
|
||
if len(upstreams) == 0 {
|
||
log.Fatal("[fatal] no upstream DNS servers provided")
|
||
}
|
||
|
||
udpClient = &dns.Client{Net: "udp", UDPSize: 4096, SingleInflight: true}
|
||
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
|
||
startCacheCleaner(ctx)
|
||
|
||
// 加载黑名单规则
|
||
bl, blRcode := initBlacklist(ctx, *blacklistStr, *blacklistFile, *blacklistRcodeFlag)
|
||
|
||
mux := dns.NewServeMux()
|
||
mux.HandleFunc(".", handleDNS(
|
||
upstreams,
|
||
*cacheTTLFlag,
|
||
*timeoutFlag,
|
||
*maxParallel,
|
||
*stripECSFlag,
|
||
*allowTCPFallback,
|
||
bl,
|
||
blRcode,
|
||
))
|
||
|
||
srv := &dns.Server{
|
||
Addr: *addr,
|
||
Net: "tcp-tls",
|
||
TLSConfig: &tls.Config{
|
||
Certificates: []tls.Certificate{cert},
|
||
MinVersion: tls.VersionTLS12,
|
||
NextProtos: []string{"dot"},
|
||
},
|
||
Handler: mux,
|
||
ReadTimeout: 10 * time.Second,
|
||
WriteTimeout: 10 * time.Second,
|
||
}
|
||
|
||
stop := make(chan os.Signal, 1)
|
||
signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
log.Printf("🚀 starting DNS-over-TLS on %s", *addr)
|
||
log.Printf(" upstreams=%v | cache_max_ttl=%s | cache_size=%d | timeout=%s | max_parallel=%d | strip_ecs=%v | tcp_fallback=%v | blacklist_rules=%d | blacklist_rcode=%s",
|
||
upstreams, cacheTTLFlag.String(), *cacheSizeFlag, timeoutFlag.String(), *maxParallel, *stripECSFlag, *allowTCPFallback, len(bl.rules), strings.ToUpper(*blacklistRcodeFlag))
|
||
errCh <- srv.ListenAndServe()
|
||
}()
|
||
|
||
select {
|
||
case sig := <-stop:
|
||
log.Printf("[shutdown] caught signal: %s", sig)
|
||
cancel()
|
||
if err := srv.Shutdown(); err != nil {
|
||
log.Printf("[shutdown] server shutdown error: %v", err)
|
||
}
|
||
case err := <-errCh:
|
||
if err != nil {
|
||
log.Fatalf("[fatal] server error: %v", err)
|
||
}
|
||
}
|
||
|
||
log.Println("[bye] server stopped.")
|
||
}
|