4to6/tcp.go

138 lines
3.1 KiB
Go
Raw Normal View History

2025-02-07 10:32:47 +08:00
package main
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"os/signal"
"sync"
"syscall"
"time"
)
var (
tcp_wg sync.WaitGroup
tcp_listeners []net.Listener
tcp_listenMux sync.Mutex // 确保 tcp_listeners 操作是线程安全的
)
// 启用 TCP KeepAlive
func setKeepAlive(conn net.Conn) {
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(30 * time.Second)
}
}
// 关闭写方向
func closeWrite(conn net.Conn) {
if tcpConn, ok := conn.(*net.TCPConn); ok {
tcpConn.CloseWrite()
}
}
// 处理客户端连接,并将数据转发到目标服务器
func HandleTcpConnection(clientConn net.Conn, targetAddr string) {
defer clientConn.Close()
// 连接目标服务器,支持 IPv6
serverConn, err := net.Dial("tcp", targetAddr)
if err != nil {
fmt.Printf("无法连接到 %s: %v\n", targetAddr, err)
return
}
defer serverConn.Close()
// 开启 TCP KeepAlive防止连接过早断开
setKeepAlive(clientConn)
setKeepAlive(serverConn)
clientAddr := clientConn.RemoteAddr().String()
fmt.Printf("连接 %s -> %s\n", clientAddr, targetAddr)
var cpWG sync.WaitGroup
cpWG.Add(2)
// 客户端 -> 目标服务器
go func() {
defer cpWG.Done()
_, err := io.Copy(serverConn, clientConn)
if err != nil && !errors.Is(err, io.EOF) {
fmt.Printf("数据转发 %s -> %s 失败: %v\n", clientAddr, targetAddr, err)
}
closeWrite(serverConn)
}()
// 目标服务器 -> 客户端
go func() {
defer cpWG.Done()
_, err := io.Copy(clientConn, serverConn)
if err != nil && !errors.Is(err, io.EOF) {
fmt.Printf("数据转发 %s <- %s 失败: %v\n", clientAddr, targetAddr, err)
}
closeWrite(clientConn)
}()
cpWG.Wait()
fmt.Printf("连接 %s 结束\n", clientAddr)
}
// 启动代理服务器
func StartTcpProxy(listenAddr, targetAddr string) {
defer tcp_wg.Done()
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
fmt.Printf("监听 %s 失败: %v\n", listenAddr, err)
return
}
tcp_listenMux.Lock()
tcp_listeners = append(tcp_listeners, listener)
tcp_listenMux.Unlock()
fmt.Printf("代理服务启动: %s -> %s\n", listenAddr, targetAddr)
for {
clientConn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
fmt.Printf("监听器 %s 已关闭\n", listenAddr)
return
}
fmt.Printf("接受连接失败: %v\n", err)
continue
}
go HandleTcpConnection(clientConn, targetAddr)
}
}
// 监听系统信号,以便优雅地关闭 TCP 和 UDP 代理服务
func WaitForExit(cancel context.CancelFunc) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-sigChan
fmt.Println("\n收到终止信号准备关闭代理服务器...")
// 关闭所有 TCP 监听器
for _, listener := range tcp_listeners {
if err := listener.Close(); err != nil {
fmt.Printf("关闭 TCP 监听器 %s 失败: %v\n", listener.Addr(), err)
}
}
// 触发 UDP 代理退出
cancel()
// 等待所有 TCP 和 UDP 代理退出
tcp_wg.Wait()
udp_wg.Wait()
fmt.Println("所有代理已安全退出")
os.Exit(0)
}