126 lines
2.8 KiB
Go
126 lines
2.8 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"syscall"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/term"
|
|
)
|
|
|
|
// 通过 SSH 执行命令
|
|
func runSSHCommand(config SSHConfig, command string) (string, error) {
|
|
// 创建 SSH 配置
|
|
var authMethods []ssh.AuthMethod
|
|
|
|
// 使用密码认证
|
|
if config.Password != "" {
|
|
authMethods = append(authMethods, ssh.Password(config.Password))
|
|
}
|
|
|
|
// 使用私钥认证
|
|
if config.PrivateKey != "" {
|
|
key, err := os.ReadFile(config.PrivateKey)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to read private key file: %w", err)
|
|
}
|
|
|
|
var signer ssh.Signer
|
|
if config.Passphrase != "" {
|
|
// 使用密码解锁私钥
|
|
signer, err = ssh.ParsePrivateKeyWithPassphrase(key, []byte(config.Passphrase))
|
|
} else {
|
|
// 无密码的私钥
|
|
signer, err = ssh.ParsePrivateKey(key)
|
|
}
|
|
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to parse private key: %w", err)
|
|
}
|
|
|
|
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
|
}
|
|
|
|
if len(authMethods) == 0 {
|
|
return "", fmt.Errorf("no authentication method provided")
|
|
}
|
|
|
|
sshConfig := &ssh.ClientConfig{
|
|
User: config.User,
|
|
Auth: authMethods,
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // 跳过主机密钥验证,生产环境中建议替换为可信的回调函数
|
|
}
|
|
|
|
// 连接到远程服务器
|
|
address := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
|
client, err := ssh.Dial("tcp", address, sshConfig)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to dial: %w", err)
|
|
}
|
|
defer client.Close()
|
|
|
|
// 创建会话
|
|
session, err := client.NewSession()
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create session: %w", err)
|
|
}
|
|
defer session.Close()
|
|
|
|
// 执行命令
|
|
var stdout, stderr bytes.Buffer
|
|
session.Stdout = &stdout
|
|
session.Stderr = &stderr
|
|
|
|
if err := session.Run(command); err != nil {
|
|
return "", fmt.Errorf("执行命令失败: %w. stderr: %s", err, stderr.String())
|
|
}
|
|
if stderr.Len() > 0 {
|
|
log.Printf("stderr: %s", stderr.String())
|
|
}
|
|
return stdout.String(), nil
|
|
|
|
}
|
|
|
|
func _ssh() (string, error) {
|
|
// 校验必需参数
|
|
if *h == "" || *u == "" || *c == "" || (*e == "" && *k == "") {
|
|
fmt.Println("Error: Missing required parameters")
|
|
flag.Usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
// 如果使用私钥且需要密码解锁,则提示用户输入密码
|
|
var passphrase string
|
|
if *k != "" {
|
|
fmt.Print("Enter passphrase for private key: ")
|
|
bytePassword, err := term.ReadPassword(int(syscall.Stdin))
|
|
if err != nil {
|
|
log.Printf("Failed to read passphrase: %v", err)
|
|
}
|
|
passphrase = string(bytePassword)
|
|
fmt.Println()
|
|
}
|
|
|
|
// 构造 SSH 配置
|
|
config := SSHConfig{
|
|
Host: *h,
|
|
Port: *p,
|
|
User: *u,
|
|
Password: *e,
|
|
PrivateKey: *k,
|
|
Passphrase: passphrase,
|
|
}
|
|
|
|
// 执行命令
|
|
output, err := runSSHCommand(config, *c)
|
|
if err != nil {
|
|
log.Printf("Error: %v", err)
|
|
}
|
|
|
|
return output, err
|
|
}
|