优化UDP

This commit is contained in:
2025-06-27 17:54:34 +08:00
parent ab488d3393
commit 3b07ee819c
5 changed files with 202 additions and 95 deletions

Binary file not shown.

Binary file not shown.

34
main.go
View File

@@ -5,6 +5,8 @@ import (
"flag"
"fmt"
"os"
"os/signal"
"syscall"
)
// 主循环,为每个端口映射启动一个代理
@@ -13,13 +15,13 @@ func Loop(config Config) {
// 启动 TCP 代理
for tcp_listenAddr, tcp_targetAddr := range config.Global.LT {
tcp_wg.Add(1)
tcpWg.Add(1)
go StartTcpProxy(tcp_listenAddr, tcp_targetAddr)
}
// 启动 UDP 代理
for udp_listenAddr, udp_targetAddr := range config.Global.LU {
udp_wg.Add(1)
udpWg.Add(1)
go StartUdpProxy(ctx, udp_listenAddr, udp_targetAddr)
}
@@ -27,6 +29,34 @@ func Loop(config Config) {
WaitForExit(cancel)
}
// 等待退出信号(如 Ctrl+C并优雅地关闭所有代理服务
func WaitForExit(cancel context.CancelFunc) {
sigChan := make(chan os.Signal, 1)
// 注册接收 SIGINT、SIGTERM 信号
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-sigChan // 阻塞直到收到退出信号
fmt.Println("\n收到终止信号准备关闭代理服务器...")
// 关闭所有 TCP 监听器
for _, listener := range tcp_listeners {
if err := listener.Close(); err != nil {
fmt.Printf("关闭 TCP 监听器 %s 失败: %v\n", listener.Addr(), err)
}
}
// 通知 UDP goroutine 退出
cancel()
fmt.Println("[INFO] 已通知所有 goroutine 退出,等待清理...")
// 等待所有 TCP 和 UDP goroutine 完成
tcpWg.Wait()
udpWg.Wait()
fmt.Println("所有代理已安全退出")
os.Exit(0)
}
func main() {
daemon := flag.Bool("d", false, "守护进程模式") // 解析命令行参数,是否以守护进程模式运行
config := flag.String("c", "4to6.conf", "指定配置文件") // 解析命令行参数,是否以守护进程模式运行

78
tcp.go
View File

@@ -1,52 +1,49 @@
package main
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"os/signal"
"sync"
"syscall"
"time"
)
// 定义全局变量
var (
tcp_wg sync.WaitGroup
tcp_listeners []net.Listener
tcp_listenMux sync.Mutex // 确保 tcp_listeners 操作是线程安全的
tcpWg sync.WaitGroup // 用于等待所有 TCP 代理任务完成
tcp_listeners []net.Listener // 存储所有监听器,便于统一关闭
tcp_listenMux sync.Mutex // 互斥锁,确保 tcp_listeners 操作是线程安全的
)
// 启用 TCP KeepAlive
// 启用 TCP KeepAlive,防止连接空闲时被中间设备关闭
func setKeepAlive(conn net.Conn) {
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(30 * time.Second)
tcpConn.SetKeepAlive(true) // 启用 keep-alive
tcpConn.SetKeepAlivePeriod(30 * time.Second) // 设置保活周期为 30 秒
}
}
// 关闭写方向
// 关闭连接的写方向,常用于半关闭通信
func closeWrite(conn net.Conn) {
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.CloseWrite()
tcpConn.CloseWrite() // 关闭写通道,保持读通道开启
}
}
// 处理客户端连接,并将数据转发到目标服务器
// 处理一个 TCP 客户端连接,并将转发到目标服务器
func HandleTcpConnection(clientConn net.Conn, targetAddr string) {
defer clientConn.Close()
defer clientConn.Close() // 函数结束时关闭客户端连接
// 连接目标服务器支持 IPv6
// 尝试连接目标服务器地址(支持 IPv6
serverConn, err := net.Dial("tcp", targetAddr)
if err != nil {
fmt.Printf("无法连接到 %s: %v\n", targetAddr, err)
return
}
defer serverConn.Close()
defer serverConn.Close() // 函数结束时关闭服务器连接
// 开启 TCP KeepAlive,防止连接过早断开
// 对两个连接都开启 KeepAlive
setKeepAlive(clientConn)
setKeepAlive(serverConn)
@@ -54,42 +51,44 @@ func HandleTcpConnection(clientConn net.Conn, targetAddr string) {
fmt.Printf("连接 %s -> %s\n", clientAddr, targetAddr)
var cpWG sync.WaitGroup
cpWG.Add(2)
cpWG.Add(2) // 我们会启动两个数据复制 goroutine
// 客户端 -> 目标服务器
go func() {
defer cpWG.Done()
_, err := io.Copy(serverConn, clientConn)
_, err := io.Copy(serverConn, clientConn) // 将客户端数据复制到目标服务器
if err != nil && !errors.Is(err, io.EOF) {
fmt.Printf("数据转发 %s -> %s 失败: %v\n", clientAddr, targetAddr, err)
}
closeWrite(serverConn)
closeWrite(serverConn) // 通知目标服务器我们写完了
}()
// 目标服务器 -> 客户端
go func() {
defer cpWG.Done()
_, err := io.Copy(clientConn, serverConn)
_, err := io.Copy(clientConn, serverConn) // 将目标服务器数据复制到客户端
if err != nil && !errors.Is(err, io.EOF) {
fmt.Printf("数据转发 %s <- %s 失败: %v\n", clientAddr, targetAddr, err)
}
closeWrite(clientConn)
closeWrite(clientConn) // 通知客户端我们写完了
}()
cpWG.Wait()
cpWG.Wait() // 等待两个数据转发完成
fmt.Printf("连接 %s 结束\n", clientAddr)
}
// 启动代理服务器
// 启动一个 TCP 代理服务,将监听地址的数据转发到目标地址
func StartTcpProxy(listenAddr, targetAddr string) {
defer tcp_wg.Done()
defer tcpWg.Done() // 当前代理任务结束时计数器减 1
// 监听本地地址
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
fmt.Printf("监听 %s 失败: %v\n", listenAddr, err)
return
}
// 添加到监听器列表,供后续关闭使用
tcp_listenMux.Lock()
tcp_listeners = append(tcp_listeners, listener)
tcp_listenMux.Unlock()
@@ -97,8 +96,10 @@ func StartTcpProxy(listenAddr, targetAddr string) {
fmt.Printf("代理服务启动: %s -> %s\n", listenAddr, targetAddr)
for {
// 接受客户端连接
clientConn, err := listener.Accept()
if err != nil {
// 如果监听器已关闭,说明程序正在退出
if errors.Is(err, net.ErrClosed) {
fmt.Printf("监听器 %s 已关闭\n", listenAddr)
return
@@ -106,32 +107,7 @@ func StartTcpProxy(listenAddr, targetAddr string) {
fmt.Printf("接受连接失败: %v\n", err)
continue
}
// 启动一个 goroutine 处理该连接
go HandleTcpConnection(clientConn, targetAddr)
}
}
// 监听系统信号,以便优雅地关闭 TCP 和 UDP 代理服务
func WaitForExit(cancel context.CancelFunc) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-sigChan
fmt.Println("\n收到终止信号准备关闭代理服务器...")
// 关闭所有 TCP 监听器
for _, listener := range tcp_listeners {
if err := listener.Close(); err != nil {
fmt.Printf("关闭 TCP 监听器 %s 失败: %v\n", listener.Addr(), err)
}
}
// 触发 UDP 代理退出
cancel()
// 等待所有 TCP 和 UDP 代理退出
tcp_wg.Wait()
udp_wg.Wait()
fmt.Println("所有代理已安全退出")
os.Exit(0)
}

185
udp.go
View File

@@ -4,121 +4,222 @@ import (
"context"
"fmt"
"net"
"runtime"
"sync"
"time"
)
var (
udp_wg sync.WaitGroup
clientMap sync.Map // 代替 sync.Mutex 保护的 map[string]*net.UDPConn
udpWg sync.WaitGroup
clientMap sync.Map // key: 客户端地址字符串value: *clientConn保存客户端对应的目标服务器连接
)
// 处理 UDP 数据转发
func handleUdpTraffic(listener *net.UDPConn, targetConn *net.UDPConn, clientAddr *net.UDPAddr, data []byte) {
// 目标服务器超时设置,防止阻塞
targetConn.SetWriteDeadline(time.Now().Add(2 * time.Second))
_, err := targetConn.Write(data)
if err != nil {
fmt.Printf("UDP 数据转发失败 %s -> %s: %v\n", clientAddr, targetConn.RemoteAddr(), err)
// clientConn 保存目标服务器的 UDP 连接及其状态
type clientConn struct {
conn *net.UDPConn
lastActive time.Time
mu sync.Mutex // 保护 conn 读写和 lastActive 字段的并发安全
}
// 写数据到目标 UDP 服务器,线程安全
func (c *clientConn) write(data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
// 设置写超时时间为2秒
c.conn.SetWriteDeadline(time.Now().Add(2 * time.Second))
_, err := c.conn.Write(data)
return err
}
// 读数据从目标 UDP 服务器,线程安全
func (c *clientConn) read(buffer []byte) (int, error) {
c.mu.Lock()
defer c.mu.Unlock()
// 设置读超时时间为2秒
c.conn.SetReadDeadline(time.Now().Add(2 * time.Second))
return c.conn.Read(buffer)
}
// 更新连接最后活跃时间,线程安全
func (c *clientConn) updateLastActive() {
c.mu.Lock()
defer c.mu.Unlock()
c.lastActive = time.Now()
}
// 处理单个 UDP 客户端的数据转发
func handleUdpTraffic(listener *net.UDPConn, c *clientConn, clientAddr *net.UDPAddr, data []byte) {
// 更新客户端最后活跃时间
c.updateLastActive()
// 将收到的数据写入到目标 UDP 服务器
if err := c.write(data); err != nil {
fmt.Printf("[WARN] UDP 转发失败 %s -> %s: %v\n", clientAddr, c.conn.RemoteAddr(), err)
return
}
// 读取目标服务器响应
targetConn.SetReadDeadline(time.Now().Add(2 * time.Second))
// 读取目标 UDP 服务器响应
buffer := make([]byte, 4096)
n, _, err := targetConn.ReadFromUDP(buffer)
n, err := c.read(buffer)
if err != nil {
fmt.Printf("UDP 读取目标服务器响应失败: %v\n", err)
// 如果是超时错误,忽略日志,减少日志噪音
if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
fmt.Printf("[WARN] UDP 读取目标响应失败: %v\n", err)
}
return
}
// 发送响应回客户端
// 将目标服务器响应写回给客户端
listener.SetWriteDeadline(time.Now().Add(2 * time.Second))
_, err = listener.WriteToUDP(buffer[:n], clientAddr)
if err != nil {
fmt.Printf("UDP 发送响应失败: %v\n", err)
fmt.Printf("[WARN] UDP 发送响应失败: %v\n", err)
}
}
// 启动 UDP 代理
func StartUdpProxy(ctx context.Context, listenAddr, targetAddr string) {
defer udp_wg.Done()
// 定期清理长时间未活跃的客户端连接
func cleanupInactiveClients(ctx context.Context, interval, maxIdle time.Duration) {
ticker := time.NewTicker(interval) // 定时器,按 interval 频率执行清理
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
// 遍历所有客户端连接,判断是否超过最大空闲时间
clientMap.Range(func(key, value interface{}) bool {
c := value.(*clientConn)
c.mu.Lock()
inactive := now.Sub(c.lastActive) > maxIdle
c.mu.Unlock()
if inactive {
// 从 map 删除连接,防止新请求复用已关闭连接
clientMap.Delete(key)
fmt.Printf("[INFO] 关闭长时间未活跃客户端连接 %s\n", key)
c.conn.Close() // 关闭 UDP 连接
}
return true
})
case <-ctx.Done():
// 上下文取消,退出清理协程
return
}
}
}
// 启动 UDP 代理服务
func StartUdpProxy(ctx context.Context, listenAddr, targetAddr string) {
defer udpWg.Done()
// 解析本地监听 UDP 地址
localAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
fmt.Printf("解析本地 UDP 地址 %s 失败: %v\n", listenAddr, err)
fmt.Printf("[ERROR] 解析本地地址失败: %v\n", err)
return
}
// 监听本地 UDP 端口
listener, err := net.ListenUDP("udp", localAddr)
if err != nil {
fmt.Printf("监听 UDP %s 失败: %v\n", listenAddr, err)
fmt.Printf("[ERROR] 监听 UDP 失败: %v\n", err)
return
}
defer listener.Close()
// 解析目标 UDP 服务器地址
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
if err != nil {
fmt.Printf("解析目标 UDP 地址 %s 失败: %v\n", targetAddr, err)
fmt.Printf("[ERROR] 解析目标地址失败: %v\n", err)
return
}
fmt.Printf("UDP 代理启动: %s -> %s\n", listenAddr, targetAddr)
fmt.Printf("[INFO] UDP 代理启动 %s -> %s\n", listenAddr, targetAddr)
// 设置多个 goroutine 处理 UDP 请求,提高吞吐量
numWorkers := 4 // 线程数,可根据 CPU 资源调整
// 启动后台协程定期清理空闲连接
go cleanupInactiveClients(ctx, 1*time.Minute, 5*time.Minute)
// 根据 CPU 核心数启动多个工作协程,提高并发处理能力
var numWorkers = runtime.NumCPU()
for i := 0; i < numWorkers; i++ {
udp_wg.Add(1)
go func() {
defer udp_wg.Done()
buffer := make([]byte, 4096)
udpWg.Add(1)
go func(workerID int) {
defer udpWg.Done()
buffer := make([]byte, 4096) // 每个协程独立 buffer
for {
select {
case <-ctx.Done():
// 收到退出信号,结束协程
return
default:
// 设置读取超时,避免阻塞过久
listener.SetReadDeadline(time.Now().Add(1 * time.Second))
n, clientAddr, err := listener.ReadFromUDP(buffer)
if err != nil {
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
// 超时错误正常,继续循环
if ne, ok := err.(net.Error); ok && ne.Timeout() {
continue
}
// 如果上下文已取消,退出协程
if ctx.Err() != nil {
return
}
fmt.Printf("UDP 读取数据失败: %v\n", err)
// 监听器关闭,退出协程
if err.Error() == "use of closed network connection" {
return
}
fmt.Printf("[WARN] UDP 读取数据失败: %v\n", err)
continue
}
// 查找或创建目标连接
clientKey := clientAddr.String()
targetConn, exists := clientMap.Load(clientKey)
val, exists := clientMap.Load(clientKey)
var c *clientConn
if !exists {
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
// 新客户端,创建连接到目标服务器的 UDP 连接
conn, err := net.DialUDP("udp", nil, targetUDPAddr)
if err != nil {
fmt.Printf("连接 UDP 目标 %s 失败: %v\n", targetAddr, err)
fmt.Printf("[WARN] 连接目标服务器失败: %v\n", err)
continue
}
clientMap.Store(clientKey, targetConn)
c = &clientConn{
conn: conn,
lastActive: time.Now(),
}
clientMap.Store(clientKey, c)
} else {
// 已存在连接,复用
c = val.(*clientConn)
}
// 处理 UDP 数据
go handleUdpTraffic(listener, targetConn.(*net.UDPConn), clientAddr, buffer[:n])
// 同步调用,避免因并发过多导致压力大
handleUdpTraffic(listener, c, clientAddr, buffer[:n])
}
}
}()
}(i)
}
// 等待退出信号
<-ctx.Done()
fmt.Println("\n收到退出信号,正在关闭 UDP 代理...")
fmt.Println("[INFO] 收到退出信号,正在关闭 UDP 代理...")
listener.Close()
listener.Close() // 关闭监听
// 关闭所有客户端到目标服务器的连接
clientMap.Range(func(key, value interface{}) bool {
value.(*net.UDPConn).Close()
c := value.(*clientConn)
c.conn.Close()
return true
})
fmt.Println("UDP 代理已关闭")
fmt.Println("[INFO] UDP 代理已关闭")
}