summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_13/4386.rst11
-rw-r--r--lib/sqlalchemy/sql/compiler.py17
-rw-r--r--lib/sqlalchemy/sql/functions.py15
-rw-r--r--test/sql/test_functions.py33
4 files changed, 48 insertions, 28 deletions
diff --git a/doc/build/changelog/unreleased_13/4386.rst b/doc/build/changelog/unreleased_13/4386.rst
new file mode 100644
index 000000000..24e9f848b
--- /dev/null
+++ b/doc/build/changelog/unreleased_13/4386.rst
@@ -0,0 +1,11 @@
+.. change::
+ :tags: feature, sql
+ :tickets: 4386
+
+ Amended the :class:`.AnsiFunction` class, the base of common SQL
+ functions like ``CURRENT_TIMESTAMP``, to accept positional arguments
+ like a regular ad-hoc function. This to suit the case that many of
+ these functions on specific backends accept arguments such as
+ "fractional seconds" precision and such. If the function is created
+ with arguments, it renders the the parenthesis and the arguments. If
+ no arguents are present, the compiler generates the non-parenthesized form.
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index c2a23a758..80ed707ed 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -111,20 +111,20 @@ OPERATORS = {
}
FUNCTIONS = {
- functions.coalesce: 'coalesce%(expr)s',
+ functions.coalesce: 'coalesce',
functions.current_date: 'CURRENT_DATE',
functions.current_time: 'CURRENT_TIME',
functions.current_timestamp: 'CURRENT_TIMESTAMP',
functions.current_user: 'CURRENT_USER',
functions.localtime: 'LOCALTIME',
functions.localtimestamp: 'LOCALTIMESTAMP',
- functions.random: 'random%(expr)s',
+ functions.random: 'random',
functions.sysdate: 'sysdate',
functions.session_user: 'SESSION_USER',
functions.user: 'USER',
- functions.cube: 'CUBE%(expr)s',
- functions.rollup: 'ROLLUP%(expr)s',
- functions.grouping_sets: 'GROUPING SETS%(expr)s',
+ functions.cube: 'CUBE',
+ functions.rollup: 'ROLLUP',
+ functions.grouping_sets: 'GROUPING SETS',
}
EXTRACT_MAP = {
@@ -927,7 +927,12 @@ class SQLCompiler(Compiled):
if disp:
return disp(func, **kwargs)
else:
- name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s")
+ name = FUNCTIONS.get(func.__class__, None)
+ if name:
+ if func._has_args:
+ name += "%(expr)s"
+ else:
+ name = func.name + "%(expr)s"
return ".".join(list(func.packagenames) + [name]) % \
{'expr': self.function_argspec(func, **kwargs)}
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 5cea7750a..4b4d2d463 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -54,10 +54,13 @@ class FunctionElement(Executable, ColumnElement, FromClause):
packagenames = ()
+ _has_args = False
+
def __init__(self, *clauses, **kwargs):
"""Construct a :class:`.FunctionElement`.
"""
args = [_literal_as_binds(c, self.name) for c in clauses]
+ self._has_args = self._has_args or bool(args)
self.clause_expr = ClauseList(
operator=operators.comma_op,
group_contents=True, *args).\
@@ -635,6 +638,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
parsed_args = kwargs.pop('_parsed_args', None)
if parsed_args is None:
parsed_args = [_literal_as_binds(c, self.name) for c in args]
+ self._has_args = self._has_args or bool(parsed_args)
self.packagenames = []
self._bind = kwargs.get('bind', None)
self.clause_expr = ClauseList(
@@ -671,8 +675,8 @@ class next_value(GenericFunction):
class AnsiFunction(GenericFunction):
- def __init__(self, **kwargs):
- GenericFunction.__init__(self, **kwargs)
+ def __init__(self, *args, **kwargs):
+ GenericFunction.__init__(self, *args, **kwargs)
class ReturnTypeFromArgs(GenericFunction):
@@ -686,7 +690,7 @@ class ReturnTypeFromArgs(GenericFunction):
class coalesce(ReturnTypeFromArgs):
- pass
+ _has_args = True
class max(ReturnTypeFromArgs):
@@ -717,7 +721,7 @@ class char_length(GenericFunction):
class random(GenericFunction):
- pass
+ _has_args = True
class count(GenericFunction):
@@ -937,6 +941,7 @@ class cube(GenericFunction):
.. versionadded:: 1.2
"""
+ _has_args = True
class rollup(GenericFunction):
@@ -952,6 +957,7 @@ class rollup(GenericFunction):
.. versionadded:: 1.2
"""
+ _has_args = True
class grouping_sets(GenericFunction):
@@ -984,3 +990,4 @@ class grouping_sets(GenericFunction):
.. versionadded:: 1.2
"""
+ _has_args = True
diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py
index 48d5fc37f..ffc72e9ee 100644
--- a/test/sql/test_functions.py
+++ b/test/sql/test_functions.py
@@ -14,7 +14,7 @@ import decimal
from sqlalchemy import testing
from sqlalchemy.testing import fixtures, AssertsCompiledSQL, engines
from sqlalchemy.dialects import sqlite, postgresql, mysql, oracle
-from sqlalchemy.testing import assert_raises_message
+from sqlalchemy.testing import assert_raises_message, assert_raises
table1 = table('mytable',
column('myid', Integer),
@@ -133,12 +133,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
pass
assert isinstance(func.myfunc(), myfunc)
+ self.assert_compile(func.myfunc(), "myfunc()")
def test_custom_type(self):
class myfunc(GenericFunction):
type = DateTime
assert isinstance(func.myfunc().type, DateTime)
+ self.assert_compile(func.myfunc(), "myfunc()")
def test_custom_legacy_type(self):
# in case someone was using this system
@@ -228,24 +230,19 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
c = column('abc')
self.assert_compile(func.count(c), 'count(abc)')
- def test_constructor(self):
- try:
- func.current_timestamp('somearg')
- assert False
- except TypeError:
- assert True
-
- try:
- func.char_length('a', 'b')
- assert False
- except TypeError:
- assert True
+ def test_ansi_functions_with_args(self):
+ ct = func.current_timestamp('somearg')
+ self.assert_compile(ct, "CURRENT_TIMESTAMP(:current_timestamp_1)")
- try:
- func.char_length()
- assert False
- except TypeError:
- assert True
+ def test_char_length_fixed_args(self):
+ assert_raises(
+ TypeError,
+ func.char_length, 'a', 'b'
+ )
+ assert_raises(
+ TypeError,
+ func.char_length
+ )
def test_return_type_detection(self):