summaryrefslogtreecommitdiff
path: root/test/utils/graph.py
blob: 42f62a1893f931c5b02103db043a5e49283cecd4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import logging
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Optional, Tuple, Union

from rdflib.graph import Graph
from rdflib.util import guess_format

# Anything the loading helpers in this module accept as a graph source:
# either a fully specified ``GraphSource`` or a bare ``Path`` whose RDF
# format will be guessed from its file name (see ``GraphSource.from_path``).
GraphSourceType = Union["GraphSource", Path]


@dataclass(frozen=True)
class GraphSource:
    """An RDF file to parse into an rdflib ``Graph``, plus how to parse it.

    Instances are frozen (and therefore hashable), so tuples of them can be
    used as keys for memoised loaders.
    """

    # File to parse.
    path: Path
    # rdflib parser format name, e.g. as returned by ``guess_format``.
    format: str
    # Default ``publicID`` passed to ``Graph.parse``; ``None`` leaves it to
    # rdflib's default handling.
    public_id: Optional[str] = None

    @classmethod
    def from_path(cls, path: Path, public_id: Optional[str] = None) -> "GraphSource":
        """Build a source for *path*, guessing its format from the file name.

        Args:
            path: file to parse.
            public_id: default public ID for the new source.

        Raises:
            ValueError: if rdflib cannot guess a format for *path*.
        """
        # ``fmt`` rather than ``format`` to avoid shadowing the builtin.
        fmt = guess_format(str(path))
        if fmt is None:
            raise ValueError(f"could not guess format for source {path}")
        return cls(path, fmt, public_id)

    @classmethod
    def from_paths(cls, *paths: Path) -> Tuple["GraphSource", ...]:
        """Build one source per path via ``from_path``, preserving order.

        Raises:
            ValueError: if the format of any path cannot be guessed.
        """
        return tuple(cls.from_path(path) for path in paths)

    @classmethod
    def from_source(
        cls, source: GraphSourceType, public_id: Optional[str] = None
    ) -> "GraphSource":
        """Normalise *source* into a ``GraphSource``.

        Args:
            source: a ``GraphSource`` (returned unchanged) or a ``Path``
                (converted with ``from_path``).
            public_id: default public ID given to a source built from a
                ``Path``.

        Raises:
            ValueError: if *source* is a ``Path`` whose format cannot be
                guessed.
        """
        logging.debug("source(%s) = %r", id(source), source)
        if isinstance(source, Path):
            # Forward public_id so the caller's value is not silently dropped.
            return cls.from_path(source, public_id)
        return source

    def load(
        self, graph: Optional[Graph] = None, public_id: Optional[str] = None
    ) -> Graph:
        """Parse this source into *graph*.

        Args:
            graph: graph to add triples to; a new ``Graph`` is created when
                ``None``.
            public_id: if not ``None``, overrides ``self.public_id`` for this
                call only.

        Returns:
            The graph that was parsed into.
        """
        if graph is None:
            graph = Graph()
        graph.parse(
            source=self.path,
            format=self.format,
            publicID=self.public_id if public_id is None else public_id,
        )
        return graph


def load_sources(
    *sources: GraphSourceType,
    graph: Optional[Graph] = None,
    public_id: Optional[str] = None,
) -> Graph:
    """Parse every item of *sources*, in order, into a single graph.

    Args:
        sources: ``GraphSource`` objects or paths; paths have their format
            guessed from the file name.
        graph: graph to add triples to; a fresh ``Graph`` is created when
            omitted.
        public_id: if not ``None``, overrides each source's own public ID.

    Returns:
        The graph everything was parsed into.
    """
    target = Graph() if graph is None else graph
    for item in sources:
        normalised = GraphSource.from_source(item)
        normalised.load(target, public_id)
    return target


@lru_cache(maxsize=None)
def cached_graph(
    sources: Tuple[GraphSourceType, ...], public_id: Optional[str] = None
) -> Graph:
    """Load *sources* into a new graph, memoising the result per call arguments.

    The cache is unbounded and keyed on the call arguments, so *sources*
    must be a hashable tuple (not a list).

    Args:
        sources: ``GraphSource`` objects or paths to load, in order.
        public_id: if not ``None``, overrides each source's own public ID.

    Returns:
        The loaded graph. Repeated calls with the same arguments return the
        *same* ``Graph`` object, so callers must treat it as read-only.
        Mutating it would leak changes into every later call, so copy it
        first if modification is needed.
    """
    return load_sources(*sources, public_id=public_id)