4to6/tcp.go
2025-02-07 10:32:47 +08:00

138 lines
3.1 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"os/signal"
"sync"
"syscall"
"time"
)
// tcp_wg is released by each StartTcpProxy goroutine when it returns
// (defer tcp_wg.Done()); WaitForExit waits on it during shutdown.
var tcp_wg sync.WaitGroup

// tcp_listeners records every listener opened by StartTcpProxy so that
// WaitForExit can close them when a termination signal arrives.
var tcp_listeners []net.Listener

// tcp_listenMux guards tcp_listeners; hold it around every access to the slice.
var tcp_listenMux sync.Mutex
// setKeepAlive enables TCP keep-alive on conn with a 30-second keep-alive
// period. Connections that are not *net.TCPConn are left untouched.
// Errors from the socket options are deliberately ignored: keep-alive is a
// best-effort hint here, not a requirement for relaying data.
func setKeepAlive(conn net.Conn) {
	tcpConn, isTCP := conn.(*net.TCPConn)
	if !isTCP {
		return
	}
	_ = tcpConn.SetKeepAlive(true)
	_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
}
// closeWrite half-closes conn by shutting down its sending side, so the peer
// reads EOF while conn can still receive data. Connections that are not
// *net.TCPConn are left untouched, and any CloseWrite error is ignored.
func closeWrite(conn net.Conn) {
	tcpConn, isTCP := conn.(*net.TCPConn)
	if !isTCP {
		return
	}
	_ = tcpConn.CloseWrite()
}
// HandleTcpConnection relays bytes in both directions between clientConn and
// a newly dialed TCP connection to targetAddr, returning only after both
// directions have finished. Both connections are closed before it returns.
//
// targetAddr is passed unchanged to net.Dial("tcp", ...), so IPv4, host names
// and bracketed IPv6 literals (e.g. "[::1]:80") are all accepted.
// When a direction's copy ends (EOF or error), its destination is
// half-closed via closeWrite so that peer sees end-of-stream while the
// opposite direction may keep flowing.
func HandleTcpConnection(clientConn net.Conn, targetAddr string) {
	defer clientConn.Close()

	serverConn, dialErr := net.Dial("tcp", targetAddr)
	if dialErr != nil {
		fmt.Printf("无法连接到 %s: %v\n", targetAddr, dialErr)
		return
	}
	defer serverConn.Close()

	// Keep-alive lets a dead peer eventually be detected instead of
	// leaving the relay blocked forever.
	for _, c := range []net.Conn{clientConn, serverConn} {
		setKeepAlive(c)
	}

	clientAddr := clientConn.RemoteAddr().String()
	fmt.Printf("连接 %s -> %s\n", clientAddr, targetAddr)

	var pending sync.WaitGroup

	// relay streams src into dst, logs unexpected failures using failFmt
	// (arguments: client address, target address, error), then half-closes dst.
	relay := func(dst, src net.Conn, failFmt string) {
		defer pending.Done()
		if _, copyErr := io.Copy(dst, src); copyErr != nil && !errors.Is(copyErr, io.EOF) {
			fmt.Printf(failFmt, clientAddr, targetAddr, copyErr)
		}
		closeWrite(dst)
	}

	pending.Add(2)
	go relay(serverConn, clientConn, "数据转发 %s -> %s 失败: %v\n")
	go relay(clientConn, serverConn, "数据转发 %s <- %s 失败: %v\n")
	pending.Wait()

	fmt.Printf("连接 %s 结束\n", clientAddr)
}
// StartTcpProxy listens on listenAddr and relays every accepted TCP
// connection to targetAddr, handing each one to HandleTcpConnection in its
// own goroutine.
//
// It calls tcp_wg.Done when it returns, so the caller is expected to have
// called tcp_wg.Add(1) beforehand. After a successful Listen the listener is
// appended to tcp_listeners (under tcp_listenMux) so that shutdown code can
// close it; closing the listener (Accept returning net.ErrClosed) is the only
// way the accept loop exits.
//
// Other Accept failures (for example EMFILE when the process runs out of file
// descriptors) tend to repeat immediately. They are retried with an
// exponential backoff (5ms doubling up to 1s, reset after a successful
// Accept) instead of a tight loop that would burn a CPU core and flood the log.
func StartTcpProxy(listenAddr, targetAddr string) {
	defer tcp_wg.Done()
	listener, err := net.Listen("tcp", listenAddr)
	if err != nil {
		fmt.Printf("监听 %s 失败: %v\n", listenAddr, err)
		return
	}
	tcp_listenMux.Lock()
	tcp_listeners = append(tcp_listeners, listener)
	tcp_listenMux.Unlock()
	fmt.Printf("代理服务启动: %s -> %s\n", listenAddr, targetAddr)

	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)
	var backoff time.Duration
	for {
		clientConn, err := listener.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				fmt.Printf("监听器 %s 已关闭\n", listenAddr)
				return
			}
			fmt.Printf("接受连接失败: %v\n", err)
			// Back off before retrying so a persistent error cannot spin.
			if backoff == 0 {
				backoff = minAcceptBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxAcceptBackoff {
				backoff = maxAcceptBackoff
			}
			time.Sleep(backoff)
			continue
		}
		backoff = 0
		go HandleTcpConnection(clientConn, targetAddr)
	}
}
// 监听系统信号,以便优雅地关闭 TCP 和 UDP 代理服务
func WaitForExit(cancel context.CancelFunc) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-sigChan
fmt.Println("\n收到终止信号准备关闭代理服务器...")
// 关闭所有 TCP 监听器
for _, listener := range tcp_listeners {
if err := listener.Close(); err != nil {
fmt.Printf("关闭 TCP 监听器 %s 失败: %v\n", listener.Addr(), err)
}
}
// 触发 UDP 代理退出
cancel()
// 等待所有 TCP 和 UDP 代理退出
tcp_wg.Wait()
udp_wg.Wait()
fmt.Println("所有代理已安全退出")
os.Exit(0)
}