初始提交

This commit is contained in:
2023-02-13 18:09:13 +08:00
commit 764a8a3911
10 changed files with 1131 additions and 0 deletions

342
src/main.go Normal file
View File

@@ -0,0 +1,342 @@
package main
import (
"database/sql"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
_ "github.com/go-sql-driver/mysql"
"golang.org/x/crypto/ssh"
"gopkg.in/ini.v1"
)
// 转发
func sForward(serverAddr string, remoteAddr string, localConn net.Conn, config *ssh.ClientConfig) {
// 设置sshClientConn
sshClientConn, err := ssh.Dial("tcp", serverAddr, config)
if err != nil {
fmt.Printf("ssh.Dial failed: %s", err)
os.Exit(1)
}
// 设置Connection
sshConn, _ := sshClientConn.Dial("tcp", remoteAddr)
/*
// 将localConn.Reader复制到sshConn.Writer
go func() {
_, err = io.Copy(sshConn, localConn)
if err != nil {
fmt.Printf("io.Copy failed: %v", err)
os.Exit(1)
}
}()
// 将sshConn.Reader复制到localConn.Writer
go func() {
_, err = io.Copy(localConn, sshConn)
if err != nil {
fmt.Printf("io.Copy failed: %v", err)
os.Exit(1)
}
}()
*/
copyConn := func(writer, reader net.Conn) {
_, err := io.Copy(writer, reader)
if err != nil {
fmt.Printf(" io.Copy error: %s", err)
}
}
go copyConn(localConn, sshConn)
go copyConn(sshConn, localConn)
}
func Tunnel(username string, password string, serverAddr string, remoteAddr string, localAddr string) {
// 设置SSH配置
fmt.Printf("%s服务器:%s; 用户/密码: %s; 远程:%s; 本地:%s\n", "设置SSH配置", serverAddr, username+"/"+password, remoteAddr, localAddr)
config := &ssh.ClientConfig{
User: username,
Auth: []ssh.AuthMethod{
ssh.Password(password),
},
Timeout: 7 * time.Second,
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
}
// 设置本地监听器
localListener, err := net.Listen("tcp", localAddr)
if err != nil {
fmt.Printf("net.Listen failed: %v\n", err)
os.Exit(1)
}
for {
// 设置本地
localConn, err := localListener.Accept()
if err != nil {
fmt.Printf("localListener.Accept failed: %v\n", err)
os.Exit(1)
}
go sForward(serverAddr, remoteAddr, localConn, config)
}
}
type user struct {
id string
host_id string
host_ip string
router_ip string
router_port string
protocol_type string
protocol_port string
state string
acc_auth_id string
auth_type string
username string
username_prompt string
password_prompt string
password string
pri_key string
creator_id string
create_time string
last_secret string
}
func ExecCommand(strCommand string) string {
cmd := exec.Command("/bin/bash", "-c", strCommand)
stdout, _ := cmd.StdoutPipe()
if err := cmd.Start(); err != nil {
fmt.Println("Execute failed when Start: " + err.Error())
return ""
}
out_bytes, _ := ioutil.ReadAll(stdout)
stdout.Close()
if err := cmd.Wait(); err != nil {
fmt.Println("Execute failed when Wait: " + err.Error())
return ""
}
return string(out_bytes)
}
func ssh_tunnel(host_ip string, host_port string, host_name string, remote_ip string, local_ip string, mysql_info string) {
fmt.Println(host_ip, host_name, host_port)
//"用户名:密码@[连接方式](主机名:端口号)/数据库名"
var db *sql.DB
db, _ = sql.Open("mysql", mysql_info) // 设置连接数据库的参数
defer db.Close() //关闭数据库
err := db.Ping() //连接数据库
if err != nil {
fmt.Println("数据库连接失败")
// 退出
syscall.Kill(os.Getppid(), syscall.SIGKILL)
syscall.Kill(os.Getpid(), syscall.SIGKILL)
} else {
fmt.Println("数据库连接成功")
}
sqlStr := "select host_ip, username, protocol_port, password from tp_acc where host_ip = ? and username = ? and protocol_port = ?"
rows, err := db.Query(sqlStr, host_ip, host_name, host_port)
if err != nil {
fmt.Printf("query failed, err:%v\n", err)
return
}
// 非常重要关闭rows释放持有的数据库链接
defer rows.Close()
var u user
// 循环读取结果集中的数据
for rows.Next() {
err := rows.Scan(&u.host_ip, &u.username, &u.protocol_port, &u.password)
if err != nil {
fmt.Printf("scan failed, err:%v\n", err)
return
}
fmt.Printf("%s %s %s %s\n", u.host_ip, u.protocol_port, u.username, u.password)
}
// 解密密钥
command := "./tp_decrypt "
command = command + u.password
strData := ExecCommand(command)
strData = strings.Replace(strData, "\n", "", -1)
fmt.Println(strData)
if strData == "" {
// 判断密码是否解密成功,不成功退出
syscall.Kill(os.Getppid(), syscall.SIGKILL)
syscall.Kill(os.Getpid(), syscall.SIGKILL)
}
Tunnel(host_name, strData, u.host_ip+":"+u.protocol_port, "0.0.0.0:"+remote_ip, "0.0.0.0:"+local_ip)
}
func StripSlice(slice []string, element string) []string {
for i := 0; i < len(slice); {
if slice[i] == element && i != len(slice)-1 {
slice = append(slice[:i], slice[i+1:]...)
} else if slice[i] == element && i == len(slice)-1 {
slice = slice[:i]
} else {
i++
}
}
return slice
}
func SubProcess(args []string) *exec.Cmd {
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Start()
if err != nil {
fmt.Fprintf(os.Stderr, "[-] Error: %s\n", err)
}
return cmd
}
const (
DAEMON = "d"
FOREVER = "f"
HOST_IP = "h"
HOST_PORT = "p"
HOST_USER = "u"
HOST_REMOTE = "r"
HOST_LOCAL = "l"
)
// Check if a port is available
func Check(port int) (status bool, err error) {
// Concatenate a colon and the port
host := ":" + strconv.Itoa(port)
// Try to create a server with the port
server, err := net.Listen("tcp", host)
// if it fails then the port is likely taken
if err != nil {
return false, err
}
// close the server
server.Close()
// we successfully used and closed the port
// so it's now available to be used again
return true, nil
}
func GetCurrentDirectory() string {
dir, err := filepath.Abs(filepath.Dir(os.Args[0])) //返回绝对路径 filepath.Dir(os.Args[0])去除最后一个元素的路径
if err != nil {
log.Fatal(err)
}
return strings.Replace(dir, "\\", "/", -1) //将\替换成/
}
func PathExists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
func main() {
// 判断配置文件是否存在
INIFILE := GetCurrentDirectory() + "/" + "tunnel.ini"
b, _ := PathExists(INIFILE)
if !b {
INIFILE = "/etc/tunnel.ini"
}
// 读取配置文件
cfg, inierr := ini.Load(INIFILE)
if inierr != nil {
fmt.Printf("Fail to read file: %v", inierr)
os.Exit(1)
}
// 读取数据库连接信息
MYSQL_INFO := cfg.Section("global").Key("MYSQL_INFO").String()
daemon := flag.Bool(DAEMON, false, "run in daemon")
forever := flag.Bool(FOREVER, false, "run forever")
host_ip := flag.String(HOST_IP, "", "DB服务器SSH IP地址")
host_port := flag.String(HOST_PORT, "", "DB服务器SSH PORT")
host_user := flag.String(HOST_USER, "", "DB服务器SSH USER用户")
remote_port := flag.String(HOST_REMOTE, "", "DB服务器端口如: 3306、1521 ...")
local_port := flag.String(HOST_LOCAL, "", "本地监听端口(或者堡垒机监听端口)")
flag.Parse()
if *daemon {
SubProcess(StripSlice(os.Args, "-"+DAEMON))
fmt.Printf("[*] Daemon running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid())
os.Exit(0)
} else if *forever {
for {
cmd := SubProcess(StripSlice(os.Args, "-"+FOREVER))
fmt.Printf("[*] Forever running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid())
time.Sleep(time.Second * 5)
cmd.Wait()
}
os.Exit(0)
} else {
fmt.Printf("[*] Service running in PID: %d PPID: %d\n", os.Getpid(), os.Getppid())
}
if 0 == 0 {
fmt.Println(*host_ip, *host_port, *host_user, *remote_port, *local_port)
//
local_port_, _ := strconv.Atoi(*local_port)
r, _ := Check(local_port_)
//fmt.Println(r)
if r {
ssh_tunnel(*host_ip, *host_port, *host_user, *remote_port, *local_port, MYSQL_INFO)
} else {
fmt.Println(local_port_, "端口不可用!", r, "退出!")
fp, _ := os.OpenFile("./pid.log", os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
log.SetOutput(fp)
log.Printf("DoSomething running in PPID: %d\n", os.Getppid())
syscall.Kill(os.Getppid(), syscall.SIGKILL)
time.Sleep(time.Second * 10)
os.Exit(0)
}
}
time.Sleep(time.Second * 5)
os.Exit(0)
}