138 lines
3.1 KiB
Go
138 lines
3.1 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net"
|
||
"os"
|
||
"os/signal"
|
||
"sync"
|
||
"syscall"
|
||
"time"
|
||
)
|
||
|
||
var (
|
||
tcp_wg sync.WaitGroup
|
||
tcp_listeners []net.Listener
|
||
tcp_listenMux sync.Mutex // 确保 tcp_listeners 操作是线程安全的
|
||
)
|
||
|
||
// 启用 TCP KeepAlive
|
||
func setKeepAlive(conn net.Conn) {
|
||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||
tcpConn.SetKeepAlive(true)
|
||
tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||
}
|
||
}
|
||
|
||
// 关闭写方向
|
||
func closeWrite(conn net.Conn) {
|
||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||
tcpConn.CloseWrite()
|
||
}
|
||
}
|
||
|
||
// 处理客户端连接,并将数据转发到目标服务器
|
||
func HandleTcpConnection(clientConn net.Conn, targetAddr string) {
|
||
defer clientConn.Close()
|
||
|
||
// 连接目标服务器,支持 IPv6
|
||
serverConn, err := net.Dial("tcp", targetAddr)
|
||
if err != nil {
|
||
fmt.Printf("无法连接到 %s: %v\n", targetAddr, err)
|
||
return
|
||
}
|
||
defer serverConn.Close()
|
||
|
||
// 开启 TCP KeepAlive,防止连接过早断开
|
||
setKeepAlive(clientConn)
|
||
setKeepAlive(serverConn)
|
||
|
||
clientAddr := clientConn.RemoteAddr().String()
|
||
fmt.Printf("连接 %s -> %s\n", clientAddr, targetAddr)
|
||
|
||
var cpWG sync.WaitGroup
|
||
cpWG.Add(2)
|
||
|
||
// 客户端 -> 目标服务器
|
||
go func() {
|
||
defer cpWG.Done()
|
||
_, err := io.Copy(serverConn, clientConn)
|
||
if err != nil && !errors.Is(err, io.EOF) {
|
||
fmt.Printf("数据转发 %s -> %s 失败: %v\n", clientAddr, targetAddr, err)
|
||
}
|
||
closeWrite(serverConn)
|
||
}()
|
||
|
||
// 目标服务器 -> 客户端
|
||
go func() {
|
||
defer cpWG.Done()
|
||
_, err := io.Copy(clientConn, serverConn)
|
||
if err != nil && !errors.Is(err, io.EOF) {
|
||
fmt.Printf("数据转发 %s <- %s 失败: %v\n", clientAddr, targetAddr, err)
|
||
}
|
||
closeWrite(clientConn)
|
||
}()
|
||
|
||
cpWG.Wait()
|
||
fmt.Printf("连接 %s 结束\n", clientAddr)
|
||
}
|
||
|
||
// 启动代理服务器
|
||
func StartTcpProxy(listenAddr, targetAddr string) {
|
||
defer tcp_wg.Done()
|
||
|
||
listener, err := net.Listen("tcp", listenAddr)
|
||
if err != nil {
|
||
fmt.Printf("监听 %s 失败: %v\n", listenAddr, err)
|
||
return
|
||
}
|
||
|
||
tcp_listenMux.Lock()
|
||
tcp_listeners = append(tcp_listeners, listener)
|
||
tcp_listenMux.Unlock()
|
||
|
||
fmt.Printf("代理服务启动: %s -> %s\n", listenAddr, targetAddr)
|
||
|
||
for {
|
||
clientConn, err := listener.Accept()
|
||
if err != nil {
|
||
if errors.Is(err, net.ErrClosed) {
|
||
fmt.Printf("监听器 %s 已关闭\n", listenAddr)
|
||
return
|
||
}
|
||
fmt.Printf("接受连接失败: %v\n", err)
|
||
continue
|
||
}
|
||
go HandleTcpConnection(clientConn, targetAddr)
|
||
}
|
||
}
|
||
|
||
// 监听系统信号,以便优雅地关闭 TCP 和 UDP 代理服务
|
||
func WaitForExit(cancel context.CancelFunc) {
|
||
sigChan := make(chan os.Signal, 1)
|
||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||
|
||
<-sigChan
|
||
fmt.Println("\n收到终止信号,准备关闭代理服务器...")
|
||
|
||
// 关闭所有 TCP 监听器
|
||
for _, listener := range tcp_listeners {
|
||
if err := listener.Close(); err != nil {
|
||
fmt.Printf("关闭 TCP 监听器 %s 失败: %v\n", listener.Addr(), err)
|
||
}
|
||
}
|
||
|
||
// 触发 UDP 代理退出
|
||
cancel()
|
||
|
||
// 等待所有 TCP 和 UDP 代理退出
|
||
tcp_wg.Wait()
|
||
udp_wg.Wait()
|
||
|
||
fmt.Println("所有代理已安全退出")
|
||
os.Exit(0)
|
||
}
|