summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYurii Karabas <1998uriyyo@gmail.com>2023-04-14 13:37:40 -0400
committersqla-tester <sqla-tester@sqlalchemy.org>2023-04-14 13:37:40 -0400
commit609f432563954167b8f0148e43c70c08380e8ba4 (patch)
treeaa803817ae45a8ce68805cb7ffddb6c62a2c8aed
parent6f3c741e41e13624d63f4faa4bdcec8466bda7f4 (diff)
downloadsqlalchemy-609f432563954167b8f0148e43c70c08380e8ba4.tar.gz
Add intersection method to Range class
<!-- Provide a general summary of your proposed changes in the Title field above --> ### Description Fixes: #9509 <!-- Describe your changes in detail --> ### Checklist <!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once) --> This pull request is: - [ ] A documentation / typographical error fix - Good to go, no issue or tests are needed - [ ] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #<issue number>` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [x] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #<issue number>` in the commit message - please include tests. **Have a nice day!** Closes: #9510 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9510 Pull-request-sha: 596648e7989327eef1807057519b2295b48f1adf Change-Id: I7b527edda09eb78dee6948edd4d49b00ea437011
-rw-r--r--doc/build/changelog/unreleased_20/9509.rst6
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ranges.py46
-rw-r--r--test/dialect/postgresql/test_types.py44
3 files changed, 96 insertions, 0 deletions
diff --git a/doc/build/changelog/unreleased_20/9509.rst b/doc/build/changelog/unreleased_20/9509.rst
new file mode 100644
index 000000000..b50a4a028
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/9509.rst
@@ -0,0 +1,6 @@
+.. change::
+ :tags: usecase, postgresql
+ :tickets: 9509
+
+ Add missing :meth:`_postgresql.Range.intersection` method.
+ Pull request courtesy Yurii Karabas.
diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py
index 3cf2ceb44..cefd280ea 100644
--- a/lib/sqlalchemy/dialects/postgresql/ranges.py
+++ b/lib/sqlalchemy/dialects/postgresql/ranges.py
@@ -641,6 +641,43 @@ class Range(Generic[_T]):
def __sub__(self, other: Range[_T]) -> Range[_T]:
return self.difference(other)
+ def intersection(self, other: Range[_T]) -> Range[_T]:
+ """Compute the intersection of this range with the `other`."""
+ if self.empty or other.empty or not self.overlaps(other):
+ return Range(None, None, empty=True)
+
+ slower = self.lower
+ slower_b = self.bounds[0]
+ supper = self.upper
+ supper_b = self.bounds[1]
+ olower = other.lower
+ olower_b = other.bounds[0]
+ oupper = other.upper
+ oupper_b = other.bounds[1]
+
+ if self._compare_edges(slower, slower_b, olower, olower_b) < 0:
+ rlower = olower
+ rlower_b = olower_b
+ else:
+ rlower = slower
+ rlower_b = slower_b
+
+ if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0:
+ rupper = oupper
+ rupper_b = oupper_b
+ else:
+ rupper = supper
+ rupper_b = supper_b
+
+ return Range(
+ rlower,
+ rupper,
+ bounds=cast(_BoundsType, rlower_b + rupper_b),
+ )
+
+ def __mul__(self, other: Range[_T]) -> Range[_T]:
+ return self.intersection(other)
+
def __str__(self) -> str:
return self._stringify()
@@ -809,6 +846,15 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
__sub__ = difference
+ def intersection(self, other: Any) -> ColumnElement[Range[_T]]:
+ """Range expression. Returns the intersection of the two ranges.
+ Will raise an exception if the resulting range is not
+ contiguous.
+ """
+ return self.expr.op("*")(other) # type: ignore
+
+ __mul__ = intersection
+
class AbstractRangeImpl(AbstractRange[Range[_T]]):
"""Marker for AbstractRange that will apply a subclass-specific
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 5f5be3c57..f322bf354 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -4499,6 +4499,50 @@ class _RangeComparisonFixtures(_RangeTests):
@testing.combinations(
*_common_ranges_to_test,
lambda r, e: Range(r.lower, r.lower, bounds="[]"),
+ lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"),
+ lambda r, e: Range(r.lower - e, r.upper + e, bounds="[)"),
+ lambda r, e: Range(r.lower - e, r.upper + e, bounds="[]"),
+ argnames="r1t",
+ )
+ @testing.combinations(
+ *_common_ranges_to_test,
+ lambda r, e: Range(r.lower, r.lower, bounds="[]"),
+ lambda r, e: Range(r.lower, r.upper - e, bounds="(]"),
+ lambda r, e: Range(r.lower, r.lower + e, bounds="[)"),
+ lambda r, e: Range(r.lower - e, r.lower, bounds="(]"),
+ lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"),
+ lambda r, e: Range(r.lower, r.upper, bounds="[]"),
+ lambda r, e: Range(r.lower, r.upper, bounds="()"),
+ argnames="r2t",
+ )
+ def test_intersection(self, connection, r1t, r2t):
+ r1 = r1t(self._data_obj(), self._epsilon)
+ r2 = r2t(self._data_obj(), self._epsilon)
+
+ RANGE = self._col_type
+ range_typ = self._col_str
+
+ q = select(
+ cast(r1, RANGE).intersection(r2),
+ )
+ validate_q = select(
+ literal_column(f"'{r1}'::{range_typ}*'{r2}'::{range_typ}", RANGE),
+ )
+
+ pg_res = connection.execute(q).scalar()
+
+ validate_intersection = connection.execute(validate_q).scalar()
+ eq_(pg_res, validate_intersection)
+ py_res = r1.intersection(r2)
+ eq_(
+ py_res,
+ pg_res,
+ f"{r1}.intersection({r2}): got {py_res}, expected {pg_res}",
+ )
+
+ @testing.combinations(
+ *_common_ranges_to_test,
+ lambda r, e: Range(r.lower, r.lower, bounds="[]"),
argnames="r1t",
)
@testing.combinations(