from test.lib.testing import eq_ import datetime from sqlalchemy import * from sqlalchemy.sql import table, column from sqlalchemy import databases, sql, util from sqlalchemy.sql.compiler import BIND_TEMPLATES from sqlalchemy.engine import default from test.lib.engines import all_dialects from sqlalchemy import types as sqltypes from test.lib import * from sqlalchemy.sql.functions import GenericFunction from test.lib.testing import eq_ from sqlalchemy.util.compat import decimal from test.lib import testing from sqlalchemy.databases import * class CompileTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = 'default' def test_compile(self): for dialect in all_dialects(exclude=('sybase', 'access', 'informix', 'maxdb')): bindtemplate = BIND_TEMPLATES[dialect.paramstyle] self.assert_compile(func.current_timestamp(), "CURRENT_TIMESTAMP", dialect=dialect) self.assert_compile(func.localtime(), "LOCALTIME", dialect=dialect) if isinstance(dialect, (firebird.dialect, maxdb.dialect)): self.assert_compile(func.nosuchfunction(), "nosuchfunction", dialect=dialect) else: self.assert_compile(func.nosuchfunction(), "nosuchfunction()", dialect=dialect) # test generic function compile class fake_func(GenericFunction): __return_type__ = sqltypes.Integer def __init__(self, arg, **kwargs): GenericFunction.__init__(self, args=[arg], **kwargs) self.assert_compile( fake_func('foo'), "fake_func(%s)" % bindtemplate % {'name':'param_1', 'position':1}, dialect=dialect) def test_use_labels(self): self.assert_compile(select([func.foo()], use_labels=True), "SELECT foo() AS foo_1" ) def test_underscores(self): self.assert_compile(func.if_(), "if()") def test_generic_now(self): assert isinstance(func.now().type, sqltypes.DateTime) for ret, dialect in [ ('CURRENT_TIMESTAMP', sqlite.dialect()), ('now()', postgresql.dialect()), ('now()', mysql.dialect()), ('CURRENT_TIMESTAMP', oracle.dialect()) ]: self.assert_compile(func.now(), ret, dialect=dialect) def test_generic_random(self): assert func.random().type == sqltypes.NULLTYPE assert isinstance(func.random(type_=Integer).type, Integer) for ret, dialect in [ ('random()', sqlite.dialect()), ('random()', postgresql.dialect()), ('rand()', mysql.dialect()), ('random()', oracle.dialect()) ]: self.assert_compile(func.random(), ret, dialect=dialect) def test_namespacing_conflicts(self): self.assert_compile(func.text('foo'), 'text(:text_1)') def test_generic_count(self): assert isinstance(func.count().type, sqltypes.Integer) self.assert_compile(func.count(), 'count(*)') self.assert_compile(func.count(1), 'count(:param_1)') c = column('abc') self.assert_compile(func.count(c), 'count(abc)') def test_constructor(self): try: func.current_timestamp('somearg') assert False except TypeError: assert True try: func.char_length('a', 'b') assert False except TypeError: assert True try: func.char_length() assert False except TypeError: assert True def test_return_type_detection(self): for fn in [func.coalesce, func.max, func.min, func.sum]: for args, type_ in [ ((datetime.date(2007, 10, 5), datetime.date(2005, 10, 15)), sqltypes.Date), ((3, 5), sqltypes.Integer), ((decimal.Decimal(3), decimal.Decimal(5)), sqltypes.Numeric), (("foo", "bar"), sqltypes.String), ((datetime.datetime(2007, 10, 5, 8, 3, 34), datetime.datetime(2005, 10, 15, 14, 45, 33)), sqltypes.DateTime) ]: assert isinstance(fn(*args).type, type_), "%s / %s" % (fn(), type_) assert isinstance(func.concat("foo", "bar").type, sqltypes.String) def test_assorted(self): table1 = table('mytable', column('myid', Integer), ) table2 = table( 'myothertable', column('otherid', Integer), ) # test an expression with a function self.assert_compile(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid, "lala(:lala_1, :lala_2, :param_1, mytable.myid) * myothertable.otherid") # test it in a SELECT self.assert_compile(select([func.count(table1.c.myid)]), "SELECT count(mytable.myid) AS count_1 FROM mytable") # test a "dotted" function name self.assert_compile(select([func.foo.bar.lala(table1.c.myid)]), "SELECT foo.bar.lala(mytable.myid) AS lala_1 FROM mytable") # test the bind parameter name with a "dotted" function name is only the name # (limits the length of the bind param name) self.assert_compile(select([func.foo.bar.lala(12)]), "SELECT foo.bar.lala(:lala_2) AS lala_1") # test a dotted func off the engine itself self.assert_compile(func.lala.hoho(7), "lala.hoho(:hoho_1)") # test None becomes NULL self.assert_compile(func.my_func(1,2,None,3), "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)") # test pickling self.assert_compile( util.pickle.loads(util.pickle.dumps(func.my_func(1, 2, None, 3))), "my_func(:my_func_1, :my_func_2, NULL, :my_func_3)") # assert func raises AttributeError for __bases__ attribute, since its not a class # fixes pydoc try: func.__bases__ assert False except AttributeError: assert True def test_functions_with_cols(self): users = table('users', column('id'), column('name'), column('fullname')) calculate = select([column('q'), column('z'), column('r')], from_obj=[func.calculate(bindparam('x'), bindparam('y'))]) self.assert_compile(select([users], users.c.id > calculate.c.z), "SELECT users.id, users.name, users.fullname " "FROM users, (SELECT q, z, r " "FROM calculate(:x, :y)) " "WHERE users.id > z" ) s = select([users], users.c.id.between( calculate.alias('c1').unique_params(x=17, y=45).c.z, calculate.alias('c2').unique_params(x=5, y=12).c.z)) self.assert_compile(s, "SELECT users.id, users.name, users.fullname " "FROM users, (SELECT q, z, r " "FROM calculate(:x_1, :y_1)) AS c1, (SELECT q, z, r " "FROM calculate(:x_2, :y_2)) AS c2 " "WHERE users.id BETWEEN c1.z AND c2.z" , checkparams={'y_1': 45, 'x_1': 17, 'y_2': 12, 'x_2': 5}) class ExecuteTest(fixtures.TestBase): @engines.close_first def tearDown(self): pass @testing.uses_deprecated def test_standalone_execute(self): x = testing.db.func.current_date().execute().scalar() y = testing.db.func.current_date().select().execute().scalar() z = testing.db.func.current_date().scalar() assert (x == y == z) is True # ansi func x = testing.db.func.current_date() assert isinstance(x.type, Date) assert isinstance(x.execute().scalar(), datetime.date) def test_conn_execute(self): from sqlalchemy.sql.expression import FunctionElement from sqlalchemy.ext.compiler import compiles class myfunc(FunctionElement): type = Date() @compiles(myfunc) def compile(elem, compiler, **kw): return compiler.process(func.current_date()) conn = testing.db.connect() try: x = conn.execute(func.current_date()).scalar() y = conn.execute(func.current_date().select()).scalar() z = conn.scalar(func.current_date()) q = conn.scalar(myfunc()) finally: conn.close() assert (x == y == z == q) is True def test_exec_options(self): f = func.foo() eq_(f._execution_options, {}) f = f.execution_options(foo='bar') eq_(f._execution_options, {'foo':'bar'}) s = f.select() eq_(s._execution_options, {'foo':'bar'}) ret = testing.db.execute(func.now().execution_options(foo='bar')) eq_(ret.context.execution_options, {'foo':'bar'}) ret.close() @engines.close_first def test_update(self): """ Tests sending functions and SQL expressions to the VALUES and SET clauses of INSERT/UPDATE instances, and that column-level defaults get overridden. """ meta = MetaData(testing.db) t = Table('t1', meta, Column('id', Integer, Sequence('t1idseq', optional=True), primary_key=True), Column('value', Integer) ) t2 = Table('t2', meta, Column('id', Integer, Sequence('t2idseq', optional=True), primary_key=True), Column('value', Integer, default=7), Column('stuff', String(20), onupdate="thisisstuff") ) meta.create_all() try: t.insert(values=dict(value=func.length("one"))).execute() assert t.select().execute().first()['value'] == 3 t.update(values=dict(value=func.length("asfda"))).execute() assert t.select().execute().first()['value'] == 5 r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute() id = r.inserted_primary_key[0] assert t.select(t.c.id==id).execute().first()['value'] == 9 t.update(values={t.c.value:func.length("asdf")}).execute() assert t.select().execute().first()['value'] == 4 print "--------------------------" t2.insert().execute() t2.insert(values=dict(value=func.length("one"))).execute() t2.insert(values=dict(value=func.length("asfda") + -19)).execute(stuff="hi") res = exec_sorted(select([t2.c.value, t2.c.stuff])) eq_(res, [(-14, 'hi'), (3, None), (7, None)]) t2.update(values=dict(value=func.length("asdsafasd"))).execute(stuff="some stuff") assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == \ [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")] t2.delete().execute() t2.insert(values=dict(value=func.length("one") + 8)).execute() assert t2.select().execute().first()['value'] == 11 t2.update(values=dict(value=func.length("asfda"))).execute() assert select([t2.c.value, t2.c.stuff]).execute().first() == (5, "thisisstuff") t2.update(values={t2.c.value:func.length("asfdaasdf"), t2.c.stuff:"foo"}).execute() print "HI", select([t2.c.value, t2.c.stuff]).execute().first() assert select([t2.c.value, t2.c.stuff]).execute().first() == (9, "foo") finally: meta.drop_all() @testing.fails_on_everything_except('postgresql') def test_as_from(self): # TODO: shouldnt this work on oracle too ? x = func.current_date(bind=testing.db).execute().scalar() y = func.current_date(bind=testing.db).select().execute().scalar() z = func.current_date(bind=testing.db).scalar() w = select(['*'], from_obj=[func.current_date(bind=testing.db)]).scalar() # construct a column-based FROM object out of a function, like in [ticket:172] s = select([sql.column('date', type_=DateTime)], from_obj=[func.current_date(bind=testing.db)]) q = s.execute().first()[s.c.date] r = s.alias('datequery').select().scalar() assert x == y == z == w == q == r def test_extract_bind(self): """Basic common denominator execution tests for extract()""" date = datetime.date(2010, 5, 1) def execute(field): return testing.db.execute(select([extract(field, date)])).scalar() assert execute('year') == 2010 assert execute('month') == 5 assert execute('day') == 1 date = datetime.datetime(2010, 5, 1, 12, 11, 10) assert execute('year') == 2010 assert execute('month') == 5 assert execute('day') == 1 def test_extract_expression(self): meta = MetaData(testing.db) table = Table('test', meta, Column('dt', DateTime), Column('d', Date)) meta.create_all() try: table.insert().execute( {'dt': datetime.datetime(2010, 5, 1, 12, 11, 10), 'd': datetime.date(2010, 5, 1) }) rs = select([extract('year', table.c.dt), extract('month', table.c.d)]).execute() row = rs.first() assert row[0] == 2010 assert row[1] == 5 rs.close() finally: meta.drop_all() def exec_sorted(statement, *args, **kw): """Executes a statement and returns a sorted list plain tuple rows.""" return sorted([tuple(row) for row in statement.execute(*args, **kw).fetchall()])