更换库
This commit is contained in:
BIN
src/db_tunnel
Normal file
BIN
src/db_tunnel
Normal file
Binary file not shown.
35
src/endpoint.go
Normal file
35
src/endpoint.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Endpoint struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
}
|
||||
|
||||
func NewEndpoint(s string) *Endpoint {
|
||||
endpoint := &Endpoint{
|
||||
Host: s,
|
||||
}
|
||||
|
||||
if parts := strings.Split(endpoint.Host, "@"); len(parts) > 1 {
|
||||
endpoint.User = parts[0]
|
||||
endpoint.Host = parts[1]
|
||||
}
|
||||
|
||||
if parts := strings.Split(endpoint.Host, ":"); len(parts) > 1 {
|
||||
endpoint.Host = parts[0]
|
||||
endpoint.Port, _ = strconv.Atoi(parts[1])
|
||||
}
|
||||
|
||||
return endpoint
|
||||
}
|
||||
|
||||
func (endpoint *Endpoint) String() string {
|
||||
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
|
||||
}
|
||||
617
src/main.go
617
src/main.go
@@ -1,356 +1,261 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
|
||||
// 转发
|
||||
func sForward(serverAddr string, remoteAddr string, localConn net.Conn, config *ssh.ClientConfig) {
|
||||
// 设置sshClientConn
|
||||
sshClientConn, err := ssh.Dial("tcp", serverAddr, config)
|
||||
if err != nil {
|
||||
fmt.Printf("ssh.Dial failed: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 设置Connection
|
||||
sshConn, _ := sshClientConn.Dial("tcp", remoteAddr)
|
||||
/*
|
||||
// 将localConn.Reader复制到sshConn.Writer
|
||||
go func() {
|
||||
_, err = io.Copy(sshConn, localConn)
|
||||
if err != nil {
|
||||
fmt.Printf("io.Copy failed: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
// 将sshConn.Reader复制到localConn.Writer
|
||||
go func() {
|
||||
_, err = io.Copy(localConn, sshConn)
|
||||
if err != nil {
|
||||
fmt.Printf("io.Copy failed: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
*/
|
||||
|
||||
copyConn := func(writer, reader net.Conn) {
|
||||
_, err := io.Copy(writer, reader)
|
||||
if err != nil {
|
||||
fmt.Printf(" io.Copy error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
go copyConn(localConn, sshConn)
|
||||
go copyConn(sshConn, localConn)
|
||||
|
||||
}
|
||||
|
||||
func Tunnel(username string, password string, serverAddr string, remoteAddr string, localAddr string, connect_max int) {
|
||||
// 设置SSH配置
|
||||
fmt.Printf("%s,服务器:%s; 用户/密码: %s; 远程:%s; 本地:%s\n", "设置SSH配置", serverAddr, username+"/"+password, remoteAddr, localAddr)
|
||||
config := &ssh.ClientConfig{
|
||||
User: username,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password(password),
|
||||
},
|
||||
Timeout: 7 * time.Second,
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// 设置本地监听器
|
||||
localListener, err := net.Listen("tcp", localAddr)
|
||||
if err != nil {
|
||||
fmt.Printf("net.Listen failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var connect int // 连接数
|
||||
connect = 0
|
||||
|
||||
for {
|
||||
|
||||
connect++
|
||||
fmt.Println("当前连接数:", connect)
|
||||
|
||||
// 设置本地
|
||||
localConn, err := localListener.Accept()
|
||||
if err != nil {
|
||||
fmt.Printf("localListener.Accept failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if connect >= connect_max { // 连接数达到1000时重启进程
|
||||
fmt.Printf("连接数达到上线,重启进程!%d\n", connect)
|
||||
syscall.Kill(os.Getpid(), syscall.SIGKILL)
|
||||
}
|
||||
|
||||
sForward(serverAddr, remoteAddr, localConn, config)
|
||||
}
|
||||
}
|
||||
|
||||
type user struct {
|
||||
id string
|
||||
host_id string
|
||||
host_ip string
|
||||
router_ip string
|
||||
router_port string
|
||||
protocol_type string
|
||||
protocol_port string
|
||||
state string
|
||||
acc_auth_id string
|
||||
auth_type string
|
||||
username string
|
||||
username_prompt string
|
||||
password_prompt string
|
||||
password string
|
||||
pri_key string
|
||||
creator_id string
|
||||
create_time string
|
||||
last_secret string
|
||||
}
|
||||
|
||||
func ExecCommand(strCommand string) string {
|
||||
cmd := exec.Command("/bin/bash", "-c", strCommand)
|
||||
|
||||
stdout, _ := cmd.StdoutPipe()
|
||||
if err := cmd.Start(); err != nil {
|
||||
fmt.Println("Execute failed when Start: " + err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
out_bytes, _ := ioutil.ReadAll(stdout)
|
||||
stdout.Close()
|
||||
|
||||
if err := cmd.Wait(); err != nil {
|
||||
fmt.Println("Execute failed when Wait: " + err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(out_bytes)
|
||||
}
|
||||
|
||||
func ssh_tunnel(host_ip string, host_port string, host_name string, remote_ip string, local_ip string, mysql_info string, connect_max int) {
|
||||
|
||||
fmt.Println(host_ip, host_name, host_port)
|
||||
|
||||
// "用户名:密码@[连接方式](主机名:端口号)/数据库名"
|
||||
var db *sql.DB
|
||||
db, _ = sql.Open("mysql", mysql_info) // 设置连接数据库的参数
|
||||
defer db.Close() // 关闭数据库
|
||||
err := db.Ping() // 连接数据库
|
||||
if err != nil {
|
||||
fmt.Println("数据库连接失败")
|
||||
// 退出
|
||||
syscall.Kill(os.Getppid(), syscall.SIGKILL)
|
||||
syscall.Kill(os.Getpid(), syscall.SIGKILL)
|
||||
} else {
|
||||
fmt.Println("数据库连接成功")
|
||||
|
||||
}
|
||||
|
||||
sqlStr := "select host_ip, username, protocol_port, password from tp_acc where host_ip = ? and username = ? and protocol_port = ?"
|
||||
rows, err := db.Query(sqlStr, host_ip, host_name, host_port)
|
||||
if err != nil {
|
||||
fmt.Printf("query failed, err:%v\n", err)
|
||||
return
|
||||
}
|
||||
// 非常重要关闭rows释放持有的数据库链接
|
||||
defer rows.Close()
|
||||
var u user
|
||||
// 循环读取结果集中的数据
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&u.host_ip, &u.username, &u.protocol_port, &u.password)
|
||||
if err != nil {
|
||||
fmt.Printf("scan failed, err:%v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("%s %s %s %s\n", u.host_ip, u.protocol_port, u.username, u.password)
|
||||
}
|
||||
|
||||
// 解密密钥
|
||||
command := "./tp_decrypt "
|
||||
command = command + u.password
|
||||
strData := ExecCommand(command)
|
||||
strData = strings.Replace(strData, "\n", "", -1)
|
||||
fmt.Println(strData)
|
||||
if strData == "" {
|
||||
// 判断密码是否解密成功,不成功退出
|
||||
syscall.Kill(os.Getppid(), syscall.SIGKILL)
|
||||
syscall.Kill(os.Getpid(), syscall.SIGKILL)
|
||||
}
|
||||
|
||||
Tunnel(host_name, strData, u.host_ip+":"+u.protocol_port, "0.0.0.0:"+remote_ip, "0.0.0.0:"+local_ip, connect_max)
|
||||
|
||||
}
|
||||
|
||||
func StripSlice(slice []string, element string) []string {
|
||||
for i := 0; i < len(slice); {
|
||||
if slice[i] == element && i != len(slice)-1 {
|
||||
slice = append(slice[:i], slice[i+1:]...)
|
||||
} else if slice[i] == element && i == len(slice)-1 {
|
||||
slice = slice[:i]
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return slice
|
||||
}
|
||||
|
||||
func SubProcess(args []string) *exec.Cmd {
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[-] Error: %s\n", err)
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Check if a port is available
|
||||
func Check(port int) (status bool, err error) {
|
||||
|
||||
// Concatenate a colon and the port
|
||||
host := ":" + strconv.Itoa(port)
|
||||
|
||||
// Try to create a server with the port
|
||||
server, err := net.Listen("tcp", host)
|
||||
|
||||
// if it fails then the port is likely taken
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// close the server
|
||||
server.Close()
|
||||
|
||||
// we successfully used and closed the port
|
||||
// so it's now available to be used again
|
||||
return true, nil
|
||||
|
||||
}
|
||||
|
||||
func GetCurrentDirectory() string {
|
||||
dir, err := filepath.Abs(filepath.Dir(os.Args[0])) // 返回绝对路径 filepath.Dir(os.Args[0])去除最后一个元素的路径
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return strings.Replace(dir, "\\", "/", -1) // 将\替换成/
|
||||
}
|
||||
|
||||
func PathExists(path string) (bool, error) {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
const (
|
||||
DAEMON = "d"
|
||||
FOREVER = "f"
|
||||
HOST_IP = "h"
|
||||
HOST_PORT = "p"
|
||||
HOST_USER = "u"
|
||||
HOST_REMOTE = "r"
|
||||
HOST_LOCAL = "l"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
// 判断配置文件是否存在
|
||||
INIFILE := GetCurrentDirectory() + "/" + "tunnel.ini"
|
||||
b, _ := PathExists(INIFILE)
|
||||
|
||||
if !b {
|
||||
INIFILE = "/etc/tunnel.ini"
|
||||
}
|
||||
|
||||
// 读取配置文件
|
||||
cfg, inierr := ini.Load(INIFILE)
|
||||
if inierr != nil {
|
||||
fmt.Printf("Fail to read file: %v", inierr)
|
||||
os.Exit(1)
|
||||
}
|
||||
// 读取数据库连接信息
|
||||
MYSQL_INFO := cfg.Section("global").Key("MYSQL_INFO").String()
|
||||
if MYSQL_INFO == "" {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
daemon := flag.Bool(DAEMON, false, "run in daemon")
|
||||
forever := flag.Bool(FOREVER, false, "run forever")
|
||||
host_ip := flag.String(HOST_IP, "", "DB服务器SSH IP地址")
|
||||
host_port := flag.String(HOST_PORT, "", "DB服务器SSH PORT")
|
||||
host_user := flag.String(HOST_USER, "", "DB服务器SSH USER用户")
|
||||
remote_port := flag.String(HOST_REMOTE, "", "DB服务器端口(如: 3306、1521 ...)")
|
||||
local_port := flag.String(HOST_LOCAL, "", "本地监听端口(或者堡垒机监听端口)")
|
||||
connect_max := flag.Int("c", 999, "建立隧道最大次数(与DB服务器断开)")
|
||||
flag.Parse()
|
||||
|
||||
if *daemon {
|
||||
SubProcess(StripSlice(os.Args, "-"+DAEMON))
|
||||
fmt.Printf("[*] Daemon running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid())
|
||||
os.Exit(0)
|
||||
} else if *forever {
|
||||
for {
|
||||
cmd := SubProcess(StripSlice(os.Args, "-"+FOREVER))
|
||||
fmt.Printf("[*] Forever running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid())
|
||||
time.Sleep(time.Second * 5)
|
||||
cmd.Wait()
|
||||
}
|
||||
//os.Exit(0)
|
||||
} else {
|
||||
fmt.Printf("[*] Service running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid())
|
||||
}
|
||||
|
||||
fmt.Println(*host_ip, *host_port, *host_user, *remote_port, *local_port)
|
||||
|
||||
//
|
||||
local_port_, _ := strconv.Atoi(*local_port)
|
||||
r, _ := Check(local_port_)
|
||||
//fmt.Println(r)
|
||||
if r {
|
||||
fmt.Println(*connect_max)
|
||||
ssh_tunnel(*host_ip, *host_port, *host_user, *remote_port, *local_port, MYSQL_INFO, *connect_max)
|
||||
} else {
|
||||
fmt.Println(local_port_, "端口不可用!", r, "退出!")
|
||||
fp, _ := os.OpenFile("./pid.log", os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
log.SetOutput(fp)
|
||||
|
||||
log.Printf("DoSomething running in PPID: %d\n", os.Getppid())
|
||||
|
||||
syscall.Kill(os.Getppid(), syscall.SIGKILL)
|
||||
time.Sleep(time.Second * 10)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 5)
|
||||
os.Exit(0)
|
||||
}
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
|
||||
const (
|
||||
DAEMON = "d"
|
||||
FOREVER = "f"
|
||||
HOST_IP = "h"
|
||||
HOST_PORT = "p"
|
||||
HOST_USER = "u"
|
||||
HOST_REMOTE = "r"
|
||||
HOST_LOCAL = "l"
|
||||
)
|
||||
|
||||
func StripSlice(slice []string, element string) []string {
|
||||
for i := 0; i < len(slice); {
|
||||
if slice[i] == element && i != len(slice)-1 {
|
||||
slice = append(slice[:i], slice[i+1:]...)
|
||||
} else if slice[i] == element && i == len(slice)-1 {
|
||||
slice = slice[:i]
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return slice
|
||||
}
|
||||
|
||||
func SubProcess(args []string) *exec.Cmd {
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[-] Error: %s\n", err)
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Check if a port is available
|
||||
func Check(port int) (status bool, err error) {
|
||||
|
||||
// Concatenate a colon and the port
|
||||
host := ":" + strconv.Itoa(port)
|
||||
|
||||
// Try to create a server with the port
|
||||
server, err := net.Listen("tcp", host)
|
||||
|
||||
// if it fails then the port is likely taken
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// close the server
|
||||
server.Close()
|
||||
|
||||
// we successfully used and closed the port
|
||||
// so it's now available to be used again
|
||||
return true, nil
|
||||
|
||||
}
|
||||
|
||||
func GetCurrentDirectory() string {
|
||||
dir, err := filepath.Abs(filepath.Dir(os.Args[0])) // 返回绝对路径 filepath.Dir(os.Args[0])去除最后一个元素的路径
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return strings.Replace(dir, "\\", "/", -1) // 将\替换成/
|
||||
}
|
||||
|
||||
func PathExists(path string) (bool, error) {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
type user struct {
|
||||
id string
|
||||
host_id string
|
||||
host_ip string
|
||||
router_ip string
|
||||
router_port string
|
||||
protocol_type string
|
||||
protocol_port string
|
||||
state string
|
||||
acc_auth_id string
|
||||
auth_type string
|
||||
username string
|
||||
username_prompt string
|
||||
password_prompt string
|
||||
password string
|
||||
password_decrypt string
|
||||
pri_key string
|
||||
creator_id string
|
||||
create_time string
|
||||
last_secret string
|
||||
}
|
||||
|
||||
func ExecCommand(strCommand string) string {
|
||||
cmd := exec.Command("/bin/bash", "-c", strCommand)
|
||||
|
||||
stdout, _ := cmd.StdoutPipe()
|
||||
if err := cmd.Start(); err != nil {
|
||||
fmt.Println("Execute failed when Start: " + err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
out_bytes, _ := ioutil.ReadAll(stdout)
|
||||
stdout.Close()
|
||||
|
||||
if err := cmd.Wait(); err != nil {
|
||||
fmt.Println("Execute failed when Wait: " + err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(out_bytes)
|
||||
}
|
||||
|
||||
// 读取mysql数据库内信息
|
||||
func mysql_show(connect_info string, host_ip string, host_user string, host_port string) user {
|
||||
|
||||
var db *sql.DB
|
||||
db, _ = sql.Open("mysql", connect_info) // 设置连接数据库的参数
|
||||
defer db.Close() // 关闭数据库
|
||||
err := db.Ping() // 连接数据库
|
||||
if err != nil {
|
||||
// 退出
|
||||
log.Printf("数据库连接失败")
|
||||
syscall.Kill(os.Getppid(), syscall.SIGKILL)
|
||||
syscall.Kill(os.Getpid(), syscall.SIGKILL)
|
||||
} else {
|
||||
log.Printf("数据库连接成功")
|
||||
|
||||
}
|
||||
sqlStr := "select host_ip, username, protocol_port, password from tp_acc where host_ip = ? and username = ? and protocol_port = ?"
|
||||
rows, err := db.Query(sqlStr, host_ip, host_user, host_port)
|
||||
if err != nil {
|
||||
log.Printf("query failed, err:%v\n", err)
|
||||
}
|
||||
// 非常重要关闭rows释放持有的数据库链接
|
||||
defer rows.Close()
|
||||
var u user
|
||||
// 循环读取结果集中的数据
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&u.host_ip, &u.username, &u.protocol_port, &u.password)
|
||||
if err != nil {
|
||||
log.Printf("scan failed, err:%v\n", err)
|
||||
}
|
||||
log.Printf("%s %s %s %s\n", u.host_ip, u.protocol_port, u.username, u.password)
|
||||
}
|
||||
|
||||
// 解密密钥
|
||||
command := "./tp_decrypt "
|
||||
command = command + u.password
|
||||
strData := ExecCommand(command)
|
||||
strData = strings.Replace(strData, "\n", "", -1)
|
||||
fmt.Println(strData)
|
||||
if strData == "" {
|
||||
// 判断密码是否解密成功,不成功退出
|
||||
syscall.Kill(os.Getppid(), syscall.SIGKILL)
|
||||
syscall.Kill(os.Getpid(), syscall.SIGKILL)
|
||||
}
|
||||
u.password_decrypt = strData
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 参数处理
|
||||
daemon := flag.Bool(DAEMON, false, "run in daemon")
|
||||
forever := flag.Bool(FOREVER, false, "run forever")
|
||||
host_ip := flag.String(HOST_IP, "", "DB服务器SSH IP地址")
|
||||
host_port := flag.String(HOST_PORT, "", "DB服务器SSH PORT")
|
||||
host_user := flag.String(HOST_USER, "", "DB服务器SSH USER用户")
|
||||
remote_port := flag.String(HOST_REMOTE, "", "DB服务器端口(如: 3306、1521 ...)")
|
||||
local_port := flag.String(HOST_LOCAL, "", "本地监听端口(或者堡垒机监听端口)")
|
||||
connect_max := flag.Int("c", 999, "建立隧道最大次数(与DB服务器断开)")
|
||||
flag.Parse()
|
||||
|
||||
// 判断配置文件是否存在
|
||||
INIFILE := GetCurrentDirectory() + "/" + "tunnel.ini"
|
||||
b, _ := PathExists(INIFILE)
|
||||
if !b {
|
||||
INIFILE = "/etc/tunnel.ini"
|
||||
}
|
||||
// 读取配置文件
|
||||
cfg, inierr := ini.Load(INIFILE)
|
||||
if inierr != nil {
|
||||
fmt.Printf("Fail to read file: %v", inierr)
|
||||
os.Exit(1)
|
||||
}
|
||||
// 读取数据库连接信息
|
||||
MYSQL_INFO := cfg.Section("global").Key("MYSQL_INFO").String()
|
||||
if MYSQL_INFO == "" {
|
||||
log.Printf("读取Mysql配置出错!\n")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if *daemon {
|
||||
SubProcess(StripSlice(os.Args, "-"+DAEMON))
|
||||
fmt.Printf("[*] Daemon running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid())
|
||||
os.Exit(0)
|
||||
} else if *forever {
|
||||
for {
|
||||
cmd := SubProcess(StripSlice(os.Args, "-"+FOREVER))
|
||||
fmt.Printf("[*] Forever running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid())
|
||||
time.Sleep(time.Second * 5)
|
||||
cmd.Wait()
|
||||
}
|
||||
//os.Exit(0)
|
||||
} else {
|
||||
fmt.Printf("[*] Service running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid())
|
||||
}
|
||||
|
||||
local_port_, _ := strconv.Atoi(*local_port)
|
||||
r, _ := Check(local_port_)
|
||||
if r {
|
||||
u := mysql_show(MYSQL_INFO, *host_ip, *host_user, *host_port)
|
||||
tunnel := NewSSHTunnel(u.username+"@"+u.host_ip+":"+u.protocol_port, ssh.Password(u.password_decrypt), "0.0.0.0:"+*remote_port, *local_port)
|
||||
tunnel.Log = log.New(os.Stdout, "", log.Ldate|log.Lmicroseconds)
|
||||
|
||||
tunnel.Start(*connect_max)
|
||||
tunnel.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
} else {
|
||||
log.Printf("端口: %d 不可用!退出!\n", local_port_)
|
||||
log.Printf("DoSomething running in PPID: %d\n", os.Getppid())
|
||||
|
||||
syscall.Kill(os.Getppid(), syscall.SIGKILL)
|
||||
time.Sleep(time.Second * 5)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 5)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
21
src/private_key_file.go
Normal file
21
src/private_key_file.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func PrivateKeyFile(file string) ssh.AuthMethod {
|
||||
buffer, err := ioutil.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key, err := ssh.ParsePrivateKey(buffer)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ssh.PublicKeys(key)
|
||||
}
|
||||
16
src/ssh_agent.go
Normal file
16
src/ssh_agent.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
func SSHAgent() ssh.AuthMethod {
|
||||
if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
|
||||
return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
194
src/ssh_tunnel.go
Normal file
194
src/ssh_tunnel.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type logger interface {
|
||||
Printf(string, ...interface{})
|
||||
}
|
||||
|
||||
type SSHTunnel struct {
|
||||
Local *Endpoint
|
||||
Server *Endpoint
|
||||
Remote *Endpoint
|
||||
Config *ssh.ClientConfig
|
||||
Log logger
|
||||
Conns []net.Conn
|
||||
SvrConns []*ssh.Client
|
||||
MaxConnectionAttempts int
|
||||
isOpen bool
|
||||
close chan interface{}
|
||||
}
|
||||
|
||||
func (tunnel *SSHTunnel) logf(fmt string, args ...interface{}) {
|
||||
if tunnel.Log != nil {
|
||||
tunnel.Log.Printf(fmt, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func newConnectionWaiter(listener net.Listener, c chan net.Conn) {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c <- conn
|
||||
}
|
||||
|
||||
func (tunnel *SSHTunnel) Close() {
|
||||
tunnel.close <- struct{}{}
|
||||
return
|
||||
}
|
||||
|
||||
func (tunnel *SSHTunnel) Start(connect_max int) error {
|
||||
listener, err := net.Listen("tcp", tunnel.Local.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tunnel.isOpen = true
|
||||
tunnel.Local.Port = listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// Ensure that MaxConnectionAttempts is at least 1. This check is done here
|
||||
// since the library user can set the value at any point before Start() is called,
|
||||
// and this check protects against the case where the programmer set MaxConnectionAttempts
|
||||
// to 0 for some reason.
|
||||
if tunnel.MaxConnectionAttempts <= 0 {
|
||||
tunnel.MaxConnectionAttempts = 9
|
||||
}
|
||||
|
||||
var connect = 0 // 连接数
|
||||
|
||||
for {
|
||||
|
||||
connect++
|
||||
tunnel.logf("PID: %d PPID: %d 本地端口: %d 当前连接数:%d\n", os.Getpid(), os.Getppid(), tunnel.Local.Port, connect)
|
||||
|
||||
if connect >= connect_max { // 连接数达到1000时重启进程
|
||||
tunnel.logf("连接数达到上线,重启进程!%d\n", connect)
|
||||
syscall.Kill(os.Getpid(), syscall.SIGKILL)
|
||||
|
||||
}
|
||||
|
||||
if !tunnel.isOpen {
|
||||
break
|
||||
}
|
||||
|
||||
c := make(chan net.Conn)
|
||||
go newConnectionWaiter(listener, c)
|
||||
tunnel.logf("正在侦听新连接 ...")
|
||||
|
||||
select {
|
||||
case <-tunnel.close:
|
||||
tunnel.logf("收到关闭信号,关闭 ...")
|
||||
tunnel.isOpen = false
|
||||
case conn := <-c:
|
||||
tunnel.Conns = append(tunnel.Conns, conn)
|
||||
tunnel.logf("accepted connection")
|
||||
|
||||
go tunnel.forward(conn)
|
||||
}
|
||||
}
|
||||
|
||||
var total int
|
||||
total = len(tunnel.Conns)
|
||||
for i, conn := range tunnel.Conns {
|
||||
tunnel.logf("closing the netConn (%d of %d)", i+1, total)
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
tunnel.logf(err.Error())
|
||||
}
|
||||
}
|
||||
total = len(tunnel.SvrConns)
|
||||
for i, conn := range tunnel.SvrConns {
|
||||
tunnel.logf("closing the serverConn (%d of %d)", i+1, total)
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
tunnel.logf(err.Error())
|
||||
}
|
||||
}
|
||||
err = listener.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tunnel.logf("tunnel closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tunnel *SSHTunnel) forward(localConn net.Conn) {
|
||||
var (
|
||||
serverConn *ssh.Client
|
||||
err error
|
||||
attemptsLeft int = tunnel.MaxConnectionAttempts
|
||||
)
|
||||
|
||||
for {
|
||||
serverConn, err = ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
|
||||
if err != nil {
|
||||
attemptsLeft--
|
||||
|
||||
if attemptsLeft <= 0 {
|
||||
tunnel.logf("服务器拨号错误: %v: exceeded %d attempts", err, tunnel.MaxConnectionAttempts)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
tunnel.logf("connected to %s (1 of 2)\n", tunnel.Server.String())
|
||||
tunnel.SvrConns = append(tunnel.SvrConns, serverConn)
|
||||
|
||||
remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
|
||||
if err != nil {
|
||||
tunnel.logf("remote dial error: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
tunnel.Conns = append(tunnel.Conns, remoteConn)
|
||||
tunnel.logf("connected to %s (2 of 2)\n", tunnel.Remote.String())
|
||||
|
||||
copyConn := func(writer, reader net.Conn) {
|
||||
_, err := io.Copy(writer, reader)
|
||||
if err != nil {
|
||||
tunnel.logf("io.Copy error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
go copyConn(localConn, remoteConn)
|
||||
go copyConn(remoteConn, localConn)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// NewSSHTunnel creates a new single-use tunnel. Supplying "0" for localport will use a random port.
|
||||
func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string, localport string) *SSHTunnel {
|
||||
|
||||
localEndpoint := NewEndpoint("0.0.0.0:" + localport)
|
||||
|
||||
server := NewEndpoint(tunnel)
|
||||
if server.Port == 0 {
|
||||
server.Port = 22
|
||||
}
|
||||
|
||||
sshTunnel := &SSHTunnel{
|
||||
Config: &ssh.ClientConfig{
|
||||
User: server.User,
|
||||
Auth: []ssh.AuthMethod{auth},
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
// Always accept key.
|
||||
return nil
|
||||
},
|
||||
},
|
||||
Local: localEndpoint,
|
||||
Server: server,
|
||||
Remote: NewEndpoint(destination),
|
||||
close: make(chan interface{}),
|
||||
}
|
||||
|
||||
return sshTunnel
|
||||
}
|
||||
Reference in New Issue
Block a user