2025-01-24 09:59:28 +08:00

202 lines
4.8 KiB
Go

package main
import (
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"syscall"
	"time"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
	"golang.org/x/term"
)
// connectSFTP dials the SSH server described by config and opens an SFTP
// session over that connection.
//
// Authentication methods are offered in this order: password (when
// config.Password is non-empty), then public key (when config.PrivateKey names
// a key file; config.Passphrase is used to decrypt it when non-empty). At
// least one of the two must be configured, otherwise an error is returned.
//
// On success the caller owns both returned clients and should close the SFTP
// client before the SSH client. On error both clients are nil and any SSH
// connection opened along the way has already been closed.
func connectSFTP(config SSHConfig) (*sftp.Client, *ssh.Client, error) {
	var authMethods []ssh.AuthMethod

	// Password authentication.
	if config.Password != "" {
		authMethods = append(authMethods, ssh.Password(config.Password))
	}

	// Public-key authentication from a key file on disk.
	if config.PrivateKey != "" {
		key, err := os.ReadFile(config.PrivateKey)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to read private key file: %w", err)
		}
		var signer ssh.Signer
		if config.Passphrase != "" {
			signer, err = ssh.ParsePrivateKeyWithPassphrase(key, []byte(config.Passphrase))
		} else {
			// An empty passphrase means the key is treated as unencrypted.
			signer, err = ssh.ParsePrivateKey(key)
		}
		if err != nil {
			return nil, nil, fmt.Errorf("failed to parse private key: %w", err)
		}
		authMethods = append(authMethods, ssh.PublicKeys(signer))
	}

	if len(authMethods) == 0 {
		return nil, nil, fmt.Errorf("no authentication method provided")
	}

	sshConfig := &ssh.ClientConfig{
		User: config.User,
		Auth: authMethods,
		// NOTE(review): InsecureIgnoreHostKey accepts any server host key, which
		// leaves the connection open to man-in-the-middle attacks. Consider
		// golang.org/x/crypto/ssh/knownhosts (or a pinned key) for real use.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		// Bound the TCP connect so an unreachable host fails instead of waiting
		// on the OS-level connect timeout.
		Timeout: 30 * time.Second,
	}

	// net.JoinHostPort brackets IPv6 literals ("[::1]:22"); plain "%s:%d"
	// formatting would yield the unparseable "::1:22".
	address := net.JoinHostPort(config.Host, fmt.Sprint(config.Port))
	sshClient, err := ssh.Dial("tcp", address, sshConfig)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to connect to SSH: %w", err)
	}

	// Open the SFTP subsystem on top of the SSH connection.
	sftpClient, err := sftp.NewClient(sshClient)
	if err != nil {
		sshClient.Close()
		return nil, nil, fmt.Errorf("failed to create SFTP client: %w", err)
	}
	return sftpClient, sshClient, nil
}
// downloadFile copies the remote file at remotePath (read over SFTP) to
// localPath, creating the local file or truncating it if it already exists.
//
// Data is copied in 32 KiB chunks. After every chunk a running
// "Downloaded X / Y bytes" line is printed to stdout, where Y is the remote
// file size reported by Stat. A trailing newline ends the progress line.
//
// An error from closing the local file is returned as well, unless an earlier
// error is already being returned. A failed close of a written file can mean
// lost data.
func downloadFile(sftpClient *sftp.Client, remotePath, localPath string) (err error) {
	// Open the remote file.
	remoteFile, err := sftpClient.Open(remotePath)
	if err != nil {
		return fmt.Errorf("failed to open remote file: %w", err)
	}
	defer remoteFile.Close()

	// Stat the remote file so progress can be shown against its total size.
	fileInfo, err := remoteFile.Stat()
	if err != nil {
		return fmt.Errorf("failed to get remote file info: %w", err)
	}

	// Create (or truncate) the local destination.
	localFile, err := os.Create(localPath)
	if err != nil {
		return fmt.Errorf("failed to create local file: %w", err)
	}
	defer func() {
		if cerr := localFile.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("failed to close local file: %w", cerr)
		}
	}()

	const bufferSize = 32 * 1024 // 32 KiB
	buf := make([]byte, bufferSize)
	var totalBytes int64
	for {
		n, rerr := remoteFile.Read(buf)
		// Per the io.Reader contract, consume the n bytes before looking at
		// rerr; a (0, nil) read is not end-of-file and simply retries.
		if n > 0 {
			if _, werr := localFile.Write(buf[:n]); werr != nil {
				return fmt.Errorf("failed to write to local file: %w", werr)
			}
			totalBytes += int64(n)
			fmt.Printf("\rDownloaded %d / %d bytes", totalBytes, fileInfo.Size())
		}
		if rerr == io.EOF {
			break
		}
		if rerr != nil {
			return fmt.Errorf("failed to read from remote file: %w", rerr)
		}
	}
	fmt.Println() // finish the "\r"-overwritten progress line
	return nil
}
// uploadFile copies the local file at localPath to remotePath on the SFTP
// server, creating the remote file or truncating it if it already exists.
//
// Data is streamed in 32 KiB chunks. An error from closing the remote file is
// returned as well, unless an earlier error is already being returned. A
// failed close can mean the upload did not complete.
func uploadFile(sftpClient *sftp.Client, localPath, remotePath string) (err error) {
	// Open the local source.
	localFile, err := os.Open(localPath)
	if err != nil {
		return fmt.Errorf("failed to open local file: %w", err)
	}
	defer localFile.Close()

	// Create (or truncate) the remote destination.
	remoteFile, err := sftpClient.Create(remotePath)
	if err != nil {
		return fmt.Errorf("failed to create remote file: %w", err)
	}
	defer func() {
		if cerr := remoteFile.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("failed to close remote file: %w", cerr)
		}
	}()

	const bufferSize = 32 * 1024 // 32 KiB
	buf := make([]byte, bufferSize)
	for {
		n, rerr := localFile.Read(buf)
		// Per the io.Reader contract, consume the n bytes before looking at
		// rerr; compare against the io.EOF sentinel, not its message text.
		if n > 0 {
			if _, werr := remoteFile.Write(buf[:n]); werr != nil {
				return fmt.Errorf("failed to write to remote file: %w", werr)
			}
		}
		if rerr == io.EOF {
			return nil
		}
		if rerr != nil {
			return fmt.Errorf("failed to read local file: %w", rerr)
		}
	}
}
// _sftp validates the command-line flags, connects to the SFTP server and
// uploads the local file (*l) to the remote path (*r).
//
// If a required flag is missing, it prints an error and the usage text, then
// exits the process with status 1. Later failures (reading the passphrase,
// connecting, uploading) are logged and returned to the caller. The success
// message is printed only when the upload actually succeeded.
func _sftp() error {
	// Required: host, user, local path, remote path, and at least one of
	// password (*e) or private key (*k).
	if *h == "" || *u == "" || *l == "" || *r == "" || (*e == "" && *k == "") {
		fmt.Println("Error: Missing required parameters")
		flag.Usage()
		os.Exit(1)
	}

	// When a private key is given, prompt (without echo) for its passphrase.
	// An empty answer is passed through unchanged in SSHConfig.Passphrase.
	var passphrase string
	if *k != "" {
		fmt.Print("Enter passphrase for private key: ")
		bytePassword, err := term.ReadPassword(int(syscall.Stdin))
		fmt.Println()
		if err != nil {
			log.Printf("Failed to read passphrase: %v", err)
			return fmt.Errorf("failed to read passphrase: %w", err)
		}
		passphrase = string(bytePassword)
	}

	// Build the SSH configuration.
	config := SSHConfig{
		Host:       *h,
		Port:       *p,
		User:       *u,
		Password:   *e,
		PrivateKey: *k,
		Passphrase: passphrase,
	}

	// Connect; on failure the returned clients are unusable, so stop here.
	sftpClient, sshClient, err := connectSFTP(config)
	if err != nil {
		log.Printf("Failed to connect to SFTP: %v", err)
		return err
	}
	// Deferred calls run LIFO: the SFTP session is closed first, then the SSH
	// connection that carries it.
	defer sshClient.Close()
	defer sftpClient.Close()

	// Upload the file.
	if err = uploadFile(sftpClient, *l, *r); err != nil {
		log.Printf("Failed to upload file: %v", err)
		return err
	}
	fmt.Println("File uploaded successfully!")
	return nil
}