1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
class SSH:
    """Connection settings plus a lazily created SSH client, optionally via jump hosts.

    Nothing is opened by the constructor; the connection is made by
    ``get_ssh_client`` (or by entering the object as a context manager).
    ``self.arguments`` is never modified after construction, so the same
    object can reconnect after a previous connection was closed.
    """

    def __init__(self, hostname, port=SSH_PORT, username='root', pkey=None, password=None,
                 connect_timeout=10, sock=None, proxy_hosts=None):
        """Store the connection settings.

        :param hostname: target host name or IP address.
        :param port: target SSH port.
        :param username: login user name.
        :param pkey: private key object, or the private key text, which is parsed
            with ``RSAKey.from_private_key`` using ``password`` as its passphrase.
        :param password: login password (also the key passphrase when ``pkey`` is text).
        :param connect_timeout: passed to ``SSHClient.connect`` as ``timeout``.
        :param sock: optional socket-like object to connect over.
        :param proxy_hosts: optional ordered list of ``SSH`` objects for the jump
            hosts, first hop first.
        :raises ValueError: if neither ``pkey`` nor ``password`` is given.
        """
        if pkey is None and password is None:
            # ValueError subclasses Exception, so existing ``except Exception`` callers still work.
            raise ValueError('public key and password must have one is not None')
        self.client = None
        # Transports to the jump hosts opened for the current connection, in hop order.
        self._jump_transports = []
        self.arguments = {
            'hostname': hostname,
            'port': port,
            'username': username,
            'password': password,
            'pkey': RSAKey.from_private_key(StringIO(pkey), password) if isinstance(pkey, str) else pkey,
            'timeout': connect_timeout,
            'sock': sock,
            'proxy_hosts': proxy_hosts,
        }

    def _connect_kwargs(self):
        """Return a fresh copy of ``self.arguments`` without the ``proxy_hosts`` entry."""
        return {key: value for key, value in self.arguments.items() if key != 'proxy_hosts'}

    def _close_jump_transports(self):
        """Close the jump-host transports opened for this connection, last hop first."""
        transports = getattr(self, '_jump_transports', None) or []
        while transports:
            transports.pop().close()

    def get_ssh_client(self):
        """Return the connected client, connecting (through any jump hosts) on first use.

        :return: the connected ``SSHClient``; the same object on repeated calls
            until the connection is closed by ``__exit__``.
        """
        if self.client is not None:
            return self.client
        # Work on a copy so self.arguments keeps the real target for later reconnects.
        connect_kwargs = self._connect_kwargs()
        if self.arguments.get('proxy_hosts'):
            # The target is reached over the last hop's direct-tcpip channel, passed as ``sock``.
            connect_kwargs['port'], connect_kwargs['sock'] = self.connect_host_with_proxy_jump()
            connect_kwargs['hostname'] = '127.0.0.1'
        client = SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without verification
        # (MITM risk); kept as-is for compatibility.
        client.set_missing_host_key_policy(AutoAddPolicy)
        try:
            client.connect(**connect_kwargs)
        except Exception:
            client.close()
            self._close_jump_transports()
            raise
        # Only cache the client once it is connected, so a failed attempt can be retried.
        self.client = client
        return client

    def list_dir_attr(self, path):
        """Open a temporary connection and return ``listdir_attr(path)`` over SFTP.

        :raises RuntimeError: if this object is already connected (see ``__enter__``).
        """
        with self as cli:
            with cli.open_sftp() as sftp:
                return sftp.listdir_attr(path)

    @staticmethod
    def get_channel(proxy_obj, dest_ip, dest_port, connect_overrides=None):
        """Connect to a jump host and open a ``direct-tcpip`` channel to a destination.

        :param proxy_obj: the jump host's ``SSH`` object. Only its ``arguments`` are
            read; any ``proxy_hosts`` entry in it is ignored, and it is not modified.
        :param dest_ip: IP address of the destination host.
        :param dest_port: port of the destination host.
        :param connect_overrides: optional dict of ``SSHClient.connect`` keyword
            arguments replacing the proxy's own (e.g. ``sock`` to tunnel through
            the previous hop).
        :return: ``(free_port, channel)`` - a local port number reported as the
            channel's originator address, and the opened channel.
        """
        connect_kwargs = {key: value for key, value in proxy_obj.arguments.items()
                          if key != 'proxy_hosts'}
        if connect_overrides:
            connect_kwargs.update(connect_overrides)
        client = SSHClient()
        client.set_missing_host_key_policy(AutoAddPolicy)
        try:
            client.connect(**connect_kwargs)
            # Let the OS pick an unused local port; the temporary server is closed
            # immediately and the port is only reported as the originator address.
            with socketserver.TCPServer(("localhost", 0), None) as server:
                free_port = server.server_address[1]
            channel = client.get_transport().open_channel(
                "direct-tcpip", (dest_ip, dest_port), ('127.0.0.1', free_port))
        except Exception:
            client.close()
            raise
        return free_port, channel

    def connect_host_with_proxy_jump(self):
        """Chain ``direct-tcpip`` channels through ``self.arguments['proxy_hosts']``.

        Hop ``i`` connects to jump host ``i`` (over the channel opened by hop
        ``i - 1``, if any) and opens a channel to jump host ``i + 1``, or to this
        object's own hostname/port from the last jump host. A single jump host is
        supported. The jump-host transports are kept and closed by ``__exit__``.

        :return: ``(free_port, channel)`` of the last hop, i.e. a channel to the target.
        :raises ValueError: if ``proxy_hosts`` is empty or missing.
        """
        proxy_hosts = list(self.arguments.get('proxy_hosts') or [])
        if not proxy_hosts:
            raise ValueError('proxy_hosts is empty')
        # Destination of each hop: the next jump host, and finally the real target.
        destinations = [(proxy.arguments['hostname'], proxy.arguments['port'])
                        for proxy in proxy_hosts[1:]]
        destinations.append((self.arguments['hostname'], self.arguments['port']))

        opened = []
        free_port, channel = None, None
        try:
            for proxy, (dest_ip, dest_port) in zip(proxy_hosts, destinations):
                overrides = None
                if channel is not None:
                    # Reach this jump host through the previous hop's channel.
                    overrides = {'hostname': '127.0.0.1', 'port': free_port, 'sock': channel}
                free_port, channel = self.get_channel(proxy, dest_ip, dest_port,
                                                      connect_overrides=overrides)
                opened.append(channel.get_transport())
        except Exception:
            # Tear down the partially built chain, innermost hop first.
            for transport in reversed(opened):
                transport.close()
            raise
        self._jump_transports.extend(opened)
        return free_port, channel

    def __enter__(self):
        """Connect and return the underlying client.

        :raises RuntimeError: if this object already holds a connected client.
        """
        if self.client is not None:
            raise RuntimeError('Already connected')
        return self.get_ssh_client()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the client and any jump-host transports; exceptions are not suppressed."""
        if self.client is not None:
            self.client.close()
            self.client = None
        self._close_jump_transports()
|