114 lines
3.3 KiB
Go
114 lines
3.3 KiB
Go
package main
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// 定义全局变量
|
||
var (
|
||
tcpWg sync.WaitGroup // 用于等待所有 TCP 代理任务完成
|
||
tcp_listeners []net.Listener // 存储所有监听器,便于统一关闭
|
||
tcp_listenMux sync.Mutex // 互斥锁,确保对 tcp_listeners 的操作是线程安全的
|
||
)
|
||
|
||
// 启用 TCP KeepAlive,防止连接空闲时被中间设备关闭
|
||
func setKeepAlive(conn net.Conn) {
|
||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||
tcpConn.SetKeepAlive(true) // 启用 keep-alive
|
||
tcpConn.SetKeepAlivePeriod(30 * time.Second) // 设置保活周期为 30 秒
|
||
}
|
||
}
|
||
|
||
// 关闭连接的写方向,常用于半关闭通信
|
||
func closeWrite(conn net.Conn) {
|
||
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
||
tcpConn.CloseWrite() // 关闭写通道,保持读通道开启
|
||
}
|
||
}
|
||
|
||
// 处理一个 TCP 客户端连接,并将其转发到目标服务器
|
||
func HandleTcpConnection(clientConn net.Conn, targetAddr string) {
|
||
defer clientConn.Close() // 函数结束时关闭客户端连接
|
||
|
||
// 尝试连接目标服务器地址(支持 IPv6)
|
||
serverConn, err := net.Dial("tcp", targetAddr)
|
||
if err != nil {
|
||
fmt.Printf("无法连接到 %s: %v\n", targetAddr, err)
|
||
return
|
||
}
|
||
defer serverConn.Close() // 函数结束时关闭服务器连接
|
||
|
||
// 对两个连接都开启 KeepAlive
|
||
setKeepAlive(clientConn)
|
||
setKeepAlive(serverConn)
|
||
|
||
clientAddr := clientConn.RemoteAddr().String()
|
||
fmt.Printf("连接 %s -> %s\n", clientAddr, targetAddr)
|
||
|
||
var cpWG sync.WaitGroup
|
||
cpWG.Add(2) // 我们会启动两个数据复制 goroutine
|
||
|
||
// 客户端 -> 目标服务器
|
||
go func() {
|
||
defer cpWG.Done()
|
||
_, err := io.Copy(serverConn, clientConn) // 将客户端数据复制到目标服务器
|
||
if err != nil && !errors.Is(err, io.EOF) {
|
||
fmt.Printf("数据转发 %s -> %s 失败: %v\n", clientAddr, targetAddr, err)
|
||
}
|
||
closeWrite(serverConn) // 通知目标服务器我们写完了
|
||
}()
|
||
|
||
// 目标服务器 -> 客户端
|
||
go func() {
|
||
defer cpWG.Done()
|
||
_, err := io.Copy(clientConn, serverConn) // 将目标服务器数据复制到客户端
|
||
if err != nil && !errors.Is(err, io.EOF) {
|
||
fmt.Printf("数据转发 %s <- %s 失败: %v\n", clientAddr, targetAddr, err)
|
||
}
|
||
closeWrite(clientConn) // 通知客户端我们写完了
|
||
}()
|
||
|
||
cpWG.Wait() // 等待两个数据转发完成
|
||
fmt.Printf("连接 %s 结束\n", clientAddr)
|
||
}
|
||
|
||
// 启动一个 TCP 代理服务,将监听地址的数据转发到目标地址
|
||
func StartTcpProxy(listenAddr, targetAddr string) {
|
||
defer tcpWg.Done() // 当前代理任务结束时计数器减 1
|
||
|
||
// 监听本地地址
|
||
listener, err := net.Listen("tcp", listenAddr)
|
||
if err != nil {
|
||
fmt.Printf("监听 %s 失败: %v\n", listenAddr, err)
|
||
return
|
||
}
|
||
|
||
// 添加到监听器列表,供后续关闭使用
|
||
tcp_listenMux.Lock()
|
||
tcp_listeners = append(tcp_listeners, listener)
|
||
tcp_listenMux.Unlock()
|
||
|
||
fmt.Printf("代理服务启动: %s -> %s\n", listenAddr, targetAddr)
|
||
|
||
for {
|
||
// 接受客户端连接
|
||
clientConn, err := listener.Accept()
|
||
if err != nil {
|
||
// 如果监听器已关闭,说明程序正在退出
|
||
if errors.Is(err, net.ErrClosed) {
|
||
fmt.Printf("监听器 %s 已关闭\n", listenAddr)
|
||
return
|
||
}
|
||
fmt.Printf("接受连接失败: %v\n", err)
|
||
continue
|
||
}
|
||
// 启动一个 goroutine 处理该连接
|
||
go HandleTcpConnection(clientConn, targetAddr)
|
||
}
|
||
}
|