package main import ( "context" "errors" "fmt" "io" "net" "os" "os/signal" "sync" "syscall" "time" ) var ( tcp_wg sync.WaitGroup tcp_listeners []net.Listener tcp_listenMux sync.Mutex // 确保 tcp_listeners 操作是线程安全的 ) // 启用 TCP KeepAlive func setKeepAlive(conn net.Conn) { if tcpConn, ok := conn.(*net.TCPConn); ok { tcpConn.SetKeepAlive(true) tcpConn.SetKeepAlivePeriod(30 * time.Second) } } // 关闭写方向 func closeWrite(conn net.Conn) { if tcpConn, ok := conn.(*net.TCPConn); ok { tcpConn.CloseWrite() } } // 处理客户端连接,并将数据转发到目标服务器 func HandleTcpConnection(clientConn net.Conn, targetAddr string) { defer clientConn.Close() // 连接目标服务器,支持 IPv6 serverConn, err := net.Dial("tcp", targetAddr) if err != nil { fmt.Printf("无法连接到 %s: %v\n", targetAddr, err) return } defer serverConn.Close() // 开启 TCP KeepAlive,防止连接过早断开 setKeepAlive(clientConn) setKeepAlive(serverConn) clientAddr := clientConn.RemoteAddr().String() fmt.Printf("连接 %s -> %s\n", clientAddr, targetAddr) var cpWG sync.WaitGroup cpWG.Add(2) // 客户端 -> 目标服务器 go func() { defer cpWG.Done() _, err := io.Copy(serverConn, clientConn) if err != nil && !errors.Is(err, io.EOF) { fmt.Printf("数据转发 %s -> %s 失败: %v\n", clientAddr, targetAddr, err) } closeWrite(serverConn) }() // 目标服务器 -> 客户端 go func() { defer cpWG.Done() _, err := io.Copy(clientConn, serverConn) if err != nil && !errors.Is(err, io.EOF) { fmt.Printf("数据转发 %s <- %s 失败: %v\n", clientAddr, targetAddr, err) } closeWrite(clientConn) }() cpWG.Wait() fmt.Printf("连接 %s 结束\n", clientAddr) } // 启动代理服务器 func StartTcpProxy(listenAddr, targetAddr string) { defer tcp_wg.Done() listener, err := net.Listen("tcp", listenAddr) if err != nil { fmt.Printf("监听 %s 失败: %v\n", listenAddr, err) return } tcp_listenMux.Lock() tcp_listeners = append(tcp_listeners, listener) tcp_listenMux.Unlock() fmt.Printf("代理服务启动: %s -> %s\n", listenAddr, targetAddr) for { clientConn, err := listener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { fmt.Printf("监听器 %s 已关闭\n", listenAddr) return } fmt.Printf("接受连接失败: %v\n", err) continue } go HandleTcpConnection(clientConn, targetAddr) } } // 监听系统信号,以便优雅地关闭 TCP 和 UDP 代理服务 func WaitForExit(cancel context.CancelFunc) { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) <-sigChan fmt.Println("\n收到终止信号,准备关闭代理服务器...") // 关闭所有 TCP 监听器 for _, listener := range tcp_listeners { if err := listener.Close(); err != nil { fmt.Printf("关闭 TCP 监听器 %s 失败: %v\n", listener.Addr(), err) } } // 触发 UDP 代理退出 cancel() // 等待所有 TCP 和 UDP 代理退出 tcp_wg.Wait() udp_wg.Wait() fmt.Println("所有代理已安全退出") os.Exit(0) }