diff options
| -rw-r--r-- | doc/build/changelog/unreleased_13/4386.rst | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 17 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/functions.py | 15 | ||||
| -rw-r--r-- | test/sql/test_functions.py | 33 |
4 files changed, 48 insertions, 28 deletions
diff --git a/doc/build/changelog/unreleased_13/4386.rst b/doc/build/changelog/unreleased_13/4386.rst new file mode 100644 index 000000000..24e9f848b --- /dev/null +++ b/doc/build/changelog/unreleased_13/4386.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: feature, sql + :tickets: 4386 + + Amended the :class:`.AnsiFunction` class, the base of common SQL + functions like ``CURRENT_TIMESTAMP``, to accept positional arguments + like a regular ad-hoc function. This to suit the case that many of + these functions on specific backends accept arguments such as + "fractional seconds" precision and such. If the function is created + with arguments, it renders the the parenthesis and the arguments. If + no arguents are present, the compiler generates the non-parenthesized form. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c2a23a758..80ed707ed 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -111,20 +111,20 @@ OPERATORS = { } FUNCTIONS = { - functions.coalesce: 'coalesce%(expr)s', + functions.coalesce: 'coalesce', functions.current_date: 'CURRENT_DATE', functions.current_time: 'CURRENT_TIME', functions.current_timestamp: 'CURRENT_TIMESTAMP', functions.current_user: 'CURRENT_USER', functions.localtime: 'LOCALTIME', functions.localtimestamp: 'LOCALTIMESTAMP', - functions.random: 'random%(expr)s', + functions.random: 'random', functions.sysdate: 'sysdate', functions.session_user: 'SESSION_USER', functions.user: 'USER', - functions.cube: 'CUBE%(expr)s', - functions.rollup: 'ROLLUP%(expr)s', - functions.grouping_sets: 'GROUPING SETS%(expr)s', + functions.cube: 'CUBE', + functions.rollup: 'ROLLUP', + functions.grouping_sets: 'GROUPING SETS', } EXTRACT_MAP = { @@ -927,7 +927,12 @@ class SQLCompiler(Compiled): if disp: return disp(func, **kwargs) else: - name = FUNCTIONS.get(func.__class__, func.name + "%(expr)s") + name = FUNCTIONS.get(func.__class__, None) + if name: + if func._has_args: + name += "%(expr)s" + else: + name = func.name + "%(expr)s" return ".".join(list(func.packagenames) + [name]) % \ {'expr': self.function_argspec(func, **kwargs)} diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 5cea7750a..4b4d2d463 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -54,10 +54,13 @@ class FunctionElement(Executable, ColumnElement, FromClause): packagenames = () + _has_args = False + def __init__(self, *clauses, **kwargs): """Construct a :class:`.FunctionElement`. """ args = [_literal_as_binds(c, self.name) for c in clauses] + self._has_args = self._has_args or bool(args) self.clause_expr = ClauseList( operator=operators.comma_op, group_contents=True, *args).\ @@ -635,6 +638,7 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)): parsed_args = kwargs.pop('_parsed_args', None) if parsed_args is None: parsed_args = [_literal_as_binds(c, self.name) for c in args] + self._has_args = self._has_args or bool(parsed_args) self.packagenames = [] self._bind = kwargs.get('bind', None) self.clause_expr = ClauseList( @@ -671,8 +675,8 @@ class next_value(GenericFunction): class AnsiFunction(GenericFunction): - def __init__(self, **kwargs): - GenericFunction.__init__(self, **kwargs) + def __init__(self, *args, **kwargs): + GenericFunction.__init__(self, *args, **kwargs) class ReturnTypeFromArgs(GenericFunction): @@ -686,7 +690,7 @@ class ReturnTypeFromArgs(GenericFunction): class coalesce(ReturnTypeFromArgs): - pass + _has_args = True class max(ReturnTypeFromArgs): @@ -717,7 +721,7 @@ class char_length(GenericFunction): class random(GenericFunction): - pass + _has_args = True class count(GenericFunction): @@ -937,6 +941,7 @@ class cube(GenericFunction): .. versionadded:: 1.2 """ + _has_args = True class rollup(GenericFunction): @@ -952,6 +957,7 @@ class rollup(GenericFunction): .. versionadded:: 1.2 """ + _has_args = True class grouping_sets(GenericFunction): @@ -984,3 +990,4 @@ class grouping_sets(GenericFunction): .. versionadded:: 1.2 """ + _has_args = True diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 48d5fc37f..ffc72e9ee 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -14,7 +14,7 @@ import decimal from sqlalchemy import testing from sqlalchemy.testing import fixtures, AssertsCompiledSQL, engines from sqlalchemy.dialects import sqlite, postgresql, mysql, oracle -from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import assert_raises_message, assert_raises table1 = table('mytable', column('myid', Integer), @@ -133,12 +133,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): pass assert isinstance(func.myfunc(), myfunc) + self.assert_compile(func.myfunc(), "myfunc()") def test_custom_type(self): class myfunc(GenericFunction): type = DateTime assert isinstance(func.myfunc().type, DateTime) + self.assert_compile(func.myfunc(), "myfunc()") def test_custom_legacy_type(self): # in case someone was using this system @@ -228,24 +230,19 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): c = column('abc') self.assert_compile(func.count(c), 'count(abc)') - def test_constructor(self): - try: - func.current_timestamp('somearg') - assert False - except TypeError: - assert True - - try: - func.char_length('a', 'b') - assert False - except TypeError: - assert True + def test_ansi_functions_with_args(self): + ct = func.current_timestamp('somearg') + self.assert_compile(ct, "CURRENT_TIMESTAMP(:current_timestamp_1)") - try: - func.char_length() - assert False - except TypeError: - assert True + def test_char_length_fixed_args(self): + assert_raises( + TypeError, + func.char_length, 'a', 'b' + ) + assert_raises( + TypeError, + func.char_length + ) def test_return_type_detection(self): |
