diff options
| -rw-r--r-- | pygerrit/ssh.py | 48 |
1 files changed, 31 insertions, 17 deletions
diff --git a/pygerrit/ssh.py b/pygerrit/ssh.py index 0436165..d508d56 100644 --- a/pygerrit/ssh.py +++ b/pygerrit/ssh.py @@ -65,9 +65,14 @@ class GerritSSHClient(SSHClient): def __init__(self, hostname): """ Initialise and connect to SSH. """ super(GerritSSHClient, self).__init__() - self.load_system_host_keys() self.remote_version = None + self.hostname = hostname + self.connected = False + def _connect(self): + if self.connected: + return + self.load_system_host_keys() configfile = expanduser("~/.ssh/config") if not isfile(configfile): raise GerritError("ssh config file '%s' does not exist" % @@ -75,9 +80,9 @@ class GerritSSHClient(SSHClient): config = SSHConfig() config.parse(open(configfile)) - data = config.lookup(hostname) + data = config.lookup(self.hostname) if not data: - raise GerritError("No ssh config for host %s" % hostname) + raise GerritError("No ssh config for host %s" % self.hostname) if not 'hostname' in data or not 'port' in data or not 'user' in data: raise GerritError("Missing configuration data in %s" % configfile) key_filename = None @@ -95,27 +100,35 @@ class GerritSSHClient(SSHClient): port=port, username=data['user'], key_filename=key_filename) + self.connected = True except socket.error as e: raise GerritError("Failed to connect to server: %s" % e) - def get_remote_version(self): - """ Return the version of the remote Gerrit server. """ - if self.remote_version is not None: - return self.remote_version - try: version_string = self._transport.remote_version pattern = re.compile(r'^.*GerritCodeReview_([a-z0-9-\.]*) .*$') self.remote_version = _extract_version(version_string, pattern) except AttributeError: - try: - result = self.run_gerrit_command("version") - except GerritError: - self.remote_version = "" - else: - version_string = result.stdout.read() - pattern = re.compile(r'^gerrit version (.*)$') - self.remote_version = _extract_version(version_string, pattern) + self.remote_version = None + + def exec_command(self, command, bufsize=1): + """ Execute the command. + + Make sure we're connected and then execute the command. + + Return a tuple of stdin, stdout, stderr. + + """ + self._connect() + return super(GerritSSHClient, self).exec_command(command, bufsize) + + def get_remote_version(self): + """ Return the version of the remote Gerrit server. """ + if self.remote_version is None: + result = self.run_gerrit_command("version") + version_string = result.stdout.read() + pattern = re.compile(r'^gerrit version (.*)$') + self.remote_version = _extract_version(version_string, pattern) return self.remote_version def run_gerrit_command(self, command): @@ -130,7 +143,8 @@ class GerritSSHClient(SSHClient): else: gerrit_command.append(command) try: - stdin, stdout, stderr = self.exec_command(" ".join(gerrit_command)) + c = " ".join(gerrit_command) + stdin, stdout, stderr = self.exec_command(c) except SSHException as err: raise GerritError("Command execution error: %s" % err) return GerritSSHCommandResult(stdin, stdout, stderr) |
