summaryrefslogtreecommitdiff
path: root/test/sql/test_functions.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-08-22 13:54:13 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2012-08-22 13:54:13 -0400
commit0e9ad8076086bdf22705c2a82b6316b35a7daaa5 (patch)
tree1a932f621e9db76507ea11e3f3bd0e024520c1d8 /test/sql/test_functions.py
parentffab937e4c44e8ff27e92cab26efcb42e7aca4ab (diff)
downloadsqlalchemy-0e9ad8076086bdf22705c2a82b6316b35a7daaa5.tar.gz
- [feature] Enhanced GenericFunction and func.*
to allow for user-defined GenericFunction subclasses to be available via the func.* namespace automatically by classname, optionally using a package name as well.
Diffstat (limited to 'test/sql/test_functions.py')
-rw-r--r--test/sql/test_functions.py147
1 files changed, 105 insertions, 42 deletions
diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py
index 2f9c6f908..5769e4a1a 100644
--- a/test/sql/test_functions.py
+++ b/test/sql/test_functions.py
@@ -2,31 +2,36 @@ from test.lib.testing import eq_
import datetime
from sqlalchemy import *
from sqlalchemy.sql import table, column
-from sqlalchemy import databases, sql, util
+from sqlalchemy import sql, util
from sqlalchemy.sql.compiler import BIND_TEMPLATES
-from sqlalchemy.engine import default
from test.lib.engines import all_dialects
from sqlalchemy import types as sqltypes
-from test.lib import *
+from sqlalchemy.sql import functions
from sqlalchemy.sql.functions import GenericFunction
-from test.lib.testing import eq_
from sqlalchemy.util.compat import decimal
-from test.lib import testing
-from sqlalchemy.databases import *
+from test.lib import testing, fixtures, AssertsCompiledSQL, engines
+from sqlalchemy.dialects import sqlite, postgresql, mysql, oracle
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = 'default'
+ def tear_down(self):
+ functions._registry.clear()
+
def test_compile(self):
- for dialect in all_dialects(exclude=('sybase', 'access', 'informix', 'maxdb')):
+ for dialect in all_dialects(exclude=('sybase', 'access',
+ 'informix', 'maxdb')):
bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
- self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect)
+ self.assert_compile(func.current_timestamp(),
+ "CURRENT_TIMESTAMP", dialect=dialect)
self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect)
- if isinstance(dialect, (firebird.dialect, maxdb.dialect)):
- self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect)
+ if dialect.name in ('firebird', 'maxdb'):
+ self.assert_compile(func.nosuchfunction(),
+ "nosuchfunction", dialect=dialect)
else:
- self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect)
+ self.assert_compile(func.nosuchfunction(),
+ "nosuchfunction()", dialect=dialect)
# test generic function compile
class fake_func(GenericFunction):
@@ -38,7 +43,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(
fake_func('foo'),
"fake_func(%s)" %
- bindtemplate % {'name':'param_1', 'position':1},
+ bindtemplate % {'name': 'param_1', 'position': 1},
dialect=dialect)
def test_use_labels(self):
@@ -71,6 +76,44 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
]:
self.assert_compile(func.random(), ret, dialect=dialect)
+ def test_custom_default_namespace(self):
+ class myfunc(GenericFunction):
+ pass
+
+ assert isinstance(func.myfunc(), myfunc)
+
+ def test_custom_type(self):
+ class myfunc(GenericFunction):
+ type = DateTime
+
+ assert isinstance(func.myfunc().type, DateTime)
+
+ def test_custom_legacy_type(self):
+ # in case someone was using this system
+ class myfunc(GenericFunction):
+ __return_type__ = DateTime
+
+ assert isinstance(func.myfunc().type, DateTime)
+
+ def test_custom_w_custom_name(self):
+ class myfunc(GenericFunction):
+ name = "notmyfunc"
+
+ assert isinstance(func.notmyfunc(), myfunc)
+ assert not isinstance(func.myfunc(), myfunc)
+
+ def test_custom_package_namespace(self):
+ def cls1(pk_name):
+ class myfunc(GenericFunction):
+ package = pk_name
+ return myfunc
+
+ f1 = cls1("mypackage")
+ f2 = cls1("myotherpackage")
+
+ assert isinstance(func.mypackage.myfunc(), f1)
+ assert isinstance(func.myotherpackage.myfunc(), f2)
+
def test_namespacing_conflicts(self):
self.assert_compile(func.text('foo'), 'text(:text_1)')
@@ -108,12 +151,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
((datetime.date(2007, 10, 5),
datetime.date(2005, 10, 15)), sqltypes.Date),
((3, 5), sqltypes.Integer),
- ((decimal.Decimal(3), decimal.Decimal(5)), sqltypes.Numeric),
+ ((decimal.Decimal(3), decimal.Decimal(5)),
+ sqltypes.Numeric),
(("foo", "bar"), sqltypes.String),
((datetime.datetime(2007, 10, 5, 8, 3, 34),
- datetime.datetime(2005, 10, 15, 14, 45, 33)), sqltypes.DateTime)
+ datetime.datetime(2005, 10, 15, 14, 45, 33)),
+ sqltypes.DateTime)
]:
- assert isinstance(fn(*args).type, type_), "%s / %s" % (fn(), type_)
+ assert isinstance(fn(*args).type, type_), \
+ "%s / %s" % (fn(), type_)
assert isinstance(func.concat("foo", "bar").type, sqltypes.String)
@@ -129,8 +175,10 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
)
# test an expression with a function
- self.assert_compile(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid,
- "lala(:lala_1, :lala_2, :param_1, mytable.myid) * myothertable.otherid")
+ self.assert_compile(func.lala(3, 4, literal("five"),
+ table1.c.myid) * table2.c.otherid,
+ "lala(:lala_1, :lala_2, :param_1, mytable.myid) * "
+ "myothertable.otherid")
# test it in a SELECT
self.assert_compile(select([func.count(table1.c.myid)]),
@@ -140,8 +188,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(select([func.foo.bar.lala(table1.c.myid)]),
"SELECT foo.bar.lala(mytable.myid) AS lala_1 FROM mytable")
- # test the bind parameter name with a "dotted" function name is only the name
- # (limits the length of the bind param name)
+ # test the bind parameter name with a "dotted" function name is
+ # only the name (limits the length of the bind param name)
self.assert_compile(select([func.foo.bar.lala(12)]),
"SELECT foo.bar.lala(:lala_2) AS lala_1")
@@ -149,16 +197,17 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(func.lala.hoho(7), "lala.hoho(:hoho_1)")
# test None becomes NULL
- self.assert_compile(func.my_func(1,2,None,3),
+ self.assert_compile(func.my_func(1, 2, None, 3),
"my_func(:my_func_1, :my_func_2, NULL, :my_func_3)")
# test pickling
self.assert_compile(
- util.pickle.loads(util.pickle.dumps(func.my_func(1, 2, None, 3))),
+ util.pickle.loads(util.pickle.dumps(
+ func.my_func(1, 2, None, 3))),
"my_func(:my_func_1, :my_func_2, NULL, :my_func_3)")
- # assert func raises AttributeError for __bases__ attribute, since its not a class
- # fixes pydoc
+ # assert func raises AttributeError for __bases__ attribute, since
+ # its not a class fixes pydoc
try:
func.__bases__
assert False
@@ -186,8 +235,8 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
"FROM users, (SELECT q, z, r "
"FROM calculate(:x_1, :y_1)) AS c1, (SELECT q, z, r "
"FROM calculate(:x_2, :y_2)) AS c2 "
- "WHERE users.id BETWEEN c1.z AND c2.z"
- , checkparams={'y_1': 45, 'x_1': 17, 'y_2': 12, 'x_2': 5})
+ "WHERE users.id BETWEEN c1.z AND c2.z",
+ checkparams={'y_1': 45, 'x_1': 17, 'y_2': 12, 'x_2': 5})
class ExecuteTest(fixtures.TestBase):
@@ -233,12 +282,12 @@ class ExecuteTest(fixtures.TestBase):
eq_(f._execution_options, {})
f = f.execution_options(foo='bar')
- eq_(f._execution_options, {'foo':'bar'})
+ eq_(f._execution_options, {'foo': 'bar'})
s = f.select()
- eq_(s._execution_options, {'foo':'bar'})
+ eq_(s._execution_options, {'foo': 'bar'})
ret = testing.db.execute(func.now().execution_options(foo='bar'))
- eq_(ret.context.execution_options, {'foo':'bar'})
+ eq_(ret.context.execution_options, {'foo': 'bar'})
ret.close()
@@ -252,11 +301,13 @@ class ExecuteTest(fixtures.TestBase):
meta = MetaData(testing.db)
t = Table('t1', meta,
- Column('id', Integer, Sequence('t1idseq', optional=True), primary_key=True),
+ Column('id', Integer, Sequence('t1idseq', optional=True),
+ primary_key=True),
Column('value', Integer)
)
t2 = Table('t2', meta,
- Column('id', Integer, Sequence('t2idseq', optional=True), primary_key=True),
+ Column('id', Integer, Sequence('t2idseq', optional=True),
+ primary_key=True),
Column('value', Integer, default=7),
Column('stuff', String(20), onupdate="thisisstuff")
)
@@ -269,20 +320,23 @@ class ExecuteTest(fixtures.TestBase):
r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute()
id = r.inserted_primary_key[0]
- assert t.select(t.c.id==id).execute().first()['value'] == 9
- t.update(values={t.c.value:func.length("asdf")}).execute()
+ assert t.select(t.c.id == id).execute().first()['value'] == 9
+ t.update(values={t.c.value: func.length("asdf")}).execute()
assert t.select().execute().first()['value'] == 4
print "--------------------------"
t2.insert().execute()
t2.insert(values=dict(value=func.length("one"))).execute()
- t2.insert(values=dict(value=func.length("asfda") + -19)).execute(stuff="hi")
+ t2.insert(values=dict(value=func.length("asfda") + -19)).\
+ execute(stuff="hi")
res = exec_sorted(select([t2.c.value, t2.c.stuff]))
eq_(res, [(-14, 'hi'), (3, None), (7, None)])
- t2.update(values=dict(value=func.length("asdsafasd"))).execute(stuff="some stuff")
+ t2.update(values=dict(value=func.length("asdsafasd"))).\
+ execute(stuff="some stuff")
assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == \
- [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")]
+ [(9, "some stuff"), (9, "some stuff"),
+ (9, "some stuff")]
t2.delete().execute()
@@ -290,11 +344,17 @@ class ExecuteTest(fixtures.TestBase):
assert t2.select().execute().first()['value'] == 11
t2.update(values=dict(value=func.length("asfda"))).execute()
- assert select([t2.c.value, t2.c.stuff]).execute().first() == (5, "thisisstuff")
+ eq_(
+ select([t2.c.value, t2.c.stuff]).execute().first(),
+ (5, "thisisstuff")
+ )
- t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute()
+ t2.update(values={t2.c.value: func.length("asfdaasdf"),
+ t2.c.stuff: "foo"}).execute()
print "HI", select([t2.c.value, t2.c.stuff]).execute().first()
- assert select([t2.c.value, t2.c.stuff]).execute().first() == (9, "foo")
+ eq_(select([t2.c.value, t2.c.stuff]).execute().first(),
+ (9, "foo")
+ )
finally:
meta.drop_all()
@@ -304,10 +364,13 @@ class ExecuteTest(fixtures.TestBase):
x = func.current_date(bind=testing.db).execute().scalar()
y = func.current_date(bind=testing.db).select().execute().scalar()
z = func.current_date(bind=testing.db).scalar()
- w = select(['*'], from_obj=[func.current_date(bind=testing.db)]).scalar()
+ w = select(['*'], from_obj=[func.current_date(bind=testing.db)]).\
+ scalar()
- # construct a column-based FROM object out of a function, like in [ticket:172]
- s = select([sql.column('date', type_=DateTime)], from_obj=[func.current_date(bind=testing.db)])
+ # construct a column-based FROM object out of a function,
+ # like in [ticket:172]
+ s = select([sql.column('date', type_=DateTime)],
+ from_obj=[func.current_date(bind=testing.db)])
q = s.execute().first()[s.c.date]
r = s.alias('datequery').select().scalar()
@@ -340,7 +403,7 @@ class ExecuteTest(fixtures.TestBase):
try:
table.insert().execute(
{'dt': datetime.datetime(2010, 5, 1, 12, 11, 10),
- 'd': datetime.date(2010, 5, 1) })
+ 'd': datetime.date(2010, 5, 1)})
rs = select([extract('year', table.c.dt),
extract('month', table.c.d)]).execute()
row = rs.first()