package main

import (
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"syscall"
	"time"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
	"golang.org/x/term"
)

// connectSFTP opens an SSH connection described by config and starts an SFTP
// session on top of it.
//
// Password authentication is offered when config.Password is non-empty, and
// public-key authentication when config.PrivateKey (a path to a key file) is
// non-empty; config.Passphrase, if set, is used to decrypt that key. At least
// one method must be configured.
//
// On success the caller owns both returned clients and should close the SFTP
// client before the SSH client. On error both clients are nil and any SSH
// connection that was opened has already been closed.
func connectSFTP(config SSHConfig) (*sftp.Client, *ssh.Client, error) {
	// Bound the TCP connect + SSH handshake so an unreachable host does not
	// hang the program forever.
	const dialTimeout = 30 * time.Second

	var authMethods []ssh.AuthMethod

	// Password authentication.
	if config.Password != "" {
		authMethods = append(authMethods, ssh.Password(config.Password))
	}

	// Private-key authentication.
	if config.PrivateKey != "" {
		key, err := os.ReadFile(config.PrivateKey)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to read private key file: %w", err)
		}

		var signer ssh.Signer
		if config.Passphrase != "" {
			signer, err = ssh.ParsePrivateKeyWithPassphrase(key, []byte(config.Passphrase))
		} else {
			signer, err = ssh.ParsePrivateKey(key)
		}
		if err != nil {
			return nil, nil, fmt.Errorf("failed to parse private key: %w", err)
		}
		authMethods = append(authMethods, ssh.PublicKeys(signer))
	}

	if len(authMethods) == 0 {
		return nil, nil, errors.New("no authentication method provided")
	}

	sshConfig := &ssh.ClientConfig{
		User: config.User,
		Auth: authMethods,
		// SECURITY(review): host keys are not verified, so this connection is
		// open to man-in-the-middle attacks. Consider verifying against a
		// known_hosts file (golang.org/x/crypto/ssh/knownhosts).
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         dialTimeout,
	}

	// JoinHostPort brackets IPv6 literals ("[::1]:22"); a plain "%s:%d" does not.
	address := net.JoinHostPort(config.Host, fmt.Sprint(config.Port))
	sshClient, err := ssh.Dial("tcp", address, sshConfig)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to connect to SSH: %w", err)
	}

	sftpClient, err := sftp.NewClient(sshClient)
	if err != nil {
		sshClient.Close()
		return nil, nil, fmt.Errorf("failed to create SFTP client: %w", err)
	}

	return sftpClient, sshClient, nil
}

// downloadProgressWriter forwards writes to w and prints a running
// "Downloaded X / Y bytes" line, where Y is the expected total size.
type downloadProgressWriter struct {
	w       io.Writer
	written int64
	size    int64
}

// Write implements io.Writer, updating the progress line after each write.
func (pw *downloadProgressWriter) Write(p []byte) (int, error) {
	n, err := pw.w.Write(p)
	pw.written += int64(n)
	fmt.Printf("\rDownloaded %d / %d bytes", pw.written, pw.size)
	return n, err
}

// downloadFile copies remotePath from the SFTP server to localPath, creating
// or truncating the local file and printing byte-count progress to stdout.
//
// It returns an error if the remote file cannot be opened or stat'ed, the
// local file cannot be created, the copy fails, or closing the local file
// fails (which can surface deferred write errors).
func downloadFile(sftpClient *sftp.Client, remotePath, localPath string) (err error) {
	remoteFile, err := sftpClient.Open(remotePath)
	if err != nil {
		return fmt.Errorf("failed to open remote file: %w", err)
	}
	defer remoteFile.Close()

	// The remote size is only used for the progress display.
	fileInfo, err := remoteFile.Stat()
	if err != nil {
		return fmt.Errorf("failed to get remote file info: %w", err)
	}

	localFile, err := os.Create(localPath)
	if err != nil {
		return fmt.Errorf("failed to create local file: %w", err)
	}
	defer func() {
		// A failed Close on a written file means data may not have been
		// persisted; report it unless an earlier error is already returned.
		if cerr := localFile.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("failed to close local file: %w", cerr)
		}
	}()

	// io.Copy loops until io.EOF, so a (0, nil) read cannot truncate the copy.
	progress := &downloadProgressWriter{w: localFile, size: fileInfo.Size()}
	if _, err := io.Copy(progress, remoteFile); err != nil {
		return fmt.Errorf("failed to copy remote file to local file: %w", err)
	}
	fmt.Println() // finish the "\r" progress line

	return nil
}

// uploadFile copies localPath to remotePath on the SFTP server, creating or
// truncating the remote file.
//
// It returns an error if the local file cannot be opened, the remote file
// cannot be created, the copy fails, or closing the remote file fails (which
// can surface deferred write errors reported by the server).
func uploadFile(sftpClient *sftp.Client, localPath, remotePath string) (err error) {
	localFile, err := os.Open(localPath)
	if err != nil {
		return fmt.Errorf("failed to open local file: %w", err)
	}
	defer localFile.Close()

	remoteFile, err := sftpClient.Create(remotePath)
	if err != nil {
		return fmt.Errorf("failed to create remote file: %w", err)
	}
	defer func() {
		if cerr := remoteFile.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("failed to close remote file: %w", cerr)
		}
	}()

	// io.Copy handles io.EOF correctly and can use the SFTP file's ReadFrom
	// implementation.
	if _, err := io.Copy(remoteFile, localFile); err != nil {
		return fmt.Errorf("failed to write to remote file: %w", err)
	}

	return nil
}

// _sftp validates the command-line flags, connects to the server, and uploads
// the local file *l to the remote path *r.
//
// Missing required flags print usage and terminate the process with exit
// code 1. Connection and upload failures are logged and returned. It returns
// nil only after a successful upload.
func _sftp() error {
	// Host, user, local path and remote path are required, plus at least one
	// of password (*e) or private-key path (*k).
	if *h == "" || *u == "" || *l == "" || *r == "" || (*e == "" && *k == "") {
		fmt.Println("Error: Missing required parameters")
		flag.Usage()
		os.Exit(1)
	}

	// When a private key is given, always prompt for its passphrase; an empty
	// answer means the key is treated as unencrypted.
	var passphrase string
	if *k != "" {
		fmt.Print("Enter passphrase for private key: ")
		bytePassword, err := term.ReadPassword(int(syscall.Stdin))
		if err != nil {
			// Best effort: e.g. stdin is not a terminal. Continue with an
			// empty passphrase so unencrypted keys still work.
			log.Printf("Failed to read passphrase: %v", err)
		}
		passphrase = string(bytePassword)
		fmt.Println()
	}

	config := SSHConfig{
		Host:       *h,
		Port:       *p,
		User:       *u,
		Password:   *e,
		PrivateKey: *k,
		Passphrase: passphrase,
	}

	sftpClient, sshClient, err := connectSFTP(config)
	if err != nil {
		// Both clients are nil here; returning avoids a nil-pointer Close.
		log.Printf("Failed to connect to SFTP: %v", err)
		return err
	}
	// Defers run LIFO: register the SSH close first so the SFTP session is
	// closed before its underlying SSH connection.
	defer sshClient.Close()
	defer sftpClient.Close()

	if err := uploadFile(sftpClient, *l, *r); err != nil {
		log.Printf("Failed to upload file: %v", err)
		return err
	}

	fmt.Println("File uploaded successfully!")
	return nil
}