初始提交
This commit is contained in:
342
src/main.go
Normal file
342
src/main.go
Normal file
@@ -0,0 +1,342 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"gopkg.in/ini.v1"
|
||||
)
|
||||
|
||||
// 转发
|
||||
func sForward(serverAddr string, remoteAddr string, localConn net.Conn, config *ssh.ClientConfig) {
|
||||
// 设置sshClientConn
|
||||
sshClientConn, err := ssh.Dial("tcp", serverAddr, config)
|
||||
if err != nil {
|
||||
fmt.Printf("ssh.Dial failed: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 设置Connection
|
||||
sshConn, _ := sshClientConn.Dial("tcp", remoteAddr)
|
||||
/*
|
||||
// 将localConn.Reader复制到sshConn.Writer
|
||||
go func() {
|
||||
_, err = io.Copy(sshConn, localConn)
|
||||
if err != nil {
|
||||
fmt.Printf("io.Copy failed: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
// 将sshConn.Reader复制到localConn.Writer
|
||||
go func() {
|
||||
_, err = io.Copy(localConn, sshConn)
|
||||
if err != nil {
|
||||
fmt.Printf("io.Copy failed: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
*/
|
||||
|
||||
copyConn := func(writer, reader net.Conn) {
|
||||
_, err := io.Copy(writer, reader)
|
||||
if err != nil {
|
||||
fmt.Printf(" io.Copy error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
go copyConn(localConn, sshConn)
|
||||
go copyConn(sshConn, localConn)
|
||||
|
||||
}
|
||||
|
||||
func Tunnel(username string, password string, serverAddr string, remoteAddr string, localAddr string) {
|
||||
// 设置SSH配置
|
||||
fmt.Printf("%s,服务器:%s; 用户/密码: %s; 远程:%s; 本地:%s\n", "设置SSH配置", serverAddr, username+"/"+password, remoteAddr, localAddr)
|
||||
config := &ssh.ClientConfig{
|
||||
User: username,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password(password),
|
||||
},
|
||||
Timeout: 7 * time.Second,
|
||||
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// 设置本地监听器
|
||||
localListener, err := net.Listen("tcp", localAddr)
|
||||
if err != nil {
|
||||
fmt.Printf("net.Listen failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
for {
|
||||
// 设置本地
|
||||
localConn, err := localListener.Accept()
|
||||
if err != nil {
|
||||
fmt.Printf("localListener.Accept failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
go sForward(serverAddr, remoteAddr, localConn, config)
|
||||
}
|
||||
}
|
||||
|
||||
type user struct {
|
||||
id string
|
||||
host_id string
|
||||
host_ip string
|
||||
router_ip string
|
||||
router_port string
|
||||
protocol_type string
|
||||
protocol_port string
|
||||
state string
|
||||
acc_auth_id string
|
||||
auth_type string
|
||||
username string
|
||||
username_prompt string
|
||||
password_prompt string
|
||||
password string
|
||||
pri_key string
|
||||
creator_id string
|
||||
create_time string
|
||||
last_secret string
|
||||
}
|
||||
|
||||
func ExecCommand(strCommand string) string {
|
||||
cmd := exec.Command("/bin/bash", "-c", strCommand)
|
||||
|
||||
stdout, _ := cmd.StdoutPipe()
|
||||
if err := cmd.Start(); err != nil {
|
||||
fmt.Println("Execute failed when Start: " + err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
out_bytes, _ := ioutil.ReadAll(stdout)
|
||||
stdout.Close()
|
||||
|
||||
if err := cmd.Wait(); err != nil {
|
||||
fmt.Println("Execute failed when Wait: " + err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(out_bytes)
|
||||
}
|
||||
|
||||
func ssh_tunnel(host_ip string, host_port string, host_name string, remote_ip string, local_ip string, mysql_info string) {
|
||||
|
||||
fmt.Println(host_ip, host_name, host_port)
|
||||
|
||||
//"用户名:密码@[连接方式](主机名:端口号)/数据库名"
|
||||
var db *sql.DB
|
||||
db, _ = sql.Open("mysql", mysql_info) // 设置连接数据库的参数
|
||||
defer db.Close() //关闭数据库
|
||||
err := db.Ping() //连接数据库
|
||||
if err != nil {
|
||||
fmt.Println("数据库连接失败")
|
||||
// 退出
|
||||
syscall.Kill(os.Getppid(), syscall.SIGKILL)
|
||||
syscall.Kill(os.Getpid(), syscall.SIGKILL)
|
||||
} else {
|
||||
fmt.Println("数据库连接成功")
|
||||
|
||||
}
|
||||
|
||||
sqlStr := "select host_ip, username, protocol_port, password from tp_acc where host_ip = ? and username = ? and protocol_port = ?"
|
||||
rows, err := db.Query(sqlStr, host_ip, host_name, host_port)
|
||||
if err != nil {
|
||||
fmt.Printf("query failed, err:%v\n", err)
|
||||
return
|
||||
}
|
||||
// 非常重要关闭rows释放持有的数据库链接
|
||||
defer rows.Close()
|
||||
var u user
|
||||
// 循环读取结果集中的数据
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&u.host_ip, &u.username, &u.protocol_port, &u.password)
|
||||
if err != nil {
|
||||
fmt.Printf("scan failed, err:%v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("%s %s %s %s\n", u.host_ip, u.protocol_port, u.username, u.password)
|
||||
}
|
||||
|
||||
// 解密密钥
|
||||
command := "./tp_decrypt "
|
||||
command = command + u.password
|
||||
strData := ExecCommand(command)
|
||||
strData = strings.Replace(strData, "\n", "", -1)
|
||||
fmt.Println(strData)
|
||||
if strData == "" {
|
||||
// 判断密码是否解密成功,不成功退出
|
||||
syscall.Kill(os.Getppid(), syscall.SIGKILL)
|
||||
syscall.Kill(os.Getpid(), syscall.SIGKILL)
|
||||
}
|
||||
|
||||
Tunnel(host_name, strData, u.host_ip+":"+u.protocol_port, "0.0.0.0:"+remote_ip, "0.0.0.0:"+local_ip)
|
||||
|
||||
}
|
||||
|
||||
func StripSlice(slice []string, element string) []string {
|
||||
for i := 0; i < len(slice); {
|
||||
if slice[i] == element && i != len(slice)-1 {
|
||||
slice = append(slice[:i], slice[i+1:]...)
|
||||
} else if slice[i] == element && i == len(slice)-1 {
|
||||
slice = slice[:i]
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return slice
|
||||
}
|
||||
|
||||
func SubProcess(args []string) *exec.Cmd {
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[-] Error: %s\n", err)
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
const (
|
||||
DAEMON = "d"
|
||||
FOREVER = "f"
|
||||
HOST_IP = "h"
|
||||
HOST_PORT = "p"
|
||||
HOST_USER = "u"
|
||||
HOST_REMOTE = "r"
|
||||
HOST_LOCAL = "l"
|
||||
)
|
||||
|
||||
// Check if a port is available
|
||||
func Check(port int) (status bool, err error) {
|
||||
|
||||
// Concatenate a colon and the port
|
||||
host := ":" + strconv.Itoa(port)
|
||||
|
||||
// Try to create a server with the port
|
||||
server, err := net.Listen("tcp", host)
|
||||
|
||||
// if it fails then the port is likely taken
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// close the server
|
||||
server.Close()
|
||||
|
||||
// we successfully used and closed the port
|
||||
// so it's now available to be used again
|
||||
return true, nil
|
||||
|
||||
}
|
||||
|
||||
func GetCurrentDirectory() string {
|
||||
dir, err := filepath.Abs(filepath.Dir(os.Args[0])) //返回绝对路径 filepath.Dir(os.Args[0])去除最后一个元素的路径
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return strings.Replace(dir, "\\", "/", -1) //将\替换成/
|
||||
}
|
||||
|
||||
func PathExists(path string) (bool, error) {
|
||||
_, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
// 判断配置文件是否存在
|
||||
INIFILE := GetCurrentDirectory() + "/" + "tunnel.ini"
|
||||
b, _ := PathExists(INIFILE)
|
||||
|
||||
if !b {
|
||||
INIFILE = "/etc/tunnel.ini"
|
||||
}
|
||||
|
||||
// 读取配置文件
|
||||
cfg, inierr := ini.Load(INIFILE)
|
||||
if inierr != nil {
|
||||
fmt.Printf("Fail to read file: %v", inierr)
|
||||
os.Exit(1)
|
||||
}
|
||||
// 读取数据库连接信息
|
||||
MYSQL_INFO := cfg.Section("global").Key("MYSQL_INFO").String()
|
||||
|
||||
daemon := flag.Bool(DAEMON, false, "run in daemon")
|
||||
forever := flag.Bool(FOREVER, false, "run forever")
|
||||
host_ip := flag.String(HOST_IP, "", "DB服务器SSH IP地址")
|
||||
host_port := flag.String(HOST_PORT, "", "DB服务器SSH PORT")
|
||||
host_user := flag.String(HOST_USER, "", "DB服务器SSH USER用户")
|
||||
remote_port := flag.String(HOST_REMOTE, "", "DB服务器端口(如: 3306、1521 ...)")
|
||||
local_port := flag.String(HOST_LOCAL, "", "本地监听端口(或者堡垒机监听端口)")
|
||||
flag.Parse()
|
||||
|
||||
if *daemon {
|
||||
SubProcess(StripSlice(os.Args, "-"+DAEMON))
|
||||
fmt.Printf("[*] Daemon running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid())
|
||||
os.Exit(0)
|
||||
} else if *forever {
|
||||
for {
|
||||
cmd := SubProcess(StripSlice(os.Args, "-"+FOREVER))
|
||||
fmt.Printf("[*] Forever running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid())
|
||||
time.Sleep(time.Second * 5)
|
||||
cmd.Wait()
|
||||
}
|
||||
os.Exit(0)
|
||||
} else {
|
||||
fmt.Printf("[*] Service running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid())
|
||||
}
|
||||
|
||||
if 0 == 0 {
|
||||
fmt.Println(*host_ip, *host_port, *host_user, *remote_port, *local_port)
|
||||
|
||||
//
|
||||
local_port_, _ := strconv.Atoi(*local_port)
|
||||
r, _ := Check(local_port_)
|
||||
//fmt.Println(r)
|
||||
if r {
|
||||
ssh_tunnel(*host_ip, *host_port, *host_user, *remote_port, *local_port, MYSQL_INFO)
|
||||
} else {
|
||||
fmt.Println(local_port_, "端口不可用!", r, "退出!")
|
||||
fp, _ := os.OpenFile("./pid.log", os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
|
||||
log.SetOutput(fp)
|
||||
|
||||
log.Printf("DoSomething running in PPID: %d\n", os.Getppid())
|
||||
|
||||
syscall.Kill(os.Getppid(), syscall.SIGKILL)
|
||||
time.Sleep(time.Second * 10)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 5)
|
||||
os.Exit(0)
|
||||
}
|
||||
Reference in New Issue
Block a user