diff options
Diffstat (limited to 'docker/context/context.py')
-rw-r--r-- | docker/context/context.py | 106 |
1 files changed, 69 insertions, 37 deletions
diff --git a/docker/context/context.py b/docker/context/context.py index b2af20c..dbaa01c 100644 --- a/docker/context/context.py +++ b/docker/context/context.py @@ -11,35 +11,48 @@ from docker.context.config import get_context_host class Context: """A context.""" + def __init__(self, name, orchestrator=None, host=None, endpoints=None, tls=False): if not name: raise Exception("Name not provided") self.name = name + self.context_type = None self.orchestrator = orchestrator + self.endpoints = {} + self.tls_cfg = {} + self.meta_path = "IN MEMORY" + self.tls_path = "IN MEMORY" + if not endpoints: + # set default docker endpoint if no endpoint is set default_endpoint = "docker" if ( not orchestrator or orchestrator == "swarm" ) else orchestrator + self.endpoints = { default_endpoint: { "Host": get_context_host(host, tls), "SkipTLSVerify": not tls } } - else: - for k, v in endpoints.items(): - ekeys = v.keys() - for param in ["Host", "SkipTLSVerify"]: - if param not in ekeys: - raise ContextException( - "Missing parameter {} from endpoint {}".format( - param, k)) - self.endpoints = endpoints + return - self.tls_cfg = {} - self.meta_path = "IN MEMORY" - self.tls_path = "IN MEMORY" + # check docker endpoints + for k, v in endpoints.items(): + if not isinstance(v, dict): + # unknown format + raise ContextException("""Unknown endpoint format for + context {}: {}""".format(name, v)) + + self.endpoints[k] = v + if k != "docker": + continue + + self.endpoints[k]["Host"] = v.get("Host", get_context_host( + host, tls)) + self.endpoints[k]["SkipTLSVerify"] = bool(v.get( + "SkipTLSVerify", not tls)) def set_endpoint( self, name="docker", host=None, tls_cfg=None, @@ -59,9 +72,13 @@ class Context: @classmethod def load_context(cls, name): - name, orchestrator, endpoints = Context._load_meta(name) - if name: - instance = cls(name, orchestrator, endpoints=endpoints) + meta = Context._load_meta(name) + if meta: + instance = cls( + meta["Name"], + orchestrator=meta["Metadata"].get("StackOrchestrator", None), + endpoints=meta.get("Endpoints", None)) + instance.context_type = meta["Metadata"].get("Type", None) instance._load_certs() instance.meta_path = get_meta_dir(name) return instance @@ -69,26 +86,30 @@ class Context: @classmethod def _load_meta(cls, name): - metadata = {} meta_file = get_meta_file(name) - if os.path.isfile(meta_file): + if not os.path.isfile(meta_file): + return None + + metadata = {} + try: with open(meta_file) as f: - try: - with open(meta_file) as f: - metadata = json.load(f) - for k, v in metadata["Endpoints"].items(): - metadata["Endpoints"][k]["SkipTLSVerify"] = bool( - v["SkipTLSVerify"]) - except (IOError, KeyError, ValueError) as e: - # unknown format - raise Exception("""Detected corrupted meta file for - context {} : {}""".format(name, e)) - - return ( - metadata["Name"], - metadata["Metadata"].get("StackOrchestrator", None), - metadata["Endpoints"]) - return None, None, None + metadata = json.load(f) + except (OSError, KeyError, ValueError) as e: + # unknown format + raise Exception("""Detected corrupted meta file for + context {} : {}""".format(name, e)) + + # for docker endpoints, set defaults for + # Host and SkipTLSVerify fields + for k, v in metadata["Endpoints"].items(): + if k != "docker": + continue + metadata["Endpoints"][k]["Host"] = v.get( + "Host", get_context_host(None, False)) + metadata["Endpoints"][k]["SkipTLSVerify"] = bool( + v.get("SkipTLSVerify", True)) + + return metadata def _load_certs(self): certs = {} @@ -107,8 +128,12 @@ class Context: elif filename.startswith("key"): key = os.path.join(tls_dir, endpoint, filename) if all([ca_cert, cert, key]): + verify = None + if endpoint == "docker" and not self.endpoints["docker"].get( + "SkipTLSVerify", False): + verify = True certs[endpoint] = TLSConfig( - client_cert=(cert, key), ca_cert=ca_cert) + client_cert=(cert, key), ca_cert=ca_cert, verify=verify) self.tls_cfg = certs self.tls_path = tls_dir @@ -146,7 +171,7 @@ class Context: rmtree(self.tls_path) def __repr__(self): - return "<%s: '%s'>" % (self.__class__.__name__, self.name) + return f"<{self.__class__.__name__}: '{self.name}'>" def __str__(self): return json.dumps(self.__call__(), indent=2) @@ -157,6 +182,9 @@ class Context: result.update(self.Storage) return result + def is_docker_host(self): + return self.context_type is None + @property def Name(self): return self.name @@ -164,8 +192,12 @@ class Context: @property def Host(self): if not self.orchestrator or self.orchestrator == "swarm": - return self.endpoints["docker"]["Host"] - return self.endpoints[self.orchestrator]["Host"] + endpoint = self.endpoints.get("docker", None) + if endpoint: + return endpoint.get("Host", None) + return None + + return self.endpoints[self.orchestrator].get("Host", None) @property def Orchestrator(self): |