summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/util.py
blob: b39e2f6168e8ae7da6c18adb784cf9277310b0f0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# orm/util.py
# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

import sqlalchemy.util as util
import sqlalchemy.sql as sql

class CascadeOptions(object):
    """Parsed form of the comma-separated string passed to relation().cascade.

    Each recognized option becomes a boolean attribute (dashes replaced by
    underscores): ``delete_orphan``, ``delete``, ``save_update``, ``merge``
    and ``expunge``.  "all" switches on everything except ``delete_orphan``;
    "delete-orphan" also switches on ``delete``.
    """
    def __init__(self, arg=""):
        tokens = util.Set([token.strip() for token in arg.split(',')])
        everything = "all" in tokens
        self.delete_orphan = "delete-orphan" in tokens
        self.delete = "delete" in tokens or self.delete_orphan or everything
        self.save_update = "save-update" in tokens or everything
        self.merge = "merge" in tokens or everything
        self.expunge = "expunge" in tokens or everything
        # refresh_expire not really implemented as of yet
        #self.refresh_expire = "refresh-expire" in values or "all" in values
    def __contains__(self, item):
        """Support ``"save-update" in options``; names with no matching
        attribute report False."""
        attrname = item.replace("-", "_")
        return getattr(self, attrname, False)
    

def polymorphic_union(table_map, typecolname, aliasname='p_union'):
    """Build an aliased UNION ALL over every selectable in table_map.

    table_map -- dictionary of discriminator key -> selectable.  Values that
        are sql.Select instances are replaced *in place* by their alias().
    typecolname -- label of an extra column added to each SELECT holding the
        map key rendered as a quoted literal ('key'); None omits that column.
    aliasname -- name given to the alias of the resulting union.

    Every SELECT lists the union of all column names found across the
    selectables.  Where a selectable lacks a column, CAST(NULL AS <type>)
    labeled with the column name is used instead; the type is the one
    recorded last for that name while scanning the selectables.
    """
    all_names = util.Set()
    columns_by_table = {}
    type_by_name = {}
    for key in table_map.keys():
        selectable = table_map[key]

        # mysql doesnt like selecting from a select; make it an alias of the select
        if isinstance(selectable, sql.Select):
            selectable = selectable.alias()
            table_map[key] = selectable

        by_name = {}
        for column in selectable.c:
            all_names.add(column.name)
            by_name[column.name] = column
            type_by_name[column.name] = column.type
        columns_by_table[selectable] = by_name

    def column_or_null(name, selectable):
        # the selectable's own column if it has one, else a typed NULL placeholder
        try:
            return columns_by_table[selectable][name]
        except KeyError:
            return sql.cast(sql.null(), type_by_name[name]).label(name)

    selects = []
    for discriminator, selectable in table_map.iteritems():
        columns = [column_or_null(name, selectable) for name in all_names]
        if typecolname is not None:
            columns = columns + [sql.column("'%s'" % discriminator).label(typecolname)]
        selects.append(sql.select(columns, from_obj=[selectable]))
    return sql.union_all(*selects).alias(aliasname)

class TranslatingDict(dict):
    """a dictionary that stores ColumnElement objects as keys.  incoming ColumnElement
    keys are translated against those of an underlying FromClause for all single-key
    operations.  This way the columns from any Selectable that is derived from or
    underlying this TranslatingDict's selectable can be used as keys.

    Translation is applied by item get/set/delete, ``in``, has_key(), get(),
    setdefault() and pop().  Bulk operations inherited from dict (such as
    update()) are not overridden here and store keys untranslated."""
    def __init__(self, selectable):
        super(TranslatingDict, self).__init__()
        self.selectable = selectable

    def __translate_col(self, col):
        """return the column of self.selectable corresponding to col, or col
        itself when the selectable reports no corresponding column (None)."""
        ourcol = self.selectable.corresponding_column(col, keys_ok=False, raiseerr=False)
        if ourcol is None:
            return col
        return ourcol

    def __getitem__(self, col):
        return super(TranslatingDict, self).__getitem__(self.__translate_col(col))

    def __setitem__(self, col, value):
        return super(TranslatingDict, self).__setitem__(self.__translate_col(col), value)

    def __delitem__(self, col):
        return super(TranslatingDict, self).__delitem__(self.__translate_col(col))

    def __contains__(self, col):
        # go through dict.__contains__ rather than dict.has_key: the latter
        # does not exist on Python 3, the former exists on both.
        return super(TranslatingDict, self).__contains__(self.__translate_col(col))

    def has_key(self, col):
        """legacy spelling of ``col in self``."""
        return self.__contains__(col)

    def get(self, col, default=None):
        # dict.get does not call an overridden __getitem__, so translate here
        return super(TranslatingDict, self).get(self.__translate_col(col), default)

    def setdefault(self, col, value=None):
        return super(TranslatingDict, self).setdefault(self.__translate_col(col), value)

    def pop(self, col, *args):
        """remove and return the value for the translated key; an optional
        second argument is the default, as with dict.pop()."""
        return super(TranslatingDict, self).pop(self.__translate_col(col), *args)

class BinaryVisitor(sql.ClauseVisitor):
    """a ClauseVisitor that forwards every visit_binary() call to a
    caller-supplied callable, stored as ``self.func``."""
    def __init__(self, func):
        self.func = func
    def visit_binary(self, binary):
        callback = self.func
        callback(binary)

def instance_str(instance):
    """return a string describing an instance, in the form
    ``ClassName@<hex of id(instance)>``"""
    class_name = instance.__class__.__name__
    address = hex(id(instance))
    return "@".join([class_name, address])
    
def attribute_str(instance, attribute):
    """return a string describing an attribute of an instance: the
    instance_str() description, a dot, then the attribute name"""
    owner = instance_str(instance)
    return ".".join([owner, attribute])