summaryrefslogtreecommitdiff
path: root/redis/commands/graph/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'redis/commands/graph/__init__.py')
-rw-r--r--redis/commands/graph/__init__.py129
1 files changed, 111 insertions, 18 deletions
diff --git a/redis/commands/graph/__init__.py b/redis/commands/graph/__init__.py
index 3736195..a882dd5 100644
--- a/redis/commands/graph/__init__.py
+++ b/redis/commands/graph/__init__.py
@@ -1,9 +1,13 @@
from ..helpers import quote_string, random_string, stringify_param_value
-from .commands import GraphCommands
+from .commands import AsyncGraphCommands, GraphCommands
from .edge import Edge # noqa
from .node import Node # noqa
from .path import Path # noqa
+DB_LABELS = "DB.LABELS"
+DB_RAELATIONSHIPTYPES = "DB.RELATIONSHIPTYPES"
+DB_PROPERTYKEYS = "DB.PROPERTYKEYS"
+
class Graph(GraphCommands):
"""
@@ -44,25 +48,19 @@ class Graph(GraphCommands):
lbls = self.labels()
# Unpack data.
- self._labels = [None] * len(lbls)
- for i, l in enumerate(lbls):
- self._labels[i] = l[0]
+ self._labels = [l[0] for _, l in enumerate(lbls)]
def _refresh_relations(self):
rels = self.relationship_types()
# Unpack data.
- self._relationship_types = [None] * len(rels)
- for i, r in enumerate(rels):
- self._relationship_types[i] = r[0]
+ self._relationship_types = [r[0] for _, r in enumerate(rels)]
def _refresh_attributes(self):
props = self.property_keys()
# Unpack data.
- self._properties = [None] * len(props)
- for i, p in enumerate(props):
- self._properties[i] = p[0]
+ self._properties = [p[0] for _, p in enumerate(props)]
def get_label(self, idx):
"""
@@ -108,12 +106,12 @@ class Graph(GraphCommands):
The index of the property
"""
try:
- propertie = self._properties[idx]
+ p = self._properties[idx]
except IndexError:
# Refresh properties.
self._refresh_attributes()
- propertie = self._properties[idx]
- return propertie
+ p = self._properties[idx]
+ return p
def add_node(self, node):
"""
@@ -133,6 +131,8 @@ class Graph(GraphCommands):
self.edges.append(edge)
def _build_params_header(self, params):
+ if params is None:
+ return ""
if not isinstance(params, dict):
raise TypeError("'params' must be a dict")
# Header starts with "CYPHER"
@@ -147,16 +147,109 @@ class Graph(GraphCommands):
q = f"CALL {procedure}({','.join(args)})"
y = kwagrs.get("y", None)
- if y:
- q += f" YIELD {','.join(y)}"
+ if y is not None:
+ q += f"YIELD {','.join(y)}"
return self.query(q, read_only=read_only)
def labels(self):
- return self.call_procedure("db.labels", read_only=True).result_set
+ return self.call_procedure(DB_LABELS, read_only=True).result_set
def relationship_types(self):
- return self.call_procedure("db.relationshipTypes", read_only=True).result_set
+ return self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True).result_set
def property_keys(self):
- return self.call_procedure("db.propertyKeys", read_only=True).result_set
+ return self.call_procedure(DB_PROPERTYKEYS, read_only=True).result_set
+
+
+class AsyncGraph(Graph, AsyncGraphCommands):
+ """Async version for Graph"""
+
+ async def _refresh_labels(self):
+ lbls = await self.labels()
+
+ # Unpack data.
+ self._labels = [l[0] for _, l in enumerate(lbls)]
+
+ async def _refresh_attributes(self):
+ props = await self.property_keys()
+
+ # Unpack data.
+ self._properties = [p[0] for _, p in enumerate(props)]
+
+ async def _refresh_relations(self):
+ rels = await self.relationship_types()
+
+ # Unpack data.
+ self._relationship_types = [r[0] for _, r in enumerate(rels)]
+
+ async def get_label(self, idx):
+ """
+ Returns a label by it's index
+
+ Args:
+
+ idx:
+ The index of the label
+ """
+ try:
+ label = self._labels[idx]
+ except IndexError:
+ # Refresh labels.
+ await self._refresh_labels()
+ label = self._labels[idx]
+ return label
+
+ async def get_property(self, idx):
+ """
+ Returns a property by it's index
+
+ Args:
+
+ idx:
+ The index of the property
+ """
+ try:
+ p = self._properties[idx]
+ except IndexError:
+ # Refresh properties.
+ await self._refresh_attributes()
+ p = self._properties[idx]
+ return p
+
+ async def get_relation(self, idx):
+ """
+ Returns a relationship type by it's index
+
+ Args:
+
+ idx:
+ The index of the relation
+ """
+ try:
+ relationship_type = self._relationship_types[idx]
+ except IndexError:
+ # Refresh relationship types.
+ await self._refresh_relations()
+ relationship_type = self._relationship_types[idx]
+ return relationship_type
+
+ async def call_procedure(self, procedure, *args, read_only=False, **kwagrs):
+ args = [quote_string(arg) for arg in args]
+ q = f"CALL {procedure}({','.join(args)})"
+
+ y = kwagrs.get("y", None)
+ if y is not None:
+ f"YIELD {','.join(y)}"
+ return await self.query(q, read_only=read_only)
+
+ async def labels(self):
+ return ((await self.call_procedure(DB_LABELS, read_only=True))).result_set
+
+ async def property_keys(self):
+ return (await self.call_procedure(DB_PROPERTYKEYS, read_only=True)).result_set
+
+ async def relationship_types(self):
+ return (
+ await self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True)
+ ).result_set