diff options
| -rwxr-xr-x | example.py | 1 | ||||
| -rw-r--r-- | pygerrit/client.py | 19 | ||||
| -rw-r--r-- | pygerrit/ssh.py | 31 |
3 files changed, 35 insertions, 16 deletions
@@ -56,6 +56,7 @@ def _main(): logging.basicConfig(format='%(message)s', level=logging.INFO) gerrit = GerritClient(host=options.hostname) + logging.info("Connected to Gerrit version [%s]", gerrit.gerrit_version()) gerrit.start_event_stream() try: while True: diff --git a/pygerrit/client.py b/pygerrit/client.py index a6bb77f..5ab3950 100644 --- a/pygerrit/client.py +++ b/pygerrit/client.py @@ -34,8 +34,6 @@ from pygerrit.models import Change from pygerrit.ssh import GerritSSHClient from pygerrit.stream import GerritStream -_GERRIT_VERSION_PREFIX = "gerrit version " - class GerritClient(object): @@ -46,21 +44,10 @@ class GerritClient(object): self._events = Queue() self._stream = None self._ssh_client = GerritSSHClient(host) - self._gerrit_version = self._get_gerrit_version() - - def _get_gerrit_version(self): - """ Run `gerrit version` to get the version of Gerrit connected to. - Return the version as a string. Empty if version was not returned. - - """ - _stdin, stdout, _stderr = self._ssh_client.run_gerrit_command("version") - version_string = stdout.read() - if version_string: - if version_string.startswith(_GERRIT_VERSION_PREFIX): - return version_string[len(_GERRIT_VERSION_PREFIX):].strip() - return version_string.strip() - return "" + def gerrit_version(self): + """ Get the version of Gerrit that is connected to. """ + return self._ssh_client.get_remote_version() def query(self, term): """ Run `gerrit query` with the given term. diff --git a/pygerrit/ssh.py b/pygerrit/ssh.py index 06154c4..52817c2 100644 --- a/pygerrit/ssh.py +++ b/pygerrit/ssh.py @@ -25,6 +25,7 @@ THE SOFTWARE. """ from os.path import abspath, expanduser, isfile +import re from threading import Lock from pygerrit.error import GerritError @@ -33,6 +34,15 @@ from paramiko import SSHClient, SSHConfig from paramiko.ssh_exception import SSHException +def _extract_version(version_string, pattern): + """ Extract the version from `version_string` using `pattern`. """ + if version_string: + match = pattern.match(version_string.strip()) + if match: + return match.group(1) + return "" + + class GerritSSHClient(SSHClient): """ Gerrit SSH Client, wrapping the paramiko SSH Client. """ @@ -42,6 +52,7 @@ class GerritSSHClient(SSHClient): super(GerritSSHClient, self).__init__() self.load_system_host_keys() self.lock = Lock() + self.remote_version = None configfile = expanduser("~/.ssh/config") if not isfile(configfile): @@ -70,6 +81,26 @@ class GerritSSHClient(SSHClient): username=data['user'], key_filename=key_filename) + def get_remote_version(self): + """ Return the version of the remote Gerrit server. """ + if self.remote_version is not None: + return self.remote_version + + try: + version_string = self._transport.remote_version + pattern = re.compile(r'^.*GerritCodeReview_([a-z0-9-\.]*) .*$') + self.remote_version = _extract_version(version_string, pattern) + except AttributeError: + try: + _stdin, stdout, _stderr = self.run_gerrit_command("version") + except GerritError: + self.remote_version = "" + else: + version_string = stdout.read() + pattern = re.compile(r'^gerrit version (.*)$') + self.remote_version = _extract_version(version_string, pattern) + return self.remote_version + def run_gerrit_command(self, command): """ Run the given command. |
