diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-08-22 13:54:13 -0400 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2012-08-22 13:54:13 -0400 |
commit | 0e9ad8076086bdf22705c2a82b6316b35a7daaa5 (patch) | |
tree | 1a932f621e9db76507ea11e3f3bd0e024520c1d8 /test/sql/test_functions.py | |
parent | ffab937e4c44e8ff27e92cab26efcb42e7aca4ab (diff) | |
download | sqlalchemy-0e9ad8076086bdf22705c2a82b6316b35a7daaa5.tar.gz |
- [feature] Enhanced GenericFunction and func.*
to allow for user-defined GenericFunction
subclasses to be available via the func.*
namespace automatically by classname,
optionally using a package name as well.
Diffstat (limited to 'test/sql/test_functions.py')
-rw-r--r-- | test/sql/test_functions.py | 147 |
1 files changed, 105 insertions, 42 deletions
diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 2f9c6f908..5769e4a1a 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -2,31 +2,36 @@ from test.lib.testing import eq_ import datetime from sqlalchemy import * from sqlalchemy.sql import table, column -from sqlalchemy import databases, sql, util +from sqlalchemy import sql, util from sqlalchemy.sql.compiler import BIND_TEMPLATES -from sqlalchemy.engine import default from test.lib.engines import all_dialects from sqlalchemy import types as sqltypes -from test.lib import * +from sqlalchemy.sql import functions from sqlalchemy.sql.functions import GenericFunction -from test.lib.testing import eq_ from sqlalchemy.util.compat import decimal -from test.lib import testing -from sqlalchemy.databases import * +from test.lib import testing, fixtures, AssertsCompiledSQL, engines +from sqlalchemy.dialects import sqlite, postgresql, mysql, oracle class CompileTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = 'default' + def tear_down(self): + functions._registry.clear() + def test_compile(self): - for dialect in all_dialects(exclude=('sybase', 'access', 'informix', 'maxdb')): + for dialect in all_dialects(exclude=('sybase', 'access', + 'informix', 'maxdb')): bindtemplate = BIND_TEMPLATES[dialect.paramstyle] - self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect) + self.assert_compile(func.current_timestamp(), + "CURRENT_TIMESTAMP", dialect=dialect) self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect) - if isinstance(dialect, (firebird.dialect, maxdb.dialect)): - self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect) + if dialect.name in ('firebird', 'maxdb'): + self.assert_compile(func.nosuchfunction(), + "nosuchfunction", dialect=dialect) else: - self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect) + self.assert_compile(func.nosuchfunction(), + "nosuchfunction()", dialect=dialect) # test generic function compile class fake_func(GenericFunction): @@ -38,7 +43,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( fake_func('foo'), "fake_func(%s)" % - bindtemplate % {'name':'param_1', 'position':1}, + bindtemplate % {'name': 'param_1', 'position': 1}, dialect=dialect) def test_use_labels(self): @@ -71,6 +76,44 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ]: self.assert_compile(func.random(), ret, dialect=dialect) + def test_custom_default_namespace(self): + class myfunc(GenericFunction): + pass + + assert isinstance(func.myfunc(), myfunc) + + def test_custom_type(self): + class myfunc(GenericFunction): + type = DateTime + + assert isinstance(func.myfunc().type, DateTime) + + def test_custom_legacy_type(self): + # in case someone was using this system + class myfunc(GenericFunction): + __return_type__ = DateTime + + assert isinstance(func.myfunc().type, DateTime) + + def test_custom_w_custom_name(self): + class myfunc(GenericFunction): + name = "notmyfunc" + + assert isinstance(func.notmyfunc(), myfunc) + assert not isinstance(func.myfunc(), myfunc) + + def test_custom_package_namespace(self): + def cls1(pk_name): + class myfunc(GenericFunction): + package = pk_name + return myfunc + + f1 = cls1("mypackage") + f2 = cls1("myotherpackage") + + assert isinstance(func.mypackage.myfunc(), f1) + assert isinstance(func.myotherpackage.myfunc(), f2) + def test_namespacing_conflicts(self): self.assert_compile(func.text('foo'), 'text(:text_1)') @@ -108,12 +151,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ((datetime.date(2007, 10, 5), datetime.date(2005, 10, 15)), sqltypes.Date), ((3, 5), sqltypes.Integer), - ((decimal.Decimal(3), decimal.Decimal(5)), sqltypes.Numeric), + ((decimal.Decimal(3), decimal.Decimal(5)), + sqltypes.Numeric), (("foo", "bar"), sqltypes.String), ((datetime.datetime(2007, 10, 5, 8, 3, 34), - datetime.datetime(2005, 10, 15, 14, 45, 33)), sqltypes.DateTime) + datetime.datetime(2005, 10, 15, 14, 45, 33)), + sqltypes.DateTime) ]: - assert isinstance(fn(*args).type, type_), "%s / %s" % (fn(), type_) + assert isinstance(fn(*args).type, type_), \ + "%s / %s" % (fn(), type_) assert isinstance(func.concat("foo", "bar").type, sqltypes.String) @@ -129,8 +175,10 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) # test an expression with a function - self.assert_compile(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid, - "lala(:lala_1, :lala_2, :param_1, mytable.myid) * myothertable.otherid") + self.assert_compile(func.lala(3, 4, literal("five"), + table1.c.myid) * table2.c.otherid, + "lala(:lala_1, :lala_2, :param_1, mytable.myid) * " + "myothertable.otherid") # test it in a SELECT self.assert_compile(select([func.count(table1.c.myid)]), @@ -140,8 +188,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(select([func.foo.bar.lala(table1.c.myid)]), "SELECT foo.bar.lala(mytable.myid) AS lala_1 FROM mytable") - # test the bind parameter name with a "dotted" function name is only the name - # (limits the length of the bind param name) + # test the bind parameter name with a "dotted" function name is + # only the name (limits the length of the bind param name) self.assert_compile(select([func.foo.bar.lala(12)]), "SELECT foo.bar.lala(:lala_2) AS lala_1") @@ -149,16 +197,17 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(func.lala.hoho(7), "lala.hoho(:hoho_1)") # test None becomes NULL - self.assert_compile(func.my_func(1,2,None,3), + self.assert_compile(func.my_func(1, 2, None, 3), "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)") # test pickling self.assert_compile( - util.pickle.loads(util.pickle.dumps(func.my_func(1, 2, None, 3))), + util.pickle.loads(util.pickle.dumps( + func.my_func(1, 2, None, 3))), "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)") - # assert func raises AttributeError for __bases__ attribute, since its not a class - # fixes pydoc + # assert func raises AttributeError for __bases__ attribute, since + # its not a class fixes pydoc try: func.__bases__ assert False @@ -186,8 +235,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "FROM users, (SELECT q, z, r " "FROM calculate(:x_1, :y_1)) AS c1, (SELECT q, z, r " "FROM calculate(:x_2, :y_2)) AS c2 " - "WHERE users.id BETWEEN c1.z AND c2.z" - , checkparams={'y_1': 45, 'x_1': 17, 'y_2': 12, 'x_2': 5}) + "WHERE users.id BETWEEN c1.z AND c2.z", + checkparams={'y_1': 45, 'x_1': 17, 'y_2': 12, 'x_2': 5}) class ExecuteTest(fixtures.TestBase): @@ -233,12 +282,12 @@ class ExecuteTest(fixtures.TestBase): eq_(f._execution_options, {}) f = f.execution_options(foo='bar') - eq_(f._execution_options, {'foo':'bar'}) + eq_(f._execution_options, {'foo': 'bar'}) s = f.select() - eq_(s._execution_options, {'foo':'bar'}) + eq_(s._execution_options, {'foo': 'bar'}) ret = testing.db.execute(func.now().execution_options(foo='bar')) - eq_(ret.context.execution_options, {'foo':'bar'}) + eq_(ret.context.execution_options, {'foo': 'bar'}) ret.close() @@ -252,11 +301,13 @@ class ExecuteTest(fixtures.TestBase): meta = MetaData(testing.db) t = Table('t1', meta, - Column('id', Integer, Sequence('t1idseq', optional=True), primary_key=True), + Column('id', Integer, Sequence('t1idseq', optional=True), + primary_key=True), Column('value', Integer) ) t2 = Table('t2', meta, - Column('id', Integer, Sequence('t2idseq', optional=True), primary_key=True), + Column('id', Integer, Sequence('t2idseq', optional=True), + primary_key=True), Column('value', Integer, default=7), Column('stuff', String(20), onupdate="thisisstuff") ) @@ -269,20 +320,23 @@ class ExecuteTest(fixtures.TestBase): r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute() id = r.inserted_primary_key[0] - assert t.select(t.c.id==id).execute().first()['value'] == 9 - t.update(values={t.c.value:func.length("asdf")}).execute() + assert t.select(t.c.id == id).execute().first()['value'] == 9 + t.update(values={t.c.value: func.length("asdf")}).execute() assert t.select().execute().first()['value'] == 4 print "--------------------------" t2.insert().execute() t2.insert(values=dict(value=func.length("one"))).execute() - t2.insert(values=dict(value=func.length("asfda") + -19)).execute(stuff="hi") + t2.insert(values=dict(value=func.length("asfda") + -19)).\ + execute(stuff="hi") res = exec_sorted(select([t2.c.value, t2.c.stuff])) eq_(res, [(-14, 'hi'), (3, None), (7, None)]) - t2.update(values=dict(value=func.length("asdsafasd"))).execute(stuff="some stuff") + t2.update(values=dict(value=func.length("asdsafasd"))).\ + execute(stuff="some stuff") assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == \ - [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")] + [(9, "some stuff"), (9, "some stuff"), + (9, "some stuff")] t2.delete().execute() @@ -290,11 +344,17 @@ class ExecuteTest(fixtures.TestBase): assert t2.select().execute().first()['value'] == 11 t2.update(values=dict(value=func.length("asfda"))).execute() - assert select([t2.c.value, t2.c.stuff]).execute().first() == (5, "thisisstuff") + eq_( + select([t2.c.value, t2.c.stuff]).execute().first(), + (5, "thisisstuff") + ) - t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute() + t2.update(values={t2.c.value: func.length("asfdaasdf"), + t2.c.stuff: "foo"}).execute() print "HI", select([t2.c.value, t2.c.stuff]).execute().first() - assert select([t2.c.value, t2.c.stuff]).execute().first() == (9, "foo") + eq_(select([t2.c.value, t2.c.stuff]).execute().first(), + (9, "foo") + ) finally: meta.drop_all() @@ -304,10 +364,13 @@ class ExecuteTest(fixtures.TestBase): x = func.current_date(bind=testing.db).execute().scalar() y = func.current_date(bind=testing.db).select().execute().scalar() z = func.current_date(bind=testing.db).scalar() - w = select(['*'], from_obj=[func.current_date(bind=testing.db)]).scalar() + w = select(['*'], from_obj=[func.current_date(bind=testing.db)]).\ + scalar() - # construct a column-based FROM object out of a function, like in [ticket:172] - s = select([sql.column('date', type_=DateTime)], from_obj=[func.current_date(bind=testing.db)]) + # construct a column-based FROM object out of a function, + # like in [ticket:172] + s = select([sql.column('date', type_=DateTime)], + from_obj=[func.current_date(bind=testing.db)]) q = s.execute().first()[s.c.date] r = s.alias('datequery').select().scalar() @@ -340,7 +403,7 @@ class ExecuteTest(fixtures.TestBase): try: table.insert().execute( {'dt': datetime.datetime(2010, 5, 1, 12, 11, 10), - 'd': datetime.date(2010, 5, 1) }) + 'd': datetime.date(2010, 5, 1)}) rs = select([extract('year', table.c.dt), extract('month', table.c.d)]).execute() row = rs.first() |