2025-01-24 09:59:28 +08:00

126 lines
2.8 KiB
Go

package main
import (
	"bytes"
	"flag"
	"fmt"
	"log"
	"net"
	"os"
	"syscall"

	"golang.org/x/crypto/ssh"
	"golang.org/x/term"
)
// runSSHCommand connects to the SSH server described by config, runs command
// in a single session, and returns everything the command wrote to stdout.
//
// Authentication methods are collected from config: password authentication
// is offered when config.Password is non-empty, and public-key authentication
// is offered when config.PrivateKey names a key file (parsed with
// config.Passphrase when that is non-empty, otherwise as an unencrypted key).
//
// An error is returned when no authentication method is configured, when the
// key file cannot be read or parsed, when the connection or session cannot be
// established, or when running the command fails; in the last case the error
// text includes the captured stderr. Stderr produced by a successful command
// is written to the standard logger.
func runSSHCommand(config SSHConfig, command string) (string, error) {
	var authMethods []ssh.AuthMethod

	// Password authentication.
	if config.Password != "" {
		authMethods = append(authMethods, ssh.Password(config.Password))
	}

	// Private-key authentication.
	if config.PrivateKey != "" {
		key, err := os.ReadFile(config.PrivateKey)
		if err != nil {
			return "", fmt.Errorf("failed to read private key file: %w", err)
		}

		var signer ssh.Signer
		if config.Passphrase != "" {
			// Decrypt the key with the supplied passphrase.
			signer, err = ssh.ParsePrivateKeyWithPassphrase(key, []byte(config.Passphrase))
		} else {
			// No passphrase given: parse the key as unencrypted.
			signer, err = ssh.ParsePrivateKey(key)
		}
		if err != nil {
			return "", fmt.Errorf("failed to parse private key: %w", err)
		}
		authMethods = append(authMethods, ssh.PublicKeys(signer))
	}

	if len(authMethods) == 0 {
		return "", fmt.Errorf("no authentication method provided")
	}

	sshConfig := &ssh.ClientConfig{
		User: config.User,
		Auth: authMethods,
		// NOTE(review): host key verification is skipped, so the server's
		// identity is never checked. Replace with a trusted callback before
		// using this against hosts reached over untrusted networks.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	// Build the dial address with net.JoinHostPort so IPv6 literals are
	// bracketed ("[::1]:22"); plain "%s:%d" formatting yields an address
	// that cannot be parsed for such hosts.
	address := net.JoinHostPort(config.Host, fmt.Sprint(config.Port))
	client, err := ssh.Dial("tcp", address, sshConfig)
	if err != nil {
		return "", fmt.Errorf("failed to dial: %w", err)
	}
	defer client.Close()

	session, err := client.NewSession()
	if err != nil {
		return "", fmt.Errorf("failed to create session: %w", err)
	}
	defer session.Close()

	// Capture both output streams so stderr can be reported separately.
	var stdout, stderr bytes.Buffer
	session.Stdout = &stdout
	session.Stderr = &stderr

	if err := session.Run(command); err != nil {
		return "", fmt.Errorf("执行命令失败: %w. stderr: %s", err, stderr.String())
	}

	// The command succeeded but still wrote diagnostics; surface them in the log.
	if stderr.Len() > 0 {
		log.Printf("stderr: %s", stderr.String())
	}
	return stdout.String(), nil
}
func _ssh() (string, error) {
// 校验必需参数
if *h == "" || *u == "" || *c == "" || (*e == "" && *k == "") {
fmt.Println("Error: Missing required parameters")
flag.Usage()
os.Exit(1)
}
// 如果使用私钥且需要密码解锁,则提示用户输入密码
var passphrase string
if *k != "" {
fmt.Print("Enter passphrase for private key: ")
bytePassword, err := term.ReadPassword(int(syscall.Stdin))
if err != nil {
log.Printf("Failed to read passphrase: %v", err)
}
passphrase = string(bytePassword)
fmt.Println()
}
// 构造 SSH 配置
config := SSHConfig{
Host: *h,
Port: *p,
User: *u,
Password: *e,
PrivateKey: *k,
Passphrase: passphrase,
}
// 执行命令
output, err := runSSHCommand(config, *c)
if err != nil {
log.Printf("Error: %v", err)
}
return output, err
}