125 lines
3.1 KiB
Go
125 lines
3.1 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"net"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
var (
|
||
udp_wg sync.WaitGroup
|
||
clientMap sync.Map // 代替 sync.Mutex 保护的 map[string]*net.UDPConn
|
||
)
|
||
|
||
// 处理 UDP 数据转发
|
||
func handleUdpTraffic(listener *net.UDPConn, targetConn *net.UDPConn, clientAddr *net.UDPAddr, data []byte) {
|
||
// 目标服务器超时设置,防止阻塞
|
||
targetConn.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||
_, err := targetConn.Write(data)
|
||
if err != nil {
|
||
fmt.Printf("UDP 数据转发失败 %s -> %s: %v\n", clientAddr, targetConn.RemoteAddr(), err)
|
||
return
|
||
}
|
||
|
||
// 读取目标服务器响应
|
||
targetConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||
buffer := make([]byte, 4096)
|
||
n, _, err := targetConn.ReadFromUDP(buffer)
|
||
if err != nil {
|
||
fmt.Printf("UDP 读取目标服务器响应失败: %v\n", err)
|
||
return
|
||
}
|
||
|
||
// 发送响应回客户端
|
||
listener.SetWriteDeadline(time.Now().Add(2 * time.Second))
|
||
_, err = listener.WriteToUDP(buffer[:n], clientAddr)
|
||
if err != nil {
|
||
fmt.Printf("UDP 发送响应失败: %v\n", err)
|
||
}
|
||
}
|
||
|
||
// 启动 UDP 代理
|
||
func StartUdpProxy(ctx context.Context, listenAddr, targetAddr string) {
|
||
defer udp_wg.Done()
|
||
|
||
localAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||
if err != nil {
|
||
fmt.Printf("解析本地 UDP 地址 %s 失败: %v\n", listenAddr, err)
|
||
return
|
||
}
|
||
|
||
listener, err := net.ListenUDP("udp", localAddr)
|
||
if err != nil {
|
||
fmt.Printf("监听 UDP %s 失败: %v\n", listenAddr, err)
|
||
return
|
||
}
|
||
defer listener.Close()
|
||
|
||
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
|
||
if err != nil {
|
||
fmt.Printf("解析目标 UDP 地址 %s 失败: %v\n", targetAddr, err)
|
||
return
|
||
}
|
||
|
||
fmt.Printf("UDP 代理启动: %s -> %s\n", listenAddr, targetAddr)
|
||
|
||
// 设置多个 goroutine 处理 UDP 请求,提高吞吐量
|
||
numWorkers := 4 // 线程数,可根据 CPU 资源调整
|
||
for i := 0; i < numWorkers; i++ {
|
||
udp_wg.Add(1)
|
||
go func() {
|
||
defer udp_wg.Done()
|
||
buffer := make([]byte, 4096)
|
||
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return
|
||
default:
|
||
listener.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||
n, clientAddr, err := listener.ReadFromUDP(buffer)
|
||
if err != nil {
|
||
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
|
||
continue
|
||
}
|
||
if ctx.Err() != nil {
|
||
return
|
||
}
|
||
fmt.Printf("UDP 读取数据失败: %v\n", err)
|
||
continue
|
||
}
|
||
|
||
// 查找或创建目标连接
|
||
clientKey := clientAddr.String()
|
||
targetConn, exists := clientMap.Load(clientKey)
|
||
if !exists {
|
||
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
|
||
if err != nil {
|
||
fmt.Printf("连接 UDP 目标 %s 失败: %v\n", targetAddr, err)
|
||
continue
|
||
}
|
||
clientMap.Store(clientKey, targetConn)
|
||
}
|
||
|
||
// 处理 UDP 数据
|
||
go handleUdpTraffic(listener, targetConn.(*net.UDPConn), clientAddr, buffer[:n])
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
// 等待退出信号
|
||
<-ctx.Done()
|
||
fmt.Println("\n收到退出信号,正在关闭 UDP 代理...")
|
||
|
||
listener.Close()
|
||
clientMap.Range(func(key, value interface{}) bool {
|
||
value.(*net.UDPConn).Close()
|
||
return true
|
||
})
|
||
|
||
fmt.Println("UDP 代理已关闭")
|
||
}
|