4to6/udp.go

125 lines
3.1 KiB
Go
Raw Normal View History

2025-02-07 10:32:47 +08:00
package main
import (
"context"
"fmt"
"net"
"sync"
"time"
)
var (
udp_wg sync.WaitGroup
clientMap sync.Map // 代替 sync.Mutex 保护的 map[string]*net.UDPConn
)
// 处理 UDP 数据转发
func handleUdpTraffic(listener *net.UDPConn, targetConn *net.UDPConn, clientAddr *net.UDPAddr, data []byte) {
// 目标服务器超时设置,防止阻塞
targetConn.SetWriteDeadline(time.Now().Add(2 * time.Second))
_, err := targetConn.Write(data)
if err != nil {
fmt.Printf("UDP 数据转发失败 %s -> %s: %v\n", clientAddr, targetConn.RemoteAddr(), err)
return
}
// 读取目标服务器响应
targetConn.SetReadDeadline(time.Now().Add(2 * time.Second))
buffer := make([]byte, 4096)
n, _, err := targetConn.ReadFromUDP(buffer)
if err != nil {
fmt.Printf("UDP 读取目标服务器响应失败: %v\n", err)
return
}
// 发送响应回客户端
listener.SetWriteDeadline(time.Now().Add(2 * time.Second))
_, err = listener.WriteToUDP(buffer[:n], clientAddr)
if err != nil {
fmt.Printf("UDP 发送响应失败: %v\n", err)
}
}
// 启动 UDP 代理
func StartUdpProxy(ctx context.Context, listenAddr, targetAddr string) {
defer udp_wg.Done()
localAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
fmt.Printf("解析本地 UDP 地址 %s 失败: %v\n", listenAddr, err)
return
}
listener, err := net.ListenUDP("udp", localAddr)
if err != nil {
fmt.Printf("监听 UDP %s 失败: %v\n", listenAddr, err)
return
}
defer listener.Close()
targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
if err != nil {
fmt.Printf("解析目标 UDP 地址 %s 失败: %v\n", targetAddr, err)
return
}
fmt.Printf("UDP 代理启动: %s -> %s\n", listenAddr, targetAddr)
// 设置多个 goroutine 处理 UDP 请求,提高吞吐量
numWorkers := 4 // 线程数,可根据 CPU 资源调整
for i := 0; i < numWorkers; i++ {
udp_wg.Add(1)
go func() {
defer udp_wg.Done()
buffer := make([]byte, 4096)
for {
select {
case <-ctx.Done():
return
default:
listener.SetReadDeadline(time.Now().Add(1 * time.Second))
n, clientAddr, err := listener.ReadFromUDP(buffer)
if err != nil {
if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
continue
}
if ctx.Err() != nil {
return
}
fmt.Printf("UDP 读取数据失败: %v\n", err)
continue
}
// 查找或创建目标连接
clientKey := clientAddr.String()
targetConn, exists := clientMap.Load(clientKey)
if !exists {
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
if err != nil {
fmt.Printf("连接 UDP 目标 %s 失败: %v\n", targetAddr, err)
continue
}
clientMap.Store(clientKey, targetConn)
}
// 处理 UDP 数据
go handleUdpTraffic(listener, targetConn.(*net.UDPConn), clientAddr, buffer[:n])
}
}
}()
}
// 等待退出信号
<-ctx.Done()
fmt.Println("\n收到退出信号正在关闭 UDP 代理...")
listener.Close()
clientMap.Range(func(key, value interface{}) bool {
value.(*net.UDPConn).Close()
return true
})
fmt.Println("UDP 代理已关闭")
}