diff options
Diffstat (limited to 'redis/commands/graph/__init__.py')
| -rw-r--r-- | redis/commands/graph/__init__.py | 129 |
1 files changed, 111 insertions, 18 deletions
diff --git a/redis/commands/graph/__init__.py b/redis/commands/graph/__init__.py index 3736195..a882dd5 100644 --- a/redis/commands/graph/__init__.py +++ b/redis/commands/graph/__init__.py @@ -1,9 +1,13 @@ from ..helpers import quote_string, random_string, stringify_param_value -from .commands import GraphCommands +from .commands import AsyncGraphCommands, GraphCommands from .edge import Edge # noqa from .node import Node # noqa from .path import Path # noqa +DB_LABELS = "DB.LABELS" +DB_RAELATIONSHIPTYPES = "DB.RELATIONSHIPTYPES" +DB_PROPERTYKEYS = "DB.PROPERTYKEYS" + class Graph(GraphCommands): """ @@ -44,25 +48,19 @@ class Graph(GraphCommands): lbls = self.labels() # Unpack data. - self._labels = [None] * len(lbls) - for i, l in enumerate(lbls): - self._labels[i] = l[0] + self._labels = [l[0] for _, l in enumerate(lbls)] def _refresh_relations(self): rels = self.relationship_types() # Unpack data. - self._relationship_types = [None] * len(rels) - for i, r in enumerate(rels): - self._relationship_types[i] = r[0] + self._relationship_types = [r[0] for _, r in enumerate(rels)] def _refresh_attributes(self): props = self.property_keys() # Unpack data. - self._properties = [None] * len(props) - for i, p in enumerate(props): - self._properties[i] = p[0] + self._properties = [p[0] for _, p in enumerate(props)] def get_label(self, idx): """ @@ -108,12 +106,12 @@ class Graph(GraphCommands): The index of the property """ try: - propertie = self._properties[idx] + p = self._properties[idx] except IndexError: # Refresh properties. self._refresh_attributes() - propertie = self._properties[idx] - return propertie + p = self._properties[idx] + return p def add_node(self, node): """ @@ -133,6 +131,8 @@ class Graph(GraphCommands): self.edges.append(edge) def _build_params_header(self, params): + if params is None: + return "" if not isinstance(params, dict): raise TypeError("'params' must be a dict") # Header starts with "CYPHER" @@ -147,16 +147,109 @@ class Graph(GraphCommands): q = f"CALL {procedure}({','.join(args)})" y = kwagrs.get("y", None) - if y: - q += f" YIELD {','.join(y)}" + if y is not None: + q += f"YIELD {','.join(y)}" return self.query(q, read_only=read_only) def labels(self): - return self.call_procedure("db.labels", read_only=True).result_set + return self.call_procedure(DB_LABELS, read_only=True).result_set def relationship_types(self): - return self.call_procedure("db.relationshipTypes", read_only=True).result_set + return self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True).result_set def property_keys(self): - return self.call_procedure("db.propertyKeys", read_only=True).result_set + return self.call_procedure(DB_PROPERTYKEYS, read_only=True).result_set + + +class AsyncGraph(Graph, AsyncGraphCommands): + """Async version for Graph""" + + async def _refresh_labels(self): + lbls = await self.labels() + + # Unpack data. + self._labels = [l[0] for _, l in enumerate(lbls)] + + async def _refresh_attributes(self): + props = await self.property_keys() + + # Unpack data. + self._properties = [p[0] for _, p in enumerate(props)] + + async def _refresh_relations(self): + rels = await self.relationship_types() + + # Unpack data. + self._relationship_types = [r[0] for _, r in enumerate(rels)] + + async def get_label(self, idx): + """ + Returns a label by it's index + + Args: + + idx: + The index of the label + """ + try: + label = self._labels[idx] + except IndexError: + # Refresh labels. + await self._refresh_labels() + label = self._labels[idx] + return label + + async def get_property(self, idx): + """ + Returns a property by it's index + + Args: + + idx: + The index of the property + """ + try: + p = self._properties[idx] + except IndexError: + # Refresh properties. + await self._refresh_attributes() + p = self._properties[idx] + return p + + async def get_relation(self, idx): + """ + Returns a relationship type by it's index + + Args: + + idx: + The index of the relation + """ + try: + relationship_type = self._relationship_types[idx] + except IndexError: + # Refresh relationship types. + await self._refresh_relations() + relationship_type = self._relationship_types[idx] + return relationship_type + + async def call_procedure(self, procedure, *args, read_only=False, **kwagrs): + args = [quote_string(arg) for arg in args] + q = f"CALL {procedure}({','.join(args)})" + + y = kwagrs.get("y", None) + if y is not None: + f"YIELD {','.join(y)}" + return await self.query(q, read_only=read_only) + + async def labels(self): + return ((await self.call_procedure(DB_LABELS, read_only=True))).result_set + + async def property_keys(self): + return (await self.call_procedure(DB_PROPERTYKEYS, read_only=True)).result_set + + async def relationship_types(self): + return ( + await self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True) + ).result_set |
