summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIlja Everilä <saarni@gmail.com>2014-09-10 11:34:33 +0300
committerIlja Everilä <saarni@gmail.com>2014-09-10 11:34:33 +0300
commitad82849bbe4ef329129204d02781f737c0c79fcb (patch)
tree58bb07abaada3c96277933520fefd973c365a103
parenta23264e1dc43b1250b9b5de541ff27bd49a2b2c1 (diff)
downloadsqlalchemy-ad82849bbe4ef329129204d02781f737c0c79fcb.tar.gz
implementation for <aggregate_fun> FILTER (WHERE ...)
-rw-r--r--lib/sqlalchemy/__init__.py1
-rw-r--r--lib/sqlalchemy/sql/__init__.py1
-rw-r--r--lib/sqlalchemy/sql/compiler.py6
-rw-r--r--lib/sqlalchemy/sql/elements.py65
-rw-r--r--lib/sqlalchemy/sql/expression.py4
-rw-r--r--lib/sqlalchemy/sql/functions.py24
6 files changed, 99 insertions, 2 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index 853566172..1af0de3ba 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -7,6 +7,7 @@
from .sql import (
+ aggregatefilter,
alias,
and_,
asc,
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index 4d013859c..8fbf1b536 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -19,6 +19,7 @@ from .expression import (
Selectable,
TableClause,
Update,
+ aggregatefilter,
alias,
and_,
asc,
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 5149fa4fe..6ebd61e9c 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -760,6 +760,12 @@ class SQLCompiler(Compiled):
)
)
+ def visit_aggregatefilter(self, aggregatefilter, **kwargs):
+ return "%s FILTER (WHERE %s)" % (
+ aggregatefilter.func._compiler_dispatch(self, **kwargs),
+ aggregatefilter.criterion._compiler_dispatch(self, **kwargs)
+ )
+
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
return "EXTRACT(%s FROM %s)" % (
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 8ec0aa700..5562e80d7 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -2888,6 +2888,71 @@ class Over(ColumnElement):
))
+class AggregateFilter(ColumnElement):
+ """Represent an aggregate FILTER clause.
+
+ This is a special operator against aggregate functions,
+ which controls which rows are passed to it.
+ It's supported only by certain database backends.
+
+ """
+ __visit_name__ = 'aggregatefilter'
+
+ criterion = None
+
+ def __init__(self, func, *criterion):
+ """Produce an :class:`.AggregateFilter` object against a function.
+
+ Used against aggregate functions,
+ for database backends that support aggregate "FILTER" clause.
+
+ E.g.::
+
+ from sqlalchemy import aggregatefilter
+ aggregatefilter(func.count(1), MyClass.name == 'some name')
+
+ Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')".
+
+ This function is also available from the :data:`~.expression.func`
+ construct itself via the :meth:`.FunctionElement.filter` method.
+
+ """
+ self.func = func
+ self.filter(*criterion)
+
+ def filter(self, *criterion):
+ for criterion in list(criterion):
+ criterion = _expression_literal_as_text(criterion)
+
+ if self.criterion is not None:
+ self.criterion = self.criterion & criterion
+ else:
+ self.criterion = criterion
+
+ return self
+
+ @util.memoized_property
+ def type(self):
+ return self.func.type
+
+ def get_children(self, **kwargs):
+ return [c for c in
+ (self.func, self.criterion)
+ if c is not None]
+
+ def _copy_internals(self, clone=_clone, **kw):
+ self.func = clone(self.func, **kw)
+ if self.criterion is not None:
+ self.criterion = clone(self.criterion, **kw)
+
+ @property
+ def _from_objects(self):
+ return list(itertools.chain(
+ *[c._from_objects for c in (self.func, self.criterion)
+ if c is not None]
+ ))
+
+
class Label(ColumnElement):
"""Represents a column label (AS).
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index d96f048b9..7b22cab3e 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -36,7 +36,7 @@ from .elements import ClauseElement, ColumnElement,\
True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \
Grouping, not_, \
collate, literal_column, between,\
- literal, outparam, type_coerce, ClauseList
+ literal, outparam, type_coerce, ClauseList, AggregateFilter
from .elements import SavepointClause, RollbackToSavepointClause, \
ReleaseSavepointClause
@@ -97,6 +97,8 @@ outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin")
insert = public_factory(Insert, ".expression.insert")
update = public_factory(Update, ".expression.update")
delete = public_factory(Delete, ".expression.delete")
+aggregatefilter = public_factory(
+ AggregateFilter, ".expression.aggregatefilter")
# internal functions still being called from tests and the ORM,
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 7efb1e916..46f3e27dc 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -12,7 +12,7 @@ from . import sqltypes, schema
from .base import Executable, ColumnCollection
from .elements import ClauseList, Cast, Extract, _literal_as_binds, \
literal_column, _type_from_args, ColumnElement, _clone,\
- Over, BindParameter
+ Over, BindParameter, AggregateFilter
from .selectable import FromClause, Select, Alias
from . import operators
@@ -116,6 +116,28 @@ class FunctionElement(Executable, ColumnElement, FromClause):
"""
return Over(self, partition_by=partition_by, order_by=order_by)
+ def filter(self, *criterion):
+ """Produce a FILTER clause against this function.
+
+ Used against aggregate functions,
+ for database backends that support aggregate "FILTER" clause.
+
+ The expression::
+
+ func.count(1).filter(True)
+
+ is shorthand for::
+
+ from sqlalchemy import aggregatefilter
+ aggregatefilter(func.count(1), True)
+
+ See :func:`~.expression.aggregatefilter` for a full description.
+
+ """
+ if not criterion:
+ return self
+ return AggregateFilter(self, *criterion)
+
@property
def _from_objects(self):
return self.clauses._from_objects