diff --git a/README.md b/README.md index 79e039c..f25f610 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,44 @@ # db_tunnel + 结合TELEPORT堡垒机使用的DB SSH数据库隧道工具 数据库配置文件请放到/etc或者与bin文件同一路径 - + # build + git clone https://git.aixiao.me/aixiao/db_tunnel.git cd db_tunnel bash bash build.sh(可能会提示安装包) # test - root@NIUYULING:/mnt/c/Users/niuyuling/Desktop/db_tunnel# ./db_tunnel -h - flag needs an argument: -h - Usage of ./db_tunnel: - -d run in daemon - -f run forever - -h string - DB服务器SSH IP地址 - -l string - 本地监听端口(或者堡垒机监听端口) - -p string - DB服务器SSH PORT - -r string - DB服务器端口(如: 3306、1521 ...) - -u string - DB服务器SSH USER用户 - root@NIUYULING:/mnt/c/Users/niuyuling/Desktop/db_tunnel# + + root@NIUYULING:/mnt/c/Users/niuyuling/Desktop/tunnel# ./bin/tunnel -h + flag needs an argument: -h + Usage of ./bin/tunnel: + -c int + 建立隧道最大次数(与DB服务器断开) (default 999) + -d run in daemon + -f run forever + -h string + DB服务器SSH IP地址 + -l string + 本地监听端口(或者堡垒机监听端口) + -p string + DB服务器SSH PORT + -r string + DB服务器端口(如: 3306、1521 ...) + -u string + DB服务器SSH USER用户 + root@NIUYULING:/mnt/c/Users/niuyuling/Desktop/tunnel# - root@NIUYULING:/mnt/c/Users/niuyuling/Desktop/db_tunnel/bin# ./db_tunnel -d -f -h 39.104.27.21 -l 3308 -r 3306 -p 22 -u app - [*] Daemon running in PID: 12918 PPID: 67 - root@NIUYULING:/mnt/c/Users/niuyuling/Desktop/db_tunnel/bin# [*] Forever running in PID: 12922 PPID: 1 - [*] Service running in PID: 12927 PPID: 12922 - 39.104.27.21 22 app 3306 3308 - 39.104.27.21 app 22 - 数据库连接成功 - 39.104.27.21 22 app 6+sNDSN4QL7VCLSr+Vt/fNSAX1XsTUwf6fYRTf3pGS8rYBn8Ik - I9EKbb - 设置SSH配置,服务器:39.104.27.21:22; 用户/密码: app/I9EKbbH; 远程:0.0.0.0:3306; 本地:0.0.0.0:3308 - root@NIUYULING:/mnt/c/Users/niuyuling/Desktop/db_tunnel/bin# \ No newline at end of file + root@NIUYULING:/mnt/c/Users/niuyuling/Desktop/tunnel/bin# ./tunnel -d -f -c 500 -h 47.108.253.59 -p 22 -u app -r 3306 -l 3306 + [*] Daemon running in PID: 28640 PPID: 24 + root@NIUYULING:/mnt/c/Users/niuyuling/Desktop/tunnel/bin# [*] Forever running in PID: 28645 PPID: 1 + [*] Service running in PID: 28649 PPID: 28645 + 数据库连接成功 + 47.108.253.59 22 app hyj4Sxq1UQUJ0RF95MQGf1oNUSdZ/rlIYNXu9DPSCLt3sNsXCpKmRtm85tFCRPN + 0Hv5m&uRwYKFZdu + 2023/03/15 12:15:55.161984 PID: 28649 PPID: 28645 本地端口: 3306 当前连接数:1 + 2023/03/15 12:15:55.162630 listening for new connections... + + root@NIUYULING:/mnt/c/Users/niuyuling/Desktop/tunnel/bin# diff --git a/bin/cmdline b/bin/cmdline new file mode 100644 index 0000000..fa4dcf0 Binary files /dev/null and b/bin/cmdline differ diff --git a/bin/db_tunnel b/bin/db_tunnel deleted file mode 100644 index d427f0a..0000000 Binary files a/bin/db_tunnel and /dev/null differ diff --git a/bin/kill.sh b/bin/kill.sh new file mode 100644 index 0000000..d58806a --- /dev/null +++ b/bin/kill.sh @@ -0,0 +1,36 @@ +#!/bin/bash +# +# +# + + +init() +{ + DB_PORT=$1 +} + + +main() +{ + PROCESS=`ps -axww | grep tunnel | grep ${DB_PORT} | grep "\-f" | awk '{print $1}'` + if test "${PROCESS}" = ""; then + echo "进程不存在!!!" + exit -1 + fi + + echo "重启参数:" `./cmdline $PROCESS | sed "s|-f|-d -f|g"` + + + echo "关闭进程: $PROCESS, 关闭端口: ${DB_PORT}" + kill `ps -ax | grep tunnel | grep ${DB_PORT} | awk '{print $1}' | xargs ` +} + + +if test "$1" = ""; then + echo "参数错误!" + exit -1 +fi + +init $1 +main + diff --git a/bin/tp_decrypt b/bin/tp_decrypt index ff9ac02..5c4288c 100644 Binary files a/bin/tp_decrypt and b/bin/tp_decrypt differ diff --git a/bin/tunnel b/bin/tunnel new file mode 100644 index 0000000..6a9ae8c Binary files /dev/null and b/bin/tunnel differ diff --git a/build.sh b/build.sh index 736289d..39275e2 100644 --- a/build.sh +++ b/build.sh @@ -1,2 +1,16 @@ +#!/bin/bash +# +# Build Project +# + + +# 构建主程序 +GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -ldflags '-w -s' -o bin/tunnel src/*.go && upx -9 bin/tunnel + +# 处理进程参数 +gcc -Wall -Os -g cmdline/cmdline.c -o bin/cmdline -static + +# 解密程序 +gcc -Wall tp_decrypt/tp_decrypt.py.c -o bin/tp_decrypt -static + -CGO_ENABLED=0 go build -ldflags '-w -s' -o bin/db_tunnel src/main.go && upx -9 bin/db_tunnel \ No newline at end of file diff --git a/cmdline/cmdline.c b/cmdline/cmdline.c new file mode 100644 index 0000000..0d9290d --- /dev/null +++ b/cmdline/cmdline.c @@ -0,0 +1,47 @@ +#include +#include + +#define BUFFER 270 + +char *cmdline(char *pid) +{ + FILE *fp; + char path[BUFFER]; + char temp[BUFFER * 100]; + char *p; + unsigned char ch; + int i = 0; + + memset(path, 0, BUFFER); + memset(temp, 0, BUFFER * 100); + snprintf(path, BUFFER, "/proc/%s/cmdline", pid); + + if ((fp = fopen(path, "r")) == NULL) { + perror("fopen"); + return 0; + } + + while (!feof(fp)) { + ch = fgetc(fp); + if (ch == 0) { + temp[i] = ' '; + } else { + temp[i] = ch; + } + i++; + } + + p = strrchr(temp, ' '); + temp[strlen(temp) - strlen(p)] = '\0'; + printf("%s\n", temp); + + fclose(fp); + return 0; +} + +int main(int argc, char *argv[], char **envlp) +{ + cmdline(argv[1]); + + return 0; +} diff --git a/go.mod b/go.mod index 8d02b89..a5c7e63 100644 --- a/go.mod +++ b/go.mod @@ -1,14 +1,10 @@ -module db_tunnel +module tunnel go 1.19 -require github.com/docker/docker v23.0.0+incompatible - require ( - gitee.com/dtapps/go-ssh-tunnel v1.0.4 // indirect - github.com/CodyGuo/godaemon v0.0.0-20200413142854-c36b39fdd071 // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect - golang.org/x/crypto v0.3.0 // indirect - golang.org/x/sys v0.2.0 // indirect + golang.org/x/crypto v0.6.0 // indirect + golang.org/x/sys v0.5.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect ) diff --git a/go.sum b/go.sum index 440f695..c42b937 100644 --- a/go.sum +++ b/go.sum @@ -1,23 +1,20 @@ -gitee.com/dtapps/go-ssh-tunnel v1.0.4 h1:jQfS6WeoSWmMz+hSBxopUcgb0xi3fssMjHUYpF1BnYY= -gitee.com/dtapps/go-ssh-tunnel v1.0.4/go.mod h1:VzlCBrBerVEL1fgikHzApbRU5/Ru+KrzktRjnDfxu1M= -github.com/CodyGuo/godaemon v0.0.0-20200413142854-c36b39fdd071 h1:GFI7Rs86D4qip+tBvMcv0ux5kHbngC0rNWfgpTeVoAQ= -github.com/CodyGuo/godaemon v0.0.0-20200413142854-c36b39fdd071/go.mod h1:VBC/JvjvRkcgE7wMjDJs7Y94Ta6KSpCWDquUKW+WbJo= -github.com/docker/docker v23.0.0+incompatible h1:L6c28tNyqZ4/ub9AZC9d5QUuunoHHfEH4/Ue+h/E5nE= -github.com/docker/docker v23.0.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= -golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A= -golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191122220453-ac88ee75c92c h1:/nJuwDLoL/zrqY6gf57vxC+Pi+pZ8bfhpPkicO5H7W4= +golang.org/x/crypto v0.0.0-20191122220453-ac88ee75c92c/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= +golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/src/db_tunnel b/src/db_tunnel new file mode 100644 index 0000000..5a1908a Binary files /dev/null and b/src/db_tunnel differ diff --git a/src/endpoint.go b/src/endpoint.go new file mode 100644 index 0000000..f873b3d --- /dev/null +++ b/src/endpoint.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "strconv" + "strings" +) + +type Endpoint struct { + Host string + Port int + User string +} + +func NewEndpoint(s string) *Endpoint { + endpoint := &Endpoint{ + Host: s, + } + + if parts := strings.Split(endpoint.Host, "@"); len(parts) > 1 { + endpoint.User = parts[0] + endpoint.Host = parts[1] + } + + if parts := strings.Split(endpoint.Host, ":"); len(parts) > 1 { + endpoint.Host = parts[0] + endpoint.Port, _ = strconv.Atoi(parts[1]) + } + + return endpoint +} + +func (endpoint *Endpoint) String() string { + return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port) +} diff --git a/src/main.go b/src/main.go index bee52ae..9495274 100644 --- a/src/main.go +++ b/src/main.go @@ -1,356 +1,261 @@ -package main - -import ( - "database/sql" - "flag" - "fmt" - "io" - "io/ioutil" - "log" - "net" - "os" - "os/exec" - "path/filepath" - "strconv" - "strings" - "syscall" - "time" - - _ "github.com/go-sql-driver/mysql" - "golang.org/x/crypto/ssh" - "gopkg.in/ini.v1" -) - -// 转发 -func sForward(serverAddr string, remoteAddr string, localConn net.Conn, config *ssh.ClientConfig) { - // 设置sshClientConn - sshClientConn, err := ssh.Dial("tcp", serverAddr, config) - if err != nil { - fmt.Printf("ssh.Dial failed: %s", err) - os.Exit(1) - } - - // 设置Connection - sshConn, _ := sshClientConn.Dial("tcp", remoteAddr) - /* - // 将localConn.Reader复制到sshConn.Writer - go func() { - _, err = io.Copy(sshConn, localConn) - if err != nil { - fmt.Printf("io.Copy failed: %v", err) - os.Exit(1) - } - }() - - // 将sshConn.Reader复制到localConn.Writer - go func() { - _, err = io.Copy(localConn, sshConn) - if err != nil { - fmt.Printf("io.Copy failed: %v", err) - os.Exit(1) - } - }() - */ - - copyConn := func(writer, reader net.Conn) { - _, err := io.Copy(writer, reader) - if err != nil { - fmt.Printf(" io.Copy error: %s", err) - } - } - - go copyConn(localConn, sshConn) - go copyConn(sshConn, localConn) - -} - -func Tunnel(username string, password string, serverAddr string, remoteAddr string, localAddr string, connect_max int) { - // 设置SSH配置 - fmt.Printf("%s,服务器:%s; 用户/密码: %s; 远程:%s; 本地:%s\n", "设置SSH配置", serverAddr, username+"/"+password, remoteAddr, localAddr) - config := &ssh.ClientConfig{ - User: username, - Auth: []ssh.AuthMethod{ - ssh.Password(password), - }, - Timeout: 7 * time.Second, - HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { - return nil - }, - } - - // 设置本地监听器 - localListener, err := net.Listen("tcp", localAddr) - if err != nil { - fmt.Printf("net.Listen failed: %v\n", err) - os.Exit(1) - } - - var connect int // 连接数 - connect = 0 - - for { - - connect++ - fmt.Println("当前连接数:", connect) - - // 设置本地 - localConn, err := localListener.Accept() - if err != nil { - fmt.Printf("localListener.Accept failed: %v\n", err) - os.Exit(1) - } - if connect >= connect_max { // 连接数达到1000时重启进程 - fmt.Printf("连接数达到上线,重启进程!%d\n", connect) - syscall.Kill(os.Getpid(), syscall.SIGKILL) - } - - sForward(serverAddr, remoteAddr, localConn, config) - } -} - -type user struct { - id string - host_id string - host_ip string - router_ip string - router_port string - protocol_type string - protocol_port string - state string - acc_auth_id string - auth_type string - username string - username_prompt string - password_prompt string - password string - pri_key string - creator_id string - create_time string - last_secret string -} - -func ExecCommand(strCommand string) string { - cmd := exec.Command("/bin/bash", "-c", strCommand) - - stdout, _ := cmd.StdoutPipe() - if err := cmd.Start(); err != nil { - fmt.Println("Execute failed when Start: " + err.Error()) - return "" - } - - out_bytes, _ := ioutil.ReadAll(stdout) - stdout.Close() - - if err := cmd.Wait(); err != nil { - fmt.Println("Execute failed when Wait: " + err.Error()) - return "" - } - - return string(out_bytes) -} - -func ssh_tunnel(host_ip string, host_port string, host_name string, remote_ip string, local_ip string, mysql_info string, connect_max int) { - - fmt.Println(host_ip, host_name, host_port) - - // "用户名:密码@[连接方式](主机名:端口号)/数据库名" - var db *sql.DB - db, _ = sql.Open("mysql", mysql_info) // 设置连接数据库的参数 - defer db.Close() // 关闭数据库 - err := db.Ping() // 连接数据库 - if err != nil { - fmt.Println("数据库连接失败") - // 退出 - syscall.Kill(os.Getppid(), syscall.SIGKILL) - syscall.Kill(os.Getpid(), syscall.SIGKILL) - } else { - fmt.Println("数据库连接成功") - - } - - sqlStr := "select host_ip, username, protocol_port, password from tp_acc where host_ip = ? and username = ? and protocol_port = ?" - rows, err := db.Query(sqlStr, host_ip, host_name, host_port) - if err != nil { - fmt.Printf("query failed, err:%v\n", err) - return - } - // 非常重要关闭rows释放持有的数据库链接 - defer rows.Close() - var u user - // 循环读取结果集中的数据 - for rows.Next() { - err := rows.Scan(&u.host_ip, &u.username, &u.protocol_port, &u.password) - if err != nil { - fmt.Printf("scan failed, err:%v\n", err) - return - } - fmt.Printf("%s %s %s %s\n", u.host_ip, u.protocol_port, u.username, u.password) - } - - // 解密密钥 - command := "./tp_decrypt " - command = command + u.password - strData := ExecCommand(command) - strData = strings.Replace(strData, "\n", "", -1) - fmt.Println(strData) - if strData == "" { - // 判断密码是否解密成功,不成功退出 - syscall.Kill(os.Getppid(), syscall.SIGKILL) - syscall.Kill(os.Getpid(), syscall.SIGKILL) - } - - Tunnel(host_name, strData, u.host_ip+":"+u.protocol_port, "0.0.0.0:"+remote_ip, "0.0.0.0:"+local_ip, connect_max) - -} - -func StripSlice(slice []string, element string) []string { - for i := 0; i < len(slice); { - if slice[i] == element && i != len(slice)-1 { - slice = append(slice[:i], slice[i+1:]...) - } else if slice[i] == element && i == len(slice)-1 { - slice = slice[:i] - } else { - i++ - } - } - return slice -} - -func SubProcess(args []string) *exec.Cmd { - cmd := exec.Command(args[0], args[1:]...) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - err := cmd.Start() - if err != nil { - fmt.Fprintf(os.Stderr, "[-] Error: %s\n", err) - } - return cmd -} - -// Check if a port is available -func Check(port int) (status bool, err error) { - - // Concatenate a colon and the port - host := ":" + strconv.Itoa(port) - - // Try to create a server with the port - server, err := net.Listen("tcp", host) - - // if it fails then the port is likely taken - if err != nil { - return false, err - } - - // close the server - server.Close() - - // we successfully used and closed the port - // so it's now available to be used again - return true, nil - -} - -func GetCurrentDirectory() string { - dir, err := filepath.Abs(filepath.Dir(os.Args[0])) // 返回绝对路径 filepath.Dir(os.Args[0])去除最后一个元素的路径 - if err != nil { - log.Fatal(err) - } - return strings.Replace(dir, "\\", "/", -1) // 将\替换成/ -} - -func PathExists(path string) (bool, error) { - _, err := os.Stat(path) - if err == nil { - return true, nil - } - - if os.IsNotExist(err) { - return false, nil - } - - return false, err -} - -const ( - DAEMON = "d" - FOREVER = "f" - HOST_IP = "h" - HOST_PORT = "p" - HOST_USER = "u" - HOST_REMOTE = "r" - HOST_LOCAL = "l" -) - -func main() { - - // 判断配置文件是否存在 - INIFILE := GetCurrentDirectory() + "/" + "tunnel.ini" - b, _ := PathExists(INIFILE) - - if !b { - INIFILE = "/etc/tunnel.ini" - } - - // 读取配置文件 - cfg, inierr := ini.Load(INIFILE) - if inierr != nil { - fmt.Printf("Fail to read file: %v", inierr) - os.Exit(1) - } - // 读取数据库连接信息 - MYSQL_INFO := cfg.Section("global").Key("MYSQL_INFO").String() - if MYSQL_INFO == "" { - os.Exit(0) - } - - daemon := flag.Bool(DAEMON, false, "run in daemon") - forever := flag.Bool(FOREVER, false, "run forever") - host_ip := flag.String(HOST_IP, "", "DB服务器SSH IP地址") - host_port := flag.String(HOST_PORT, "", "DB服务器SSH PORT") - host_user := flag.String(HOST_USER, "", "DB服务器SSH USER用户") - remote_port := flag.String(HOST_REMOTE, "", "DB服务器端口(如: 3306、1521 ...)") - local_port := flag.String(HOST_LOCAL, "", "本地监听端口(或者堡垒机监听端口)") - connect_max := flag.Int("c", 999, "建立隧道最大次数(与DB服务器断开)") - flag.Parse() - - if *daemon { - SubProcess(StripSlice(os.Args, "-"+DAEMON)) - fmt.Printf("[*] Daemon running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid()) - os.Exit(0) - } else if *forever { - for { - cmd := SubProcess(StripSlice(os.Args, "-"+FOREVER)) - fmt.Printf("[*] Forever running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid()) - time.Sleep(time.Second * 5) - cmd.Wait() - } - //os.Exit(0) - } else { - fmt.Printf("[*] Service running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid()) - } - - fmt.Println(*host_ip, *host_port, *host_user, *remote_port, *local_port) - - // - local_port_, _ := strconv.Atoi(*local_port) - r, _ := Check(local_port_) - //fmt.Println(r) - if r { - fmt.Println(*connect_max) - ssh_tunnel(*host_ip, *host_port, *host_user, *remote_port, *local_port, MYSQL_INFO, *connect_max) - } else { - fmt.Println(local_port_, "端口不可用!", r, "退出!") - fp, _ := os.OpenFile("./pid.log", os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) - log.SetOutput(fp) - - log.Printf("DoSomething running in PPID: %d\n", os.Getppid()) - - syscall.Kill(os.Getppid(), syscall.SIGKILL) - time.Sleep(time.Second * 10) - os.Exit(0) - } - - time.Sleep(time.Second * 5) - os.Exit(0) -} +package main + +import ( + "database/sql" + "flag" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "syscall" + "time" + + _ "github.com/go-sql-driver/mysql" + "golang.org/x/crypto/ssh" + "gopkg.in/ini.v1" +) + +const ( + DAEMON = "d" + FOREVER = "f" + HOST_IP = "h" + HOST_PORT = "p" + HOST_USER = "u" + HOST_REMOTE = "r" + HOST_LOCAL = "l" +) + +func StripSlice(slice []string, element string) []string { + for i := 0; i < len(slice); { + if slice[i] == element && i != len(slice)-1 { + slice = append(slice[:i], slice[i+1:]...) + } else if slice[i] == element && i == len(slice)-1 { + slice = slice[:i] + } else { + i++ + } + } + return slice +} + +func SubProcess(args []string) *exec.Cmd { + cmd := exec.Command(args[0], args[1:]...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if err != nil { + fmt.Fprintf(os.Stderr, "[-] Error: %s\n", err) + } + return cmd +} + +// Check if a port is available +func Check(port int) (status bool, err error) { + + // Concatenate a colon and the port + host := ":" + strconv.Itoa(port) + + // Try to create a server with the port + server, err := net.Listen("tcp", host) + + // if it fails then the port is likely taken + if err != nil { + return false, err + } + + // close the server + server.Close() + + // we successfully used and closed the port + // so it's now available to be used again + return true, nil + +} + +func GetCurrentDirectory() string { + dir, err := filepath.Abs(filepath.Dir(os.Args[0])) // 返回绝对路径 filepath.Dir(os.Args[0])去除最后一个元素的路径 + if err != nil { + log.Fatal(err) + } + return strings.Replace(dir, "\\", "/", -1) // 将\替换成/ +} + +func PathExists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + + if os.IsNotExist(err) { + return false, nil + } + + return false, err +} + +type user struct { + id string + host_id string + host_ip string + router_ip string + router_port string + protocol_type string + protocol_port string + state string + acc_auth_id string + auth_type string + username string + username_prompt string + password_prompt string + password string + password_decrypt string + pri_key string + creator_id string + create_time string + last_secret string +} + +func ExecCommand(strCommand string) string { + cmd := exec.Command("/bin/bash", "-c", strCommand) + + stdout, _ := cmd.StdoutPipe() + if err := cmd.Start(); err != nil { + fmt.Println("Execute failed when Start: " + err.Error()) + return "" + } + + out_bytes, _ := ioutil.ReadAll(stdout) + stdout.Close() + + if err := cmd.Wait(); err != nil { + fmt.Println("Execute failed when Wait: " + err.Error()) + return "" + } + + return string(out_bytes) +} + +// 读取mysql数据库内信息 +func mysql_show(connect_info string, host_ip string, host_user string, host_port string) user { + + var db *sql.DB + db, _ = sql.Open("mysql", connect_info) // 设置连接数据库的参数 + defer db.Close() // 关闭数据库 + err := db.Ping() // 连接数据库 + if err != nil { + // 退出 + log.Printf("数据库连接失败") + syscall.Kill(os.Getppid(), syscall.SIGKILL) + syscall.Kill(os.Getpid(), syscall.SIGKILL) + } else { + log.Printf("数据库连接成功") + + } + sqlStr := "select host_ip, username, protocol_port, password from tp_acc where host_ip = ? and username = ? and protocol_port = ?" + rows, err := db.Query(sqlStr, host_ip, host_user, host_port) + if err != nil { + log.Printf("query failed, err:%v\n", err) + } + // 非常重要关闭rows释放持有的数据库链接 + defer rows.Close() + var u user + // 循环读取结果集中的数据 + for rows.Next() { + err := rows.Scan(&u.host_ip, &u.username, &u.protocol_port, &u.password) + if err != nil { + log.Printf("scan failed, err:%v\n", err) + } + log.Printf("%s %s %s %s\n", u.host_ip, u.protocol_port, u.username, u.password) + } + + // 解密密钥 + command := "./tp_decrypt " + command = command + u.password + strData := ExecCommand(command) + strData = strings.Replace(strData, "\n", "", -1) + fmt.Println(strData) + if strData == "" { + // 判断密码是否解密成功,不成功退出 + syscall.Kill(os.Getppid(), syscall.SIGKILL) + syscall.Kill(os.Getpid(), syscall.SIGKILL) + } + u.password_decrypt = strData + + return u +} + +func main() { + // 参数处理 + daemon := flag.Bool(DAEMON, false, "run in daemon") + forever := flag.Bool(FOREVER, false, "run forever") + host_ip := flag.String(HOST_IP, "", "DB服务器SSH IP地址") + host_port := flag.String(HOST_PORT, "", "DB服务器SSH PORT") + host_user := flag.String(HOST_USER, "", "DB服务器SSH USER用户") + remote_port := flag.String(HOST_REMOTE, "", "DB服务器端口(如: 3306、1521 ...)") + local_port := flag.String(HOST_LOCAL, "", "本地监听端口(或者堡垒机监听端口)") + connect_max := flag.Int("c", 999, "建立隧道最大次数(与DB服务器断开)") + flag.Parse() + + // 判断配置文件是否存在 + INIFILE := GetCurrentDirectory() + "/" + "tunnel.ini" + b, _ := PathExists(INIFILE) + if !b { + INIFILE = "/etc/tunnel.ini" + } + // 读取配置文件 + cfg, inierr := ini.Load(INIFILE) + if inierr != nil { + fmt.Printf("Fail to read file: %v", inierr) + os.Exit(1) + } + // 读取数据库连接信息 + MYSQL_INFO := cfg.Section("global").Key("MYSQL_INFO").String() + if MYSQL_INFO == "" { + log.Printf("读取Mysql配置出错!\n") + os.Exit(0) + } + + if *daemon { + SubProcess(StripSlice(os.Args, "-"+DAEMON)) + fmt.Printf("[*] Daemon running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid()) + os.Exit(0) + } else if *forever { + for { + cmd := SubProcess(StripSlice(os.Args, "-"+FOREVER)) + fmt.Printf("[*] Forever running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid()) + time.Sleep(time.Second * 5) + cmd.Wait() + } + //os.Exit(0) + } else { + fmt.Printf("[*] Service running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid()) + } + + local_port_, _ := strconv.Atoi(*local_port) + r, _ := Check(local_port_) + if r { + u := mysql_show(MYSQL_INFO, *host_ip, *host_user, *host_port) + tunnel := NewSSHTunnel(u.username+"@"+u.host_ip+":"+u.protocol_port, ssh.Password(u.password_decrypt), "0.0.0.0:"+*remote_port, *local_port) + tunnel.Log = log.New(os.Stdout, "", log.Ldate|log.Lmicroseconds) + + tunnel.Start(*connect_max) + tunnel.Close() + time.Sleep(100 * time.Millisecond) + } else { + log.Printf("端口: %d 不可用!退出!\n", local_port_) + log.Printf("DoSomething running in PPID: %d\n", os.Getppid()) + + syscall.Kill(os.Getppid(), syscall.SIGKILL) + time.Sleep(time.Second * 5) + os.Exit(0) + } + + time.Sleep(time.Second * 5) + os.Exit(0) +} diff --git a/src/private_key_file.go b/src/private_key_file.go new file mode 100644 index 0000000..25bd53c --- /dev/null +++ b/src/private_key_file.go @@ -0,0 +1,21 @@ +package main + +import ( + "io/ioutil" + + "golang.org/x/crypto/ssh" +) + +func PrivateKeyFile(file string) ssh.AuthMethod { + buffer, err := ioutil.ReadFile(file) + if err != nil { + return nil + } + + key, err := ssh.ParsePrivateKey(buffer) + if err != nil { + return nil + } + + return ssh.PublicKeys(key) +} diff --git a/src/ssh_agent.go b/src/ssh_agent.go new file mode 100644 index 0000000..2ee7f06 --- /dev/null +++ b/src/ssh_agent.go @@ -0,0 +1,16 @@ +package main + +import ( + "net" + "os" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +func SSHAgent() ssh.AuthMethod { + if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers) + } + return nil +} diff --git a/src/ssh_tunnel.go b/src/ssh_tunnel.go new file mode 100644 index 0000000..2b81208 --- /dev/null +++ b/src/ssh_tunnel.go @@ -0,0 +1,194 @@ +package main + +import ( + "io" + "net" + "os" + "syscall" + + "golang.org/x/crypto/ssh" +) + +type logger interface { + Printf(string, ...interface{}) +} + +type SSHTunnel struct { + Local *Endpoint + Server *Endpoint + Remote *Endpoint + Config *ssh.ClientConfig + Log logger + Conns []net.Conn + SvrConns []*ssh.Client + MaxConnectionAttempts int + isOpen bool + close chan interface{} +} + +func (tunnel *SSHTunnel) logf(fmt string, args ...interface{}) { + if tunnel.Log != nil { + tunnel.Log.Printf(fmt, args...) + } +} + +func newConnectionWaiter(listener net.Listener, c chan net.Conn) { + conn, err := listener.Accept() + if err != nil { + return + } + c <- conn +} + +func (tunnel *SSHTunnel) Close() { + tunnel.close <- struct{}{} + return +} + +func (tunnel *SSHTunnel) Start(connect_max int) error { + listener, err := net.Listen("tcp", tunnel.Local.String()) + if err != nil { + return err + } + tunnel.isOpen = true + tunnel.Local.Port = listener.Addr().(*net.TCPAddr).Port + + // Ensure that MaxConnectionAttempts is at least 1. This check is done here + // since the library user can set the value at any point before Start() is called, + // and this check protects against the case where the programmer set MaxConnectionAttempts + // to 0 for some reason. + if tunnel.MaxConnectionAttempts <= 0 { + tunnel.MaxConnectionAttempts = 9 + } + + var connect = 0 // 连接数 + + for { + + connect++ + tunnel.logf("PID: %d PPID: %d 本地端口: %d 当前连接数:%d\n", os.Getpid(), os.Getppid(), tunnel.Local.Port, connect) + + if connect >= connect_max { // 连接数达到1000时重启进程 + tunnel.logf("连接数达到上线,重启进程!%d\n", connect) + syscall.Kill(os.Getpid(), syscall.SIGKILL) + + } + + if !tunnel.isOpen { + break + } + + c := make(chan net.Conn) + go newConnectionWaiter(listener, c) + tunnel.logf("正在侦听新连接 ...") + + select { + case <-tunnel.close: + tunnel.logf("收到关闭信号,关闭 ...") + tunnel.isOpen = false + case conn := <-c: + tunnel.Conns = append(tunnel.Conns, conn) + tunnel.logf("accepted connection") + + go tunnel.forward(conn) + } + } + + var total int + total = len(tunnel.Conns) + for i, conn := range tunnel.Conns { + tunnel.logf("closing the netConn (%d of %d)", i+1, total) + err := conn.Close() + if err != nil { + tunnel.logf(err.Error()) + } + } + total = len(tunnel.SvrConns) + for i, conn := range tunnel.SvrConns { + tunnel.logf("closing the serverConn (%d of %d)", i+1, total) + err := conn.Close() + if err != nil { + tunnel.logf(err.Error()) + } + } + err = listener.Close() + if err != nil { + return err + } + tunnel.logf("tunnel closed") + return nil +} + +func (tunnel *SSHTunnel) forward(localConn net.Conn) { + var ( + serverConn *ssh.Client + err error + attemptsLeft int = tunnel.MaxConnectionAttempts + ) + + for { + serverConn, err = ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config) + if err != nil { + attemptsLeft-- + + if attemptsLeft <= 0 { + tunnel.logf("服务器拨号错误: %v: exceeded %d attempts", err, tunnel.MaxConnectionAttempts) + return + } + } else { + break + } + } + + tunnel.logf("connected to %s (1 of 2)\n", tunnel.Server.String()) + tunnel.SvrConns = append(tunnel.SvrConns, serverConn) + + remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String()) + if err != nil { + tunnel.logf("remote dial error: %s", err) + return + } + + tunnel.Conns = append(tunnel.Conns, remoteConn) + tunnel.logf("connected to %s (2 of 2)\n", tunnel.Remote.String()) + + copyConn := func(writer, reader net.Conn) { + _, err := io.Copy(writer, reader) + if err != nil { + tunnel.logf("io.Copy error: %s", err) + } + } + + go copyConn(localConn, remoteConn) + go copyConn(remoteConn, localConn) + + return +} + +// NewSSHTunnel creates a new single-use tunnel. Supplying "0" for localport will use a random port. +func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string, localport string) *SSHTunnel { + + localEndpoint := NewEndpoint("0.0.0.0:" + localport) + + server := NewEndpoint(tunnel) + if server.Port == 0 { + server.Port = 22 + } + + sshTunnel := &SSHTunnel{ + Config: &ssh.ClientConfig{ + User: server.User, + Auth: []ssh.AuthMethod{auth}, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + // Always accept key. + return nil + }, + }, + Local: localEndpoint, + Server: server, + Remote: NewEndpoint(destination), + close: make(chan interface{}), + } + + return sshTunnel +}