138 lines
3.1 KiB
Go
138 lines
3.1 KiB
Go
|
package main
|
|||
|
|
|||
|
import (
|
|||
|
"context"
|
|||
|
"errors"
|
|||
|
"fmt"
|
|||
|
"io"
|
|||
|
"net"
|
|||
|
"os"
|
|||
|
"os/signal"
|
|||
|
"sync"
|
|||
|
"syscall"
|
|||
|
"time"
|
|||
|
)
|
|||
|
|
|||
|
var (
|
|||
|
tcp_wg sync.WaitGroup
|
|||
|
tcp_listeners []net.Listener
|
|||
|
tcp_listenMux sync.Mutex // 确保 tcp_listeners 操作是线程安全的
|
|||
|
)
|
|||
|
|
|||
|
// 启用 TCP KeepAlive
|
|||
|
func setKeepAlive(conn net.Conn) {
|
|||
|
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
|||
|
tcpConn.SetKeepAlive(true)
|
|||
|
tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// 关闭写方向
|
|||
|
func closeWrite(conn net.Conn) {
|
|||
|
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
|||
|
tcpConn.CloseWrite()
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// 处理客户端连接,并将数据转发到目标服务器
|
|||
|
func HandleTcpConnection(clientConn net.Conn, targetAddr string) {
|
|||
|
defer clientConn.Close()
|
|||
|
|
|||
|
// 连接目标服务器,支持 IPv6
|
|||
|
serverConn, err := net.Dial("tcp", targetAddr)
|
|||
|
if err != nil {
|
|||
|
fmt.Printf("无法连接到 %s: %v\n", targetAddr, err)
|
|||
|
return
|
|||
|
}
|
|||
|
defer serverConn.Close()
|
|||
|
|
|||
|
// 开启 TCP KeepAlive,防止连接过早断开
|
|||
|
setKeepAlive(clientConn)
|
|||
|
setKeepAlive(serverConn)
|
|||
|
|
|||
|
clientAddr := clientConn.RemoteAddr().String()
|
|||
|
fmt.Printf("连接 %s -> %s\n", clientAddr, targetAddr)
|
|||
|
|
|||
|
var cpWG sync.WaitGroup
|
|||
|
cpWG.Add(2)
|
|||
|
|
|||
|
// 客户端 -> 目标服务器
|
|||
|
go func() {
|
|||
|
defer cpWG.Done()
|
|||
|
_, err := io.Copy(serverConn, clientConn)
|
|||
|
if err != nil && !errors.Is(err, io.EOF) {
|
|||
|
fmt.Printf("数据转发 %s -> %s 失败: %v\n", clientAddr, targetAddr, err)
|
|||
|
}
|
|||
|
closeWrite(serverConn)
|
|||
|
}()
|
|||
|
|
|||
|
// 目标服务器 -> 客户端
|
|||
|
go func() {
|
|||
|
defer cpWG.Done()
|
|||
|
_, err := io.Copy(clientConn, serverConn)
|
|||
|
if err != nil && !errors.Is(err, io.EOF) {
|
|||
|
fmt.Printf("数据转发 %s <- %s 失败: %v\n", clientAddr, targetAddr, err)
|
|||
|
}
|
|||
|
closeWrite(clientConn)
|
|||
|
}()
|
|||
|
|
|||
|
cpWG.Wait()
|
|||
|
fmt.Printf("连接 %s 结束\n", clientAddr)
|
|||
|
}
|
|||
|
|
|||
|
// 启动代理服务器
|
|||
|
func StartTcpProxy(listenAddr, targetAddr string) {
|
|||
|
defer tcp_wg.Done()
|
|||
|
|
|||
|
listener, err := net.Listen("tcp", listenAddr)
|
|||
|
if err != nil {
|
|||
|
fmt.Printf("监听 %s 失败: %v\n", listenAddr, err)
|
|||
|
return
|
|||
|
}
|
|||
|
|
|||
|
tcp_listenMux.Lock()
|
|||
|
tcp_listeners = append(tcp_listeners, listener)
|
|||
|
tcp_listenMux.Unlock()
|
|||
|
|
|||
|
fmt.Printf("代理服务启动: %s -> %s\n", listenAddr, targetAddr)
|
|||
|
|
|||
|
for {
|
|||
|
clientConn, err := listener.Accept()
|
|||
|
if err != nil {
|
|||
|
if errors.Is(err, net.ErrClosed) {
|
|||
|
fmt.Printf("监听器 %s 已关闭\n", listenAddr)
|
|||
|
return
|
|||
|
}
|
|||
|
fmt.Printf("接受连接失败: %v\n", err)
|
|||
|
continue
|
|||
|
}
|
|||
|
go HandleTcpConnection(clientConn, targetAddr)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// 监听系统信号,以便优雅地关闭 TCP 和 UDP 代理服务
|
|||
|
func WaitForExit(cancel context.CancelFunc) {
|
|||
|
sigChan := make(chan os.Signal, 1)
|
|||
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|||
|
|
|||
|
<-sigChan
|
|||
|
fmt.Println("\n收到终止信号,准备关闭代理服务器...")
|
|||
|
|
|||
|
// 关闭所有 TCP 监听器
|
|||
|
for _, listener := range tcp_listeners {
|
|||
|
if err := listener.Close(); err != nil {
|
|||
|
fmt.Printf("关闭 TCP 监听器 %s 失败: %v\n", listener.Addr(), err)
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
// 触发 UDP 代理退出
|
|||
|
cancel()
|
|||
|
|
|||
|
// 等待所有 TCP 和 UDP 代理退出
|
|||
|
tcp_wg.Wait()
|
|||
|
udp_wg.Wait()
|
|||
|
|
|||
|
fmt.Println("所有代理已安全退出")
|
|||
|
os.Exit(0)
|
|||
|
}
|