202 lines
4.8 KiB
Go
202 lines
4.8 KiB
Go
package main
|
|
|
|
import (
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"syscall"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
	"golang.org/x/term"
)
|
|
|
|
// 创建 SFTP 客户端
|
|
func connectSFTP(config SSHConfig) (*sftp.Client, *ssh.Client, error) {
|
|
var authMethods []ssh.AuthMethod
|
|
|
|
// 使用密码认证
|
|
if config.Password != "" {
|
|
authMethods = append(authMethods, ssh.Password(config.Password))
|
|
}
|
|
|
|
// 使用私钥认证
|
|
if config.PrivateKey != "" {
|
|
key, err := os.ReadFile(config.PrivateKey)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to read private key file: %w", err)
|
|
}
|
|
|
|
var signer ssh.Signer
|
|
if config.Passphrase != "" {
|
|
signer, err = ssh.ParsePrivateKeyWithPassphrase(key, []byte(config.Passphrase))
|
|
} else {
|
|
signer, err = ssh.ParsePrivateKey(key)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to parse private key: %w", err)
|
|
}
|
|
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
|
}
|
|
|
|
if len(authMethods) == 0 {
|
|
return nil, nil, fmt.Errorf("no authentication method provided")
|
|
}
|
|
|
|
sshConfig := &ssh.ClientConfig{
|
|
User: config.User,
|
|
Auth: authMethods,
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
}
|
|
|
|
// 建立 SSH 连接
|
|
address := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
|
sshClient, err := ssh.Dial("tcp", address, sshConfig)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to connect to SSH: %w", err)
|
|
}
|
|
|
|
// 创建 SFTP 客户端
|
|
sftpClient, err := sftp.NewClient(sshClient)
|
|
if err != nil {
|
|
sshClient.Close()
|
|
return nil, nil, fmt.Errorf("failed to create SFTP client: %w", err)
|
|
}
|
|
|
|
return sftpClient, sshClient, nil
|
|
}
|
|
|
|
// downloadFile 通过 SFTP 协议从远程服务器下载文件到本地
|
|
func downloadFile(sftpClient *sftp.Client, remotePath, localPath string) error {
|
|
// 打开远程文件
|
|
remoteFile, err := sftpClient.Open(remotePath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open remote file: %w", err)
|
|
}
|
|
defer remoteFile.Close()
|
|
|
|
// 获取远程文件信息以确定文件大小
|
|
fileInfo, err := remoteFile.Stat()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get remote file info: %w", err)
|
|
}
|
|
|
|
// 创建本地文件(创建或覆盖)
|
|
localFile, err := os.Create(localPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create local file: %w", err)
|
|
}
|
|
defer localFile.Close()
|
|
|
|
// 分块复制文件内容
|
|
const bufferSize = 32 * 1024 // 32KB
|
|
buf := make([]byte, bufferSize)
|
|
var totalBytes int64 = 0
|
|
for {
|
|
n, err := remoteFile.Read(buf)
|
|
if err != nil && err != io.EOF {
|
|
return fmt.Errorf("failed to read from remote file: %w", err)
|
|
}
|
|
if n == 0 {
|
|
break
|
|
}
|
|
if _, err := localFile.Write(buf[:n]); err != nil {
|
|
return fmt.Errorf("failed to write to local file: %w", err)
|
|
}
|
|
totalBytes += int64(n)
|
|
fmt.Printf("\rDownloaded %d / %d bytes", totalBytes, fileInfo.Size())
|
|
}
|
|
fmt.Println() // 换行
|
|
|
|
return nil
|
|
}
|
|
|
|
// 上传文件
|
|
func uploadFile(sftpClient *sftp.Client, localPath, remotePath string) error {
|
|
// 打开本地文件
|
|
localFile, err := os.Open(localPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open local file: %w", err)
|
|
}
|
|
defer localFile.Close()
|
|
|
|
// 打开远程文件(创建或覆盖)
|
|
RemoteFile, err := sftpClient.Create(remotePath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create remote file: %w", err)
|
|
}
|
|
defer RemoteFile.Close()
|
|
|
|
// 分块复制文件内容
|
|
const bufferSize = 32 * 1024 // 32KB
|
|
buf := make([]byte, bufferSize)
|
|
for {
|
|
n, err := localFile.Read(buf)
|
|
if err != nil && err.Error() != "EOF" {
|
|
return fmt.Errorf("failed to read local file: %w", err)
|
|
}
|
|
if n == 0 {
|
|
break
|
|
}
|
|
if _, err := RemoteFile.Write(buf[:n]); err != nil {
|
|
return fmt.Errorf("failed to write to remote file: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func _sftp() error {
|
|
|
|
var err error
|
|
|
|
// 校验参数
|
|
if *h == "" || *u == "" || *l == "" || *r == "" || (*e == "" && *k == "") {
|
|
fmt.Println("Error: Missing required parameters")
|
|
flag.Usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
// 如果使用私钥且需要密码解锁,则提示用户输入密码
|
|
var passphrase string
|
|
if *k != "" {
|
|
fmt.Print("Enter passphrase for private key: ")
|
|
bytePassword, err := term.ReadPassword(int(syscall.Stdin))
|
|
if err != nil {
|
|
log.Printf("Failed to read passphrase: %v", err)
|
|
}
|
|
passphrase = string(bytePassword)
|
|
fmt.Println()
|
|
}
|
|
|
|
// 构造 SSH 配置
|
|
config := SSHConfig{
|
|
Host: *h,
|
|
Port: *p,
|
|
User: *u,
|
|
Password: *e,
|
|
PrivateKey: *k,
|
|
Passphrase: passphrase,
|
|
}
|
|
|
|
// 连接 SFTP
|
|
sftpClient, sshClient, err := connectSFTP(config)
|
|
if err != nil {
|
|
log.Printf("Failed to connect to SFTP: %v", err)
|
|
}
|
|
defer sftpClient.Close()
|
|
defer sshClient.Close()
|
|
|
|
// 上传文件
|
|
err = uploadFile(sftpClient, *l, *r)
|
|
if err != nil {
|
|
log.Printf("Failed to upload file: %v", err)
|
|
}
|
|
|
|
fmt.Println("File uploaded successfully!")
|
|
return err
|
|
}
|