diff --git a/bin/db_tunnel b/bin/db_tunnel index bd6550e..c55a4dc 100644 Binary files a/bin/db_tunnel and b/bin/db_tunnel differ diff --git a/src/main.go b/src/main.go index 27ce799..e830030 100644 --- a/src/main.go +++ b/src/main.go @@ -64,7 +64,7 @@ func sForward(serverAddr string, remoteAddr string, localConn net.Conn, config * } -func Tunnel(username string, password string, serverAddr string, remoteAddr string, localAddr string) { +func Tunnel(username string, password string, serverAddr string, remoteAddr string, localAddr string, connect_max int) { // 设置SSH配置 fmt.Printf("%s,服务器:%s; 用户/密码: %s; 远程:%s; 本地:%s\n", "设置SSH配置", serverAddr, username+"/"+password, remoteAddr, localAddr) config := &ssh.ClientConfig{ @@ -97,7 +97,7 @@ func Tunnel(username string, password string, serverAddr string, remoteAddr stri fmt.Printf("localListener.Accept failed: %v\n", err) os.Exit(1) } - if connect == 3 { + if connect >= connect_max { // 连接数达到1000时重启进程 fmt.Printf("连接数达到上线,重启进程!%d\n", connect) syscall.Kill(os.Getpid(), syscall.SIGKILL) } @@ -147,7 +147,7 @@ func ExecCommand(strCommand string) string { return string(out_bytes) } -func ssh_tunnel(host_ip string, host_port string, host_name string, remote_ip string, local_ip string, mysql_info string) { +func ssh_tunnel(host_ip string, host_port string, host_name string, remote_ip string, local_ip string, mysql_info string, connect_max int) { fmt.Println(host_ip, host_name, host_port) @@ -197,7 +197,7 @@ func ssh_tunnel(host_ip string, host_port string, host_name string, remote_ip st syscall.Kill(os.Getpid(), syscall.SIGKILL) } - Tunnel(host_name, strData, u.host_ip+":"+u.protocol_port, "0.0.0.0:"+remote_ip, "0.0.0.0:"+local_ip) + Tunnel(host_name, strData, u.host_ip+":"+u.protocol_port, "0.0.0.0:"+remote_ip, "0.0.0.0:"+local_ip, connect_max) } @@ -309,6 +309,7 @@ func main() { host_user := flag.String(HOST_USER, "", "DB服务器SSH USER用户") remote_port := flag.String(HOST_REMOTE, "", "DB服务器端口(如: 3306、1521 ...)") local_port := flag.String(HOST_LOCAL, "", "本地监听端口(或者堡垒机监听端口)") + connect_max := flag.Int("c", 999, "建立隧道最大次数(与DB服务器断开)") flag.Parse() if *daemon { @@ -327,27 +328,25 @@ func main() { fmt.Printf("[*] Service running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid()) } - if 0 == 0 { - fmt.Println(*host_ip, *host_port, *host_user, *remote_port, *local_port) + fmt.Println(*host_ip, *host_port, *host_user, *remote_port, *local_port) - // - local_port_, _ := strconv.Atoi(*local_port) - r, _ := Check(local_port_) - //fmt.Println(r) - if r { + // + local_port_, _ := strconv.Atoi(*local_port) + r, _ := Check(local_port_) + //fmt.Println(r) + if r { + fmt.Println(*connect_max) + ssh_tunnel(*host_ip, *host_port, *host_user, *remote_port, *local_port, MYSQL_INFO, *connect_max) + } else { + fmt.Println(local_port_, "端口不可用!", r, "退出!") + fp, _ := os.OpenFile("./pid.log", os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) + log.SetOutput(fp) - ssh_tunnel(*host_ip, *host_port, *host_user, *remote_port, *local_port, MYSQL_INFO) - } else { - fmt.Println(local_port_, "端口不可用!", r, "退出!") - fp, _ := os.OpenFile("./pid.log", os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) - log.SetOutput(fp) + log.Printf("DoSomething running in PPID: %d\n", os.Getppid()) - log.Printf("DoSomething running in PPID: %d\n", os.Getppid()) - - syscall.Kill(os.Getppid(), syscall.SIGKILL) - time.Sleep(time.Second * 10) - os.Exit(0) - } + syscall.Kill(os.Getppid(), syscall.SIGKILL) + time.Sleep(time.Second * 10) + os.Exit(0) } time.Sleep(time.Second * 5)