diff options
Diffstat (limited to 'migrate')
83 files changed, 0 insertions, 9525 deletions
diff --git a/migrate/__init__.py b/migrate/__init__.py deleted file mode 100644 index 5e765d9..0000000 --- a/migrate/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" - SQLAlchemy migrate provides two APIs :mod:`migrate.versioning` for - database schema version and repository management and - :mod:`migrate.changeset` that allows to define database schema changes - using Python. -""" - -import pkg_resources - -from migrate.versioning import * -from migrate.changeset import * - -__version__ = pkg_resources.get_provider( - pkg_resources.Requirement.parse('sqlalchemy-migrate')).version diff --git a/migrate/changeset/__init__.py b/migrate/changeset/__init__.py deleted file mode 100644 index 507fa73..0000000 --- a/migrate/changeset/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -""" - This module extends SQLAlchemy and provides additional DDL [#]_ - support. - - .. [#] SQL Data Definition Language -""" -import re - -import sqlalchemy -from sqlalchemy import __version__ as _sa_version - -_sa_version = tuple(int(re.match("\d+", x).group(0)) for x in _sa_version.split(".")) -SQLA_07 = _sa_version >= (0, 7) -SQLA_08 = _sa_version >= (0, 8) -SQLA_09 = _sa_version >= (0, 9) -SQLA_10 = _sa_version >= (1, 0) - -del re -del _sa_version - -from migrate.changeset.schema import * -from migrate.changeset.constraint import * - -sqlalchemy.schema.Table.__bases__ += (ChangesetTable, ) -sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, ) -sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, ) - -sqlalchemy.schema.DefaultClause.__bases__ += (ChangesetDefaultClause, ) diff --git a/migrate/changeset/ansisql.py b/migrate/changeset/ansisql.py deleted file mode 100644 index 0a298a2..0000000 --- a/migrate/changeset/ansisql.py +++ /dev/null @@ -1,311 +0,0 @@ -""" - Extensions to SQLAlchemy for altering existing tables. - - At the moment, this isn't so much based off of ANSI as much as - things that just happen to work with multiple databases. 
-""" - -import sqlalchemy as sa -from sqlalchemy.schema import SchemaVisitor -from sqlalchemy.engine.default import DefaultDialect -from sqlalchemy.sql import ClauseElement -from sqlalchemy.schema import (ForeignKeyConstraint, - PrimaryKeyConstraint, - CheckConstraint, - UniqueConstraint, - Index) - -from migrate import exceptions -import sqlalchemy.sql.compiler -from migrate.changeset import constraint -from migrate.changeset import util -from six.moves import StringIO - -from sqlalchemy.schema import AddConstraint, DropConstraint -from sqlalchemy.sql.compiler import DDLCompiler -SchemaGenerator = SchemaDropper = DDLCompiler - - -class AlterTableVisitor(SchemaVisitor): - """Common operations for ``ALTER TABLE`` statements.""" - - # engine.Compiler looks for .statement - # when it spawns off a new compiler - statement = ClauseElement() - - def append(self, s): - """Append content to the SchemaIterator's query buffer.""" - - self.buffer.write(s) - - def execute(self): - """Execute the contents of the SchemaIterator's buffer.""" - try: - return self.connection.execute(self.buffer.getvalue()) - finally: - self.buffer.seek(0) - self.buffer.truncate() - - def __init__(self, dialect, connection, **kw): - self.connection = connection - self.buffer = StringIO() - self.preparer = dialect.identifier_preparer - self.dialect = dialect - - def traverse_single(self, elem): - ret = super(AlterTableVisitor, self).traverse_single(elem) - if ret: - # adapt to 0.6 which uses a string-returning - # object - self.append(" %s" % ret) - - def _to_table(self, param): - """Returns the table object for the given param object.""" - if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)): - ret = param.table - else: - ret = param - return ret - - def start_alter_table(self, param): - """Returns the start of an ``ALTER TABLE`` SQL-Statement. - - Use the param object to determine the table name and use it - for building the SQL statement. 
- - :param param: object to determine the table from - :type param: :class:`sqlalchemy.Column`, :class:`sqlalchemy.Index`, - :class:`sqlalchemy.schema.Constraint`, :class:`sqlalchemy.Table`, - or string (table name) - """ - table = self._to_table(param) - self.append('\nALTER TABLE %s ' % self.preparer.format_table(table)) - return table - - -class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator): - """Extends ansisql generator for column creation (alter table add col)""" - - def visit_column(self, column): - """Create a column (table already exists). - - :param column: column object - :type column: :class:`sqlalchemy.Column` instance - """ - if column.default is not None: - self.traverse_single(column.default) - - table = self.start_alter_table(column) - self.append("ADD ") - self.append(self.get_column_specification(column)) - - for cons in column.constraints: - self.traverse_single(cons) - self.execute() - - # ALTER TABLE STATEMENTS - - # add indexes and unique constraints - if column.index_name: - Index(column.index_name,column).create() - elif column.unique_name: - constraint.UniqueConstraint(column, - name=column.unique_name).create() - - # SA bounds FK constraints to table, add manually - for fk in column.foreign_keys: - self.add_foreignkey(fk.constraint) - - # add primary key constraint if needed - if column.primary_key_name: - cons = constraint.PrimaryKeyConstraint(column, - name=column.primary_key_name) - cons.create() - - def add_foreignkey(self, fk): - self.connection.execute(AddConstraint(fk)) - -class ANSIColumnDropper(AlterTableVisitor, SchemaDropper): - """Extends ANSI SQL dropper for column dropping (``ALTER TABLE - DROP COLUMN``). - """ - - def visit_column(self, column): - """Drop a column from its table. 
- - :param column: the column object - :type column: :class:`sqlalchemy.Column` - """ - table = self.start_alter_table(column) - self.append('DROP COLUMN %s' % self.preparer.format_column(column)) - self.execute() - - -class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator): - """Manages changes to existing schema elements. - - Note that columns are schema elements; ``ALTER TABLE ADD COLUMN`` - is in SchemaGenerator. - - All items may be renamed. Columns can also have many of their properties - - type, for example - changed. - - Each function is passed a tuple, containing (object, name); where - object is a type of object you'd expect for that function - (ie. table for visit_table) and name is the object's new - name. NONE means the name is unchanged. - """ - - def visit_table(self, table): - """Rename a table. Other ops aren't supported.""" - self.start_alter_table(table) - self.append("RENAME TO %s" % self.preparer.quote(table.new_name)) - self.execute() - - def visit_index(self, index): - """Rename an index""" - if hasattr(self, '_validate_identifier'): - # SA <= 0.6.3 - self.append("ALTER INDEX %s RENAME TO %s" % ( - self.preparer.quote( - self._validate_identifier( - index.name, True)), - self.preparer.quote( - self._validate_identifier( - index.new_name, True)))) - elif hasattr(self, '_index_identifier'): - # SA >= 0.6.5, < 0.8 - self.append("ALTER INDEX %s RENAME TO %s" % ( - self.preparer.quote( - self._index_identifier( - index.name)), - self.preparer.quote( - self._index_identifier( - index.new_name)))) - else: - # SA >= 0.8 - class NewName(object): - """Map obj.name -> obj.new_name""" - def __init__(self, index): - self.name = index.new_name - self._obj = index - - def __getattr__(self, attr): - if attr == 'name': - return getattr(self, attr) - return getattr(self._obj, attr) - - self.append("ALTER INDEX %s RENAME TO %s" % ( - self._prepared_index_name(index), - self._prepared_index_name(NewName(index)))) - - self.execute() - - def visit_column(self, 
delta): - """Rename/change a column.""" - # ALTER COLUMN is implemented as several ALTER statements - keys = delta.keys() - if 'type' in keys: - self._run_subvisit(delta, self._visit_column_type) - if 'nullable' in keys: - self._run_subvisit(delta, self._visit_column_nullable) - if 'server_default' in keys: - # Skip 'default': only handle server-side defaults, others - # are managed by the app, not the db. - self._run_subvisit(delta, self._visit_column_default) - if 'name' in keys: - self._run_subvisit(delta, self._visit_column_name, start_alter=False) - - def _run_subvisit(self, delta, func, start_alter=True): - """Runs visit method based on what needs to be changed on column""" - table = self._to_table(delta.table) - col_name = delta.current_name - if start_alter: - self.start_alter_column(table, col_name) - ret = func(table, delta.result_column, delta) - self.execute() - - def start_alter_column(self, table, col_name): - """Starts ALTER COLUMN""" - self.start_alter_table(table) - self.append("ALTER COLUMN %s " % self.preparer.quote(col_name)) - - def _visit_column_nullable(self, table, column, delta): - nullable = delta['nullable'] - if nullable: - self.append("DROP NOT NULL") - else: - self.append("SET NOT NULL") - - def _visit_column_default(self, table, column, delta): - default_text = self.get_column_default_string(column) - if default_text is not None: - self.append("SET DEFAULT %s" % default_text) - else: - self.append("DROP DEFAULT") - - def _visit_column_type(self, table, column, delta): - type_ = delta['type'] - type_text = str(type_.compile(dialect=self.dialect)) - self.append("TYPE %s" % type_text) - - def _visit_column_name(self, table, column, delta): - self.start_alter_table(table) - col_name = self.preparer.quote(delta.current_name) - new_name = self.preparer.format_column(delta.result_column) - self.append('RENAME COLUMN %s TO %s' % (col_name, new_name)) - - -class ANSIConstraintCommon(AlterTableVisitor): - """ - Migrate's constraints require a 
separate creation function from - SA's: Migrate's constraints are created independently of a table; - SA's are created at the same time as the table. - """ - - def get_constraint_name(self, cons): - """Gets a name for the given constraint. - - If the name is already set it will be used otherwise the - constraint's :meth:`autoname <migrate.changeset.constraint.ConstraintChangeset.autoname>` - method is used. - - :param cons: constraint object - """ - if cons.name is not None: - ret = cons.name - else: - ret = cons.name = cons.autoname() - return ret - - def visit_migrate_primary_key_constraint(self, *p, **k): - self._visit_constraint(*p, **k) - - def visit_migrate_foreign_key_constraint(self, *p, **k): - self._visit_constraint(*p, **k) - - def visit_migrate_check_constraint(self, *p, **k): - self._visit_constraint(*p, **k) - - def visit_migrate_unique_constraint(self, *p, **k): - self._visit_constraint(*p, **k) - -class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator): - def _visit_constraint(self, constraint): - constraint.name = self.get_constraint_name(constraint) - self.append(self.process(AddConstraint(constraint))) - self.execute() - -class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper): - def _visit_constraint(self, constraint): - constraint.name = self.get_constraint_name(constraint) - self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade))) - self.execute() - - -class ANSIDialect(DefaultDialect): - columngenerator = ANSIColumnGenerator - columndropper = ANSIColumnDropper - schemachanger = ANSISchemaChanger - constraintgenerator = ANSIConstraintGenerator - constraintdropper = ANSIConstraintDropper diff --git a/migrate/changeset/constraint.py b/migrate/changeset/constraint.py deleted file mode 100644 index 96407bd..0000000 --- a/migrate/changeset/constraint.py +++ /dev/null @@ -1,199 +0,0 @@ -""" - This module defines standalone schema constraint classes. 
-""" -from sqlalchemy import schema - -from migrate.exceptions import * - -class ConstraintChangeset(object): - """Base class for Constraint classes.""" - - def _normalize_columns(self, cols, table_name=False): - """Given: column objects or names; return col names and - (maybe) a table""" - colnames = [] - table = None - for col in cols: - if isinstance(col, schema.Column): - if col.table is not None and table is None: - table = col.table - if table_name: - col = '.'.join((col.table.name, col.name)) - else: - col = col.name - colnames.append(col) - return colnames, table - - def __do_imports(self, visitor_name, *a, **kw): - engine = kw.pop('engine', self.table.bind) - from migrate.changeset.databases.visitor import (get_engine_visitor, - run_single_visitor) - visitorcallable = get_engine_visitor(engine, visitor_name) - run_single_visitor(engine, visitorcallable, self, *a, **kw) - - def create(self, *a, **kw): - """Create the constraint in the database. - - :param engine: the database engine to use. If this is \ - :keyword:`None` the instance's engine will be used - :type engine: :class:`sqlalchemy.engine.base.Engine` - :param connection: reuse connection istead of creating new one. - :type connection: :class:`sqlalchemy.engine.base.Connection` instance - """ - # TODO: set the parent here instead of in __init__ - self.__do_imports('constraintgenerator', *a, **kw) - - def drop(self, *a, **kw): - """Drop the constraint from the database. - - :param engine: the database engine to use. If this is - :keyword:`None` the instance's engine will be used - :param cascade: Issue CASCADE drop if database supports it - :type engine: :class:`sqlalchemy.engine.base.Engine` - :type cascade: bool - :param connection: reuse connection istead of creating new one. 
- :type connection: :class:`sqlalchemy.engine.base.Connection` instance - :returns: Instance with cleared columns - """ - self.cascade = kw.pop('cascade', False) - self.__do_imports('constraintdropper', *a, **kw) - # the spirit of Constraint objects is that they - # are immutable (just like in a DB. they're only ADDed - # or DROPped). - #self.columns.clear() - return self - - -class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint): - """Construct PrimaryKeyConstraint - - Migrate's additional parameters: - - :param cols: Columns in constraint. - :param table: If columns are passed as strings, this kw is required - :type table: Table instance - :type cols: strings or Column instances - """ - - __migrate_visit_name__ = 'migrate_primary_key_constraint' - - def __init__(self, *cols, **kwargs): - colnames, table = self._normalize_columns(cols) - table = kwargs.pop('table', table) - super(PrimaryKeyConstraint, self).__init__(*colnames, **kwargs) - if table is not None: - self._set_parent(table) - - - def autoname(self): - """Mimic the database's automatic constraint names""" - return "%s_pkey" % self.table.name - - -class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint): - """Construct ForeignKeyConstraint - - Migrate's additional parameters: - - :param columns: Columns in constraint - :param refcolumns: Columns that this FK reffers to in another table. 
- :param table: If columns are passed as strings, this kw is required - :type table: Table instance - :type columns: list of strings or Column instances - :type refcolumns: list of strings or Column instances - """ - - __migrate_visit_name__ = 'migrate_foreign_key_constraint' - - def __init__(self, columns, refcolumns, *args, **kwargs): - colnames, table = self._normalize_columns(columns) - table = kwargs.pop('table', table) - refcolnames, reftable = self._normalize_columns(refcolumns, - table_name=True) - super(ForeignKeyConstraint, self).__init__(colnames, refcolnames, *args, - **kwargs) - if table is not None: - self._set_parent(table) - - @property - def referenced(self): - return [e.column for e in self.elements] - - @property - def reftable(self): - return self.referenced[0].table - - def autoname(self): - """Mimic the database's automatic constraint names""" - if hasattr(self.columns, 'keys'): - # SA <= 0.5 - firstcol = self.columns[self.columns.keys()[0]] - ret = "%(table)s_%(firstcolumn)s_fkey" % dict( - table=firstcol.table.name, - firstcolumn=firstcol.name,) - else: - # SA >= 0.6 - ret = "%(table)s_%(firstcolumn)s_fkey" % dict( - table=self.table.name, - firstcolumn=self.columns[0],) - return ret - - -class CheckConstraint(ConstraintChangeset, schema.CheckConstraint): - """Construct CheckConstraint - - Migrate's additional parameters: - - :param sqltext: Plain SQL text to check condition - :param columns: If not name is applied, you must supply this kw\ - to autoname constraint - :param table: If columns are passed as strings, this kw is required - :type table: Table instance - :type columns: list of Columns instances - :type sqltext: string - """ - - __migrate_visit_name__ = 'migrate_check_constraint' - - def __init__(self, sqltext, *args, **kwargs): - cols = kwargs.pop('columns', []) - if not cols and not kwargs.get('name', False): - raise InvalidConstraintError('You must either set "name"' - 'parameter or "columns" to autogenarate it.') - colnames, 
table = self._normalize_columns(cols) - table = kwargs.pop('table', table) - schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs) - if table is not None: - self._set_parent(table) - self.colnames = colnames - - def autoname(self): - return "%(table)s_%(cols)s_check" % \ - dict(table=self.table.name, cols="_".join(self.colnames)) - - -class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint): - """Construct UniqueConstraint - - Migrate's additional parameters: - - :param cols: Columns in constraint. - :param table: If columns are passed as strings, this kw is required - :type table: Table instance - :type cols: strings or Column instances - - .. versionadded:: 0.6.0 - """ - - __migrate_visit_name__ = 'migrate_unique_constraint' - - def __init__(self, *cols, **kwargs): - self.colnames, table = self._normalize_columns(cols) - table = kwargs.pop('table', table) - super(UniqueConstraint, self).__init__(*self.colnames, **kwargs) - if table is not None: - self._set_parent(table) - - def autoname(self): - """Mimic the database's automatic constraint names""" - return "%s_%s_key" % (self.table.name, self.colnames[0]) diff --git a/migrate/changeset/databases/__init__.py b/migrate/changeset/databases/__init__.py deleted file mode 100644 index 075a787..0000000 --- a/migrate/changeset/databases/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -""" - This module contains database dialect specific changeset - implementations. -""" -__all__ = [ - 'postgres', - 'sqlite', - 'mysql', - 'oracle', - 'ibmdb2', -] diff --git a/migrate/changeset/databases/firebird.py b/migrate/changeset/databases/firebird.py deleted file mode 100644 index 0f16b0a..0000000 --- a/migrate/changeset/databases/firebird.py +++ /dev/null @@ -1,93 +0,0 @@ -""" - Firebird database specific implementations of changeset classes. 
-""" -from sqlalchemy.databases import firebird as sa_base -from sqlalchemy.schema import PrimaryKeyConstraint -from migrate import exceptions -from migrate.changeset import ansisql - - -FBSchemaGenerator = sa_base.FBDDLCompiler - -class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator): - """Firebird column generator implementation.""" - - -class FBColumnDropper(ansisql.ANSIColumnDropper): - """Firebird column dropper implementation.""" - - def visit_column(self, column): - """Firebird supports 'DROP col' instead of 'DROP COLUMN col' syntax - - Drop primary key and unique constraints if dropped column is referencing it.""" - if column.primary_key: - if column.table.primary_key.columns.contains_column(column): - column.table.primary_key.drop() - # TODO: recreate primary key if it references more than this column - - for index in column.table.indexes: - # "column in index.columns" causes problems as all - # column objects compare equal and return a SQL expression - if column.name in [col.name for col in index.columns]: - index.drop() - # TODO: recreate index if it references more than this column - - for cons in column.table.constraints: - if isinstance(cons,PrimaryKeyConstraint): - # will be deleted only when the column its on - # is deleted! 
- continue - - should_drop = column.name in cons.columns - if should_drop: - self.start_alter_table(column) - self.append("DROP CONSTRAINT ") - self.append(self.preparer.format_constraint(cons)) - self.execute() - # TODO: recreate unique constraint if it refenrences more than this column - - self.start_alter_table(column) - self.append('DROP %s' % self.preparer.format_column(column)) - self.execute() - - -class FBSchemaChanger(ansisql.ANSISchemaChanger): - """Firebird schema changer implementation.""" - - def visit_table(self, table): - """Rename table not supported""" - raise exceptions.NotSupportedError( - "Firebird does not support renaming tables.") - - def _visit_column_name(self, table, column, delta): - self.start_alter_table(table) - col_name = self.preparer.quote(delta.current_name) - new_name = self.preparer.format_column(delta.result_column) - self.append('ALTER COLUMN %s TO %s' % (col_name, new_name)) - - def _visit_column_nullable(self, table, column, delta): - """Changing NULL is not supported""" - # TODO: http://www.firebirdfaq.org/faq103/ - raise exceptions.NotSupportedError( - "Firebird does not support altering NULL bevahior.") - - -class FBConstraintGenerator(ansisql.ANSIConstraintGenerator): - """Firebird constraint generator implementation.""" - - -class FBConstraintDropper(ansisql.ANSIConstraintDropper): - """Firebird constaint dropper implementation.""" - - def cascade_constraint(self, constraint): - """Cascading constraints is not supported""" - raise exceptions.NotSupportedError( - "Firebird does not support cascading constraints") - - -class FBDialect(ansisql.ANSIDialect): - columngenerator = FBColumnGenerator - columndropper = FBColumnDropper - schemachanger = FBSchemaChanger - constraintgenerator = FBConstraintGenerator - constraintdropper = FBConstraintDropper diff --git a/migrate/changeset/databases/ibmdb2.py b/migrate/changeset/databases/ibmdb2.py deleted file mode 100644 index a12d73b..0000000 --- 
a/migrate/changeset/databases/ibmdb2.py +++ /dev/null @@ -1,337 +0,0 @@ -""" - DB2 database specific implementations of changeset classes. -""" - -import logging - -from ibm_db_sa import base -from sqlalchemy.schema import (AddConstraint, - CreateIndex, - DropConstraint) -from sqlalchemy.schema import (Index, - PrimaryKeyConstraint, - UniqueConstraint) - -from migrate.changeset import ansisql -from migrate.changeset import constraint -from migrate.changeset import util -from migrate import exceptions - - -LOG = logging.getLogger(__name__) - -IBMDBSchemaGenerator = base.IBM_DBDDLCompiler - - -def get_server_version_info(dialect): - """Returns the DB2 server major and minor version as a list of ints.""" - return [int(ver_token) for ver_token in dialect.dbms_ver.split('.')[0:2]] - - -def is_unique_constraint_with_null_columns_supported(dialect): - """Checks to see if the DB2 version is at least 10.5. - - This is needed for checking if unique constraints with null columns - are supported. - """ - return get_server_version_info(dialect) >= [10, 5] - - -class IBMDBColumnGenerator(IBMDBSchemaGenerator, - ansisql.ANSIColumnGenerator): - def visit_column(self, column): - nullable = True - if not column.nullable: - nullable = False - column.nullable = True - - table = self.start_alter_table(column) - self.append("ADD COLUMN ") - self.append(self.get_column_specification(column)) - - for cons in column.constraints: - self.traverse_single(cons) - if column.default is not None: - self.traverse_single(column.default) - self.execute() - - #ALTER TABLE STATEMENTS - if not nullable: - self.start_alter_table(column) - self.append("ALTER COLUMN %s SET NOT NULL" % - self.preparer.format_column(column)) - self.execute() - self.append("CALL SYSPROC.ADMIN_CMD('REORG TABLE %s')" % - self.preparer.format_table(table)) - self.execute() - - # add indexes and unique constraints - if column.index_name: - Index(column.index_name, column).create() - elif column.unique_name: - 
constraint.UniqueConstraint(column, - name=column.unique_name).create() - - # SA bounds FK constraints to table, add manually - for fk in column.foreign_keys: - self.add_foreignkey(fk.constraint) - - # add primary key constraint if needed - if column.primary_key_name: - pk = constraint.PrimaryKeyConstraint( - column, name=column.primary_key_name) - pk.create() - - self.append("COMMIT") - self.execute() - self.append("CALL SYSPROC.ADMIN_CMD('REORG TABLE %s')" % - self.preparer.format_table(table)) - self.execute() - - -class IBMDBColumnDropper(ansisql.ANSIColumnDropper): - def visit_column(self, column): - """Drop a column from its table. - - :param column: the column object - :type column: :class:`sqlalchemy.Column` - """ - #table = self.start_alter_table(column) - super(IBMDBColumnDropper, self).visit_column(column) - self.append("CALL SYSPROC.ADMIN_CMD('REORG TABLE %s')" % - self.preparer.format_table(column.table)) - self.execute() - - -class IBMDBSchemaChanger(IBMDBSchemaGenerator, ansisql.ANSISchemaChanger): - def visit_table(self, table): - """Rename a table; #38. 
Other ops aren't supported.""" - - self._rename_table(table) - self.append("TO %s" % self.preparer.quote(table.new_name)) - self.execute() - self.append("COMMIT") - self.execute() - - def _rename_table(self, table): - self.append("RENAME TABLE %s " % self.preparer.format_table(table)) - - def visit_index(self, index): - if hasattr(self, '_index_identifier'): - # SA >= 0.6.5, < 0.8 - old_name = self.preparer.quote( - self._index_identifier(index.name)) - new_name = self.preparer.quote( - self._index_identifier(index.new_name)) - else: - # SA >= 0.8 - class NewName(object): - """Map obj.name -> obj.new_name""" - def __init__(self, index): - self.name = index.new_name - self._obj = index - - def __getattr__(self, attr): - if attr == 'name': - return getattr(self, attr) - return getattr(self._obj, attr) - - old_name = self._prepared_index_name(index) - new_name = self._prepared_index_name(NewName(index)) - - self.append("RENAME INDEX %s TO %s" % (old_name, new_name)) - self.execute() - self.append("COMMIT") - self.execute() - - def _run_subvisit(self, delta, func, start_alter=True): - """Runs visit method based on what needs to be changed on column""" - table = delta.table - if start_alter: - self.start_alter_table(table) - ret = func(table, - self.preparer.quote(delta.current_name), - delta) - self.execute() - self._reorg_table(self.preparer.format_table(delta.table)) - - def _reorg_table(self, delta): - self.append("CALL SYSPROC.ADMIN_CMD('REORG TABLE %s')" % delta) - self.execute() - - def visit_column(self, delta): - keys = delta.keys() - tr = self.connection.begin() - column = delta.result_column.copy() - - if 'type' in keys: - try: - self._run_subvisit(delta, self._visit_column_change, False) - except Exception as e: - LOG.warn("Unable to change the column type. Error: %s" % e) - - if column.primary_key and 'primary_key' not in keys: - try: - self._run_subvisit(delta, self._visit_primary_key) - except Exception as e: - LOG.warn("Unable to add primary key. 
Error: %s" % e) - - if 'nullable' in keys: - self._run_subvisit(delta, self._visit_column_nullable) - - if 'server_default' in keys: - self._run_subvisit(delta, self._visit_column_default) - - if 'primary_key' in keys: - self._run_subvisit(delta, self._visit_primary_key) - self._run_subvisit(delta, self._visit_unique_constraint) - - if 'name' in keys: - try: - self._run_subvisit(delta, self._visit_column_name, False) - except Exception as e: - LOG.warn("Unable to change column %(name)s. Error: %(error)s" % - {'name': delta.current_name, 'error': e}) - - self._reorg_table(self.preparer.format_table(delta.table)) - self.append("COMMIT") - self.execute() - tr.commit() - - def _visit_unique_constraint(self, table, col_name, delta): - # Add primary key to the current column - self.append("ADD CONSTRAINT %s " % col_name) - self.append("UNIQUE (%s)" % col_name) - - def _visit_primary_key(self, table, col_name, delta): - # Add primary key to the current column - self.append("ADD PRIMARY KEY (%s)" % col_name) - - def _visit_column_name(self, table, col_name, delta): - column = delta.result_column.copy() - - # Delete the primary key before renaming the column - if column.primary_key: - try: - self.start_alter_table(table) - self.append("DROP PRIMARY KEY") - self.execute() - except Exception: - LOG.debug("Continue since Primary key does not exist.") - - self.start_alter_table(table) - new_name = self.preparer.format_column(delta.result_column) - self.append("RENAME COLUMN %s TO %s" % (col_name, new_name)) - - if column.primary_key: - # execute the rename before adding primary key back - self.execute() - self.start_alter_table(table) - self.append("ADD PRIMARY KEY (%s)" % new_name) - - def _visit_column_nullable(self, table, col_name, delta): - self.append("ALTER COLUMN %s " % col_name) - nullable = delta['nullable'] - if nullable: - self.append("DROP NOT NULL") - else: - self.append("SET NOT NULL") - - def _visit_column_default(self, table, col_name, delta): - default_text = 
self.get_column_default_string(delta.result_column) - self.append("ALTER COLUMN %s " % col_name) - if default_text is None: - self.append("DROP DEFAULT") - else: - self.append("SET WITH DEFAULT %s" % default_text) - - def _visit_column_change(self, table, col_name, delta): - column = delta.result_column.copy() - - # Delete the primary key before - if column.primary_key: - try: - self.start_alter_table(table) - self.append("DROP PRIMARY KEY") - self.execute() - except Exception: - LOG.debug("Continue since Primary key does not exist.") - # Delete the identity before - try: - self.start_alter_table(table) - self.append("ALTER COLUMN %s DROP IDENTITY" % col_name) - self.execute() - except Exception: - LOG.debug("Continue since identity does not exist.") - - column.default = None - if not column.table: - column.table = delta.table - self.start_alter_table(table) - self.append("ALTER COLUMN %s " % col_name) - self.append("SET DATA TYPE ") - type_text = self.dialect.type_compiler.process( - delta.result_column.type) - self.append(type_text) - - -class IBMDBConstraintGenerator(ansisql.ANSIConstraintGenerator): - def _visit_constraint(self, constraint): - constraint.name = self.get_constraint_name(constraint) - if (isinstance(constraint, UniqueConstraint) and - is_unique_constraint_with_null_columns_supported( - self.dialect)): - for column in constraint: - if column.nullable: - constraint.exclude_nulls = True - break - if getattr(constraint, 'exclude_nulls', None): - index = Index(constraint.name, - *(column for column in constraint), - unique=True) - sql = self.process(CreateIndex(index)) - sql += ' EXCLUDE NULL KEYS' - else: - sql = self.process(AddConstraint(constraint)) - self.append(sql) - self.execute() - - -class IBMDBConstraintDropper(ansisql.ANSIConstraintDropper, - ansisql.ANSIConstraintCommon): - def _visit_constraint(self, constraint): - constraint.name = self.get_constraint_name(constraint) - if (isinstance(constraint, UniqueConstraint) and - 
is_unique_constraint_with_null_columns_supported( - self.dialect)): - for column in constraint: - if column.nullable: - constraint.exclude_nulls = True - break - if getattr(constraint, 'exclude_nulls', None): - if hasattr(self, '_index_identifier'): - # SA >= 0.6.5, < 0.8 - index_name = self.preparer.quote( - self._index_identifier(constraint.name)) - else: - # SA >= 0.8 - index_name = self._prepared_index_name(constraint) - sql = 'DROP INDEX %s ' % index_name - else: - sql = self.process(DropConstraint(constraint, - cascade=constraint.cascade)) - self.append(sql) - self.execute() - - def visit_migrate_primary_key_constraint(self, constraint): - self.start_alter_table(constraint.table) - self.append("DROP PRIMARY KEY") - self.execute() - - -class IBMDBDialect(ansisql.ANSIDialect): - columngenerator = IBMDBColumnGenerator - columndropper = IBMDBColumnDropper - schemachanger = IBMDBSchemaChanger - constraintgenerator = IBMDBConstraintGenerator - constraintdropper = IBMDBConstraintDropper diff --git a/migrate/changeset/databases/mysql.py b/migrate/changeset/databases/mysql.py deleted file mode 100644 index 1c01706..0000000 --- a/migrate/changeset/databases/mysql.py +++ /dev/null @@ -1,68 +0,0 @@ -""" - MySQL database specific implementations of changeset classes. 
-""" - -import sqlalchemy -from sqlalchemy.databases import mysql as sa_base -from sqlalchemy import types as sqltypes - -from migrate import exceptions -from migrate.changeset import ansisql -from migrate.changeset import util - - - -MySQLSchemaGenerator = sa_base.MySQLDDLCompiler - -class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator): - pass - - -class MySQLColumnDropper(ansisql.ANSIColumnDropper): - pass - - -class MySQLSchemaChanger(MySQLSchemaGenerator, ansisql.ANSISchemaChanger): - - def visit_column(self, delta): - table = delta.table - colspec = self.get_column_specification(delta.result_column) - if delta.result_column.autoincrement: - primary_keys = [c for c in table.primary_key.columns - if (c.autoincrement and - isinstance(c.type, sqltypes.Integer) and - not c.foreign_keys)] - - if primary_keys: - first = primary_keys.pop(0) - if first.name == delta.current_name: - colspec += " AUTO_INCREMENT" - old_col_name = self.preparer.quote(delta.current_name) - - self.start_alter_table(table) - - self.append("CHANGE COLUMN %s " % old_col_name) - self.append(colspec) - self.execute() - - def visit_index(self, param): - # If MySQL can do this, I can't find how - raise exceptions.NotSupportedError("MySQL cannot rename indexes") - - -class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator): - pass - - -class MySQLConstraintDropper(MySQLSchemaGenerator, ansisql.ANSIConstraintDropper): - def visit_migrate_check_constraint(self, *p, **k): - raise exceptions.NotSupportedError("MySQL does not support CHECK" - " constraints, use triggers instead.") - - -class MySQLDialect(ansisql.ANSIDialect): - columngenerator = MySQLColumnGenerator - columndropper = MySQLColumnDropper - schemachanger = MySQLSchemaChanger - constraintgenerator = MySQLConstraintGenerator - constraintdropper = MySQLConstraintDropper diff --git a/migrate/changeset/databases/oracle.py b/migrate/changeset/databases/oracle.py deleted file mode 100644 index 2f16b5b..0000000 --- 
"""
    Oracle database specific implementations of changeset classes.
"""
import sqlalchemy as sa
from sqlalchemy.databases import oracle as sa_base

from migrate import exceptions
from migrate.changeset import ansisql


OracleSchemaGenerator = sa_base.OracleDDLCompiler


class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
    """Column creation: ANSI behaviour with Oracle DDL compilation."""
    pass


class OracleColumnDropper(ansisql.ANSIColumnDropper):
    """Column removal: plain ANSI behaviour."""
    pass


class OracleSchemaChanger(OracleSchemaGenerator, ansisql.ANSISchemaChanger):
    """Column alteration via Oracle's ``ALTER TABLE ... MODIFY (...)``."""

    def get_column_specification(self, column, **kwargs):
        """Build a column spec, optionally suppressing NOT NULL output."""
        suppress_nullable = kwargs.pop('override_nullable', None)
        if suppress_nullable:
            saved_nullable = column.nullable
            column.nullable = True
        spec = super(OracleSchemaChanger, self).get_column_specification(
            column, **kwargs)
        if suppress_nullable:
            column.nullable = saved_nullable
        return spec

    def visit_column(self, delta):
        keys = delta.keys()
        if 'name' in keys:
            self._run_subvisit(delta,
                               self._visit_column_name,
                               start_alter=False)
        if set(keys) & {'type', 'nullable', 'server_default'}:
            self._run_subvisit(delta,
                               self._visit_column_change,
                               start_alter=False)

    def _visit_column_change(self, table, column, delta):
        changed = delta.keys()
        # Oracle cannot drop a default once created, but it can set it
        # to null.  We'll do that if default=None
        # http://forums.oracle.com/forums/message.jspa?messageID=1273234#1273234
        # NOTE(review): sa.PassiveDefault is the pre-0.6 spelling of
        # DefaultClause; kept for compatibility with pinned SA versions.
        dropdefault_hack = (column.server_default is None
                            and 'server_default' in changed)
        # Oracle complains about "not null" when the column is already
        # NOT NULL, so only emit it when nullability actually changes.
        notnull_hack = ((not column.nullable)
                        and ('nullable' not in changed))
        # We must say NULL explicitly when removing a NOT NULL constraint.
        null_hack = (column.nullable and ('nullable' in changed))

        if dropdefault_hack:
            column.server_default = sa.PassiveDefault(sa.sql.null())
        if notnull_hack:
            column.nullable = True
        colspec = self.get_column_specification(column,
                                                override_nullable=null_hack)
        if null_hack:
            colspec += ' NULL'
        if notnull_hack:
            column.nullable = False
        if dropdefault_hack:
            column.server_default = None

        self.start_alter_table(table)
        self.append("MODIFY (")
        self.append(colspec)
        self.append(")")


class OracleConstraintCommon(object):
    """Shared naming rule for Oracle constraint visitors."""

    def get_constraint_name(self, cons):
        # Unlike other backends, Oracle cannot auto-generate a name here.
        if not cons.name:
            raise exceptions.NotSupportedError(
                "Oracle constraint names must be explicitly stated")
        return cons.name


class OracleConstraintGenerator(OracleConstraintCommon,
                                ansisql.ANSIConstraintGenerator):
    pass


class OracleConstraintDropper(OracleConstraintCommon,
                              ansisql.ANSIConstraintDropper):
    pass


class OracleDialect(ansisql.ANSIDialect):
    columngenerator = OracleColumnGenerator
    columndropper = OracleColumnDropper
    schemachanger = OracleSchemaChanger
    constraintgenerator = OracleConstraintGenerator
    constraintdropper = OracleConstraintDropper
_`PostgreSQL`: http://www.postgresql.org/ -""" -from migrate.changeset import ansisql - -from sqlalchemy.databases import postgresql as sa_base -PGSchemaGenerator = sa_base.PGDDLCompiler - - -class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator): - """PostgreSQL column generator implementation.""" - pass - - -class PGColumnDropper(ansisql.ANSIColumnDropper): - """PostgreSQL column dropper implementation.""" - pass - - -class PGSchemaChanger(ansisql.ANSISchemaChanger): - """PostgreSQL schema changer implementation.""" - pass - - -class PGConstraintGenerator(ansisql.ANSIConstraintGenerator): - """PostgreSQL constraint generator implementation.""" - pass - - -class PGConstraintDropper(ansisql.ANSIConstraintDropper): - """PostgreSQL constaint dropper implementation.""" - pass - - -class PGDialect(ansisql.ANSIDialect): - columngenerator = PGColumnGenerator - columndropper = PGColumnDropper - schemachanger = PGSchemaChanger - constraintgenerator = PGConstraintGenerator - constraintdropper = PGConstraintDropper diff --git a/migrate/changeset/databases/sqlite.py b/migrate/changeset/databases/sqlite.py deleted file mode 100644 index 908c800..0000000 --- a/migrate/changeset/databases/sqlite.py +++ /dev/null @@ -1,229 +0,0 @@ -""" - `SQLite`_ database specific implementations of changeset classes. - - .. 
"""
    `SQLite`_ database specific implementations of changeset classes.

    .. _`SQLite`: http://www.sqlite.org/
"""
try:  # Python 3
    from collections.abc import MutableMapping as DictMixin
except ImportError:  # Python 2
    from UserDict import DictMixin
from copy import copy
import re

from sqlalchemy.databases import sqlite as sa_base
from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.schema import UniqueConstraint

from migrate import exceptions
from migrate.changeset import ansisql


SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler


class SQLiteCommon(object):

    def _not_supported(self, op):
        """Raise NotSupportedError for an ALTER operation SQLite lacks."""
        raise exceptions.NotSupportedError("SQLite does not support "
            "%s; see http://www.sqlite.org/lang_altertable.html" % op)


class SQLiteHelper(SQLiteCommon):
    """Shared table-recreation machinery: SQLite cannot ALTER most things,
    so the table is renamed aside, re-created, and data copied back."""

    def _filter_columns(self, cols, table):
        """Splits the string of columns and returns those only in the table.

        :param cols: comma-delimited string of table columns
        :param table: the table to check
        :return: list of column names present in the table
        """
        columns = []
        for c in cols.split(","):
            if c in table.columns:
                # There was a bug in reflection of SQLite columns with
                # reserved identifiers as names (SQLite can return them
                # wrapped with double quotes), so strip double quotes.
                # BUG FIX: was columns.extend(...), which iterated the
                # string and appended the name one character at a time.
                columns.append(c.strip(' "'))
        return columns

    def _get_constraints(self, table):
        """Retrieve information about existing constraints of the table.

        This feature is needed for recreate_table() to work properly.
        """
        data = table.metadata.bind.execute(
            """SELECT sql
            FROM sqlite_master
            WHERE
                type='table' AND
                name=:table_name""",
            table_name=table.name
        ).fetchone()[0]

        # Raw strings so the regex escapes survive future lint/compile
        # strictness; patterns are unchanged.
        UNIQUE_PATTERN = r"CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)"
        constraints = []
        for name, cols in re.findall(UNIQUE_PATTERN, data):
            # Filter out any columns that were dropped from the table.
            columns = self._filter_columns(cols, table)
            if columns:
                # BUG FIX: was constraints.extend(...), which iterated the
                # constraint (yielding its columns) instead of appending
                # the constraint object itself.
                constraints.append(UniqueConstraint(*columns, name=name))

        FKEY_PATTERN = r"CONSTRAINT (\w+) FOREIGN KEY \(([^\)]+)\)"
        for name, cols in re.findall(FKEY_PATTERN, data):
            # Filter out any columns that were dropped from the table.
            columns = self._filter_columns(cols, table)
            if columns:
                # BUG FIX: append, not extend (see above).
                constraints.append(ForeignKeyConstraint(*columns, name=name))

        return constraints

    def recreate_table(self, table, column=None, delta=None,
                       omit_constraints=None):
        """Rename the table aside, re-create it, and copy rows back.

        :param omit_constraints: constraint names to leave out of the
            re-created table (used to drop constraints).
        """
        table_name = self.preparer.format_table(table)

        # we remove all indexes so as not to have
        # problems during copy and re-create
        for index in table.indexes:
            index.drop()

        # reflect existing constraints
        for constraint in self._get_constraints(table):
            table.append_constraint(constraint)
        # omit given constraints when creating a new table if required
        table.constraints = set([
            cons for cons in table.constraints
            if omit_constraints is None or cons.name not in omit_constraints
        ])

        # Use "PRAGMA legacy_alter_table = ON" with sqlite >= 3.26 when
        # using "ALTER TABLE RENAME TO migration_tmp" to maintain legacy
        # behavior.  See: https://www.sqlite.org/src/info/ae9638e9c0ad0c36
        if self.connection.engine.dialect.server_version_info >= (3, 26):
            self.append('PRAGMA legacy_alter_table = ON')
            self.execute()
        self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name)
        self.execute()
        if self.connection.engine.dialect.server_version_info >= (3, 26):
            self.append('PRAGMA legacy_alter_table = OFF')
            self.execute()

        insertion_string = self._modify_table(table, column, delta)

        table.create(bind=self.connection)
        self.append(insertion_string % {'table_name': table_name})
        self.execute()
        self.append('DROP TABLE migration_tmp')
        self.execute()

    def visit_column(self, delta):
        # Accept either a ColumnDelta (dict-like) or a bare Column.
        if isinstance(delta, DictMixin):
            column = delta.result_column
            table = self._to_table(delta.table)
        else:
            column = delta
            table = self._to_table(column.table)
        self.recreate_table(table, column, delta)


class SQLiteColumnGenerator(SQLiteSchemaGenerator,
                            ansisql.ANSIColumnGenerator,
                            # at the end so we get the normal
                            # visit_column by default
                            SQLiteHelper,
                            SQLiteCommon):
    """SQLite ColumnGenerator"""

    def _modify_table(self, table, column, delta):
        # Copy every column except the one being added.
        columns = ' ,'.join(map(
            self.preparer.format_column,
            [c for c in table.columns if c.name != column.name]))
        return ('INSERT INTO %%(table_name)s (%(cols)s) '
                'SELECT %(cols)s from migration_tmp') % {'cols': columns}

    def visit_column(self, column):
        # Foreign-keyed columns cannot be added with ALTER TABLE ADD
        # COLUMN on SQLite; fall back to full table recreation.
        if column.foreign_keys:
            SQLiteHelper.visit_column(self, column)
        else:
            super(SQLiteColumnGenerator, self).visit_column(column)


class SQLiteColumnDropper(SQLiteHelper, ansisql.ANSIColumnDropper):
    """SQLite ColumnDropper"""

    def _modify_table(self, table, column, delta):
        columns = ' ,'.join(map(self.preparer.format_column, table.columns))
        return 'INSERT INTO %(table_name)s SELECT ' + columns + \
            ' from migration_tmp'

    def visit_column(self, column):
        # For SQLite, we *have* to remove the column here so the table
        # is re-created properly.
        column.remove_from_table(column.table, unset_table=False)
        super(SQLiteColumnDropper, self).visit_column(column)


class SQLiteSchemaChanger(SQLiteHelper, ansisql.ANSISchemaChanger):
    """SQLite SchemaChanger"""

    def _modify_table(self, table, column, delta):
        return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'

    def visit_index(self, index):
        """Does not support ALTER INDEX"""
        self._not_supported('ALTER INDEX')


class SQLiteConstraintGenerator(ansisql.ANSIConstraintGenerator,
                                SQLiteHelper, SQLiteCommon):

    def visit_migrate_primary_key_constraint(self, constraint):
        # Emulate a PK with a unique index (SQLite cannot add a PK later).
        tmpl = "CREATE UNIQUE INDEX %s ON %s ( %s )"
        cols = ', '.join(map(self.preparer.format_column, constraint.columns))
        tname = self.preparer.format_table(constraint.table)
        name = self.get_constraint_name(constraint)
        msg = tmpl % (name, tname, cols)
        self.append(msg)
        self.execute()

    def _modify_table(self, table, column, delta):
        return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'

    def visit_migrate_foreign_key_constraint(self, *p, **k):
        self.recreate_table(p[0].table)

    def visit_migrate_unique_constraint(self, *p, **k):
        self.recreate_table(p[0].table)


class SQLiteConstraintDropper(ansisql.ANSIColumnDropper,
                              SQLiteHelper,
                              ansisql.ANSIConstraintCommon):

    def _modify_table(self, table, column, delta):
        return 'INSERT INTO %(table_name)s SELECT * from migration_tmp'

    def visit_migrate_primary_key_constraint(self, constraint):
        # PKs were emulated as unique indexes; drop the index.
        tmpl = "DROP INDEX %s "
        name = self.get_constraint_name(constraint)
        msg = tmpl % (name)
        self.append(msg)
        self.execute()

    def visit_migrate_foreign_key_constraint(self, *p, **k):
        self.recreate_table(p[0].table, omit_constraints=[p[0].name])

    def visit_migrate_check_constraint(self, *p, **k):
        self._not_supported('ALTER TABLE DROP CONSTRAINT')

    def visit_migrate_unique_constraint(self, *p, **k):
        self.recreate_table(p[0].table, omit_constraints=[p[0].name])
# TODO: technically primary key is a NOT NULL + UNIQUE constraint, should
# add NOT NULL to index

class SQLiteDialect(ansisql.ANSIDialect):
    columngenerator = SQLiteColumnGenerator
    columndropper = SQLiteColumnDropper
    schemachanger = SQLiteSchemaChanger
    constraintgenerator = SQLiteConstraintGenerator
    constraintdropper = SQLiteConstraintDropper
"""
    Module for visitor class mapping.
"""
import sqlalchemy as sa

from migrate.changeset import ansisql
from migrate.changeset.databases import (sqlite,
                                         postgres,
                                         mysql,
                                         oracle,
                                         firebird)


# Map SA dialects to the corresponding Migrate extensions
DIALECTS = {
    "default": ansisql.ANSIDialect,
    "sqlite": sqlite.SQLiteDialect,
    "postgres": postgres.PGDialect,
    "postgresql": postgres.PGDialect,
    "mysql": mysql.MySQLDialect,
    "oracle": oracle.OracleDialect,
    "firebird": firebird.FBDialect,
}


# NOTE(mriedem): We have to conditionally check for DB2 in case ibm_db_sa
# isn't available since ibm_db_sa is not packaged in sqlalchemy like the
# other dialects.
try:
    from migrate.changeset.databases import ibmdb2
    DIALECTS["ibm_db_sa"] = ibmdb2.IBMDBDialect
except ImportError:
    pass


def get_engine_visitor(engine, name):
    """
    Get the visitor implementation for the given database engine.

    :param engine: SQLAlchemy Engine
    :param name: Name of the visitor
    :type name: string
    :type engine: Engine
    :returns: visitor
    """
    # TODO: link to supported visitors
    return get_dialect_visitor(engine.dialect, name)


def get_dialect_visitor(sa_dialect, name):
    """
    Get the visitor implementation for the given dialect.

    Finds the visitor implementation based on the dialect class and
    returns and instance initialized with the given name.

    Binds dialect specific preparer to visitor.
    """
    # map sa dialect to migrate dialect and return visitor
    dialect_name = getattr(sa_dialect, 'name', 'default')
    visitor = getattr(DIALECTS[dialect_name], name)

    # bind preparer (note: set on the visitor *class*)
    visitor.preparer = sa_dialect.preparer(sa_dialect)

    return visitor


def run_single_visitor(engine, visitorcallable, element,
                       connection=None, **kwargs):
    """Taken from :meth:`sqlalchemy.engine.base.Engine._run_single_visitor`
    with support for migrate visitors.
    """
    conn = (engine.contextual_connect(close_with_result=False)
            if connection is None else connection)
    visitor = visitorcallable(engine.dialect, conn)
    try:
        # Migrate elements carry their own visit name; fall back to SA's.
        visit_name = getattr(element, '__migrate_visit_name__',
                             element.__visit_name__)
        getattr(visitor, 'visit_' + visit_name)(element, **kwargs)
    finally:
        # Only close connections we opened ourselves.
        if connection is None:
            conn.close()
"""
    Schema module providing common schema operations.
"""
import abc
try:  # Python 3
    from collections.abc import MutableMapping as DictMixin
except ImportError:  # Python 2
    from UserDict import DictMixin
import warnings

import six
import sqlalchemy

from sqlalchemy.schema import ForeignKeyConstraint
from sqlalchemy.schema import UniqueConstraint

from migrate.exceptions import *
from migrate.changeset import SQLA_07, SQLA_08
from migrate.changeset import util
from migrate.changeset.databases.visitor import (get_engine_visitor,
                                                 run_single_visitor)


__all__ = [
    'create_column',
    'drop_column',
    'alter_column',
    'rename_table',
    'rename_index',
    'ChangesetTable',
    'ChangesetColumn',
    'ChangesetIndex',
    'ChangesetDefaultClause',
    'ColumnDelta',
]


def create_column(column, table=None, *p, **kw):
    """Create a column, given the table.

    API to :meth:`ChangesetColumn.create`.
    """
    if table is None:
        return column.create(*p, **kw)
    return table.create_column(column, *p, **kw)


def drop_column(column, table=None, *p, **kw):
    """Drop a column, given the table.

    API to :meth:`ChangesetColumn.drop`.
    """
    if table is None:
        return column.drop(*p, **kw)
    return table.drop_column(column, *p, **kw)


def rename_table(table, name, engine=None, **kw):
    """Rename a table.

    If Table instance is given, engine is not used.

    API to :meth:`ChangesetTable.rename`.

    :param table: Table to be renamed.
    :param name: New name for Table.
    :param engine: Engine instance.
    :type table: string or Table instance
    :type name: string
    :type engine: obj
    """
    _to_table(table, engine).rename(name, **kw)


def rename_index(index, name, table=None, engine=None, **kw):
    """Rename an index.

    If Index instance is given, table and engine are not used.

    API to :meth:`ChangesetIndex.rename`.

    :param index: Index to be renamed.
    :param name: New name for index.
    :param table: Table to which Index is referred.
    :param engine: Engine instance.
    :type index: string or Index instance
    :type name: string
    :type table: string or Table instance
    :type engine: obj
    """
    _to_index(index, table, engine).rename(name, **kw)


def alter_column(*p, **k):
    """Alter a column.

    This is a helper function that creates a :class:`ColumnDelta` and
    runs it.

    :argument column:
        The name of the column to be altered or a
        :class:`ChangesetColumn` column representing it.

    :param table:
        A :class:`~sqlalchemy.schema.Table` or table name to
        for the table where the column will be changed.

    :param engine:
        The :class:`~sqlalchemy.engine.base.Engine` to use for table
        reflection and schema alterations.

    :returns: A :class:`ColumnDelta` instance representing the change.
    """
    if 'table' not in k and isinstance(p[0], sqlalchemy.Column):
        k['table'] = p[0].table
    if 'engine' not in k:
        k['engine'] = k['table'].bind

    # deprecation
    if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column):
        warnings.warn(
            "Passing a Column object to alter_column is deprecated."
            " Just pass in keyword parameters instead.",
            MigrateDeprecationWarning
            )
    engine = k['engine']

    # enough tests seem to break when metadata is always altered
    # that this crutch has to be left in until they can be sorted
    # out
    k['alter_metadata'] = True

    delta = ColumnDelta(*p, **k)

    visitorcallable = get_engine_visitor(engine, 'schemachanger')
    engine._run_visitor(visitorcallable, delta)

    return delta


def _to_table(table, engine=None):
    """Return if instance of Table, else construct new with metadata"""
    if isinstance(table, sqlalchemy.Table):
        return table

    # Given: table name, maybe an engine
    metadata = sqlalchemy.MetaData()
    if engine is not None:
        metadata.bind = engine
    return sqlalchemy.Table(table, metadata)


def _to_index(index, table=None, engine=None):
    """Return if instance of Index, else construct new with metadata"""
    if isinstance(index, sqlalchemy.Index):
        return index

    # Given: index name; table name required
    result = sqlalchemy.Index(index)
    result.table = _to_table(table, engine)
    return result
Let's call it "MyMeta". -class MyMeta(sqlalchemy.sql.visitors.VisitableType, abc.ABCMeta, object): - pass - - -class ColumnDelta(six.with_metaclass(MyMeta, DictMixin, sqlalchemy.schema.SchemaItem)): - """Extracts the differences between two columns/column-parameters - - May receive parameters arranged in several different ways: - - * **current_column, new_column, \*p, \*\*kw** - Additional parameters can be specified to override column - differences. - - * **current_column, \*p, \*\*kw** - Additional parameters alter current_column. Table name is extracted - from current_column object. - Name is changed to current_column.name from current_name, - if current_name is specified. - - * **current_col_name, \*p, \*\*kw** - Table kw must specified. - - :param table: Table at which current Column should be bound to.\ - If table name is given, reflection will be used. - :type table: string or Table instance - - :param metadata: A :class:`MetaData` instance to store - reflected table names - - :param engine: When reflecting tables, either engine or metadata must \ - be specified to acquire engine object. - :type engine: :class:`Engine` instance - :returns: :class:`ColumnDelta` instance provides interface for altered attributes to \ - `result_column` through :func:`dict` alike object. - - * :class:`ColumnDelta`.result_column is altered column with new attributes - - * :class:`ColumnDelta`.current_name is current name of column in db - - - """ - - # Column attributes that can be altered - diff_keys = ('name', 'type', 'primary_key', 'nullable', - 'server_onupdate', 'server_default', 'autoincrement') - diffs = dict() - __visit_name__ = 'column' - - def __init__(self, *p, **kw): - # 'alter_metadata' is not a public api. 
It exists purely - # as a crutch until the tests that fail when 'alter_metadata' - # behaviour always happens can be sorted out - self.alter_metadata = kw.pop("alter_metadata", False) - - self.meta = kw.pop("metadata", None) - self.engine = kw.pop("engine", None) - - # Things are initialized differently depending on how many column - # parameters are given. Figure out how many and call the appropriate - # method. - if len(p) >= 1 and isinstance(p[0], sqlalchemy.Column): - # At least one column specified - if len(p) >= 2 and isinstance(p[1], sqlalchemy.Column): - # Two columns specified - diffs = self.compare_2_columns(*p, **kw) - else: - # Exactly one column specified - diffs = self.compare_1_column(*p, **kw) - else: - # Zero columns specified - if not len(p) or not isinstance(p[0], six.string_types): - raise ValueError("First argument must be column name") - diffs = self.compare_parameters(*p, **kw) - - self.apply_diffs(diffs) - - def __repr__(self): - return '<ColumnDelta altermetadata=%r, %s>' % ( - self.alter_metadata, - super(ColumnDelta, self).__repr__() - ) - - def __getitem__(self, key): - if key not in self.keys(): - raise KeyError("No such diff key, available: %s" % self.diffs ) - return getattr(self.result_column, key) - - def __setitem__(self, key, value): - if key not in self.keys(): - raise KeyError("No such diff key, available: %s" % self.diffs ) - setattr(self.result_column, key, value) - - def __delitem__(self, key): - raise NotImplementedError - - def __len__(self): - raise NotImplementedError - - def __iter__(self): - raise NotImplementedError - - def keys(self): - return self.diffs.keys() - - def compare_parameters(self, current_name, *p, **k): - """Compares Column objects with reflection""" - self.table = k.pop('table') - self.result_column = self._table.c.get(current_name) - if len(p): - k = self._extract_parameters(p, k, self.result_column) - return k - - def compare_1_column(self, col, *p, **k): - """Compares one Column object""" - 
self.table = k.pop('table', None) - if self.table is None: - self.table = col.table - self.result_column = col - if len(p): - k = self._extract_parameters(p, k, self.result_column) - return k - - def compare_2_columns(self, old_col, new_col, *p, **k): - """Compares two Column objects""" - self.process_column(new_col) - self.table = k.pop('table', None) - # we cannot use bool() on table in SA06 - if self.table is None: - self.table = old_col.table - if self.table is None: - new_col.table - self.result_column = old_col - - # set differences - # leave out some stuff for later comp - for key in (set(self.diff_keys) - set(('type',))): - val = getattr(new_col, key, None) - if getattr(self.result_column, key, None) != val: - k.setdefault(key, val) - - # inspect types - if not self.are_column_types_eq(self.result_column.type, new_col.type): - k.setdefault('type', new_col.type) - - if len(p): - k = self._extract_parameters(p, k, self.result_column) - return k - - def apply_diffs(self, diffs): - """Populate dict and column object with new values""" - self.diffs = diffs - for key in self.diff_keys: - if key in diffs: - setattr(self.result_column, key, diffs[key]) - - self.process_column(self.result_column) - - # create an instance of class type if not yet - if 'type' in diffs: - if callable(self.result_column.type): - self.result_column.type = self.result_column.type() - if self.result_column.autoincrement and \ - not issubclass( - self.result_column.type._type_affinity, - sqlalchemy.Integer): - self.result_column.autoincrement = False - - # add column to the table - if self.table is not None and self.alter_metadata: - self.result_column.add_to_table(self.table) - - def are_column_types_eq(self, old_type, new_type): - """Compares two types to be equal""" - ret = old_type.__class__ == new_type.__class__ - - # String length is a special case - if ret and isinstance(new_type, sqlalchemy.types.String): - ret = (getattr(old_type, 'length', None) == \ - getattr(new_type, 'length', 
None)) - return ret - - def _extract_parameters(self, p, k, column): - """Extracts data from p and modifies diffs""" - p = list(p) - while len(p): - if isinstance(p[0], six.string_types): - k.setdefault('name', p.pop(0)) - elif isinstance(p[0], sqlalchemy.types.TypeEngine): - k.setdefault('type', p.pop(0)) - elif callable(p[0]): - p[0] = p[0]() - else: - break - - if len(p): - new_col = column.copy_fixed() - new_col._init_items(*p) - k = self.compare_2_columns(column, new_col, **k) - return k - - def process_column(self, column): - """Processes default values for column""" - # XXX: this is a snippet from SA processing of positional parameters - toinit = list() - - if column.server_default is not None: - if isinstance(column.server_default, sqlalchemy.FetchedValue): - toinit.append(column.server_default) - else: - toinit.append(sqlalchemy.DefaultClause(column.server_default)) - if column.server_onupdate is not None: - if isinstance(column.server_onupdate, FetchedValue): - toinit.append(column.server_default) - else: - toinit.append(sqlalchemy.DefaultClause(column.server_onupdate, - for_update=True)) - if toinit: - column._init_items(*toinit) - - def _get_table(self): - return getattr(self, '_table', None) - - def _set_table(self, table): - if isinstance(table, six.string_types): - if self.alter_metadata: - if not self.meta: - raise ValueError("metadata must be specified for table" - " reflection when using alter_metadata") - meta = self.meta - if self.engine: - meta.bind = self.engine - else: - if not self.engine and not self.meta: - raise ValueError("engine or metadata must be specified" - " to reflect tables") - if not self.engine: - self.engine = self.meta.bind - meta = sqlalchemy.MetaData(bind=self.engine) - self._table = sqlalchemy.Table(table, meta, autoload=True) - elif isinstance(table, sqlalchemy.Table): - self._table = table - if not self.alter_metadata: - self._table.meta = sqlalchemy.MetaData(bind=self._table.bind) - def _get_result_column(self): - 
return getattr(self, '_result_column', None) - - def _set_result_column(self, column): - """Set Column to Table based on alter_metadata evaluation.""" - self.process_column(column) - if not hasattr(self, 'current_name'): - self.current_name = column.name - if self.alter_metadata: - self._result_column = column - else: - self._result_column = column.copy_fixed() - - table = property(_get_table, _set_table) - result_column = property(_get_result_column, _set_result_column) - - -class ChangesetTable(object): - """Changeset extensions to SQLAlchemy tables.""" - - def create_column(self, column, *p, **kw): - """Creates a column. - - The column parameter may be a column definition or the name of - a column in this table. - - API to :meth:`ChangesetColumn.create` - - :param column: Column to be created - :type column: Column instance or string - """ - if not isinstance(column, sqlalchemy.Column): - # It's a column name - column = getattr(self.c, str(column)) - column.create(table=self, *p, **kw) - - def drop_column(self, column, *p, **kw): - """Drop a column, given its name or definition. - - API to :meth:`ChangesetColumn.drop` - - :param column: Column to be droped - :type column: Column instance or string - """ - if not isinstance(column, sqlalchemy.Column): - # It's a column name - try: - column = getattr(self.c, str(column)) - except AttributeError: - # That column isn't part of the table. We don't need - # its entire definition to drop the column, just its - # name, so create a dummy column with the same name. - column = sqlalchemy.Column(str(column), sqlalchemy.Integer()) - column.drop(table=self, *p, **kw) - - def rename(self, name, connection=None, **kwargs): - """Rename this table. - - :param name: New name of the table. - :type name: string - :param connection: reuse connection istead of creating new one. 
- :type connection: :class:`sqlalchemy.engine.base.Connection` instance - """ - engine = self.bind - self.new_name = name - visitorcallable = get_engine_visitor(engine, 'schemachanger') - run_single_visitor(engine, visitorcallable, self, connection, **kwargs) - - # Fix metadata registration - self.name = name - self.deregister() - self._set_parent(self.metadata) - - def _meta_key(self): - """Get the meta key for this table.""" - return sqlalchemy.schema._get_table_key(self.name, self.schema) - - def deregister(self): - """Remove this table from its metadata""" - if SQLA_07: - self.metadata._remove_table(self.name, self.schema) - else: - key = self._meta_key() - meta = self.metadata - if key in meta.tables: - del meta.tables[key] - - -class ChangesetColumn(object): - """Changeset extensions to SQLAlchemy columns.""" - - def alter(self, *p, **k): - """Makes a call to :func:`alter_column` for the column this - method is called on. - """ - if 'table' not in k: - k['table'] = self.table - if 'engine' not in k: - k['engine'] = k['table'].bind - return alter_column(self, *p, **k) - - def create(self, table=None, index_name=None, unique_name=None, - primary_key_name=None, populate_default=True, connection=None, **kwargs): - """Create this column in the database. - - Assumes the given table exists. ``ALTER TABLE ADD COLUMN``, - for most databases. - - :param table: Table instance to create on. - :param index_name: Creates :class:`ChangesetIndex` on this column. - :param unique_name: Creates :class:\ -`~migrate.changeset.constraint.UniqueConstraint` on this column. - :param primary_key_name: Creates :class:\ -`~migrate.changeset.constraint.PrimaryKeyConstraint` on this column. - :param populate_default: If True, created column will be \ -populated with defaults - :param connection: reuse connection istead of creating new one. 
- :type table: Table instance - :type index_name: string - :type unique_name: string - :type primary_key_name: string - :type populate_default: bool - :type connection: :class:`sqlalchemy.engine.base.Connection` instance - - :returns: self - """ - self.populate_default = populate_default - self.index_name = index_name - self.unique_name = unique_name - self.primary_key_name = primary_key_name - for cons in ('index_name', 'unique_name', 'primary_key_name'): - self._check_sanity_constraints(cons) - - self.add_to_table(table) - engine = self.table.bind - visitorcallable = get_engine_visitor(engine, 'columngenerator') - engine._run_visitor(visitorcallable, self, connection, **kwargs) - - # TODO: reuse existing connection - if self.populate_default and self.default is not None: - stmt = table.update().values({self: engine._execute_default(self.default)}) - engine.execute(stmt) - - return self - - def drop(self, table=None, connection=None, **kwargs): - """Drop this column from the database, leaving its table intact. - - ``ALTER TABLE DROP COLUMN``, for most databases. - - :param connection: reuse connection istead of creating new one. 
- :type connection: :class:`sqlalchemy.engine.base.Connection` instance - """ - if table is not None: - self.table = table - engine = self.table.bind - visitorcallable = get_engine_visitor(engine, 'columndropper') - engine._run_visitor(visitorcallable, self, connection, **kwargs) - self.remove_from_table(self.table, unset_table=False) - self.table = None - return self - - def add_to_table(self, table): - if table is not None and self.table is None: - if SQLA_07: - table.append_column(self) - else: - self._set_parent(table) - - def _col_name_in_constraint(self,cons,name): - return False - - def remove_from_table(self, table, unset_table=True): - # TODO: remove primary keys, constraints, etc - if unset_table: - self.table = None - - to_drop = set() - for index in table.indexes: - columns = [] - for col in index.columns: - if col.name!=self.name: - columns.append(col) - if columns: - index.columns = columns - if SQLA_08: - index.expressions = columns - else: - to_drop.add(index) - table.indexes = table.indexes - to_drop - - to_drop = set() - for cons in table.constraints: - # TODO: deal with other types of constraint - if isinstance(cons,(ForeignKeyConstraint, - UniqueConstraint)): - for col_name in cons.columns: - if not isinstance(col_name,six.string_types): - col_name = col_name.name - if self.name==col_name: - to_drop.add(cons) - table.constraints = table.constraints - to_drop - - if table.c.contains_column(self): - if SQLA_07: - table._columns.remove(self) - else: - table.c.remove(self) - - # TODO: this is fixed in 0.6 - def copy_fixed(self, **kw): - """Create a copy of this ``Column``, with all attributes.""" - return sqlalchemy.Column(self.name, self.type, self.default, - key=self.key, - primary_key=self.primary_key, - nullable=self.nullable, - index=self.index, - unique=self.unique, - onupdate=self.onupdate, - autoincrement=self.autoincrement, - server_default=self.server_default, - server_onupdate=self.server_onupdate, - *[c.copy(**kw) for c in 
self.constraints]) - - def _check_sanity_constraints(self, name): - """Check if constraints names are correct""" - obj = getattr(self, name) - if (getattr(self, name[:-5]) and not obj): - raise InvalidConstraintError("Column.create() accepts index_name," - " primary_key_name and unique_name to generate constraints") - if not isinstance(obj, six.string_types) and obj is not None: - raise InvalidConstraintError( - "%s argument for column must be constraint name" % name) - - -class ChangesetIndex(object): - """Changeset extensions to SQLAlchemy Indexes.""" - - __visit_name__ = 'index' - - def rename(self, name, connection=None, **kwargs): - """Change the name of an index. - - :param name: New name of the Index. - :type name: string - :param connection: reuse connection istead of creating new one. - :type connection: :class:`sqlalchemy.engine.base.Connection` instance - """ - engine = self.table.bind - self.new_name = name - visitorcallable = get_engine_visitor(engine, 'schemachanger') - engine._run_visitor(visitorcallable, self, connection, **kwargs) - self.name = name - - -class ChangesetDefaultClause(object): - """Implements comparison between :class:`DefaultClause` instances""" - - def __eq__(self, other): - if isinstance(other, self.__class__): - if self.arg == other.arg: - return True - - def __ne__(self, other): - return not self.__eq__(other) diff --git a/migrate/changeset/util.py b/migrate/changeset/util.py deleted file mode 100644 index 68b7609..0000000 --- a/migrate/changeset/util.py +++ /dev/null @@ -1,10 +0,0 @@ -from migrate.changeset import SQLA_10 - - -def fk_column_names(constraint): - if SQLA_10: - return [ - constraint.columns[key].name for key in constraint.column_keys] - else: - return [ - element.parent.name for element in constraint.elements] diff --git a/migrate/exceptions.py b/migrate/exceptions.py deleted file mode 100644 index 31c8cd9..0000000 --- a/migrate/exceptions.py +++ /dev/null @@ -1,91 +0,0 @@ -""" - Provide exception classes for 
:mod:`migrate` -""" - - -class Error(Exception): - """Error base class.""" - - -class ApiError(Error): - """Base class for API errors.""" - - -class KnownError(ApiError): - """A known error condition.""" - - -class UsageError(ApiError): - """A known error condition where help should be displayed.""" - - -class ControlledSchemaError(Error): - """Base class for controlled schema errors.""" - - -class InvalidVersionError(ControlledSchemaError): - """Invalid version number.""" - - -class VersionNotFoundError(KeyError): - """Specified version is not present.""" - - -class DatabaseNotControlledError(ControlledSchemaError): - """Database should be under version control, but it's not.""" - - -class DatabaseAlreadyControlledError(ControlledSchemaError): - """Database shouldn't be under version control, but it is""" - - -class WrongRepositoryError(ControlledSchemaError): - """This database is under version control by another repository.""" - - -class NoSuchTableError(ControlledSchemaError): - """The table does not exist.""" - - -class PathError(Error): - """Base class for path errors.""" - - -class PathNotFoundError(PathError): - """A path with no file was required; found a file.""" - - -class PathFoundError(PathError): - """A path with a file was required; found no file.""" - - -class RepositoryError(Error): - """Base class for repository errors.""" - - -class InvalidRepositoryError(RepositoryError): - """Invalid repository error.""" - - -class ScriptError(Error): - """Base class for script errors.""" - - -class InvalidScriptError(ScriptError): - """Invalid script error.""" - - -class InvalidVersionError(Error): - """Invalid version error.""" - -# migrate.changeset - -class NotSupportedError(Error): - """Not supported error""" - - -class InvalidConstraintError(Error): - """Invalid constraint error""" - -class MigrateDeprecationWarning(DeprecationWarning): - """Warning for deprecated features in Migrate""" diff --git a/migrate/tests/__init__.py b/migrate/tests/__init__.py 
deleted file mode 100644 index c03fbf4..0000000 --- a/migrate/tests/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# make this package available during imports as long as we support <python2.5 -import sys -import os -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - - -from unittest import TestCase -import migrate -import six - - -class TestVersionDefined(TestCase): - def test_version(self): - """Test for migrate.__version__""" - self.assertTrue(isinstance(migrate.__version__, six.string_types)) - self.assertTrue(len(migrate.__version__) > 0) diff --git a/migrate/tests/changeset/__init__.py b/migrate/tests/changeset/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/migrate/tests/changeset/__init__.py +++ /dev/null diff --git a/migrate/tests/changeset/databases/__init__.py b/migrate/tests/changeset/databases/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/migrate/tests/changeset/databases/__init__.py +++ /dev/null diff --git a/migrate/tests/changeset/databases/test_ibmdb2.py b/migrate/tests/changeset/databases/test_ibmdb2.py deleted file mode 100644 index 4b3f983..0000000 --- a/migrate/tests/changeset/databases/test_ibmdb2.py +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import mock - -import six - -from migrate.changeset.databases import ibmdb2 -from migrate.tests import fixture - - -class TestIBMDBDialect(fixture.Base): - """ - Test class for ibmdb2 dialect unit tests which do not require - a live backend database connection. 
- """ - - def test_is_unique_constraint_with_null_cols_supported(self): - test_values = { - '10.1': False, - '10.4.99': False, - '10.5': True, - '10.5.1': True - } - for version, supported in six.iteritems(test_values): - mock_dialect = mock.MagicMock() - mock_dialect.dbms_ver = version - self.assertEqual( - supported, - ibmdb2.is_unique_constraint_with_null_columns_supported( - mock_dialect), - 'Assertion failed on version: %s' % version) diff --git a/migrate/tests/changeset/test_changeset.py b/migrate/tests/changeset/test_changeset.py deleted file mode 100644 index c870c52..0000000 --- a/migrate/tests/changeset/test_changeset.py +++ /dev/null @@ -1,976 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import sqlalchemy -import warnings - -from sqlalchemy import * - -from migrate import changeset, exceptions -from migrate.changeset import * -from migrate.changeset import constraint -from migrate.changeset.schema import ColumnDelta -from migrate.tests import fixture -from migrate.tests.fixture.warnings import catch_warnings -import six - -class TestAddDropColumn(fixture.DB): - """Test add/drop column through all possible interfaces - also test for constraints - """ - level = fixture.DB.CONNECT - table_name = 'tmp_adddropcol' - table_name_idx = 'tmp_adddropcol_idx' - table_int = 0 - - def _setup(self, url): - super(TestAddDropColumn, self)._setup(url) - self.meta = MetaData() - self.table = Table(self.table_name, self.meta, - Column('id', Integer, unique=True), - ) - self.table_idx = Table( - self.table_name_idx, - self.meta, - Column('id', Integer, primary_key=True), - Column('a', Integer), - Column('b', Integer), - Index('test_idx', 'a', 'b') - ) - self.meta.bind = self.engine - if self.engine.has_table(self.table.name): - self.table.drop() - if self.engine.has_table(self.table_idx.name): - self.table_idx.drop() - self.table.create() - self.table_idx.create() - - def _teardown(self): - if self.engine.has_table(self.table.name): - self.table.drop() - if 
self.engine.has_table(self.table_idx.name): - self.table_idx.drop() - self.meta.clear() - super(TestAddDropColumn,self)._teardown() - - def run_(self, create_column_func, drop_column_func, *col_p, **col_k): - col_name = 'data' - - def assert_numcols(num_of_expected_cols): - # number of cols should be correct in table object and in database - self.refresh_table(self.table_name) - result = len(self.table.c) - - self.assertEqual(result, num_of_expected_cols), - if col_k.get('primary_key', None): - # new primary key: check its length too - result = len(self.table.primary_key) - self.assertEqual(result, num_of_expected_cols) - - # we have 1 columns and there is no data column - assert_numcols(1) - self.assertTrue(getattr(self.table.c, 'data', None) is None) - if len(col_p) == 0: - col_p = [String(40)] - col = Column(col_name, *col_p, **col_k) - create_column_func(col) - assert_numcols(2) - # data column exists - self.assertTrue(self.table.c.data.type.length, 40) - - col2 = self.table.c.data - drop_column_func(col2) - assert_numcols(1) - - @fixture.usedb() - def test_undefined(self): - """Add/drop columns not yet defined in the table""" - def add_func(col): - return create_column(col, self.table) - def drop_func(col): - return drop_column(col, self.table) - return self.run_(add_func, drop_func) - - @fixture.usedb() - def test_defined(self): - """Add/drop columns already defined in the table""" - def add_func(col): - self.meta.clear() - self.table = Table(self.table_name, self.meta, - Column('id', Integer, primary_key=True), - col, - ) - return create_column(col) - def drop_func(col): - return drop_column(col) - return self.run_(add_func, drop_func) - - @fixture.usedb() - def test_method_bound(self): - """Add/drop columns via column methods; columns bound to a table - ie. 
no table parameter passed to function - """ - def add_func(col): - self.assertTrue(col.table is None, col.table) - self.table.append_column(col) - return col.create() - def drop_func(col): - #self.assertTrue(col.table is None,col.table) - #self.table.append_column(col) - return col.drop() - return self.run_(add_func, drop_func) - - @fixture.usedb() - def test_method_notbound(self): - """Add/drop columns via column methods; columns not bound to a table""" - def add_func(col): - return col.create(self.table) - def drop_func(col): - return col.drop(self.table) - return self.run_(add_func, drop_func) - - @fixture.usedb() - def test_tablemethod_obj(self): - """Add/drop columns via table methods; by column object""" - def add_func(col): - return self.table.create_column(col) - def drop_func(col): - return self.table.drop_column(col) - return self.run_(add_func, drop_func) - - @fixture.usedb() - def test_tablemethod_name(self): - """Add/drop columns via table methods; by column name""" - def add_func(col): - # must be bound to table - self.table.append_column(col) - return self.table.create_column(col.name) - def drop_func(col): - # Not necessarily bound to table - return self.table.drop_column(col.name) - return self.run_(add_func, drop_func) - - @fixture.usedb() - def test_byname(self): - """Add/drop columns via functions; by table object and column name""" - def add_func(col): - self.table.append_column(col) - return create_column(col.name, self.table) - def drop_func(col): - return drop_column(col.name, self.table) - return self.run_(add_func, drop_func) - - @fixture.usedb() - def test_drop_column_not_in_table(self): - """Drop column by name""" - def add_func(col): - return self.table.create_column(col) - def drop_func(col): - if SQLA_07: - self.table._columns.remove(col) - else: - self.table.c.remove(col) - return self.table.drop_column(col.name) - self.run_(add_func, drop_func) - - @fixture.usedb() - def test_fk(self): - """Can create columns with foreign keys""" - 
# create FK's target - reftable = Table('tmp_ref', self.meta, - Column('id', Integer, primary_key=True), - ) - if self.engine.has_table(reftable.name): - reftable.drop() - reftable.create() - - # create column with fk - col = Column('data', Integer, ForeignKey(reftable.c.id, name='testfk')) - col.create(self.table) - - # check if constraint is added - for cons in self.table.constraints: - if isinstance(cons, sqlalchemy.schema.ForeignKeyConstraint): - break - else: - self.fail('No constraint found') - - # TODO: test on db level if constraints work - - if SQLA_07: - self.assertEqual(reftable.c.id.name, - list(col.foreign_keys)[0].column.name) - else: - self.assertEqual(reftable.c.id.name, - col.foreign_keys[0].column.name) - - if self.engine.name == 'mysql': - constraint.ForeignKeyConstraint([self.table.c.data], - [reftable.c.id], - name='testfk').drop() - col.drop(self.table) - - if self.engine.has_table(reftable.name): - reftable.drop() - - @fixture.usedb(not_supported='sqlite') - def test_pk(self): - """Can create columns with primary key""" - col = Column('data', Integer, nullable=False) - self.assertRaises(exceptions.InvalidConstraintError, - col.create, self.table, primary_key_name=True) - col.create(self.table, primary_key_name='data_pkey') - - # check if constraint was added (cannot test on objects) - self.table.insert(values={'data': 4}).execute() - try: - self.table.insert(values={'data': 4}).execute() - except (sqlalchemy.exc.IntegrityError, - sqlalchemy.exc.ProgrammingError): - pass - else: - self.fail() - - col.drop() - - @fixture.usedb(not_supported=['mysql']) - def test_check(self): - """Can create columns with check constraint""" - col = Column('foo', - Integer, - sqlalchemy.schema.CheckConstraint('foo > 4')) - col.create(self.table) - - # check if constraint was added (cannot test on objects) - self.table.insert(values={'foo': 5}).execute() - try: - self.table.insert(values={'foo': 3}).execute() - except (sqlalchemy.exc.IntegrityError, - 
sqlalchemy.exc.ProgrammingError): - pass - else: - self.fail() - - col.drop() - - @fixture.usedb() - def test_unique_constraint(self): - self.assertRaises(exceptions.InvalidConstraintError, - Column('data', Integer, unique=True).create, self.table) - - col = Column('data', Integer) - col.create(self.table, unique_name='data_unique') - - # check if constraint was added (cannot test on objects) - self.table.insert(values={'data': 5}).execute() - try: - self.table.insert(values={'data': 5}).execute() - except (sqlalchemy.exc.IntegrityError, - sqlalchemy.exc.ProgrammingError): - pass - else: - self.fail() - - col.drop(self.table) - -# TODO: remove already attached columns with uniques, pks, fks .. - @fixture.usedb(not_supported=['ibm_db_sa', 'postgresql']) - def test_drop_column_of_composite_index(self): - # NOTE(rpodolyaka): postgresql automatically drops a composite index - # if one of its columns is dropped - # NOTE(mriedem): DB2 does the same. - self.table_idx.c.b.drop() - - reflected = Table(self.table_idx.name, MetaData(), autoload=True, - autoload_with=self.engine) - index = next(iter(reflected.indexes)) - self.assertEquals(['a'], [c.name for c in index.columns]) - - @fixture.usedb() - def test_drop_all_columns_of_composite_index(self): - self.table_idx.c.a.drop() - self.table_idx.c.b.drop() - - reflected = Table(self.table_idx.name, MetaData(), autoload=True, - autoload_with=self.engine) - self.assertEquals(0, len(reflected.indexes)) - - def _check_index(self,expected): - if 'mysql' in self.engine.name or 'postgres' in self.engine.name: - for index in tuple( - Table(self.table.name, MetaData(), - autoload=True, autoload_with=self.engine).indexes - ): - if index.name=='ix_data': - break - self.assertEqual(expected,index.unique) - - @fixture.usedb() - def test_index(self): - col = Column('data', Integer) - col.create(self.table, index_name='ix_data') - - self._check_index(False) - - col.drop() - - @fixture.usedb() - def test_index_unique(self): - # shows how to 
create a unique index - col = Column('data', Integer) - col.create(self.table) - Index('ix_data', col, unique=True).create(bind=self.engine) - - # check if index was added - self.table.insert(values={'data': 5}).execute() - try: - self.table.insert(values={'data': 5}).execute() - except (sqlalchemy.exc.IntegrityError, - sqlalchemy.exc.ProgrammingError): - pass - else: - self.fail() - - self._check_index(True) - - col.drop() - - @fixture.usedb() - def test_server_defaults(self): - """Can create columns with server_default values""" - col = Column('data', String(244), server_default='foobar') - col.create(self.table) - - self.table.insert(values={'id': 10}).execute() - row = self._select_row() - self.assertEqual(u'foobar', row['data']) - - col.drop() - - @fixture.usedb() - def test_populate_default(self): - """Test populate_default=True""" - def default(): - return 'foobar' - col = Column('data', String(244), default=default) - col.create(self.table, populate_default=True) - - self.table.insert(values={'id': 10}).execute() - row = self._select_row() - self.assertEqual(u'foobar', row['data']) - - col.drop() - - # TODO: test sequence - # TODO: test quoting - # TODO: test non-autoname constraints - - @fixture.usedb() - def test_drop_doesnt_delete_other_indexes(self): - # add two indexed columns - self.table.drop() - self.meta.clear() - self.table = Table( - self.table_name, self.meta, - Column('id', Integer, primary_key=True), - Column('d1', String(10), index=True), - Column('d2', String(10), index=True), - ) - self.table.create() - - # paranoid check - self.refresh_table() - self.assertEqual( - sorted([i.name for i in self.table.indexes]), - [u'ix_tmp_adddropcol_d1', u'ix_tmp_adddropcol_d2'] - ) - - # delete one - self.table.c.d2.drop() - - # ensure the other index is still there - self.refresh_table() - self.assertEqual( - sorted([i.name for i in self.table.indexes]), - [u'ix_tmp_adddropcol_d1'] - ) - - def _actual_foreign_keys(self): - from sqlalchemy.schema import 
ForeignKeyConstraint - result = [] - for cons in self.table.constraints: - if isinstance(cons,ForeignKeyConstraint): - col_names = [] - for col_name in cons.columns: - if not isinstance(col_name,six.string_types): - col_name = col_name.name - col_names.append(col_name) - result.append(col_names) - result.sort() - return result - - @fixture.usedb() - def test_drop_with_foreign_keys(self): - self.table.drop() - self.meta.clear() - - # create FK's target - reftable = Table('tmp_ref', self.meta, - Column('id', Integer, primary_key=True), - ) - if self.engine.has_table(reftable.name): - reftable.drop() - reftable.create() - - # add a table with two foreign key columns - self.table = Table( - self.table_name, self.meta, - Column('id', Integer, primary_key=True), - Column('r1', Integer, ForeignKey('tmp_ref.id', name='test_fk1')), - Column('r2', Integer, ForeignKey('tmp_ref.id', name='test_fk2')), - ) - self.table.create() - - # paranoid check - self.assertEqual([['r1'],['r2']], - self._actual_foreign_keys()) - - # delete one - if self.engine.name == 'mysql': - constraint.ForeignKeyConstraint([self.table.c.r2], [reftable.c.id], - name='test_fk2').drop() - self.table.c.r2.drop() - - # check remaining foreign key is there - self.assertEqual([['r1']], - self._actual_foreign_keys()) - - @fixture.usedb() - def test_drop_with_complex_foreign_keys(self): - from sqlalchemy.schema import ForeignKeyConstraint - from sqlalchemy.schema import UniqueConstraint - - self.table.drop() - self.meta.clear() - - # NOTE(mriedem): DB2 does not currently support unique constraints - # on nullable columns, so the columns that are used to create the - # foreign keys here need to be non-nullable for testing with DB2 - # to work. 
- - # create FK's target - reftable = Table('tmp_ref', self.meta, - Column('id', Integer, primary_key=True), - Column('jd', Integer, nullable=False), - UniqueConstraint('id','jd') - ) - if self.engine.has_table(reftable.name): - reftable.drop() - reftable.create() - - # add a table with a complex foreign key constraint - self.table = Table( - self.table_name, self.meta, - Column('id', Integer, primary_key=True), - Column('r1', Integer, nullable=False), - Column('r2', Integer, nullable=False), - ForeignKeyConstraint(['r1','r2'], - [reftable.c.id,reftable.c.jd], - name='test_fk') - ) - self.table.create() - - # paranoid check - self.assertEqual([['r1','r2']], - self._actual_foreign_keys()) - - # delete one - if self.engine.name == 'mysql': - constraint.ForeignKeyConstraint([self.table.c.r1, self.table.c.r2], - [reftable.c.id, reftable.c.jd], - name='test_fk').drop() - self.table.c.r2.drop() - - # check the constraint is gone, since part of it - # is no longer there - if people hit this, - # they may be confused, maybe we should raise an error - # and insist that the constraint is deleted first, separately? 
- self.assertEqual([], - self._actual_foreign_keys()) - -class TestRename(fixture.DB): - """Tests for table and index rename methods""" - level = fixture.DB.CONNECT - meta = MetaData() - - def _setup(self, url): - super(TestRename, self)._setup(url) - self.meta.bind = self.engine - - @fixture.usedb(not_supported='firebird') - def test_rename_table(self): - """Tables can be renamed""" - c_name = 'col_1' - table_name1 = 'name_one' - table_name2 = 'name_two' - index_name1 = 'x' + table_name1 - index_name2 = 'x' + table_name2 - - self.meta.clear() - self.column = Column(c_name, Integer) - self.table = Table(table_name1, self.meta, self.column) - self.index = Index(index_name1, self.column, unique=False) - - if self.engine.has_table(self.table.name): - self.table.drop() - if self.engine.has_table(table_name2): - tmp = Table(table_name2, self.meta, autoload=True) - tmp.drop() - tmp.deregister() - del tmp - self.table.create() - - def assert_table_name(expected, skip_object_check=False): - """Refresh a table via autoload - SA has changed some since this test was written; we now need to do - meta.clear() upon reloading a table - clear all rather than a - select few. So, this works only if we're working with one table at - a time (else, others will vanish too). 
- """ - if not skip_object_check: - # Table object check - self.assertEqual(self.table.name,expected) - newname = self.table.name - else: - # we know the object's name isn't consistent: just assign it - newname = expected - # Table DB check - self.meta.clear() - self.table = Table(newname, self.meta, autoload=True) - self.assertEqual(self.table.name, expected) - - def assert_index_name(expected, skip_object_check=False): - if not skip_object_check: - # Index object check - self.assertEqual(self.index.name, expected) - else: - # object is inconsistent - self.index.name = expected - # TODO: Index DB check - - def add_table_to_meta(name): - # trigger the case where table_name2 needs to be - # removed from the metadata in ChangesetTable.deregister() - tmp = Table(name, self.meta, Column(c_name, Integer)) - tmp.create() - tmp.drop() - - try: - # Table renames - assert_table_name(table_name1) - add_table_to_meta(table_name2) - rename_table(self.table, table_name2) - assert_table_name(table_name2) - self.table.rename(table_name1) - assert_table_name(table_name1) - - # test by just the string - rename_table(table_name1, table_name2, engine=self.engine) - assert_table_name(table_name2, True) # object not updated - - # Index renames - if self.url.startswith('sqlite') or self.url.startswith('mysql'): - self.assertRaises(exceptions.NotSupportedError, - self.index.rename, index_name2) - else: - assert_index_name(index_name1) - rename_index(self.index, index_name2, engine=self.engine) - assert_index_name(index_name2) - self.index.rename(index_name1) - assert_index_name(index_name1) - - # test by just the string - rename_index(index_name1, index_name2, engine=self.engine) - assert_index_name(index_name2, True) - - finally: - if self.table.exists(): - self.table.drop() - - -class TestColumnChange(fixture.DB): - level = fixture.DB.CONNECT - table_name = 'tmp_colchange' - - def _setup(self, url): - super(TestColumnChange, self)._setup(url) - self.meta = MetaData(self.engine) - 
self.table = Table(self.table_name, self.meta, - Column('id', Integer, primary_key=True), - Column('data', String(40), server_default=DefaultClause("tluafed"), - nullable=True), - ) - if self.table.exists(): - self.table.drop() - try: - self.table.create() - except sqlalchemy.exc.SQLError: - # SQLite: database schema has changed - if not self.url.startswith('sqlite://'): - raise - - def _teardown(self): - if self.table.exists(): - try: - self.table.drop(self.engine) - except sqlalchemy.exc.SQLError: - # SQLite: database schema has changed - if not self.url.startswith('sqlite://'): - raise - super(TestColumnChange, self)._teardown() - - @fixture.usedb() - def test_rename(self): - """Can rename a column""" - def num_rows(col, content): - return len(list(self.table.select(col == content).execute())) - # Table content should be preserved in changed columns - content = "fgsfds" - self.engine.execute(self.table.insert(), data=content, id=42) - self.assertEqual(num_rows(self.table.c.data, content), 1) - - # ...as a function, given a column object and the new name - alter_column('data', name='data2', table=self.table) - self.refresh_table() - alter_column(self.table.c.data2, name='atad') - self.refresh_table(self.table.name) - self.assertTrue('data' not in self.table.c.keys()) - self.assertTrue('atad' in self.table.c.keys()) - self.assertEqual(num_rows(self.table.c.atad, content), 1) - - # ...as a method, given a new name - self.table.c.atad.alter(name='data') - self.refresh_table(self.table.name) - self.assertTrue('atad' not in self.table.c.keys()) - self.table.c.data # Should not raise exception - self.assertEqual(num_rows(self.table.c.data, content), 1) - - # ...as a function, given a new object - alter_column(self.table.c.data, - name = 'atad', type=String(40), - server_default=self.table.c.data.server_default) - self.refresh_table(self.table.name) - self.assertTrue('data' not in self.table.c.keys()) - self.table.c.atad # Should not raise exception - 
self.assertEqual(num_rows(self.table.c.atad, content), 1) - - # ...as a method, given a new object - self.table.c.atad.alter( - name='data',type=String(40), - server_default=self.table.c.atad.server_default - ) - self.refresh_table(self.table.name) - self.assertTrue('atad' not in self.table.c.keys()) - self.table.c.data # Should not raise exception - self.assertEqual(num_rows(self.table.c.data,content), 1) - - @fixture.usedb() - def test_type(self): - # Test we can change a column's type - - # Just the new type - self.table.c.data.alter(type=String(43)) - self.refresh_table(self.table.name) - self.assertTrue(isinstance(self.table.c.data.type, String)) - self.assertEqual(self.table.c.data.type.length, 43) - - # Different type - self.assertTrue(isinstance(self.table.c.id.type, Integer)) - self.assertEqual(self.table.c.id.nullable, False) - - # SQLAlchemy 1.1 adds a third state to "autoincrement" called - # "auto". - self.assertTrue(self.table.c.id.autoincrement in ('auto', True)) - - if not self.engine.name == 'firebird': - self.table.c.id.alter(type=String(20)) - self.assertEqual(self.table.c.id.nullable, False) - - # a rule makes sure that autoincrement is set to False - # when we change off of Integer - self.assertEqual(self.table.c.id.autoincrement, False) - self.refresh_table(self.table.name) - self.assertTrue(isinstance(self.table.c.id.type, String)) - - # note that after reflection, "autoincrement" is likely - # to change back to a database-generated value. Should be - # False or "auto". 
if True, it's a bug; at least one of these - # exists prior to SQLAlchemy 1.1.3 - - @fixture.usedb() - def test_default(self): - """Can change a column's server_default value (DefaultClauses only) - Only DefaultClauses are changed here: others are managed by the - application / by SA - """ - self.assertEqual(self.table.c.data.server_default.arg, 'tluafed') - - # Just the new default - default = 'my_default' - self.table.c.data.alter(server_default=DefaultClause(default)) - self.refresh_table(self.table.name) - #self.assertEqual(self.table.c.data.server_default.arg,default) - # TextClause returned by autoload - self.assertTrue(default in str(self.table.c.data.server_default.arg)) - self.engine.execute(self.table.insert(), id=12) - row = self._select_row() - self.assertEqual(row['data'], default) - - # Column object - default = 'your_default' - self.table.c.data.alter(type=String(40), server_default=DefaultClause(default)) - self.refresh_table(self.table.name) - self.assertTrue(default in str(self.table.c.data.server_default.arg)) - - # Drop/remove default - self.table.c.data.alter(server_default=None) - self.assertEqual(self.table.c.data.server_default, None) - - self.refresh_table(self.table.name) - # server_default isn't necessarily None for Oracle - #self.assertTrue(self.table.c.data.server_default is None,self.table.c.data.server_default) - self.engine.execute(self.table.insert(), id=11) - row = self.table.select(self.table.c.id == 11).execution_options(autocommit=True).execute().fetchone() - self.assertTrue(row['data'] is None, row['data']) - - @fixture.usedb(not_supported='firebird') - def test_null(self): - """Can change a column's null constraint""" - self.assertEqual(self.table.c.data.nullable, True) - - # Full column - self.table.c.data.alter(type=String(40), nullable=False) - self.table.nullable = None - self.refresh_table(self.table.name) - self.assertEqual(self.table.c.data.nullable, False) - - # Just the new status - 
self.table.c.data.alter(nullable=True) - self.refresh_table(self.table.name) - self.assertEqual(self.table.c.data.nullable, True) - - @fixture.usedb() - def test_alter_deprecated(self): - try: - # py 2.4 compatibility :-/ - cw = catch_warnings(record=True) - w = cw.__enter__() - - warnings.simplefilter("always") - self.table.c.data.alter(Column('data', String(100))) - - self.assertEqual(len(w),1) - self.assertTrue(issubclass(w[-1].category, - MigrateDeprecationWarning)) - self.assertEqual( - 'Passing a Column object to alter_column is deprecated. ' - 'Just pass in keyword parameters instead.', - str(w[-1].message)) - finally: - cw.__exit__() - - @fixture.usedb() - def test_alter_returns_delta(self): - """Test if alter constructs return delta""" - - delta = self.table.c.data.alter(type=String(100)) - self.assertTrue('type' in delta) - - @fixture.usedb() - def test_alter_all(self): - """Tests all alter changes at one time""" - # test for each db separately - # since currently some dont support everything - - # test pre settings - self.assertEqual(self.table.c.data.nullable, True) - self.assertEqual(self.table.c.data.server_default.arg, 'tluafed') - self.assertEqual(self.table.c.data.name, 'data') - self.assertTrue(isinstance(self.table.c.data.type, String)) - self.assertTrue(self.table.c.data.type.length, 40) - - kw = dict(nullable=False, - server_default='foobar', - name='data_new', - type=String(50)) - if self.engine.name == 'firebird': - del kw['nullable'] - self.table.c.data.alter(**kw) - - # test altered objects - self.assertEqual(self.table.c.data.server_default.arg, 'foobar') - if not self.engine.name == 'firebird': - self.assertEqual(self.table.c.data.nullable, False) - self.assertEqual(self.table.c.data.name, 'data_new') - self.assertEqual(self.table.c.data.type.length, 50) - - self.refresh_table(self.table.name) - - # test post settings - if not self.engine.name == 'firebird': - self.assertEqual(self.table.c.data_new.nullable, False) - 
self.assertEqual(self.table.c.data_new.name, 'data_new') - self.assertTrue(isinstance(self.table.c.data_new.type, String)) - self.assertTrue(self.table.c.data_new.type.length, 50) - - # insert data and assert default - self.table.insert(values={'id': 10}).execute() - row = self._select_row() - self.assertEqual(u'foobar', row['data_new']) - - -class TestColumnDelta(fixture.DB): - """Tests ColumnDelta class""" - - level = fixture.DB.CONNECT - table_name = 'tmp_coldelta' - table_int = 0 - - def _setup(self, url): - super(TestColumnDelta, self)._setup(url) - self.meta = MetaData() - self.table = Table(self.table_name, self.meta, - Column('ids', String(10)), - ) - self.meta.bind = self.engine - if self.engine.has_table(self.table.name): - self.table.drop() - self.table.create() - - def _teardown(self): - if self.engine.has_table(self.table.name): - self.table.drop() - self.meta.clear() - super(TestColumnDelta,self)._teardown() - - def mkcol(self, name='id', type=String, *p, **k): - return Column(name, type, *p, **k) - - def verify(self, expected, original, *p, **k): - self.delta = ColumnDelta(original, *p, **k) - result = list(self.delta.keys()) - result.sort() - self.assertEqual(expected, result) - return self.delta - - def test_deltas_two_columns(self): - """Testing ColumnDelta with two columns""" - col_orig = self.mkcol(primary_key=True) - col_new = self.mkcol(name='ids', primary_key=True) - self.verify([], col_orig, col_orig) - self.verify(['name'], col_orig, col_orig, 'ids') - self.verify(['name'], col_orig, col_orig, name='ids') - self.verify(['name'], col_orig, col_new) - self.verify(['name', 'type'], col_orig, col_new, type=String) - - # Type comparisons - self.verify([], self.mkcol(type=String), self.mkcol(type=String)) - self.verify(['type'], self.mkcol(type=String), self.mkcol(type=Integer)) - self.verify(['type'], self.mkcol(type=String), self.mkcol(type=String(42))) - self.verify([], self.mkcol(type=String(42)), self.mkcol(type=String(42))) - 
self.verify(['type'], self.mkcol(type=String(24)), self.mkcol(type=String(42))) - self.verify(['type'], self.mkcol(type=String(24)), self.mkcol(type=Text(24))) - - # Other comparisons - self.verify(['primary_key'], self.mkcol(nullable=False), self.mkcol(primary_key=True)) - - # PK implies nullable=False - self.verify(['nullable', 'primary_key'], self.mkcol(nullable=True), self.mkcol(primary_key=True)) - self.verify([], self.mkcol(primary_key=True), self.mkcol(primary_key=True)) - self.verify(['nullable'], self.mkcol(nullable=True), self.mkcol(nullable=False)) - self.verify([], self.mkcol(nullable=True), self.mkcol(nullable=True)) - self.verify([], self.mkcol(server_default=None), self.mkcol(server_default=None)) - self.verify([], self.mkcol(server_default='42'), self.mkcol(server_default='42')) - - # test server default - delta = self.verify(['server_default'], self.mkcol(), self.mkcol('id', String, DefaultClause('foobar'))) - self.assertEqual(delta['server_default'].arg, 'foobar') - - self.verify([], self.mkcol(server_default='foobar'), self.mkcol('id', String, DefaultClause('foobar'))) - self.verify(['type'], self.mkcol(server_default='foobar'), self.mkcol('id', Text, DefaultClause('foobar'))) - - col = self.mkcol(server_default='foobar') - self.verify(['type'], col, self.mkcol('id', Text, DefaultClause('foobar')), alter_metadata=True) - self.assertTrue(isinstance(col.type, Text)) - - col = self.mkcol() - self.verify(['name', 'server_default', 'type'], col, self.mkcol('beep', Text, DefaultClause('foobar')), - alter_metadata=True) - self.assertTrue(isinstance(col.type, Text)) - self.assertEqual(col.name, 'beep') - self.assertEqual(col.server_default.arg, 'foobar') - - @fixture.usedb() - def test_deltas_zero_columns(self): - """Testing ColumnDelta with zero columns""" - - self.verify(['name'], 'ids', table=self.table, name='hey') - - # test reflection - self.verify(['type'], 'ids', table=self.table.name, type=String(80), engine=self.engine) - self.verify(['type'], 
'ids', table=self.table.name, type=String(80), metadata=self.meta) - - self.meta.clear() - delta = self.verify(['type'], 'ids', table=self.table.name, type=String(80), metadata=self.meta, - alter_metadata=True) - self.assertTrue(self.table.name in self.meta) - self.assertEqual(delta.result_column.type.length, 80) - self.assertEqual(self.meta.tables.get(self.table.name).c.ids.type.length, 80) - - # test defaults - self.meta.clear() - self.verify(['server_default'], 'ids', table=self.table.name, server_default='foobar', - metadata=self.meta, - alter_metadata=True) - self.meta.tables.get(self.table.name).c.ids.server_default.arg == 'foobar' - - # test missing parameters - self.assertRaises(ValueError, ColumnDelta, table=self.table.name) - self.assertRaises(ValueError, ColumnDelta, 'ids', table=self.table.name, alter_metadata=True) - self.assertRaises(ValueError, ColumnDelta, 'ids', table=self.table.name, alter_metadata=False) - - def test_deltas_one_column(self): - """Testing ColumnDelta with one column""" - col_orig = self.mkcol(primary_key=True) - - self.verify([], col_orig) - self.verify(['name'], col_orig, 'ids') - # Parameters are always executed, even if they're 'unchanged' - # (We can't assume given column is up-to-date) - self.verify(['name', 'primary_key', 'type'], col_orig, 'id', Integer, primary_key=True) - self.verify(['name', 'primary_key', 'type'], col_orig, name='id', type=Integer, primary_key=True) - - # Change name, given an up-to-date definition and the current name - delta = self.verify(['name'], col_orig, name='blah') - self.assertEqual(delta.get('name'), 'blah') - self.assertEqual(delta.current_name, 'id') - - col_orig = self.mkcol(primary_key=True) - self.verify(['name', 'type'], col_orig, name='id12', type=Text, alter_metadata=True) - self.assertTrue(isinstance(col_orig.type, Text)) - self.assertEqual(col_orig.name, 'id12') - - # test server default - col_orig = self.mkcol(primary_key=True) - delta = self.verify(['server_default'], col_orig, 
DefaultClause('foobar')) - self.assertEqual(delta['server_default'].arg, 'foobar') - - delta = self.verify(['server_default'], col_orig, server_default=DefaultClause('foobar')) - self.assertEqual(delta['server_default'].arg, 'foobar') - - # no change - col_orig = self.mkcol(server_default=DefaultClause('foobar')) - delta = self.verify(['type'], col_orig, DefaultClause('foobar'), type=PickleType) - self.assertTrue(isinstance(delta.result_column.type, PickleType)) - - # TODO: test server on update - # TODO: test bind metadata diff --git a/migrate/tests/changeset/test_constraint.py b/migrate/tests/changeset/test_constraint.py deleted file mode 100644 index 325b3c0..0000000 --- a/migrate/tests/changeset/test_constraint.py +++ /dev/null @@ -1,299 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from sqlalchemy import * -from sqlalchemy.util import * -from sqlalchemy.exc import * - -from migrate.changeset.util import fk_column_names -from migrate.exceptions import * -from migrate.changeset import * - -from migrate.tests import fixture - - -class CommonTestConstraint(fixture.DB): - """helper functions to test constraints. - - we just create a fresh new table and make sure everything is - as required. 
- """ - - def _setup(self, url): - super(CommonTestConstraint, self)._setup(url) - self._create_table() - - def _teardown(self): - if hasattr(self, 'table') and self.engine.has_table(self.table.name): - self.table.drop() - super(CommonTestConstraint, self)._teardown() - - def _create_table(self): - self._connect(self.url) - self.meta = MetaData(self.engine) - self.tablename = 'mytable' - self.table = Table(self.tablename, self.meta, - Column(u'id', Integer, nullable=False), - Column(u'fkey', Integer, nullable=False), - mysql_engine='InnoDB') - if self.engine.has_table(self.table.name): - self.table.drop() - self.table.create() - - # make sure we start at zero - self.assertEqual(len(self.table.primary_key), 0) - self.assertTrue(isinstance(self.table.primary_key, - schema.PrimaryKeyConstraint), self.table.primary_key.__class__) - - -class TestConstraint(CommonTestConstraint): - level = fixture.DB.CONNECT - - def _define_pk(self, *cols): - # Add a pk by creating a PK constraint - if (self.engine.name in ('oracle', 'firebird')): - # Can't drop Oracle PKs without an explicit name - pk = PrimaryKeyConstraint(table=self.table, name='temp_pk_key', *cols) - else: - pk = PrimaryKeyConstraint(table=self.table, *cols) - self.compare_columns_equal(pk.columns, cols) - pk.create() - self.refresh_table() - if not self.url.startswith('sqlite'): - self.compare_columns_equal(self.table.primary_key, cols, ['type', 'autoincrement']) - - # Drop the PK constraint - #if (self.engine.name in ('oracle', 'firebird')): - # # Apparently Oracle PK names aren't introspected - # pk.name = self.table.primary_key.name - pk.drop() - self.refresh_table() - self.assertEqual(len(self.table.primary_key), 0) - self.assertTrue(isinstance(self.table.primary_key, schema.PrimaryKeyConstraint)) - return pk - - @fixture.usedb() - def test_define_fk(self): - """FK constraints can be defined, created, and dropped""" - # FK target must be unique - pk = PrimaryKeyConstraint(self.table.c.id, table=self.table, 
name="pkid") - pk.create() - - # Add a FK by creating a FK constraint - if SQLA_07: - self.assertEqual(list(self.table.c.fkey.foreign_keys), []) - else: - self.assertEqual(self.table.c.fkey.foreign_keys._list, []) - fk = ForeignKeyConstraint([self.table.c.fkey], - [self.table.c.id], - name="fk_id_fkey", - ondelete="CASCADE") - if SQLA_07: - self.assertTrue(list(self.table.c.fkey.foreign_keys) is not []) - else: - self.assertTrue(self.table.c.fkey.foreign_keys._list is not []) - for key in fk_column_names(fk): - self.assertEqual(key, self.table.c.fkey.name) - self.assertEqual([e.column for e in fk.elements], [self.table.c.id]) - self.assertEqual(list(fk.referenced), [self.table.c.id]) - - if self.url.startswith('mysql'): - # MySQL FKs need an index - index = Index('index_name', self.table.c.fkey) - index.create() - fk.create() - - # test for ondelete/onupdate - if SQLA_07: - fkey = list(self.table.c.fkey.foreign_keys)[0] - else: - fkey = self.table.c.fkey.foreign_keys._list[0] - self.assertEqual(fkey.ondelete, "CASCADE") - # TODO: test on real db if it was set - - self.refresh_table() - if SQLA_07: - self.assertTrue(list(self.table.c.fkey.foreign_keys) is not []) - else: - self.assertTrue(self.table.c.fkey.foreign_keys._list is not []) - - fk.drop() - self.refresh_table() - if SQLA_07: - self.assertEqual(list(self.table.c.fkey.foreign_keys), []) - else: - self.assertEqual(self.table.c.fkey.foreign_keys._list, []) - - @fixture.usedb() - def test_define_pk(self): - """PK constraints can be defined, created, and dropped""" - self._define_pk(self.table.c.fkey) - - @fixture.usedb() - def test_define_pk_multi(self): - """Multicolumn PK constraints can be defined, created, and dropped""" - self._define_pk(self.table.c.id, self.table.c.fkey) - - @fixture.usedb(not_supported=['firebird']) - def test_drop_cascade(self): - """Drop constraint cascaded""" - pk = PrimaryKeyConstraint('fkey', table=self.table, name="id_pkey") - pk.create() - self.refresh_table() - - # Drop the PK 
constraint forcing cascade - pk.drop(cascade=True) - - # TODO: add real assertion if it was added - - @fixture.usedb(supported=['mysql']) - def test_fail_mysql_check_constraints(self): - """Check constraints raise NotSupported for mysql on drop""" - cons = CheckConstraint('id > 3', name="id_check", table=self.table) - cons.create() - self.refresh_table() - - try: - cons.drop() - except NotSupportedError: - pass - else: - self.fail() - - @fixture.usedb(not_supported=['sqlite', 'mysql']) - def test_named_check_constraints(self): - """Check constraints can be defined, created, and dropped""" - self.assertRaises(InvalidConstraintError, CheckConstraint, 'id > 3') - cons = CheckConstraint('id > 3', name="id_check", table=self.table) - cons.create() - self.refresh_table() - - self.table.insert(values={'id': 4, 'fkey': 1}).execute() - try: - self.table.insert(values={'id': 1, 'fkey': 1}).execute() - except (IntegrityError, ProgrammingError): - pass - else: - self.fail() - - # Remove the name, drop the constraint; it should succeed - cons.drop() - self.refresh_table() - self.table.insert(values={'id': 2, 'fkey': 2}).execute() - self.table.insert(values={'id': 1, 'fkey': 2}).execute() - - -class TestAutoname(CommonTestConstraint): - """Every method tests for a type of constraint wether it can autoname - itself and if you can pass object instance and names to classes. 
- """ - level = fixture.DB.CONNECT - - @fixture.usedb(not_supported=['oracle', 'firebird']) - def test_autoname_pk(self): - """PrimaryKeyConstraints can guess their name if None is given""" - # Don't supply a name; it should create one - cons = PrimaryKeyConstraint(self.table.c.id) - cons.create() - self.refresh_table() - if not self.url.startswith('sqlite'): - # TODO: test for index for sqlite - self.compare_columns_equal(cons.columns, self.table.primary_key, ['autoincrement', 'type']) - - # Remove the name, drop the constraint; it should succeed - cons.name = None - cons.drop() - self.refresh_table() - self.assertEqual(list(), list(self.table.primary_key)) - - # test string names - cons = PrimaryKeyConstraint('id', table=self.table) - cons.create() - self.refresh_table() - if not self.url.startswith('sqlite'): - # TODO: test for index for sqlite - self.compare_columns_equal(cons.columns, self.table.primary_key) - cons.name = None - cons.drop() - - @fixture.usedb(not_supported=['oracle', 'sqlite', 'firebird']) - def test_autoname_fk(self): - """ForeignKeyConstraints can guess their name if None is given""" - cons = PrimaryKeyConstraint(self.table.c.id) - cons.create() - - cons = ForeignKeyConstraint([self.table.c.fkey], [self.table.c.id]) - cons.create() - self.refresh_table() - if SQLA_07: - list(self.table.c.fkey.foreign_keys)[0].column is self.table.c.id - else: - self.table.c.fkey.foreign_keys[0].column is self.table.c.id - - # Remove the name, drop the constraint; it should succeed - cons.name = None - cons.drop() - self.refresh_table() - if SQLA_07: - self.assertEqual(list(self.table.c.fkey.foreign_keys), list()) - else: - self.assertEqual(self.table.c.fkey.foreign_keys._list, list()) - - # test string names - cons = ForeignKeyConstraint(['fkey'], ['%s.id' % self.tablename], table=self.table) - cons.create() - self.refresh_table() - if SQLA_07: - list(self.table.c.fkey.foreign_keys)[0].column is self.table.c.id - else: - 
self.table.c.fkey.foreign_keys[0].column is self.table.c.id - - # Remove the name, drop the constraint; it should succeed - cons.name = None - cons.drop() - - @fixture.usedb(not_supported=['oracle', 'sqlite', 'mysql']) - def test_autoname_check(self): - """CheckConstraints can guess their name if None is given""" - cons = CheckConstraint('id > 3', columns=[self.table.c.id]) - cons.create() - self.refresh_table() - - if not self.engine.name == 'mysql': - self.table.insert(values={'id': 4, 'fkey': 1}).execute() - try: - self.table.insert(values={'id': 1, 'fkey': 2}).execute() - except (IntegrityError, ProgrammingError): - pass - else: - self.fail() - - # Remove the name, drop the constraint; it should succeed - cons.name = None - cons.drop() - self.refresh_table() - self.table.insert(values={'id': 2, 'fkey': 2}).execute() - self.table.insert(values={'id': 1, 'fkey': 3}).execute() - - @fixture.usedb(not_supported=['oracle']) - def test_autoname_unique(self): - """UniqueConstraints can guess their name if None is given""" - cons = UniqueConstraint(self.table.c.fkey) - cons.create() - self.refresh_table() - - self.table.insert(values={'fkey': 4, 'id': 1}).execute() - try: - self.table.insert(values={'fkey': 4, 'id': 2}).execute() - except (sqlalchemy.exc.IntegrityError, - sqlalchemy.exc.ProgrammingError): - pass - else: - self.fail() - - # Remove the name, drop the constraint; it should succeed - cons.name = None - cons.drop() - self.refresh_table() - self.table.insert(values={'fkey': 4, 'id': 2}).execute() - self.table.insert(values={'fkey': 4, 'id': 1}).execute() diff --git a/migrate/tests/fixture/__init__.py b/migrate/tests/fixture/__init__.py deleted file mode 100644 index 6b8bc48..0000000 --- a/migrate/tests/fixture/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import testtools - -def main(imports=None): - if imports: - global suite - suite = suite(imports) - defaultTest='fixture.suite' - else: - defaultTest=None - 
return testtools.TestProgram(defaultTest=defaultTest) - -from .base import Base -from .pathed import Pathed -from .shell import Shell -from .database import DB,usedb diff --git a/migrate/tests/fixture/base.py b/migrate/tests/fixture/base.py deleted file mode 100644 index 38c91af..0000000 --- a/migrate/tests/fixture/base.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import re -import testtools - -class Base(testtools.TestCase): - - def assertEqualIgnoreWhitespace(self, v1, v2): - """Compares two strings that should be\ - identical except for whitespace - """ - def strip_whitespace(s): - return re.sub(r'\s', '', s) - - line1 = strip_whitespace(v1) - line2 = strip_whitespace(v2) - - self.assertEqual(line1, line2, "%s != %s" % (v1, v2)) - - def ignoreErrors(self, func, *p,**k): - """Call a function, ignoring any exceptions""" - try: - func(*p,**k) - except: - pass diff --git a/migrate/tests/fixture/database.py b/migrate/tests/fixture/database.py deleted file mode 100644 index 93bd69b..0000000 --- a/migrate/tests/fixture/database.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os -import logging -import sys - -import six -from decorator import decorator - -from sqlalchemy import create_engine, Table, MetaData -from sqlalchemy import exc as sa_exc -from sqlalchemy.orm import create_session -from sqlalchemy.pool import StaticPool - -from migrate.changeset.schema import ColumnDelta -from migrate.versioning.util import Memoize - -from migrate.tests.fixture.base import Base -from migrate.tests.fixture.pathed import Pathed - - -log = logging.getLogger(__name__) - -@Memoize -def readurls(): - """read URLs from config file return a list""" - # TODO: remove tmpfile since sqlite can store db in memory - filename = 'test_db.cfg' if six.PY2 else "test_db_py3.cfg" - ret = list() - tmpfile = Pathed.tmp() - fullpath = os.path.join(os.curdir, filename) - - try: - fd = open(fullpath) - except IOError: - 
raise IOError("""You must specify the databases to use for testing! -Copy %(filename)s.tmpl to %(filename)s and edit your database URLs.""" % locals()) - - for line in fd: - if line.startswith('#'): - continue - line = line.replace('__tmp__', tmpfile).strip() - ret.append(line) - fd.close() - return ret - -def is_supported(url, supported, not_supported): - db = url.split(':', 1)[0] - - if supported is not None: - if isinstance(supported, six.string_types): - return supported == db - else: - return db in supported - elif not_supported is not None: - if isinstance(not_supported, six.string_types): - return not_supported != db - else: - return not (db in not_supported) - return True - - -def usedb(supported=None, not_supported=None): - """Decorates tests to be run with a database connection - These tests are run once for each available database - - @param supported: run tests for ONLY these databases - @param not_supported: run tests for all databases EXCEPT these - - If both supported and not_supported are empty, all dbs are assumed - to be supported - """ - if supported is not None and not_supported is not None: - raise AssertionError("Can't specify both supported and not_supported in fixture.db()") - - urls = readurls() - my_urls = [url for url in urls if is_supported(url, supported, not_supported)] - - @decorator - def dec(f, self, *a, **kw): - failed_for = [] - fail = False - for url in my_urls: - try: - log.debug("Running test with engine %s", url) - try: - self._setup(url) - except sa_exc.OperationalError: - log.info('Backend %s is not available, skip it', url) - continue - except Exception as e: - raise RuntimeError('Exception during _setup(): %r' % e) - - try: - f(self, *a, **kw) - finally: - try: - self._teardown() - except Exception as e: - raise RuntimeError('Exception during _teardown(): %r' % e) - except Exception: - failed_for.append(url) - fail = sys.exc_info() - for url in failed_for: - log.error('Failed for %s', url) - if fail: - # cause the failure 
:-) - six.reraise(*fail) - return dec - - -class DB(Base): - # Constants: connection level - NONE = 0 # No connection; just set self.url - CONNECT = 1 # Connect; no transaction - TXN = 2 # Everything in a transaction - - level = TXN - - def _engineInfo(self, url=None): - if url is None: - url = self.url - return url - - def _setup(self, url): - self._connect(url) - # make sure there are no tables lying around - meta = MetaData(self.engine) - meta.reflect() - meta.drop_all() - - def _teardown(self): - self._disconnect() - - def _connect(self, url): - self.url = url - # TODO: seems like 0.5.x branch does not work with engine.dispose and staticpool - #self.engine = create_engine(url, echo=True, poolclass=StaticPool) - self.engine = create_engine(url, echo=True) - # silence the logger added by SA, nose adds its own! - logging.getLogger('sqlalchemy').handlers=[] - self.meta = MetaData(bind=self.engine) - if self.level < self.CONNECT: - return - #self.session = create_session(bind=self.engine) - if self.level < self.TXN: - return - #self.txn = self.session.begin() - - def _disconnect(self): - if hasattr(self, 'txn'): - self.txn.rollback() - if hasattr(self, 'session'): - self.session.close() - #if hasattr(self,'conn'): - # self.conn.close() - self.engine.dispose() - - def _supported(self, url): - db = url.split(':',1)[0] - func = getattr(self, self._TestCase__testMethodName) - if hasattr(func, 'supported'): - return db in func.supported - if hasattr(func, 'not_supported'): - return not (db in func.not_supported) - # Neither list assigned; assume all are supported - return True - - def _not_supported(self, url): - return not self._supported(url) - - def _select_row(self): - """Select rows, used in multiple tests""" - return self.table.select().execution_options( - autocommit=True).execute().fetchone() - - def refresh_table(self, name=None): - """Reload the table from the database - Assumes we're working with only a single table, self.table, and - metadata self.meta - - 
Working w/ multiple tables is not possible, as tables can only be - reloaded with meta.clear() - """ - if name is None: - name = self.table.name - self.meta.clear() - self.table = Table(name, self.meta, autoload=True) - - def compare_columns_equal(self, columns1, columns2, ignore=None): - """Loop through all columns and compare them""" - def key(column): - return column.name - for c1, c2 in zip(sorted(columns1, key=key), sorted(columns2, key=key)): - diffs = ColumnDelta(c1, c2).diffs - if ignore: - for key in ignore: - diffs.pop(key, None) - if diffs: - self.fail("Comparing %s to %s failed: %s" % (columns1, columns2, diffs)) - -# TODO: document engine.dispose and write tests diff --git a/migrate/tests/fixture/models.py b/migrate/tests/fixture/models.py deleted file mode 100644 index ee76429..0000000 --- a/migrate/tests/fixture/models.py +++ /dev/null @@ -1,14 +0,0 @@ -from sqlalchemy import * - -# test rundiffs in shell -meta_old_rundiffs = MetaData() -meta_rundiffs = MetaData() -meta = MetaData() - -tmp_account_rundiffs = Table('tmp_account_rundiffs', meta_rundiffs, - Column('id', Integer, primary_key=True), - Column('login', Text()), - Column('passwd', Text()), -) - -tmp_sql_table = Table('tmp_sql_table', meta, Column('id', Integer)) diff --git a/migrate/tests/fixture/pathed.py b/migrate/tests/fixture/pathed.py deleted file mode 100644 index 78cf4cd..0000000 --- a/migrate/tests/fixture/pathed.py +++ /dev/null @@ -1,77 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os -import sys -import shutil -import tempfile - -from migrate.tests.fixture import base - - -class Pathed(base.Base): - # Temporary files - - _tmpdir = tempfile.mkdtemp() - - def setUp(self): - super(Pathed, self).setUp() - self.temp_usable_dir = tempfile.mkdtemp() - sys.path.append(self.temp_usable_dir) - - def tearDown(self): - super(Pathed, self).tearDown() - try: - sys.path.remove(self.temp_usable_dir) - except: - pass # w00t? 
- Pathed.purge(self.temp_usable_dir) - - @classmethod - def _tmp(cls, prefix='', suffix=''): - """Generate a temporary file name that doesn't exist - All filenames are generated inside a temporary directory created by - tempfile.mkdtemp(); only the creating user has access to this directory. - It should be secure to return a nonexistant temp filename in this - directory, unless the user is messing with their own files. - """ - file, ret = tempfile.mkstemp(suffix,prefix,cls._tmpdir) - os.close(file) - os.remove(ret) - return ret - - @classmethod - def tmp(cls, *p, **k): - return cls._tmp(*p, **k) - - @classmethod - def tmp_py(cls, *p, **k): - return cls._tmp(suffix='.py', *p, **k) - - @classmethod - def tmp_sql(cls, *p, **k): - return cls._tmp(suffix='.sql', *p, **k) - - @classmethod - def tmp_named(cls, name): - return os.path.join(cls._tmpdir, name) - - @classmethod - def tmp_repos(cls, *p, **k): - return cls._tmp(*p, **k) - - @classmethod - def purge(cls, path): - """Removes this path if it exists, in preparation for tests - Careful - all tests should take place in /tmp. - We don't want to accidentally wipe stuff out... 
- """ - if os.path.exists(path): - if os.path.isdir(path): - shutil.rmtree(path) - else: - os.remove(path) - if path.endswith('.py'): - pyc = path + 'c' - if os.path.exists(pyc): - os.remove(pyc) diff --git a/migrate/tests/fixture/shell.py b/migrate/tests/fixture/shell.py deleted file mode 100644 index 566d250..0000000 --- a/migrate/tests/fixture/shell.py +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os -import sys -import logging - -from scripttest import TestFileEnvironment - -from migrate.tests.fixture.pathed import * - - -log = logging.getLogger(__name__) - -class Shell(Pathed): - """Base class for command line tests""" - - def setUp(self): - super(Shell, self).setUp() - migrate_path = os.path.dirname(sys.executable) - # PATH to migrate development script folder - log.debug('PATH for ScriptTest: %s', migrate_path) - self.env = TestFileEnvironment( - base_path=os.path.join(self.temp_usable_dir, 'env'), - ) - - def run_version(self, repos_path): - result = self.env.run('migrate version %s' % repos_path) - return int(result.stdout.strip()) - - def run_db_version(self, url, repos_path): - result = self.env.run('migrate db_version %s %s' % (url, repos_path)) - return int(result.stdout.strip()) diff --git a/migrate/tests/fixture/warnings.py b/migrate/tests/fixture/warnings.py deleted file mode 100644 index 8d99c0f..0000000 --- a/migrate/tests/fixture/warnings.py +++ /dev/null @@ -1,88 +0,0 @@ -# lifted from Python 2.6, so we can use it in Python 2.5 -import sys - -class WarningMessage(object): - - """Holds the result of a single showwarning() call.""" - - _WARNING_DETAILS = ("message", "category", "filename", "lineno", "file", - "line") - - def __init__(self, message, category, filename, lineno, file=None, - line=None): - local_values = locals() - for attr in self._WARNING_DETAILS: - setattr(self, attr, local_values[attr]) - if category: - self._category_name = category.__name__ - else: - self._category_name = None - - def 
__str__(self): - return ("{message : %r, category : %r, filename : %r, lineno : %s, " - "line : %r}" % (self.message, self._category_name, - self.filename, self.lineno, self.line)) - - -class catch_warnings(object): - - """A context manager that copies and restores the warnings filter upon - exiting the context. - - The 'record' argument specifies whether warnings should be captured by a - custom implementation of warnings.showwarning() and be appended to a list - returned by the context manager. Otherwise None is returned by the context - manager. The objects appended to the list are arguments whose attributes - mirror the arguments to showwarning(). - - The 'module' argument is to specify an alternative module to the module - named 'warnings' and imported under that name. This argument is only useful - when testing the warnings module itself. - - """ - - def __init__(self, record=False, module=None): - """Specify whether to record warnings and if an alternative module - should be used other than sys.modules['warnings']. - - For compatibility with Python 3.0, please consider all arguments to be - keyword-only. 
- - """ - self._record = record - if module is None: - self._module = sys.modules['warnings'] - else: - self._module = module - self._entered = False - - def __repr__(self): - args = [] - if self._record: - args.append("record=True") - if self._module is not sys.modules['warnings']: - args.append("module=%r" % self._module) - name = type(self).__name__ - return "%s(%s)" % (name, ", ".join(args)) - - def __enter__(self): - if self._entered: - raise RuntimeError("Cannot enter %r twice" % self) - self._entered = True - self._filters = self._module.filters - self._module.filters = self._filters[:] - self._showwarning = self._module.showwarning - if self._record: - log = [] - def showwarning(*args, **kwargs): - log.append(WarningMessage(*args, **kwargs)) - self._module.showwarning = showwarning - return log - else: - return None - - def __exit__(self, *exc_info): - if not self._entered: - raise RuntimeError("Cannot exit %r without entering first" % self) - self._module.filters = self._filters - self._module.showwarning = self._showwarning diff --git a/migrate/tests/integrated/__init__.py b/migrate/tests/integrated/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/migrate/tests/integrated/__init__.py +++ /dev/null diff --git a/migrate/tests/integrated/test_docs.py b/migrate/tests/integrated/test_docs.py deleted file mode 100644 index 8e35427..0000000 --- a/migrate/tests/integrated/test_docs.py +++ /dev/null @@ -1,18 +0,0 @@ -import doctest -import os - - -from migrate.tests import fixture - -# Collect tests for all handwritten docs: doc/*.rst - -dir = ('..','..','..','doc','source') -absdir = (os.path.dirname(os.path.abspath(__file__)),)+dir -dirpath = os.path.join(*absdir) -files = [f for f in os.listdir(dirpath) if f.endswith('.rst')] -paths = [os.path.join(*(dir+(f,))) for f in files] -assert len(paths) > 0 -suite = doctest.DocFileSuite(*paths) - -def test_docs(): - suite.debug() diff --git a/migrate/tests/versioning/__init__.py 
b/migrate/tests/versioning/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/migrate/tests/versioning/__init__.py +++ /dev/null diff --git a/migrate/tests/versioning/test_api.py b/migrate/tests/versioning/test_api.py deleted file mode 100644 index bc4b29d..0000000 --- a/migrate/tests/versioning/test_api.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- - -import six - -from migrate.exceptions import * -from migrate.versioning import api - -from migrate.tests.fixture.pathed import * -from migrate.tests.fixture import models -from migrate.tests import fixture - - -class TestAPI(Pathed): - - def test_help(self): - self.assertTrue(isinstance(api.help('help'), six.string_types)) - self.assertRaises(UsageError, api.help) - self.assertRaises(UsageError, api.help, 'foobar') - self.assertTrue(isinstance(api.help('create'), str)) - - # test that all commands return some text - for cmd in api.__all__: - content = api.help(cmd) - self.assertTrue(content) - - def test_create(self): - tmprepo = self.tmp_repos() - api.create(tmprepo, 'temp') - - # repository already exists - self.assertRaises(KnownError, api.create, tmprepo, 'temp') - - def test_script(self): - repo = self.tmp_repos() - api.create(repo, 'temp') - api.script('first version', repo) - - def test_script_sql(self): - repo = self.tmp_repos() - api.create(repo, 'temp') - api.script_sql('postgres', 'desc', repo) - - def test_version(self): - repo = self.tmp_repos() - api.create(repo, 'temp') - api.version(repo) - - def test_version_control(self): - repo = self.tmp_repos() - api.create(repo, 'temp') - api.version_control('sqlite:///', repo) - api.version_control('sqlite:///', six.text_type(repo)) - - def test_source(self): - repo = self.tmp_repos() - api.create(repo, 'temp') - api.script('first version', repo) - api.script_sql('default', 'desc', repo) - - # no repository - self.assertRaises(UsageError, api.source, 1) - - # stdout - out = api.source(1, dest=None, 
repository=repo) - self.assertTrue(out) - - # file - out = api.source(1, dest=self.tmp_repos(), repository=repo) - self.assertFalse(out) - - def test_manage(self): - output = api.manage(os.path.join(self.temp_usable_dir, 'manage.py')) - - -class TestSchemaAPI(fixture.DB, Pathed): - - def _setup(self, url): - super(TestSchemaAPI, self)._setup(url) - self.repo = self.tmp_repos() - api.create(self.repo, 'temp') - self.schema = api.version_control(url, self.repo) - - def _teardown(self): - self.schema = api.drop_version_control(self.url, self.repo) - super(TestSchemaAPI, self)._teardown() - - @fixture.usedb() - def test_workflow(self): - self.assertEqual(api.db_version(self.url, self.repo), 0) - api.script('First Version', self.repo) - self.assertEqual(api.db_version(self.url, self.repo), 0) - api.upgrade(self.url, self.repo, 1) - self.assertEqual(api.db_version(self.url, self.repo), 1) - api.downgrade(self.url, self.repo, 0) - self.assertEqual(api.db_version(self.url, self.repo), 0) - api.test(self.url, self.repo) - self.assertEqual(api.db_version(self.url, self.repo), 0) - - # preview - # TODO: test output - out = api.upgrade(self.url, self.repo, preview_py=True) - out = api.upgrade(self.url, self.repo, preview_sql=True) - - api.upgrade(self.url, self.repo, 1) - api.script_sql('default', 'desc', self.repo) - self.assertRaises(UsageError, api.upgrade, self.url, self.repo, 2, preview_py=True) - out = api.upgrade(self.url, self.repo, 2, preview_sql=True) - - # cant upgrade to version 1, already at version 1 - self.assertEqual(api.db_version(self.url, self.repo), 1) - self.assertRaises(KnownError, api.upgrade, self.url, self.repo, 0) - - @fixture.usedb() - def test_compare_model_to_db(self): - diff = api.compare_model_to_db(self.url, self.repo, models.meta) - - @fixture.usedb() - def test_create_model(self): - model = api.create_model(self.url, self.repo) - - @fixture.usedb() - def test_make_update_script_for_model(self): - model = 
api.make_update_script_for_model(self.url, self.repo, models.meta_old_rundiffs, models.meta_rundiffs) - - @fixture.usedb() - def test_update_db_from_model(self): - model = api.update_db_from_model(self.url, self.repo, models.meta_rundiffs) diff --git a/migrate/tests/versioning/test_cfgparse.py b/migrate/tests/versioning/test_cfgparse.py deleted file mode 100644 index a31273e..0000000 --- a/migrate/tests/versioning/test_cfgparse.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- - -from migrate.versioning import cfgparse -from migrate.versioning.repository import * -from migrate.versioning.template import Template -from migrate.tests import fixture - - -class TestConfigParser(fixture.Base): - - def test_to_dict(self): - """Correctly interpret config results as dictionaries""" - parser = cfgparse.Parser(dict(default_value=42)) - self.assertTrue(len(parser.sections()) == 0) - parser.add_section('section') - parser.set('section','option','value') - self.assertEqual(parser.get('section', 'option'), 'value') - self.assertEqual(parser.to_dict()['section']['option'], 'value') - - def test_table_config(self): - """We should be able to specify the table to be used with a repository""" - default_text = Repository.prepare_config(Template().get_repository(), - 'repository_name', {}) - specified_text = Repository.prepare_config(Template().get_repository(), - 'repository_name', {'version_table': '_other_table'}) - self.assertNotEqual(default_text, specified_text) diff --git a/migrate/tests/versioning/test_database.py b/migrate/tests/versioning/test_database.py deleted file mode 100644 index 8291c6b..0000000 --- a/migrate/tests/versioning/test_database.py +++ /dev/null @@ -1,13 +0,0 @@ -from sqlalchemy import select, text -from migrate.tests import fixture - -class TestConnect(fixture.DB): - level=fixture.DB.TXN - - @fixture.usedb() - def test_connect(self): - """Connect to the database successfully""" - # Connection is done in fixture.DB setup; make 
sure we can do stuff - self.engine.execute( - select([text('42')]) - ) diff --git a/migrate/tests/versioning/test_genmodel.py b/migrate/tests/versioning/test_genmodel.py deleted file mode 100644 index f800826..0000000 --- a/migrate/tests/versioning/test_genmodel.py +++ /dev/null @@ -1,214 +0,0 @@ -# -*- coding: utf-8 -*- - -import os - -import six -import sqlalchemy -from sqlalchemy import * - -from migrate.versioning import genmodel, schemadiff -from migrate.changeset import schema - -from migrate.tests import fixture - - -class TestSchemaDiff(fixture.DB): - table_name = 'tmp_schemadiff' - level = fixture.DB.CONNECT - - def _setup(self, url): - super(TestSchemaDiff, self)._setup(url) - self.meta = MetaData(self.engine) - self.meta.reflect() - self.meta.drop_all() # in case junk tables are lying around in the test database - self.meta = MetaData(self.engine) - self.meta.reflect() # needed if we just deleted some tables - self.table = Table(self.table_name, self.meta, - Column('id',Integer(), primary_key=True), - Column('name', UnicodeText()), - Column('data', UnicodeText()), - ) - - def _teardown(self): - if self.table.exists(): - self.meta = MetaData(self.engine) - self.meta.reflect() - self.meta.drop_all() - super(TestSchemaDiff, self)._teardown() - - def _applyLatestModel(self): - diff = schemadiff.getDiffOfModelAgainstDatabase(self.meta, self.engine, excludeTables=['migrate_version']) - genmodel.ModelGenerator(diff,self.engine).runB2A() - - # NOTE(mriedem): DB2 handles UnicodeText as LONG VARGRAPHIC - # so the schema diffs on the columns don't work with this test. 
- @fixture.usedb(not_supported='ibm_db_sa') - def test_functional(self): - def assertDiff(isDiff, tablesMissingInDatabase, tablesMissingInModel, tablesWithDiff): - diff = schemadiff.getDiffOfModelAgainstDatabase(self.meta, self.engine, excludeTables=['migrate_version']) - self.assertEqual( - (diff.tables_missing_from_B, - diff.tables_missing_from_A, - list(diff.tables_different.keys()), - bool(diff)), - (tablesMissingInDatabase, - tablesMissingInModel, - tablesWithDiff, - isDiff) - ) - - # Model is defined but database is empty. - assertDiff(True, [self.table_name], [], []) - - # Check Python upgrade and downgrade of database from updated model. - diff = schemadiff.getDiffOfModelAgainstDatabase(self.meta, self.engine, excludeTables=['migrate_version']) - decls, upgradeCommands, downgradeCommands = genmodel.ModelGenerator(diff,self.engine).genB2AMigration() - - # Feature test for a recent SQLa feature; - # expect different output in that case. - if repr(String()) == 'String()': - self.assertEqualIgnoreWhitespace(decls, ''' - from migrate.changeset import schema - pre_meta = MetaData() - post_meta = MetaData() - tmp_schemadiff = Table('tmp_schemadiff', post_meta, - Column('id', Integer, primary_key=True, nullable=False), - Column('name', UnicodeText), - Column('data', UnicodeText), - ) - ''') - else: - self.assertEqualIgnoreWhitespace(decls, ''' - from migrate.changeset import schema - pre_meta = MetaData() - post_meta = MetaData() - tmp_schemadiff = Table('tmp_schemadiff', post_meta, - Column('id', Integer, primary_key=True, nullable=False), - Column('name', UnicodeText(length=None)), - Column('data', UnicodeText(length=None)), - ) - ''') - - # Create table in database, now model should match database. - self._applyLatestModel() - assertDiff(False, [], [], []) - - # Check Python code gen from database. 
- diff = schemadiff.getDiffOfModelAgainstDatabase(MetaData(), self.engine, excludeTables=['migrate_version']) - src = genmodel.ModelGenerator(diff,self.engine).genBDefinition() - - namespace = {} - six.exec_(src, namespace) - - c1 = Table('tmp_schemadiff', self.meta, autoload=True).c - c2 = namespace['tmp_schemadiff'].c - self.compare_columns_equal(c1, c2, ['type']) - # TODO: get rid of ignoring type - - if not self.engine.name == 'oracle': - # Add data, later we'll make sure it's still present. - result = self.engine.execute(self.table.insert(), id=1, name=u'mydata') - dataId = result.inserted_primary_key[0] - - # Modify table in model (by removing it and adding it back to model) - # Drop column data, add columns data2 and data3. - self.meta.remove(self.table) - self.table = Table(self.table_name,self.meta, - Column('id',Integer(),primary_key=True), - Column('name',UnicodeText(length=None)), - Column('data2',Integer(),nullable=True), - Column('data3',Integer(),nullable=True), - ) - assertDiff(True, [], [], [self.table_name]) - - # Apply latest model changes and find no more diffs. 
- self._applyLatestModel() - assertDiff(False, [], [], []) - - # Drop column data3, add data4 - self.meta.remove(self.table) - self.table = Table(self.table_name,self.meta, - Column('id',Integer(),primary_key=True), - Column('name',UnicodeText(length=None)), - Column('data2',Integer(),nullable=True), - Column('data4',Float(),nullable=True), - ) - assertDiff(True, [], [], [self.table_name]) - - diff = schemadiff.getDiffOfModelAgainstDatabase( - self.meta, self.engine, excludeTables=['migrate_version']) - decls, upgradeCommands, downgradeCommands = genmodel.ModelGenerator(diff,self.engine).genB2AMigration(indent='') - - # decls have changed since genBDefinition - six.exec_(decls, namespace) - # migration commands expect a namespace containing migrate_engine - namespace['migrate_engine'] = self.engine - # run the migration up and down - six.exec_(upgradeCommands, namespace) - assertDiff(False, [], [], []) - - six.exec_(decls, namespace) - six.exec_(downgradeCommands, namespace) - assertDiff(True, [], [], [self.table_name]) - - six.exec_(decls, namespace) - six.exec_(upgradeCommands, namespace) - assertDiff(False, [], [], []) - - if not self.engine.name == 'oracle': - # Make sure data is still present. - result = self.engine.execute(self.table.select(self.table.c.id==dataId)) - rows = result.fetchall() - self.assertEqual(len(rows), 1) - self.assertEqual(rows[0].name, 'mydata') - - # Add data, later we'll make sure it's still present. - result = self.engine.execute(self.table.insert(), id=2, name=u'mydata2', data2=123) - dataId2 = result.inserted_primary_key[0] - - # Change column type in model. - self.meta.remove(self.table) - self.table = Table(self.table_name,self.meta, - Column('id',Integer(),primary_key=True), - Column('name',UnicodeText(length=None)), - Column('data2',String(255),nullable=True), - ) - - # XXX test type diff - return - - assertDiff(True, [], [], [self.table_name]) - - # Apply latest model changes and find no more diffs. 
- self._applyLatestModel() - assertDiff(False, [], [], []) - - if not self.engine.name == 'oracle': - # Make sure data is still present. - result = self.engine.execute(self.table.select(self.table.c.id==dataId2)) - rows = result.fetchall() - self.assertEqual(len(rows), 1) - self.assertEqual(rows[0].name, 'mydata2') - self.assertEqual(rows[0].data2, '123') - - # Delete data, since we're about to make a required column. - # Not even using sqlalchemy.PassiveDefault helps because we're doing explicit column select. - self.engine.execute(self.table.delete(), id=dataId) - - if not self.engine.name == 'firebird': - # Change column nullable in model. - self.meta.remove(self.table) - self.table = Table(self.table_name,self.meta, - Column('id',Integer(),primary_key=True), - Column('name',UnicodeText(length=None)), - Column('data2',String(255),nullable=False), - ) - assertDiff(True, [], [], [self.table_name]) # TODO test nullable diff - - # Apply latest model changes and find no more diffs. - self._applyLatestModel() - assertDiff(False, [], [], []) - - # Remove table from model. 
- self.meta.remove(self.table) - assertDiff(True, [], [self.table_name], []) diff --git a/migrate/tests/versioning/test_keyedinstance.py b/migrate/tests/versioning/test_keyedinstance.py deleted file mode 100644 index 485cbbb..0000000 --- a/migrate/tests/versioning/test_keyedinstance.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- - -from migrate.tests import fixture -from migrate.versioning.util.keyedinstance import * - -class TestKeydInstance(fixture.Base): - def test_unique(self): - """UniqueInstance should produce unique object instances""" - class Uniq1(KeyedInstance): - @classmethod - def _key(cls,key): - return str(key) - def __init__(self,value): - self.value=value - class Uniq2(KeyedInstance): - @classmethod - def _key(cls,key): - return str(key) - def __init__(self,value): - self.value=value - - a10 = Uniq1('a') - - # Different key: different instance - b10 = Uniq1('b') - self.assertTrue(a10 is not b10) - - # Different class: different instance - a20 = Uniq2('a') - self.assertTrue(a10 is not a20) - - # Same key/class: same instance - a11 = Uniq1('a') - self.assertTrue(a10 is a11) - - # __init__ is called - self.assertEqual(a10.value,'a') - - # clear() causes us to forget all existing instances - Uniq1.clear() - a12 = Uniq1('a') - self.assertTrue(a10 is not a12) - - self.assertRaises(NotImplementedError, KeyedInstance._key) diff --git a/migrate/tests/versioning/test_pathed.py b/migrate/tests/versioning/test_pathed.py deleted file mode 100644 index 53f0b47..0000000 --- a/migrate/tests/versioning/test_pathed.py +++ /dev/null @@ -1,51 +0,0 @@ -from migrate.tests import fixture -from migrate.versioning.pathed import * - -class TestPathed(fixture.Base): - def test_parent_path(self): - """Default parent_path should behave correctly""" - filepath='/fgsfds/moot.py' - dirpath='/fgsfds/moot' - sdirpath='/fgsfds/moot/' - - result='/fgsfds' - self.assertTrue(result==Pathed._parent_path(filepath)) - 
self.assertTrue(result==Pathed._parent_path(dirpath)) - self.assertTrue(result==Pathed._parent_path(sdirpath)) - - def test_new(self): - """Pathed(path) shouldn't create duplicate objects of the same path""" - path='/fgsfds' - class Test(Pathed): - attr=None - o1=Test(path) - o2=Test(path) - self.assertTrue(isinstance(o1,Test)) - self.assertTrue(o1.path==path) - self.assertTrue(o1 is o2) - o1.attr='herring' - self.assertTrue(o2.attr=='herring') - o2.attr='shrubbery' - self.assertTrue(o1.attr=='shrubbery') - - def test_parent(self): - """Parents should be fetched correctly""" - class Parent(Pathed): - parent=None - children=0 - def _init_child(self,child,path): - """Keep a tally of children. - (A real class might do something more interesting here) - """ - self.__class__.children+=1 - - class Child(Pathed): - parent=Parent - - path='/fgsfds/moot.py' - parent_path='/fgsfds' - object=Child(path) - self.assertTrue(isinstance(object,Child)) - self.assertTrue(isinstance(object.parent,Parent)) - self.assertTrue(object.path==path) - self.assertTrue(object.parent.path==parent_path) diff --git a/migrate/tests/versioning/test_repository.py b/migrate/tests/versioning/test_repository.py deleted file mode 100644 index 6e87c02..0000000 --- a/migrate/tests/versioning/test_repository.py +++ /dev/null @@ -1,216 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os -import shutil - -from migrate import exceptions -from migrate.versioning.repository import * -from migrate.versioning.script import * - -from migrate.tests import fixture -from datetime import datetime - - -class TestRepository(fixture.Pathed): - def test_create(self): - """Repositories are created successfully""" - path = self.tmp_repos() - name = 'repository_name' - - # Creating a repository that doesn't exist should succeed - repo = Repository.create(path, name) - config_path = repo.config.path - manage_path = os.path.join(repo.path, 'manage.py') - self.assertTrue(repo) - - # Files should actually be 
created - self.assertTrue(os.path.exists(path)) - self.assertTrue(os.path.exists(config_path)) - self.assertTrue(os.path.exists(manage_path)) - - # Can't create it again: it already exists - self.assertRaises(exceptions.PathFoundError, Repository.create, path, name) - return path - - def test_load(self): - """We should be able to load information about an existing repository""" - # Create a repository to load - path = self.test_create() - repos = Repository(path) - - self.assertTrue(repos) - self.assertTrue(repos.config) - self.assertTrue(repos.config.get('db_settings', 'version_table')) - - # version_table's default isn't none - self.assertNotEqual(repos.config.get('db_settings', 'version_table'), 'None') - - def test_load_notfound(self): - """Nonexistant repositories shouldn't be loaded""" - path = self.tmp_repos() - self.assertTrue(not os.path.exists(path)) - self.assertRaises(exceptions.InvalidRepositoryError, Repository, path) - - def test_load_invalid(self): - """Invalid repos shouldn't be loaded""" - # Here, invalid=empty directory. 
There may be other conditions too, - # but we shouldn't need to test all of them - path = self.tmp_repos() - os.mkdir(path) - self.assertRaises(exceptions.InvalidRepositoryError, Repository, path) - - -class TestVersionedRepository(fixture.Pathed): - """Tests on an existing repository with a single python script""" - - def setUp(self): - super(TestVersionedRepository, self).setUp() - Repository.clear() - self.path_repos = self.tmp_repos() - Repository.create(self.path_repos, 'repository_name') - - def test_version(self): - """We should correctly detect the version of a repository""" - repos = Repository(self.path_repos) - - # Get latest version, or detect if a specified version exists - self.assertEqual(repos.latest, 0) - # repos.latest isn't an integer, but a VerNum - # (so we can't just assume the following tests are correct) - self.assertTrue(repos.latest >= 0) - self.assertTrue(repos.latest < 1) - - # Create a script and test again - repos.create_script('') - self.assertEqual(repos.latest, 1) - self.assertTrue(repos.latest >= 0) - self.assertTrue(repos.latest >= 1) - self.assertTrue(repos.latest < 2) - - # Create a new script and test again - repos.create_script('') - self.assertEqual(repos.latest, 2) - self.assertTrue(repos.latest >= 0) - self.assertTrue(repos.latest >= 1) - self.assertTrue(repos.latest >= 2) - self.assertTrue(repos.latest < 3) - - - def test_timestmap_numbering_version(self): - repos = Repository(self.path_repos) - repos.config.set('db_settings', 'use_timestamp_numbering', 'True') - - # Get latest version, or detect if a specified version exists - self.assertEqual(repos.latest, 0) - # repos.latest isn't an integer, but a VerNum - # (so we can't just assume the following tests are correct) - self.assertTrue(repos.latest >= 0) - self.assertTrue(repos.latest < 1) - - # Create a script and test again - now = int(datetime.utcnow().strftime('%Y%m%d%H%M%S')) - repos.create_script('') - self.assertEqual(repos.latest, now) - - def test_source(self): - 
"""Get a script object by version number and view its source""" - # Load repository and commit script - repo = Repository(self.path_repos) - repo.create_script('') - repo.create_script_sql('postgres', 'foo bar') - - # Source is valid: script must have an upgrade function - # (not a very thorough test, but should be plenty) - source = repo.version(1).script().source() - self.assertTrue(source.find('def upgrade') >= 0) - - import pprint; pprint.pprint(repo.version(2).sql) - source = repo.version(2).script('postgres', 'upgrade').source() - self.assertEqual(source.strip(), '') - - def test_latestversion(self): - """Repository.version() (no params) returns the latest version""" - repos = Repository(self.path_repos) - repos.create_script('') - self.assertTrue(repos.version(repos.latest) is repos.version()) - self.assertTrue(repos.version() is not None) - - def test_changeset(self): - """Repositories can create changesets properly""" - # Create a nonzero-version repository of empty scripts - repos = Repository(self.path_repos) - for i in range(10): - repos.create_script('') - - def check_changeset(params, length): - """Creates and verifies a changeset""" - changeset = repos.changeset('postgres', *params) - self.assertEqual(len(changeset), length) - self.assertTrue(isinstance(changeset, Changeset)) - uniq = list() - # Changesets are iterable - for version, change in changeset: - self.assertTrue(isinstance(change, BaseScript)) - # Changes aren't identical - self.assertTrue(id(change) not in uniq) - uniq.append(id(change)) - return changeset - - # Upgrade to a specified version... 
- cs = check_changeset((0, 10), 10) - self.assertEqual(cs.keys().pop(0),0 ) # 0 -> 1: index is starting version - self.assertEqual(cs.keys().pop(), 9) # 9 -> 10: index is starting version - self.assertEqual(cs.start, 0) # starting version - self.assertEqual(cs.end, 10) # ending version - check_changeset((0, 1), 1) - check_changeset((0, 5), 5) - check_changeset((0, 0), 0) - check_changeset((5, 5), 0) - check_changeset((10, 10), 0) - check_changeset((5, 10), 5) - - # Can't request a changeset of higher version than this repository - self.assertRaises(Exception, repos.changeset, 'postgres', 5, 11) - self.assertRaises(Exception, repos.changeset, 'postgres', -1, 5) - - # Upgrade to the latest version... - cs = check_changeset((0,), 10) - self.assertEqual(cs.keys().pop(0), 0) - self.assertEqual(cs.keys().pop(), 9) - self.assertEqual(cs.start, 0) - self.assertEqual(cs.end, 10) - check_changeset((1,), 9) - check_changeset((5,), 5) - check_changeset((9,), 1) - check_changeset((10,), 0) - - # run changes - cs.run('postgres', 'upgrade') - - # Can't request a changeset of higher/lower version than this repository - self.assertRaises(Exception, repos.changeset, 'postgres', 11) - self.assertRaises(Exception, repos.changeset, 'postgres', -1) - - # Downgrade - cs = check_changeset((10, 0),10) - self.assertEqual(cs.keys().pop(0), 10) # 10 -> 9 - self.assertEqual(cs.keys().pop(), 1) # 1 -> 0 - self.assertEqual(cs.start, 10) - self.assertEqual(cs.end, 0) - check_changeset((10, 5), 5) - check_changeset((5, 0), 5) - - def test_many_versions(self): - """Test what happens when lots of versions are created""" - repos = Repository(self.path_repos) - for i in range(1001): - repos.create_script('') - - # since we normally create 3 digit ones, let's see if we blow up - self.assertTrue(os.path.exists('%s/versions/1000.py' % self.path_repos)) - self.assertTrue(os.path.exists('%s/versions/1001.py' % self.path_repos)) - - -# TODO: test manage file -# TODO: test changeset diff --git 
a/migrate/tests/versioning/test_runchangeset.py b/migrate/tests/versioning/test_runchangeset.py deleted file mode 100644 index 12bc77c..0000000 --- a/migrate/tests/versioning/test_runchangeset.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os,shutil - -from migrate.tests import fixture -from migrate.versioning.schema import * -from migrate.versioning import script - - -class TestRunChangeset(fixture.Pathed,fixture.DB): - level=fixture.DB.CONNECT - def _setup(self, url): - super(TestRunChangeset, self)._setup(url) - Repository.clear() - self.path_repos=self.tmp_repos() - # Create repository, script - Repository.create(self.path_repos,'repository_name') - - @fixture.usedb() - def test_changeset_run(self): - """Running a changeset against a repository gives expected results""" - repos=Repository(self.path_repos) - for i in range(10): - repos.create_script('') - try: - ControlledSchema(self.engine,repos).drop() - except: - pass - db=ControlledSchema.create(self.engine,repos) - - # Scripts are empty; we'll check version # correctness. 
- # (Correct application of their content is checked elsewhere) - self.assertEqual(db.version,0) - db.upgrade(1) - self.assertEqual(db.version,1) - db.upgrade(5) - self.assertEqual(db.version,5) - db.upgrade(5) - self.assertEqual(db.version,5) - db.upgrade(None) # Latest is implied - self.assertEqual(db.version,10) - self.assertRaises(Exception,db.upgrade,11) - self.assertEqual(db.version,10) - db.upgrade(9) - self.assertEqual(db.version,9) - db.upgrade(0) - self.assertEqual(db.version,0) - self.assertRaises(Exception,db.upgrade,-1) - self.assertEqual(db.version,0) - #changeset = repos.changeset(self.url,0) - db.drop() diff --git a/migrate/tests/versioning/test_schema.py b/migrate/tests/versioning/test_schema.py deleted file mode 100644 index 5396d9d..0000000 --- a/migrate/tests/versioning/test_schema.py +++ /dev/null @@ -1,205 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os -import shutil - -import six - -from migrate import exceptions -from migrate.versioning.schema import * -from migrate.versioning import script, schemadiff - -from sqlalchemy import * - -from migrate.tests import fixture - - -class TestControlledSchema(fixture.Pathed, fixture.DB): - # Transactions break postgres in this test; we'll clean up after ourselves - level = fixture.DB.CONNECT - - def setUp(self): - super(TestControlledSchema, self).setUp() - self.path_repos = self.temp_usable_dir + '/repo/' - self.repos = Repository.create(self.path_repos, 'repo_name') - - def _setup(self, url): - self.setUp() - super(TestControlledSchema, self)._setup(url) - self.cleanup() - - def _teardown(self): - super(TestControlledSchema, self)._teardown() - self.cleanup() - self.tearDown() - - def cleanup(self): - # drop existing version table if necessary - try: - ControlledSchema(self.engine, self.repos).drop() - except: - # No table to drop; that's fine, be silent - pass - - def tearDown(self): - self.cleanup() - super(TestControlledSchema, self).tearDown() - - @fixture.usedb() - def 
test_version_control(self): - """Establish version control on a particular database""" - # Establish version control on this database - dbcontrol = ControlledSchema.create(self.engine, self.repos) - - # Trying to create another DB this way fails: table exists - self.assertRaises(exceptions.DatabaseAlreadyControlledError, - ControlledSchema.create, self.engine, self.repos) - - # We can load a controlled DB this way, too - dbcontrol0 = ControlledSchema(self.engine, self.repos) - self.assertEqual(dbcontrol, dbcontrol0) - - # We can also use a repository path, instead of a repository - dbcontrol0 = ControlledSchema(self.engine, self.repos.path) - self.assertEqual(dbcontrol, dbcontrol0) - - # We don't have to use the same connection - engine = create_engine(self.url) - dbcontrol0 = ControlledSchema(engine, self.repos.path) - self.assertEqual(dbcontrol, dbcontrol0) - - # Clean up: - dbcontrol.drop() - - # Attempting to drop vc from a db without it should fail - self.assertRaises(exceptions.DatabaseNotControlledError, dbcontrol.drop) - - # No table defined should raise error - self.assertRaises(exceptions.DatabaseNotControlledError, - ControlledSchema, self.engine, self.repos) - - @fixture.usedb() - def test_version_control_specified(self): - """Establish version control with a specified version""" - # Establish version control on this database - version = 0 - dbcontrol = ControlledSchema.create(self.engine, self.repos, version) - self.assertEqual(dbcontrol.version, version) - - # Correct when we load it, too - dbcontrol = ControlledSchema(self.engine, self.repos) - self.assertEqual(dbcontrol.version, version) - - dbcontrol.drop() - - # Now try it with a nonzero value - version = 10 - for i in range(version): - self.repos.create_script('') - self.assertEqual(self.repos.latest, version) - - # Test with some mid-range value - dbcontrol = ControlledSchema.create(self.engine,self.repos, 5) - self.assertEqual(dbcontrol.version, 5) - dbcontrol.drop() - - # Test with max value - 
dbcontrol = ControlledSchema.create(self.engine, self.repos, version) - self.assertEqual(dbcontrol.version, version) - dbcontrol.drop() - - @fixture.usedb() - def test_version_control_invalid(self): - """Try to establish version control with an invalid version""" - versions = ('Thirteen', '-1', -1, '' , 13) - # A fresh repository doesn't go up to version 13 yet - for version in versions: - #self.assertRaises(ControlledSchema.InvalidVersionError, - # Can't have custom errors with assertRaises... - try: - ControlledSchema.create(self.engine, self.repos, version) - self.assertTrue(False, repr(version)) - except exceptions.InvalidVersionError: - pass - - @fixture.usedb() - def test_changeset(self): - """Create changeset from controlled schema""" - dbschema = ControlledSchema.create(self.engine, self.repos) - - # empty schema doesn't have changesets - cs = dbschema.changeset() - self.assertEqual(cs, {}) - - for i in range(5): - self.repos.create_script('') - self.assertEqual(self.repos.latest, 5) - - cs = dbschema.changeset(5) - self.assertEqual(len(cs), 5) - - # cleanup - dbschema.drop() - - @fixture.usedb() - def test_upgrade_runchange(self): - dbschema = ControlledSchema.create(self.engine, self.repos) - - for i in range(10): - self.repos.create_script('') - - self.assertEqual(self.repos.latest, 10) - - dbschema.upgrade(10) - - self.assertRaises(ValueError, dbschema.upgrade, 'a') - self.assertRaises(exceptions.InvalidVersionError, dbschema.runchange, 20, '', 1) - - # TODO: test for table version in db - - # cleanup - dbschema.drop() - - @fixture.usedb() - def test_create_model(self): - """Test workflow to generate create_model""" - model = ControlledSchema.create_model(self.engine, self.repos, declarative=False) - self.assertTrue(isinstance(model, six.string_types)) - - model = ControlledSchema.create_model(self.engine, self.repos.path, declarative=True) - self.assertTrue(isinstance(model, six.string_types)) - - @fixture.usedb() - def test_compare_model_to_db(self): 
- meta = self.construct_model() - - diff = ControlledSchema.compare_model_to_db(self.engine, meta, self.repos) - self.assertTrue(isinstance(diff, schemadiff.SchemaDiff)) - - diff = ControlledSchema.compare_model_to_db(self.engine, meta, self.repos.path) - self.assertTrue(isinstance(diff, schemadiff.SchemaDiff)) - meta.drop_all(self.engine) - - @fixture.usedb() - def test_update_db_from_model(self): - dbschema = ControlledSchema.create(self.engine, self.repos) - - meta = self.construct_model() - - dbschema.update_db_from_model(meta) - - # TODO: test for table version in db - - # cleanup - dbschema.drop() - meta.drop_all(self.engine) - - def construct_model(self): - meta = MetaData() - - user = Table('temp_model_schema', meta, Column('id', Integer), Column('user', String(245))) - - return meta - - # TODO: test how are tables populated in db diff --git a/migrate/tests/versioning/test_schemadiff.py b/migrate/tests/versioning/test_schemadiff.py deleted file mode 100644 index f45a012..0000000 --- a/migrate/tests/versioning/test_schemadiff.py +++ /dev/null @@ -1,227 +0,0 @@ -# -*- coding: utf-8 -*- - -import os - -from sqlalchemy import * - -from migrate.versioning import schemadiff - -from migrate.tests import fixture - -class SchemaDiffBase(fixture.DB): - - level = fixture.DB.CONNECT - def _make_table(self,*cols,**kw): - self.table = Table('xtable', self.meta, - Column('id',Integer(), primary_key=True), - *cols - ) - if kw.get('create',True): - self.table.create() - - def _assert_diff(self,col_A,col_B): - self._make_table(col_A) - self.meta.clear() - self._make_table(col_B,create=False) - diff = self._run_diff() - # print diff - self.assertTrue(diff) - self.assertEqual(1,len(diff.tables_different)) - td = list(diff.tables_different.values())[0] - self.assertEqual(1,len(td.columns_different)) - cd = list(td.columns_different.values())[0] - label_width = max(len(self.name1), len(self.name2)) - self.assertEqual(('Schema diffs:\n' - ' table with differences: xtable\n' - ' 
column with differences: data\n' - ' %*s: %r\n' - ' %*s: %r')%( - label_width, - self.name1, - cd.col_A, - label_width, - self.name2, - cd.col_B - ),str(diff)) - -class Test_getDiffOfModelAgainstDatabase(SchemaDiffBase): - name1 = 'model' - name2 = 'database' - - def _run_diff(self,**kw): - return schemadiff.getDiffOfModelAgainstDatabase( - self.meta, self.engine, **kw - ) - - @fixture.usedb() - def test_table_missing_in_db(self): - self._make_table(create=False) - diff = self._run_diff() - self.assertTrue(diff) - self.assertEqual('Schema diffs:\n tables missing from %s: xtable' % self.name2, - str(diff)) - - @fixture.usedb() - def test_table_missing_in_model(self): - self._make_table() - self.meta.clear() - diff = self._run_diff() - self.assertTrue(diff) - self.assertEqual('Schema diffs:\n tables missing from %s: xtable' % self.name1, - str(diff)) - - @fixture.usedb() - def test_column_missing_in_db(self): - # db - Table('xtable', self.meta, - Column('id',Integer(), primary_key=True), - ).create() - self.meta.clear() - # model - self._make_table( - Column('xcol',Integer()), - create=False - ) - # run diff - diff = self._run_diff() - self.assertTrue(diff) - self.assertEqual('Schema diffs:\n' - ' table with differences: xtable\n' - ' %s missing these columns: xcol' % self.name2, - str(diff)) - - @fixture.usedb() - def test_column_missing_in_model(self): - # db - self._make_table( - Column('xcol',Integer()), - ) - self.meta.clear() - # model - self._make_table( - create=False - ) - # run diff - diff = self._run_diff() - self.assertTrue(diff) - self.assertEqual('Schema diffs:\n' - ' table with differences: xtable\n' - ' %s missing these columns: xcol' % self.name1, - str(diff)) - - @fixture.usedb() - def test_exclude_tables(self): - # db - Table('ytable', self.meta, - Column('id',Integer(), primary_key=True), - ).create() - Table('ztable', self.meta, - Column('id',Integer(), primary_key=True), - ).create() - self.meta.clear() - # model - self._make_table( - 
create=False - ) - Table('ztable', self.meta, - Column('id',Integer(), primary_key=True), - ) - # run diff - diff = self._run_diff(excludeTables=('xtable','ytable')) - # ytable only in database - # xtable only in model - # ztable identical on both - # ...so we expect no diff! - self.assertFalse(diff) - self.assertEqual('No schema diffs',str(diff)) - - @fixture.usedb() - def test_identical_just_pk(self): - self._make_table() - diff = self._run_diff() - self.assertFalse(diff) - self.assertEqual('No schema diffs',str(diff)) - - - @fixture.usedb() - def test_different_type(self): - self._assert_diff( - Column('data', String(10)), - Column('data', Integer()), - ) - - @fixture.usedb() - def test_int_vs_float(self): - self._assert_diff( - Column('data', Integer()), - Column('data', Float()), - ) - - # NOTE(mriedem): The ibm_db_sa driver handles the Float() as a DOUBLE() - # which extends Numeric() but isn't defined in sqlalchemy.types, so we - # can't check for it as a special case like is done in schemadiff.ColDiff. 
- @fixture.usedb(not_supported='ibm_db_sa') - def test_float_vs_numeric(self): - self._assert_diff( - Column('data', Float()), - Column('data', Numeric()), - ) - - @fixture.usedb() - def test_numeric_precision(self): - self._assert_diff( - Column('data', Numeric(precision=5)), - Column('data', Numeric(precision=6)), - ) - - @fixture.usedb() - def test_numeric_scale(self): - self._assert_diff( - Column('data', Numeric(precision=6,scale=0)), - Column('data', Numeric(precision=6,scale=1)), - ) - - @fixture.usedb() - def test_string_length(self): - self._assert_diff( - Column('data', String(10)), - Column('data', String(20)), - ) - - @fixture.usedb() - def test_integer_identical(self): - self._make_table( - Column('data', Integer()), - ) - diff = self._run_diff() - self.assertEqual('No schema diffs',str(diff)) - self.assertFalse(diff) - - @fixture.usedb() - def test_string_identical(self): - self._make_table( - Column('data', String(10)), - ) - diff = self._run_diff() - self.assertEqual('No schema diffs',str(diff)) - self.assertFalse(diff) - - @fixture.usedb() - def test_text_identical(self): - self._make_table( - Column('data', Text), - ) - diff = self._run_diff() - self.assertEqual('No schema diffs',str(diff)) - self.assertFalse(diff) - -class Test_getDiffOfModelAgainstModel(Test_getDiffOfModelAgainstDatabase): - name1 = 'metadataA' - name2 = 'metadataB' - - def _run_diff(self,**kw): - db_meta= MetaData() - db_meta.reflect(self.engine) - return schemadiff.getDiffOfModelAgainstModel( - self.meta, db_meta, **kw - ) diff --git a/migrate/tests/versioning/test_script.py b/migrate/tests/versioning/test_script.py deleted file mode 100644 index 20e6af0..0000000 --- a/migrate/tests/versioning/test_script.py +++ /dev/null @@ -1,305 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import imp -import os -import sys -import shutil - -import six -from migrate import exceptions -from migrate.versioning import version, repository -from migrate.versioning.script import * 
-from migrate.versioning.util import * - -from migrate.tests import fixture -from migrate.tests.fixture.models import tmp_sql_table - - -class TestBaseScript(fixture.Pathed): - - def test_all(self): - """Testing all basic BaseScript operations""" - # verify / source / run - src = self.tmp() - open(src, 'w').close() - bscript = BaseScript(src) - BaseScript.verify(src) - self.assertEqual(bscript.source(), '') - self.assertRaises(NotImplementedError, bscript.run, 'foobar') - - -class TestPyScript(fixture.Pathed, fixture.DB): - cls = PythonScript - def test_create(self): - """We can create a migration script""" - path = self.tmp_py() - # Creating a file that doesn't exist should succeed - self.cls.create(path) - self.assertTrue(os.path.exists(path)) - # Created file should be a valid script (If not, raises an error) - self.cls.verify(path) - # Can't create it again: it already exists - self.assertRaises(exceptions.PathFoundError,self.cls.create,path) - - @fixture.usedb(supported='sqlite') - def test_run(self): - script_path = self.tmp_py() - pyscript = PythonScript.create(script_path) - pyscript.run(self.engine, 1) - pyscript.run(self.engine, -1) - - self.assertRaises(exceptions.ScriptError, pyscript.run, self.engine, 0) - self.assertRaises(exceptions.ScriptError, pyscript._func, 'foobar') - - # clean pyc file - if six.PY3: - os.remove(imp.cache_from_source(script_path)) - else: - os.remove(script_path + 'c') - - # test deprecated upgrade/downgrade with no arguments - contents = open(script_path, 'r').read() - f = open(script_path, 'w') - f.write(contents.replace("upgrade(migrate_engine)", "upgrade()")) - f.close() - - pyscript = PythonScript(script_path) - pyscript._module = None - try: - pyscript.run(self.engine, 1) - pyscript.run(self.engine, -1) - except exceptions.ScriptError: - pass - else: - self.fail() - - def test_verify_notfound(self): - """Correctly verify a python migration script: nonexistant file""" - path = self.tmp_py() - 
self.assertFalse(os.path.exists(path)) - # Fails on empty path - self.assertRaises(exceptions.InvalidScriptError,self.cls.verify,path) - self.assertRaises(exceptions.InvalidScriptError,self.cls,path) - - def test_verify_invalidpy(self): - """Correctly verify a python migration script: invalid python file""" - path=self.tmp_py() - # Create empty file - f = open(path,'w') - f.write("def fail") - f.close() - self.assertRaises(Exception,self.cls.verify_module,path) - # script isn't verified on creation, but on module reference - py = self.cls(path) - self.assertRaises(Exception,(lambda x: x.module),py) - - def test_verify_nofuncs(self): - """Correctly verify a python migration script: valid python file; no upgrade func""" - path = self.tmp_py() - # Create empty file - f = open(path, 'w') - f.write("def zergling():\n\tprint('rush')") - f.close() - self.assertRaises(exceptions.InvalidScriptError, self.cls.verify_module, path) - # script isn't verified on creation, but on module reference - py = self.cls(path) - self.assertRaises(exceptions.InvalidScriptError,(lambda x: x.module),py) - - @fixture.usedb(supported='sqlite') - def test_preview_sql(self): - """Preview SQL abstract from ORM layer (sqlite)""" - path = self.tmp_py() - - f = open(path, 'w') - content = ''' -from migrate import * -from sqlalchemy import * - -metadata = MetaData() - -UserGroup = Table('Link', metadata, - Column('link1ID', Integer), - Column('link2ID', Integer), - UniqueConstraint('link1ID', 'link2ID')) - -def upgrade(migrate_engine): - metadata.create_all(migrate_engine) - ''' - f.write(content) - f.close() - - pyscript = self.cls(path) - SQL = pyscript.preview_sql(self.url, 1) - self.assertEqualIgnoreWhitespace(""" - CREATE TABLE "Link" - ("link1ID" INTEGER, - "link2ID" INTEGER, - UNIQUE ("link1ID", "link2ID")) - """, SQL) - # TODO: test: No SQL should be executed! 
- - def test_verify_success(self): - """Correctly verify a python migration script: success""" - path = self.tmp_py() - # Succeeds after creating - self.cls.create(path) - self.cls.verify(path) - - # test for PythonScript.make_update_script_for_model - - @fixture.usedb() - def test_make_update_script_for_model(self): - """Construct script source from differences of two models""" - - self.setup_model_params() - self.write_file(self.first_model_path, self.base_source) - self.write_file(self.second_model_path, self.base_source + self.model_source) - - source_script = self.pyscript.make_update_script_for_model( - engine=self.engine, - oldmodel=load_model('testmodel_first:meta'), - model=load_model('testmodel_second:meta'), - repository=self.repo_path, - ) - - self.assertTrue("['User'].create()" in source_script) - self.assertTrue("['User'].drop()" in source_script) - - @fixture.usedb() - def test_make_update_script_for_equal_models(self): - """Try to make update script from two identical models""" - - self.setup_model_params() - self.write_file(self.first_model_path, self.base_source + self.model_source) - self.write_file(self.second_model_path, self.base_source + self.model_source) - - source_script = self.pyscript.make_update_script_for_model( - engine=self.engine, - oldmodel=load_model('testmodel_first:meta'), - model=load_model('testmodel_second:meta'), - repository=self.repo_path, - ) - - self.assertFalse('User.create()' in source_script) - self.assertFalse('User.drop()' in source_script) - - @fixture.usedb() - def test_make_update_script_direction(self): - """Check update scripts go in the right direction""" - - self.setup_model_params() - self.write_file(self.first_model_path, self.base_source) - self.write_file(self.second_model_path, self.base_source + self.model_source) - - source_script = self.pyscript.make_update_script_for_model( - engine=self.engine, - oldmodel=load_model('testmodel_first:meta'), - model=load_model('testmodel_second:meta'), - 
repository=self.repo_path, - ) - - self.assertTrue(0 - < source_script.find('upgrade') - < source_script.find("['User'].create()") - < source_script.find('downgrade') - < source_script.find("['User'].drop()")) - - def setup_model_params(self): - self.script_path = self.tmp_py() - self.repo_path = self.tmp() - self.first_model_path = os.path.join(self.temp_usable_dir, 'testmodel_first.py') - self.second_model_path = os.path.join(self.temp_usable_dir, 'testmodel_second.py') - - self.base_source = """from sqlalchemy import *\nmeta = MetaData()\n""" - self.model_source = """ -User = Table('User', meta, - Column('id', Integer, primary_key=True), - Column('login', Unicode(40)), - Column('passwd', String(40)), -)""" - - self.repo = repository.Repository.create(self.repo_path, 'repo') - self.pyscript = PythonScript.create(self.script_path) - sys.modules.pop('testmodel_first', None) - sys.modules.pop('testmodel_second', None) - - def write_file(self, path, contents): - f = open(path, 'w') - f.write(contents) - f.close() - - -class TestSqlScript(fixture.Pathed, fixture.DB): - - @fixture.usedb() - def test_error(self): - """Test if exception is raised on wrong script source""" - src = self.tmp() - - f = open(src, 'w') - f.write("""foobar""") - f.close() - - sqls = SqlScript(src) - self.assertRaises(Exception, sqls.run, self.engine) - - @fixture.usedb() - def test_success(self): - """Test sucessful SQL execution""" - # cleanup and prepare python script - tmp_sql_table.metadata.drop_all(self.engine, checkfirst=True) - script_path = self.tmp_py() - pyscript = PythonScript.create(script_path) - - # populate python script - contents = open(script_path, 'r').read() - contents = contents.replace("pass", "tmp_sql_table.create(migrate_engine)") - contents = 'from migrate.tests.fixture.models import tmp_sql_table\n' + contents - f = open(script_path, 'w') - f.write(contents) - f.close() - - # write SQL script from python script preview - pyscript = PythonScript(script_path) - src = 
self.tmp() - f = open(src, 'w') - f.write(pyscript.preview_sql(self.url, 1)) - f.close() - - # run the change - sqls = SqlScript(src) - sqls.run(self.engine) - tmp_sql_table.metadata.drop_all(self.engine, checkfirst=True) - - @fixture.usedb() - def test_transaction_management_statements(self): - """ - Test that we can successfully execute SQL scripts with transaction - management statements. - """ - for script_pattern in ( - "BEGIN TRANSACTION; %s; COMMIT;", - "BEGIN; %s; END TRANSACTION;", - "/* comment */BEGIN TRANSACTION; %s; /* comment */COMMIT;", - "/* comment */ BEGIN TRANSACTION; %s; /* comment */ COMMIT;", - """ --- comment -BEGIN TRANSACTION; - -%s; - --- comment -COMMIT;""", - ): - - test_statement = ("CREATE TABLE TEST1 (field1 int); " - "DROP TABLE TEST1") - script = script_pattern % test_statement - src = self.tmp() - - with open(src, 'wt') as f: - f.write(script) - - sqls = SqlScript(src) - sqls.run(self.engine) diff --git a/migrate/tests/versioning/test_shell.py b/migrate/tests/versioning/test_shell.py deleted file mode 100644 index 001efcf..0000000 --- a/migrate/tests/versioning/test_shell.py +++ /dev/null @@ -1,574 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os -import sys -import tempfile - -import six -from six.moves import cStringIO -from sqlalchemy import MetaData, Table - -from migrate.exceptions import * -from migrate.versioning.repository import Repository -from migrate.versioning import genmodel, shell, api -from migrate.tests.fixture import Shell, DB, usedb -from migrate.tests.fixture import models - - -class TestShellCommands(Shell): - """Tests migrate.py commands""" - - def test_help(self): - """Displays default help dialog""" - self.assertEqual(self.env.run('migrate -h').returncode, 0) - self.assertEqual(self.env.run('migrate --help').returncode, 0) - self.assertEqual(self.env.run('migrate help').returncode, 0) - - def test_help_commands(self): - """Display help on a specific command""" - # we can only test that we 
get some output - for cmd in api.__all__: - result = self.env.run('migrate help %s' % cmd) - self.assertTrue(isinstance(result.stdout, six.string_types)) - self.assertTrue(result.stdout) - self.assertFalse(result.stderr) - - def test_shutdown_logging(self): - """Try to shutdown logging output""" - repos = self.tmp_repos() - result = self.env.run('migrate create %s repository_name' % repos) - result = self.env.run('migrate version %s --disable_logging' % repos) - self.assertEqual(result.stdout, '') - result = self.env.run('migrate version %s -q' % repos) - self.assertEqual(result.stdout, '') - - # TODO: assert logging messages to 0 - shell.main(['version', repos], logging=False) - - def test_main_with_runpy(self): - if sys.version_info[:2] == (2, 4): - self.skipTest("runpy is not part of python2.4") - from runpy import run_module - try: - original = sys.argv - sys.argv=['X','--help'] - - run_module('migrate.versioning.shell', run_name='__main__') - - finally: - sys.argv = original - - def _check_error(self,args,code,expected,**kw): - original = sys.stderr - try: - actual = cStringIO() - sys.stderr = actual - try: - shell.main(args,**kw) - except SystemExit as e: - self.assertEqual(code,e.args[0]) - else: - self.fail('No exception raised') - finally: - sys.stderr = original - actual = actual.getvalue() - self.assertTrue(expected in actual,'%r not in:\n"""\n%s\n"""'%(expected,actual)) - - def test_main(self): - """Test main() function""" - repos = self.tmp_repos() - shell.main(['help']) - shell.main(['help', 'create']) - shell.main(['create', 'repo_name', '--preview_sql'], repository=repos) - shell.main(['version', '--', '--repository=%s' % repos]) - shell.main(['version', '-d', '--repository=%s' % repos, '--version=2']) - - self._check_error(['foobar'],2,'error: Invalid command foobar') - self._check_error(['create', 'f', 'o', 'o'],2,'error: Too many arguments for command create: o') - self._check_error(['create'],2,'error: Not enough arguments for command create: 
name, repository not specified') - self._check_error(['create', 'repo_name'],2,'already exists', repository=repos) - - def test_create(self): - """Repositories are created successfully""" - repos = self.tmp_repos() - - # Creating a file that doesn't exist should succeed - result = self.env.run('migrate create %s repository_name' % repos) - - # Files should actually be created - self.assertTrue(os.path.exists(repos)) - - # The default table should not be None - repos_ = Repository(repos) - self.assertNotEqual(repos_.config.get('db_settings', 'version_table'), 'None') - - # Can't create it again: it already exists - result = self.env.run('migrate create %s repository_name' % repos, - expect_error=True) - self.assertEqual(result.returncode, 2) - - def test_script(self): - """We can create a migration script via the command line""" - repos = self.tmp_repos() - result = self.env.run('migrate create %s repository_name' % repos) - - result = self.env.run('migrate script --repository=%s Desc' % repos) - self.assertTrue(os.path.exists('%s/versions/001_Desc.py' % repos)) - - result = self.env.run('migrate script More %s' % repos) - self.assertTrue(os.path.exists('%s/versions/002_More.py' % repos)) - - result = self.env.run('migrate script "Some Random name" %s' % repos) - self.assertTrue(os.path.exists('%s/versions/003_Some_Random_name.py' % repos)) - - def test_script_sql(self): - """We can create a migration sql script via the command line""" - repos = self.tmp_repos() - result = self.env.run('migrate create %s repository_name' % repos) - - result = self.env.run('migrate script_sql mydb foo %s' % repos) - self.assertTrue(os.path.exists('%s/versions/001_foo_mydb_upgrade.sql' % repos)) - self.assertTrue(os.path.exists('%s/versions/001_foo_mydb_downgrade.sql' % repos)) - - # Test creating a second - result = self.env.run('migrate script_sql postgres foo --repository=%s' % repos) - self.assertTrue(os.path.exists('%s/versions/002_foo_postgres_upgrade.sql' % repos)) - 
self.assertTrue(os.path.exists('%s/versions/002_foo_postgres_downgrade.sql' % repos)) - - # TODO: test --previews - - def test_manage(self): - """Create a project management script""" - script = self.tmp_py() - self.assertTrue(not os.path.exists(script)) - - # No attempt is made to verify correctness of the repository path here - result = self.env.run('migrate manage %s --repository=/bla/' % script) - self.assertTrue(os.path.exists(script)) - - -class TestShellRepository(Shell): - """Shell commands on an existing repository/python script""" - - def setUp(self): - """Create repository, python change script""" - super(TestShellRepository, self).setUp() - self.path_repos = self.tmp_repos() - result = self.env.run('migrate create %s repository_name' % self.path_repos) - - def test_version(self): - """Correctly detect repository version""" - # Version: 0 (no scripts yet); successful execution - result = self.env.run('migrate version --repository=%s' % self.path_repos) - self.assertEqual(result.stdout.strip(), "0") - - # Also works as a positional param - result = self.env.run('migrate version %s' % self.path_repos) - self.assertEqual(result.stdout.strip(), "0") - - # Create a script and version should increment - result = self.env.run('migrate script Desc %s' % self.path_repos) - result = self.env.run('migrate version %s' % self.path_repos) - self.assertEqual(result.stdout.strip(), "1") - - def test_source(self): - """Correctly fetch a script's source""" - result = self.env.run('migrate script Desc --repository=%s' % self.path_repos) - - filename = '%s/versions/001_Desc.py' % self.path_repos - source = open(filename).read() - self.assertTrue(source.find('def upgrade') >= 0) - - # Version is now 1 - result = self.env.run('migrate version %s' % self.path_repos) - self.assertEqual(result.stdout.strip(), "1") - - # Output/verify the source of version 1 - result = self.env.run('migrate source 1 --repository=%s' % self.path_repos) - self.assertEqual(result.stdout.strip(), 
source.strip()) - - # We can also send the source to a file... test that too - result = self.env.run('migrate source 1 %s --repository=%s' % - (filename, self.path_repos)) - self.assertTrue(os.path.exists(filename)) - fd = open(filename) - result = fd.read() - self.assertTrue(result.strip() == source.strip()) - - -class TestShellDatabase(Shell, DB): - """Commands associated with a particular database""" - # We'll need to clean up after ourself, since the shell creates its own txn; - # we need to connect to the DB to see if things worked - - level = DB.CONNECT - - @usedb() - def test_version_control(self): - """Ensure we can set version control on a database""" - path_repos = repos = self.tmp_repos() - url = self.url - result = self.env.run('migrate create %s repository_name' % repos) - - result = self.env.run('migrate drop_version_control %(url)s %(repos)s'\ - % locals(), expect_error=True) - self.assertEqual(result.returncode, 1) - result = self.env.run('migrate version_control %(url)s %(repos)s' % locals()) - - # Clean up - result = self.env.run('migrate drop_version_control %(url)s %(repos)s' % locals()) - # Attempting to drop vc from a database without it should fail - result = self.env.run('migrate drop_version_control %(url)s %(repos)s'\ - % locals(), expect_error=True) - self.assertEqual(result.returncode, 1) - - @usedb() - def test_wrapped_kwargs(self): - """Commands with default arguments set by manage.py""" - path_repos = repos = self.tmp_repos() - url = self.url - result = self.env.run('migrate create --name=repository_name %s' % repos) - result = self.env.run('migrate drop_version_control %(url)s %(repos)s' % locals(), expect_error=True) - self.assertEqual(result.returncode, 1) - result = self.env.run('migrate version_control %(url)s %(repos)s' % locals()) - - result = self.env.run('migrate drop_version_control %(url)s %(repos)s' % locals()) - - @usedb() - def test_version_control_specified(self): - """Ensure we can set version control to a particular 
version""" - path_repos = self.tmp_repos() - url = self.url - result = self.env.run('migrate create --name=repository_name %s' % path_repos) - result = self.env.run('migrate drop_version_control %(url)s %(path_repos)s' % locals(), expect_error=True) - self.assertEqual(result.returncode, 1) - - # Fill the repository - path_script = self.tmp_py() - version = 2 - for i in range(version): - result = self.env.run('migrate script Desc --repository=%s' % path_repos) - - # Repository version is correct - result = self.env.run('migrate version %s' % path_repos) - self.assertEqual(result.stdout.strip(), str(version)) - - # Apply versioning to DB - result = self.env.run('migrate version_control %(url)s %(path_repos)s %(version)s' % locals()) - - # Test db version number (should start at 2) - result = self.env.run('migrate db_version %(url)s %(path_repos)s' % locals()) - self.assertEqual(result.stdout.strip(), str(version)) - - # Clean up - result = self.env.run('migrate drop_version_control %(url)s %(path_repos)s' % locals()) - - @usedb() - def test_upgrade(self): - """Can upgrade a versioned database""" - # Create a repository - repos_name = 'repos_name' - repos_path = self.tmp() - result = self.env.run('migrate create %(repos_path)s %(repos_name)s' % locals()) - self.assertEqual(self.run_version(repos_path), 0) - - # Version the DB - result = self.env.run('migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True) - result = self.env.run('migrate version_control %s %s' % (self.url, repos_path)) - - # Upgrades with latest version == 0 - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - result = self.env.run('migrate upgrade %s %s' % (self.url, repos_path)) - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - result = self.env.run('migrate upgrade %s %s' % (self.url, repos_path)) - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - result = self.env.run('migrate upgrade %s %s 1' % (self.url, repos_path), 
expect_error=True) - self.assertEqual(result.returncode, 1) - result = self.env.run('migrate upgrade %s %s -1' % (self.url, repos_path), expect_error=True) - self.assertEqual(result.returncode, 2) - - # Add a script to the repository; upgrade the db - result = self.env.run('migrate script Desc --repository=%s' % (repos_path)) - self.assertEqual(self.run_version(repos_path), 1) - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - - # Test preview - result = self.env.run('migrate upgrade %s %s 0 --preview_sql' % (self.url, repos_path)) - result = self.env.run('migrate upgrade %s %s 0 --preview_py' % (self.url, repos_path)) - - result = self.env.run('migrate upgrade %s %s' % (self.url, repos_path)) - self.assertEqual(self.run_db_version(self.url, repos_path), 1) - - # Downgrade must have a valid version specified - result = self.env.run('migrate downgrade %s %s' % (self.url, repos_path), expect_error=True) - self.assertEqual(result.returncode, 2) - result = self.env.run('migrate downgrade %s %s -1' % (self.url, repos_path), expect_error=True) - self.assertEqual(result.returncode, 2) - result = self.env.run('migrate downgrade %s %s 2' % (self.url, repos_path), expect_error=True) - self.assertEqual(result.returncode, 2) - self.assertEqual(self.run_db_version(self.url, repos_path), 1) - - result = self.env.run('migrate downgrade %s %s 0' % (self.url, repos_path)) - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - - result = self.env.run('migrate downgrade %s %s 1' % (self.url, repos_path), expect_error=True) - self.assertEqual(result.returncode, 2) - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - - result = self.env.run('migrate drop_version_control %s %s' % (self.url, repos_path)) - - def _run_test_sqlfile(self, upgrade_script, downgrade_script): - # TODO: add test script that checks if db really changed - repos_path = self.tmp() - repos_name = 'repos' - - result = self.env.run('migrate create %s %s' % (repos_path, 
repos_name)) - result = self.env.run('migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True) - result = self.env.run('migrate version_control %s %s' % (self.url, repos_path)) - self.assertEqual(self.run_version(repos_path), 0) - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - - beforeCount = len(os.listdir(os.path.join(repos_path, 'versions'))) # hmm, this number changes sometimes based on running from svn - result = self.env.run('migrate script_sql %s --repository=%s' % ('postgres', repos_path)) - self.assertEqual(self.run_version(repos_path), 1) - self.assertEqual(len(os.listdir(os.path.join(repos_path, 'versions'))), beforeCount + 2) - - open('%s/versions/001_postgres_upgrade.sql' % repos_path, 'a').write(upgrade_script) - open('%s/versions/001_postgres_downgrade.sql' % repos_path, 'a').write(downgrade_script) - - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - self.assertRaises(Exception, self.engine.text('select * from t_table').execute) - - result = self.env.run('migrate upgrade %s %s' % (self.url, repos_path)) - self.assertEqual(self.run_db_version(self.url, repos_path), 1) - self.engine.text('select * from t_table').execute() - - result = self.env.run('migrate downgrade %s %s 0' % (self.url, repos_path)) - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - self.assertRaises(Exception, self.engine.text('select * from t_table').execute) - - # The tests below are written with some postgres syntax, but the stuff - # being tested (.sql files) ought to work with any db. 
- @usedb(supported='postgres') - def test_sqlfile(self): - upgrade_script = """ - create table t_table ( - id serial, - primary key(id) - ); - """ - downgrade_script = """ - drop table t_table; - """ - self.meta.drop_all() - self._run_test_sqlfile(upgrade_script, downgrade_script) - - @usedb(supported='postgres') - def test_sqlfile_comment(self): - upgrade_script = """ - -- Comments in SQL break postgres autocommit - create table t_table ( - id serial, - primary key(id) - ); - """ - downgrade_script = """ - -- Comments in SQL break postgres autocommit - drop table t_table; - """ - self._run_test_sqlfile(upgrade_script, downgrade_script) - - @usedb() - def test_command_test(self): - repos_name = 'repos_name' - repos_path = self.tmp() - - result = self.env.run('migrate create repository_name --repository=%s' % repos_path) - result = self.env.run('migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True) - result = self.env.run('migrate version_control %s %s' % (self.url, repos_path)) - self.assertEqual(self.run_version(repos_path), 0) - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - - # Empty script should succeed - result = self.env.run('migrate script Desc %s' % repos_path) - result = self.env.run('migrate test %s %s' % (self.url, repos_path)) - self.assertEqual(self.run_version(repos_path), 1) - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - - # Error script should fail - script_path = self.tmp_py() - script_text=''' - from sqlalchemy import * - from migrate import * - - def upgrade(): - print 'fgsfds' - raise Exception() - - def downgrade(): - print 'sdfsgf' - raise Exception() - '''.replace("\n ", "\n") - file = open(script_path, 'w') - file.write(script_text) - file.close() - - result = self.env.run('migrate test %s %s bla' % (self.url, repos_path), expect_error=True) - self.assertEqual(result.returncode, 2) - self.assertEqual(self.run_version(repos_path), 1) - 
self.assertEqual(self.run_db_version(self.url, repos_path), 0) - - # Nonempty script using migrate_engine should succeed - script_path = self.tmp_py() - script_text = ''' - from sqlalchemy import * - from migrate import * - - from migrate.changeset import schema - - meta = MetaData(migrate_engine) - account = Table('account', meta, - Column('id', Integer, primary_key=True), - Column('login', Text), - Column('passwd', Text), - ) - def upgrade(): - # Upgrade operations go here. Don't create your own engine; use the engine - # named 'migrate_engine' imported from migrate. - meta.create_all() - - def downgrade(): - # Operations to reverse the above upgrade go here. - meta.drop_all() - '''.replace("\n ", "\n") - file = open(script_path, 'w') - file.write(script_text) - file.close() - result = self.env.run('migrate test %s %s' % (self.url, repos_path)) - self.assertEqual(self.run_version(repos_path), 1) - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - - @usedb() - def test_rundiffs_in_shell(self): - # This is a variant of the test_schemadiff tests but run through the shell level. - # These shell tests are hard to debug (since they keep forking processes) - # so they shouldn't replace the lower-level tests. - repos_name = 'repos_name' - repos_path = self.tmp() - script_path = self.tmp_py() - model_module = 'migrate.tests.fixture.models:meta_rundiffs' - old_model_module = 'migrate.tests.fixture.models:meta_old_rundiffs' - - # Create empty repository. 
- self.meta = MetaData(self.engine) - self.meta.reflect() - self.meta.drop_all() # in case junk tables are lying around in the test database - - result = self.env.run( - 'migrate create %s %s' % (repos_path, repos_name), - expect_stderr=True) - result = self.env.run( - 'migrate drop_version_control %s %s' % (self.url, repos_path), - expect_stderr=True, expect_error=True) - result = self.env.run( - 'migrate version_control %s %s' % (self.url, repos_path), - expect_stderr=True) - self.assertEqual(self.run_version(repos_path), 0) - self.assertEqual(self.run_db_version(self.url, repos_path), 0) - - # Setup helper script. - result = self.env.run( - 'migrate manage %s --repository=%s --url=%s --model=%s'\ - % (script_path, repos_path, self.url, model_module), - expect_stderr=True) - self.assertTrue(os.path.exists(script_path)) - - # Model is defined but database is empty. - result = self.env.run('migrate compare_model_to_db %s %s --model=%s' \ - % (self.url, repos_path, model_module), expect_stderr=True) - self.assertTrue( - "tables missing from database: tmp_account_rundiffs" - in result.stdout) - - # Test Deprecation - result = self.env.run('migrate compare_model_to_db %s %s --model=%s' \ - % (self.url, repos_path, model_module.replace(":", ".")), - expect_stderr=True, expect_error=True) - self.assertEqual(result.returncode, 0) - self.assertTrue( - "tables missing from database: tmp_account_rundiffs" - in result.stdout) - - # Update db to latest model. 
- result = self.env.run('migrate update_db_from_model %s %s %s'\ - % (self.url, repos_path, model_module), expect_stderr=True) - self.assertEqual(self.run_version(repos_path), 0) - self.assertEqual(self.run_db_version(self.url, repos_path), 0) # version did not get bumped yet because new version not yet created - - result = self.env.run('migrate compare_model_to_db %s %s %s'\ - % (self.url, repos_path, model_module), expect_stderr=True) - self.assertTrue("No schema diffs" in result.stdout) - - result = self.env.run( - 'migrate drop_version_control %s %s' % (self.url, repos_path), - expect_stderr=True, expect_error=True) - result = self.env.run( - 'migrate version_control %s %s' % (self.url, repos_path), - expect_stderr=True) - - result = self.env.run( - 'migrate create_model %s %s' % (self.url, repos_path), - expect_stderr=True) - temp_dict = dict() - six.exec_(result.stdout, temp_dict) - - # TODO: breaks on SA06 and SA05 - in need of total refactor - use different approach - - # TODO: compare whole table - self.compare_columns_equal(models.tmp_account_rundiffs.c, temp_dict['tmp_account_rundiffs'].c, ['type']) - ##self.assertTrue("""tmp_account_rundiffs = Table('tmp_account_rundiffs', meta, - ##Column('id', Integer(), primary_key=True, nullable=False), - ##Column('login', String(length=None, convert_unicode=False, assert_unicode=None)), - ##Column('passwd', String(length=None, convert_unicode=False, assert_unicode=None))""" in result.stdout) - - ## We're happy with db changes, make first db upgrade script to go from version 0 -> 1. 
- #result = self.env.run('migrate make_update_script_for_model', expect_error=True, expect_stderr=True) - #self.assertTrue('Not enough arguments' in result.stderr) - - #result_script = self.env.run('migrate make_update_script_for_model %s %s %s %s'\ - #% (self.url, repos_path, old_model_module, model_module)) - #self.assertEqualIgnoreWhitespace(result_script.stdout, - #'''from sqlalchemy import * - #from migrate import * - - #from migrate.changeset import schema - - #meta = MetaData() - #tmp_account_rundiffs = Table('tmp_account_rundiffs', meta, - #Column('id', Integer(), primary_key=True, nullable=False), - #Column('login', Text(length=None, convert_unicode=False, assert_unicode=None, unicode_error=None, _warn_on_bytestring=False)), - #Column('passwd', Text(length=None, convert_unicode=False, assert_unicode=None, unicode_error=None, _warn_on_bytestring=False)), - #) - - #def upgrade(migrate_engine): - ## Upgrade operations go here. Don't create your own engine; bind migrate_engine - ## to your metadata - #meta.bind = migrate_engine - #tmp_account_rundiffs.create() - - #def downgrade(migrate_engine): - ## Operations to reverse the above upgrade go here. - #meta.bind = migrate_engine - #tmp_account_rundiffs.drop()''') - - ## Save the upgrade script. 
- #result = self.env.run('migrate script Desc %s' % repos_path) - #upgrade_script_path = '%s/versions/001_Desc.py' % repos_path - #open(upgrade_script_path, 'w').write(result_script.stdout) - - #result = self.env.run('migrate compare_model_to_db %s %s %s'\ - #% (self.url, repos_path, model_module)) - #self.assertTrue("No schema diffs" in result.stdout) - - self.meta.drop_all() # in case junk tables are lying around in the test database diff --git a/migrate/tests/versioning/test_template.py b/migrate/tests/versioning/test_template.py deleted file mode 100644 index a079d8b..0000000 --- a/migrate/tests/versioning/test_template.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- - -import os -import shutil - -import migrate.versioning.templates -from migrate.versioning.template import * -from migrate.versioning import api - -from migrate.tests import fixture - - -class TestTemplate(fixture.Pathed): - def test_templates(self): - """We can find the path to all repository templates""" - path = str(Template()) - self.assertTrue(os.path.exists(path)) - - def test_repository(self): - """We can find the path to the default repository""" - path = Template().get_repository() - self.assertTrue(os.path.exists(path)) - - def test_script(self): - """We can find the path to the default migration script""" - path = Template().get_script() - self.assertTrue(os.path.exists(path)) - - def test_custom_templates_and_themes(self): - """Users can define their own templates with themes""" - new_templates_dir = os.path.join(self.temp_usable_dir, 'templates') - manage_tmpl_file = os.path.join(new_templates_dir, 'manage/custom.py_tmpl') - repository_tmpl_file = os.path.join(new_templates_dir, 'repository/custom/README') - script_tmpl_file = os.path.join(new_templates_dir, 'script/custom.py_tmpl') - sql_script_tmpl_file = os.path.join(new_templates_dir, 'sql_script/custom.py_tmpl') - - MANAGE_CONTENTS = 'print "manage.py"' - README_CONTENTS = 'MIGRATE README!' 
- SCRIPT_FILE_CONTENTS = 'print "script.py"' - new_repo_dest = self.tmp_repos() - new_manage_dest = self.tmp_py() - - # make new templates dir - shutil.copytree(migrate.versioning.templates.__path__[0], new_templates_dir) - shutil.copytree(os.path.join(new_templates_dir, 'repository/default'), - os.path.join(new_templates_dir, 'repository/custom')) - - # edit templates - f = open(manage_tmpl_file, 'w').write(MANAGE_CONTENTS) - f = open(repository_tmpl_file, 'w').write(README_CONTENTS) - f = open(script_tmpl_file, 'w').write(SCRIPT_FILE_CONTENTS) - f = open(sql_script_tmpl_file, 'w').write(SCRIPT_FILE_CONTENTS) - - # create repository, manage file and python script - kw = {} - kw['templates_path'] = new_templates_dir - kw['templates_theme'] = 'custom' - api.create(new_repo_dest, 'repo_name', **kw) - api.script('test', new_repo_dest, **kw) - api.script_sql('postgres', 'foo', new_repo_dest, **kw) - api.manage(new_manage_dest, **kw) - - # assert changes - self.assertEqual(open(new_manage_dest).read(), MANAGE_CONTENTS) - self.assertEqual(open(os.path.join(new_repo_dest, 'manage.py')).read(), MANAGE_CONTENTS) - self.assertEqual(open(os.path.join(new_repo_dest, 'README')).read(), README_CONTENTS) - self.assertEqual(open(os.path.join(new_repo_dest, 'versions/001_test.py')).read(), SCRIPT_FILE_CONTENTS) - self.assertEqual(open(os.path.join(new_repo_dest, 'versions/002_foo_postgres_downgrade.sql')).read(), SCRIPT_FILE_CONTENTS) - self.assertEqual(open(os.path.join(new_repo_dest, 'versions/002_foo_postgres_upgrade.sql')).read(), SCRIPT_FILE_CONTENTS) diff --git a/migrate/tests/versioning/test_util.py b/migrate/tests/versioning/test_util.py deleted file mode 100644 index 21e3f27..0000000 --- a/migrate/tests/versioning/test_util.py +++ /dev/null @@ -1,139 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os - -from sqlalchemy import * - -from migrate.exceptions import MigrateDeprecationWarning -from migrate.tests import fixture -from 
migrate.tests.fixture.warnings import catch_warnings -from migrate.versioning.util import * -from migrate.versioning import api - -import warnings - -class TestUtil(fixture.Pathed): - - def test_construct_engine(self): - """Construct engine the smart way""" - url = 'sqlite://' - - engine = construct_engine(url) - self.assertTrue(engine.name == 'sqlite') - - # keyword arg - engine = construct_engine(url, engine_arg_encoding='utf-8') - self.assertEqual(engine.dialect.encoding, 'utf-8') - - # dict - engine = construct_engine(url, engine_dict={'encoding': 'utf-8'}) - self.assertEqual(engine.dialect.encoding, 'utf-8') - - # engine parameter - engine_orig = create_engine('sqlite://') - engine = construct_engine(engine_orig) - self.assertEqual(engine, engine_orig) - - # test precedance - engine = construct_engine(url, engine_dict={'encoding': 'iso-8859-1'}, - engine_arg_encoding='utf-8') - self.assertEqual(engine.dialect.encoding, 'utf-8') - - # deprecated echo=True parameter - try: - # py 2.4 compatibility :-/ - cw = catch_warnings(record=True) - w = cw.__enter__() - - warnings.simplefilter("always") - engine = construct_engine(url, echo='True') - self.assertTrue(engine.echo) - - self.assertEqual(len(w),1) - self.assertTrue(issubclass(w[-1].category, - MigrateDeprecationWarning)) - self.assertEqual( - 'echo=True parameter is deprecated, pass ' - 'engine_arg_echo=True or engine_dict={"echo": True}', - str(w[-1].message)) - - finally: - cw.__exit__() - - # unsupported argument - self.assertRaises(ValueError, construct_engine, 1) - - def test_passing_engine(self): - repo = self.tmp_repos() - api.create(repo, 'temp') - api.script('First Version', repo) - engine = construct_engine('sqlite:///:memory:') - - api.version_control(engine, repo) - api.upgrade(engine, repo) - - def test_asbool(self): - """test asbool parsing""" - result = asbool(True) - self.assertEqual(result, True) - - result = asbool(False) - self.assertEqual(result, False) - - result = asbool('y') - 
self.assertEqual(result, True) - - result = asbool('n') - self.assertEqual(result, False) - - self.assertRaises(ValueError, asbool, 'test') - self.assertRaises(ValueError, asbool, object) - - - def test_load_model(self): - """load model from dotted name""" - model_path = os.path.join(self.temp_usable_dir, 'test_load_model.py') - - f = open(model_path, 'w') - f.write("class FakeFloat(int): pass") - f.close() - - try: - # py 2.4 compatibility :-/ - cw = catch_warnings(record=True) - w = cw.__enter__() - - warnings.simplefilter("always") - - # deprecated spelling - FakeFloat = load_model('test_load_model.FakeFloat') - self.assertTrue(isinstance(FakeFloat(), int)) - - self.assertEqual(len(w),1) - self.assertTrue(issubclass(w[-1].category, - MigrateDeprecationWarning)) - self.assertEqual( - 'model should be in form of module.model:User ' - 'and not module.model.User', - str(w[-1].message)) - - finally: - cw.__exit__() - - FakeFloat = load_model('test_load_model:FakeFloat') - self.assertTrue(isinstance(FakeFloat(), int)) - - FakeFloat = load_model(FakeFloat) - self.assertTrue(isinstance(FakeFloat(), int)) - - def test_guess_obj_type(self): - """guess object type from string""" - result = guess_obj_type('7') - self.assertEqual(result, 7) - - result = guess_obj_type('y') - self.assertEqual(result, True) - - result = guess_obj_type('test') - self.assertEqual(result, 'test') diff --git a/migrate/tests/versioning/test_version.py b/migrate/tests/versioning/test_version.py deleted file mode 100644 index 286dd59..0000000 --- a/migrate/tests/versioning/test_version.py +++ /dev/null @@ -1,186 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from migrate.exceptions import * -from migrate.versioning.version import * - -from migrate.tests import fixture - - -class TestVerNum(fixture.Base): - def test_invalid(self): - """Disallow invalid version numbers""" - versions = ('-1', -1, 'Thirteen', '') - for version in versions: - self.assertRaises(ValueError, VerNum, version) - - 
def test_str(self): - """Test str and repr version numbers""" - self.assertEqual(str(VerNum(2)), '2') - self.assertEqual(repr(VerNum(2)), '<VerNum(2)>') - - def test_is(self): - """Two version with the same number should be equal""" - a = VerNum(1) - b = VerNum(1) - self.assertTrue(a is b) - - self.assertEqual(VerNum(VerNum(2)), VerNum(2)) - - def test_add(self): - self.assertEqual(VerNum(1) + VerNum(1), VerNum(2)) - self.assertEqual(VerNum(1) + 1, 2) - self.assertEqual(VerNum(1) + 1, '2') - self.assertTrue(isinstance(VerNum(1) + 1, VerNum)) - - def test_sub(self): - self.assertEqual(VerNum(1) - 1, 0) - self.assertTrue(isinstance(VerNum(1) - 1, VerNum)) - self.assertRaises(ValueError, lambda: VerNum(0) - 1) - - def test_eq(self): - """Two versions are equal""" - self.assertEqual(VerNum(1), VerNum('1')) - self.assertEqual(VerNum(1), 1) - self.assertEqual(VerNum(1), '1') - self.assertNotEqual(VerNum(1), 2) - - def test_ne(self): - self.assertTrue(VerNum(1) != 2) - self.assertFalse(VerNum(1) != 1) - - def test_lt(self): - self.assertFalse(VerNum(1) < 1) - self.assertTrue(VerNum(1) < 2) - self.assertFalse(VerNum(2) < 1) - - def test_le(self): - self.assertTrue(VerNum(1) <= 1) - self.assertTrue(VerNum(1) <= 2) - self.assertFalse(VerNum(2) <= 1) - - def test_gt(self): - self.assertFalse(VerNum(1) > 1) - self.assertFalse(VerNum(1) > 2) - self.assertTrue(VerNum(2) > 1) - - def test_ge(self): - self.assertTrue(VerNum(1) >= 1) - self.assertTrue(VerNum(2) >= 1) - self.assertFalse(VerNum(1) >= 2) - - def test_int_cast(self): - ver = VerNum(3) - # test __int__ - self.assertEqual(int(ver), 3) - # test __index__: range() doesn't call __int__ - self.assertEqual(list(range(ver, ver)), []) - - -class TestVersion(fixture.Pathed): - - def setUp(self): - super(TestVersion, self).setUp() - - def test_str_to_filename(self): - self.assertEqual(str_to_filename(''), '') - self.assertEqual(str_to_filename('__'), '_') - self.assertEqual(str_to_filename('a'), 'a') - 
self.assertEqual(str_to_filename('Abc Def'), 'Abc_Def') - self.assertEqual(str_to_filename('Abc "D" Ef'), 'Abc_D_Ef') - self.assertEqual(str_to_filename("Abc's Stuff"), 'Abc_s_Stuff') - self.assertEqual(str_to_filename("a b"), 'a_b') - self.assertEqual(str_to_filename("a.b to c"), 'a_b_to_c') - - def test_collection(self): - """Let's see how we handle versions collection""" - coll = Collection(self.temp_usable_dir) - coll.create_new_python_version("foo bar") - coll.create_new_sql_version("postgres", "foo bar") - coll.create_new_sql_version("sqlite", "foo bar") - coll.create_new_python_version("") - - self.assertEqual(coll.latest, 4) - self.assertEqual(len(coll.versions), 4) - self.assertEqual(coll.version(4), coll.version(coll.latest)) - # Check for non-existing version - self.assertRaises(VersionNotFoundError, coll.version, 5) - # Check for the current version - self.assertEqual('4', coll.version(4).version) - - coll2 = Collection(self.temp_usable_dir) - self.assertEqual(coll.versions, coll2.versions) - - Collection.clear() - - def test_old_repository(self): - open(os.path.join(self.temp_usable_dir, '1'), 'w') - self.assertRaises(Exception, Collection, self.temp_usable_dir) - - #TODO: def test_collection_unicode(self): - # pass - - def test_create_new_python_version(self): - coll = Collection(self.temp_usable_dir) - coll.create_new_python_version("'") - - ver = coll.version() - self.assertTrue(ver.script().source()) - - def test_create_new_sql_version(self): - coll = Collection(self.temp_usable_dir) - coll.create_new_sql_version("sqlite", "foo bar") - - ver = coll.version() - ver_up = ver.script('sqlite', 'upgrade') - ver_down = ver.script('sqlite', 'downgrade') - ver_up.source() - ver_down.source() - - def test_selection(self): - """Verify right sql script is selected""" - - # Create empty directory. - path = self.tmp_repos() - os.mkdir(path) - - # Create files -- files must be present or you'll get an exception later. 
- python_file = '001_initial_.py' - sqlite_upgrade_file = '001_sqlite_upgrade.sql' - default_upgrade_file = '001_default_upgrade.sql' - for file_ in [sqlite_upgrade_file, default_upgrade_file, python_file]: - filepath = '%s/%s' % (path, file_) - open(filepath, 'w').close() - - ver = Version(1, path, [sqlite_upgrade_file]) - self.assertEqual(os.path.basename(ver.script('sqlite', 'upgrade').path), sqlite_upgrade_file) - - ver = Version(1, path, [default_upgrade_file]) - self.assertEqual(os.path.basename(ver.script('default', 'upgrade').path), default_upgrade_file) - - ver = Version(1, path, [sqlite_upgrade_file, default_upgrade_file]) - self.assertEqual(os.path.basename(ver.script('sqlite', 'upgrade').path), sqlite_upgrade_file) - - ver = Version(1, path, [sqlite_upgrade_file, default_upgrade_file, python_file]) - self.assertEqual(os.path.basename(ver.script('postgres', 'upgrade').path), default_upgrade_file) - - ver = Version(1, path, [sqlite_upgrade_file, python_file]) - self.assertEqual(os.path.basename(ver.script('postgres', 'upgrade').path), python_file) - - def test_bad_version(self): - ver = Version(1, self.temp_usable_dir, []) - self.assertRaises(ScriptError, ver.add_script, '123.sql') - - # tests bad ibm_db_sa filename - ver = Version(123, self.temp_usable_dir, []) - self.assertRaises(ScriptError, ver.add_script, - '123_ibm_db_sa_upgrade.sql') - - # tests that the name is ok but the script doesn't exist - self.assertRaises(InvalidScriptError, ver.add_script, - '123_test_ibm_db_sa_upgrade.sql') - - pyscript = os.path.join(self.temp_usable_dir, 'bla.py') - open(pyscript, 'w') - ver.add_script(pyscript) - self.assertRaises(ScriptError, ver.add_script, 'bla.py') diff --git a/migrate/versioning/__init__.py b/migrate/versioning/__init__.py deleted file mode 100644 index 8b5a736..0000000 --- a/migrate/versioning/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" - This package provides functionality to create and manage - repositories of database schema changesets and 
to apply these - changesets to databases. -""" diff --git a/migrate/versioning/api.py b/migrate/versioning/api.py deleted file mode 100644 index 570dc08..0000000 --- a/migrate/versioning/api.py +++ /dev/null @@ -1,384 +0,0 @@ -""" - This module provides an external API to the versioning system. - - .. versionchanged:: 0.6.0 - :func:`migrate.versioning.api.test` and schema diff functions - changed order of positional arguments so all accept `url` and `repository` - as first arguments. - - .. versionchanged:: 0.5.4 - ``--preview_sql`` displays source file when using SQL scripts. - If Python script is used, it runs the action with mocked engine and - returns captured SQL statements. - - .. versionchanged:: 0.5.4 - Deprecated ``--echo`` parameter in favour of new - :func:`migrate.versioning.util.construct_engine` behavior. -""" - -# Dear migrate developers, -# -# please do not comment this module using sphinx syntax because its -# docstrings are presented as user help and most users cannot -# interpret sphinx annotated ReStructuredText. 
-# -# Thanks, -# Jan Dittberner - -import sys -import inspect -import logging - -from migrate import exceptions -from migrate.versioning import (repository, schema, version, - script as script_) # command name conflict -from migrate.versioning.util import catch_known_errors, with_engine - - -log = logging.getLogger(__name__) -command_desc = { - 'help': 'displays help on a given command', - 'create': 'create an empty repository at the specified path', - 'script': 'create an empty change Python script', - 'script_sql': 'create empty change SQL scripts for given database', - 'version': 'display the latest version available in a repository', - 'db_version': 'show the current version of the repository under version control', - 'source': 'display the Python code for a particular version in this repository', - 'version_control': 'mark a database as under this repository\'s version control', - 'upgrade': 'upgrade a database to a later version', - 'downgrade': 'downgrade a database to an earlier version', - 'drop_version_control': 'removes version control from a database', - 'manage': 'creates a Python script that runs Migrate with a set of default values', - 'test': 'performs the upgrade and downgrade command on the given database', - 'compare_model_to_db': 'compare MetaData against the current database state', - 'create_model': 'dump the current database as a Python model to stdout', - 'make_update_script_for_model': 'create a script changing the old MetaData to the new (current) MetaData', - 'update_db_from_model': 'modify the database to match the structure of the current MetaData', -} -__all__ = command_desc.keys() - -Repository = repository.Repository -ControlledSchema = schema.ControlledSchema -VerNum = version.VerNum -PythonScript = script_.PythonScript -SqlScript = script_.SqlScript - - -# deprecated -def help(cmd=None, **opts): - """%prog help COMMAND - - Displays help on a given command. 
- """ - if cmd is None: - raise exceptions.UsageError(None) - try: - func = globals()[cmd] - except: - raise exceptions.UsageError( - "'%s' isn't a valid command. Try 'help COMMAND'" % cmd) - ret = func.__doc__ - if sys.argv[0]: - ret = ret.replace('%prog', sys.argv[0]) - return ret - -@catch_known_errors -def create(repository, name, **opts): - """%prog create REPOSITORY_PATH NAME [--table=TABLE] - - Create an empty repository at the specified path. - - You can specify the version_table to be used; by default, it is - 'migrate_version'. This table is created in all version-controlled - databases. - """ - repo_path = Repository.create(repository, name, **opts) - - -@catch_known_errors -def script(description, repository, **opts): - """%prog script DESCRIPTION REPOSITORY_PATH - - Create an empty change script using the next unused version number - appended with the given description. - - For instance, manage.py script "Add initial tables" creates: - repository/versions/001_Add_initial_tables.py - """ - repo = Repository(repository) - repo.create_script(description, **opts) - - -@catch_known_errors -def script_sql(database, description, repository, **opts): - """%prog script_sql DATABASE DESCRIPTION REPOSITORY_PATH - - Create empty change SQL scripts for given DATABASE, where DATABASE - is either specific ('postgresql', 'mysql', 'oracle', 'sqlite', etc.) - or generic ('default'). - - For instance, manage.py script_sql postgresql description creates: - repository/versions/001_description_postgresql_upgrade.sql and - repository/versions/001_description_postgresql_downgrade.sql - """ - repo = Repository(repository) - repo.create_script_sql(database, description, **opts) - - -def version(repository, **opts): - """%prog version REPOSITORY_PATH - - Display the latest version available in a repository. 
- """ - repo = Repository(repository) - return repo.latest - - -@with_engine -def db_version(url, repository, **opts): - """%prog db_version URL REPOSITORY_PATH - - Show the current version of the repository with the given - connection string, under version control of the specified - repository. - - The url should be any valid SQLAlchemy connection string. - """ - engine = opts.pop('engine') - schema = ControlledSchema(engine, repository) - return schema.version - - -def source(version, dest=None, repository=None, **opts): - """%prog source VERSION [DESTINATION] --repository=REPOSITORY_PATH - - Display the Python code for a particular version in this - repository. Save it to the file at DESTINATION or, if omitted, - send to stdout. - """ - if repository is None: - raise exceptions.UsageError("A repository must be specified") - repo = Repository(repository) - ret = repo.version(version).script().source() - if dest is not None: - dest = open(dest, 'w') - dest.write(ret) - dest.close() - ret = None - return ret - - -def upgrade(url, repository, version=None, **opts): - """%prog upgrade URL REPOSITORY_PATH [VERSION] [--preview_py|--preview_sql] - - Upgrade a database to a later version. - - This runs the upgrade() function defined in your change scripts. - - By default, the database is updated to the latest available - version. You may specify a version instead, if you wish. - - You may preview the Python or SQL code to be executed, rather than - actually executing it, using the appropriate 'preview' option. - """ - err = "Cannot upgrade a database of version %s to version %s. "\ - "Try 'downgrade' instead." - return _migrate(url, repository, version, upgrade=True, err=err, **opts) - - -def downgrade(url, repository, version, **opts): - """%prog downgrade URL REPOSITORY_PATH VERSION [--preview_py|--preview_sql] - - Downgrade a database to an earlier version. - - This is the reverse of upgrade; this runs the downgrade() function - defined in your change scripts. 
- - You may preview the Python or SQL code to be executed, rather than - actually executing it, using the appropriate 'preview' option. - """ - err = "Cannot downgrade a database of version %s to version %s. "\ - "Try 'upgrade' instead." - return _migrate(url, repository, version, upgrade=False, err=err, **opts) - -@with_engine -def test(url, repository, **opts): - """%prog test URL REPOSITORY_PATH [VERSION] - - Performs the upgrade and downgrade option on the given - database. This is not a real test and may leave the database in a - bad state. You should therefore better run the test on a copy of - your database. - """ - engine = opts.pop('engine') - repos = Repository(repository) - - # Upgrade - log.info("Upgrading...") - script = repos.version(None).script(engine.name, 'upgrade') - script.run(engine, 1) - log.info("done") - - log.info("Downgrading...") - script = repos.version(None).script(engine.name, 'downgrade') - script.run(engine, -1) - log.info("done") - log.info("Success") - - -@with_engine -def version_control(url, repository, version=None, **opts): - """%prog version_control URL REPOSITORY_PATH [VERSION] - - Mark a database as under this repository's version control. - - Once a database is under version control, schema changes should - only be done via change scripts in this repository. - - This creates the table version_table in the database. - - The url should be any valid SQLAlchemy connection string. - - By default, the database begins at version 0 and is assumed to be - empty. If the database is not empty, you may specify a version at - which to begin instead. No attempt is made to verify this - version's correctness - the database schema is expected to be - identical to what it would be if the database were created from - scratch. 
- """ - engine = opts.pop('engine') - ControlledSchema.create(engine, repository, version) - - -@with_engine -def drop_version_control(url, repository, **opts): - """%prog drop_version_control URL REPOSITORY_PATH - - Removes version control from a database. - """ - engine = opts.pop('engine') - schema = ControlledSchema(engine, repository) - schema.drop() - - -def manage(file, **opts): - """%prog manage FILENAME [VARIABLES...] - - Creates a script that runs Migrate with a set of default values. - - For example:: - - %prog manage manage.py --repository=/path/to/repository \ ---url=sqlite:///project.db - - would create the script manage.py. The following two commands - would then have exactly the same results:: - - python manage.py version - %prog version --repository=/path/to/repository - """ - Repository.create_manage_file(file, **opts) - - -@with_engine -def compare_model_to_db(url, repository, model, **opts): - """%prog compare_model_to_db URL REPOSITORY_PATH MODEL - - Compare the current model (assumed to be a module level variable - of type sqlalchemy.MetaData) against the current database. - - NOTE: This is EXPERIMENTAL. - """ # TODO: get rid of EXPERIMENTAL label - engine = opts.pop('engine') - return ControlledSchema.compare_model_to_db(engine, model, repository) - - -@with_engine -def create_model(url, repository, **opts): - """%prog create_model URL REPOSITORY_PATH [DECLERATIVE=True] - - Dump the current database as a Python model to stdout. - - NOTE: This is EXPERIMENTAL. - """ # TODO: get rid of EXPERIMENTAL label - engine = opts.pop('engine') - declarative = opts.get('declarative', False) - return ControlledSchema.create_model(engine, repository, declarative) - - -@catch_known_errors -@with_engine -def make_update_script_for_model(url, repository, oldmodel, model, **opts): - """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH - - Create a script changing the old Python model to the new (current) - Python model, sending to stdout. 
- - NOTE: This is EXPERIMENTAL. - """ # TODO: get rid of EXPERIMENTAL label - engine = opts.pop('engine') - return PythonScript.make_update_script_for_model( - engine, oldmodel, model, repository, **opts) - - -@with_engine -def update_db_from_model(url, repository, model, **opts): - """%prog update_db_from_model URL REPOSITORY_PATH MODEL - - Modify the database to match the structure of the current Python - model. This also sets the db_version number to the latest in the - repository. - - NOTE: This is EXPERIMENTAL. - """ # TODO: get rid of EXPERIMENTAL label - engine = opts.pop('engine') - schema = ControlledSchema(engine, repository) - schema.update_db_from_model(model) - -@with_engine -def _migrate(url, repository, version, upgrade, err, **opts): - engine = opts.pop('engine') - url = str(engine.url) - schema = ControlledSchema(engine, repository) - version = _migrate_version(schema, version, upgrade, err) - - changeset = schema.changeset(version) - for ver, change in changeset: - nextver = ver + changeset.step - log.info('%s -> %s... 
', ver, nextver) - - if opts.get('preview_sql'): - if isinstance(change, PythonScript): - log.info(change.preview_sql(url, changeset.step, **opts)) - elif isinstance(change, SqlScript): - log.info(change.source()) - - elif opts.get('preview_py'): - if not isinstance(change, PythonScript): - raise exceptions.UsageError("Python source can be only displayed" - " for python migration files") - source_ver = max(ver, nextver) - module = schema.repository.version(source_ver).script().module - funcname = upgrade and "upgrade" or "downgrade" - func = getattr(module, funcname) - log.info(inspect.getsource(func)) - else: - schema.runchange(ver, change, changeset.step) - log.info('done') - - -def _migrate_version(schema, version, upgrade, err): - if version is None: - return version - # Version is specified: ensure we're upgrading in the right direction - # (current version < target version for upgrading; reverse for down) - version = VerNum(version) - cur = schema.version - if upgrade is not None: - if upgrade: - direction = cur <= version - else: - direction = cur >= version - if not direction: - raise exceptions.KnownError(err % (cur, version)) - return version diff --git a/migrate/versioning/cfgparse.py b/migrate/versioning/cfgparse.py deleted file mode 100644 index 8f1ccf9..0000000 --- a/migrate/versioning/cfgparse.py +++ /dev/null @@ -1,27 +0,0 @@ -""" - Configuration parser module. 
-""" - -from six.moves.configparser import ConfigParser - -from migrate.versioning.config import * -from migrate.versioning import pathed - - -class Parser(ConfigParser): - """A project configuration file.""" - - def to_dict(self, sections=None): - """It's easier to access config values like dictionaries""" - return self._sections - - -class Config(pathed.Pathed, Parser): - """Configuration class.""" - - def __init__(self, path, *p, **k): - """Confirm the config file exists; read it.""" - self.require_found(path) - pathed.Pathed.__init__(self, path) - Parser.__init__(self, *p, **k) - self.read(path) diff --git a/migrate/versioning/config.py b/migrate/versioning/config.py deleted file mode 100644 index 2429fd8..0000000 --- a/migrate/versioning/config.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- - -from sqlalchemy.util import OrderedDict - - -__all__ = ['databases', 'operations'] - -databases = ('sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird') - -# Map operation names to function names -operations = OrderedDict() -operations['upgrade'] = 'upgrade' -operations['downgrade'] = 'downgrade' diff --git a/migrate/versioning/genmodel.py b/migrate/versioning/genmodel.py deleted file mode 100644 index 4d9cd12..0000000 --- a/migrate/versioning/genmodel.py +++ /dev/null @@ -1,306 +0,0 @@ -""" -Code to generate a Python model from a database or differences -between a model and database. 
- -Some of this is borrowed heavily from the AutoCode project at: -http://code.google.com/p/sqlautocode/ -""" - -import sys -import logging - -import six -import sqlalchemy - -import migrate -import migrate.changeset - - -log = logging.getLogger(__name__) -HEADER = """ -## File autogenerated by genmodel.py - -from sqlalchemy import * -""" - -META_DEFINITION = "meta = MetaData()" - -DECLARATIVE_DEFINITION = """ -from sqlalchemy.ext import declarative - -Base = declarative.declarative_base() -""" - - -class ModelGenerator(object): - """Various transformations from an A, B diff. - - In the implementation, A tends to be called the model and B - the database (although this is not true of all diffs). - The diff is directionless, but transformations apply the diff - in a particular direction, described in the method name. - """ - - def __init__(self, diff, engine, declarative=False): - self.diff = diff - self.engine = engine - self.declarative = declarative - - def column_repr(self, col): - kwarg = [] - if col.key != col.name: - kwarg.append('key') - if col.primary_key: - col.primary_key = True # otherwise it dumps it as 1 - kwarg.append('primary_key') - if not col.nullable: - kwarg.append('nullable') - if col.onupdate: - kwarg.append('onupdate') - if col.default: - if col.primary_key: - # I found that PostgreSQL automatically creates a - # default value for the sequence, but let's not show - # that. 
- pass - else: - kwarg.append('default') - args = ['%s=%r' % (k, getattr(col, k)) for k in kwarg] - - # crs: not sure if this is good idea, but it gets rid of extra - # u'' - if six.PY3: - name = col.name - else: - name = col.name.encode('utf8') - - type_ = col.type - for cls in col.type.__class__.__mro__: - if cls.__module__ == 'sqlalchemy.types' and \ - not cls.__name__.isupper(): - if cls is not type_.__class__: - type_ = cls() - break - - type_repr = repr(type_) - if type_repr.endswith('()'): - type_repr = type_repr[:-2] - - constraints = [repr(cn) for cn in col.constraints] - - data = { - 'name': name, - 'commonStuff': ', '.join([type_repr] + constraints + args), - } - - if self.declarative: - return """%(name)s = Column(%(commonStuff)s)""" % data - else: - return """Column(%(name)r, %(commonStuff)s)""" % data - - def _getTableDefn(self, table, metaName='meta'): - out = [] - tableName = table.name - if self.declarative: - out.append("class %(table)s(Base):" % {'table': tableName}) - out.append(" __tablename__ = '%(table)s'\n" % - {'table': tableName}) - for col in table.columns: - out.append(" %s" % self.column_repr(col)) - out.append('\n') - else: - out.append("%(table)s = Table('%(table)s', %(meta)s," % - {'table': tableName, 'meta': metaName}) - for col in table.columns: - out.append(" %s," % self.column_repr(col)) - out.append(")\n") - return out - - def _get_tables(self,missingA=False,missingB=False,modified=False): - to_process = [] - for bool_,names,metadata in ( - (missingA,self.diff.tables_missing_from_A,self.diff.metadataB), - (missingB,self.diff.tables_missing_from_B,self.diff.metadataA), - (modified,self.diff.tables_different,self.diff.metadataA), - ): - if bool_: - for name in names: - yield metadata.tables.get(name) - - def _genModelHeader(self, tables): - out = [] - import_index = [] - - out.append(HEADER) - - for table in tables: - for col in table.columns: - if "dialects" in col.type.__module__ and \ - col.type.__class__ not in import_index: - 
out.append("from " + col.type.__module__ + - " import " + col.type.__class__.__name__) - import_index.append(col.type.__class__) - - out.append("") - - if self.declarative: - out.append(DECLARATIVE_DEFINITION) - else: - out.append(META_DEFINITION) - out.append("") - - return out - - def genBDefinition(self): - """Generates the source code for a definition of B. - - Assumes a diff where A is empty. - - Was: toPython. Assume database (B) is current and model (A) is empty. - """ - - out = [] - out.extend(self._genModelHeader(self._get_tables(missingA=True))) - for table in self._get_tables(missingA=True): - out.extend(self._getTableDefn(table)) - return '\n'.join(out) - - def genB2AMigration(self, indent=' '): - '''Generate a migration from B to A. - - Was: toUpgradeDowngradePython - Assume model (A) is most current and database (B) is out-of-date. - ''' - - decls = ['from migrate.changeset import schema', - 'pre_meta = MetaData()', - 'post_meta = MetaData()', - ] - upgradeCommands = ['pre_meta.bind = migrate_engine', - 'post_meta.bind = migrate_engine'] - downgradeCommands = list(upgradeCommands) - - for tn in self.diff.tables_missing_from_A: - pre_table = self.diff.metadataB.tables[tn] - decls.extend(self._getTableDefn(pre_table, metaName='pre_meta')) - upgradeCommands.append( - "pre_meta.tables[%(table)r].drop()" % {'table': tn}) - downgradeCommands.append( - "pre_meta.tables[%(table)r].create()" % {'table': tn}) - - for tn in self.diff.tables_missing_from_B: - post_table = self.diff.metadataA.tables[tn] - decls.extend(self._getTableDefn(post_table, metaName='post_meta')) - upgradeCommands.append( - "post_meta.tables[%(table)r].create()" % {'table': tn}) - downgradeCommands.append( - "post_meta.tables[%(table)r].drop()" % {'table': tn}) - - for (tn, td) in six.iteritems(self.diff.tables_different): - if td.columns_missing_from_A or td.columns_different: - pre_table = self.diff.metadataB.tables[tn] - decls.extend(self._getTableDefn( - pre_table, 
metaName='pre_meta')) - if td.columns_missing_from_B or td.columns_different: - post_table = self.diff.metadataA.tables[tn] - decls.extend(self._getTableDefn( - post_table, metaName='post_meta')) - - for col in td.columns_missing_from_A: - upgradeCommands.append( - 'pre_meta.tables[%r].columns[%r].drop()' % (tn, col)) - downgradeCommands.append( - 'pre_meta.tables[%r].columns[%r].create()' % (tn, col)) - for col in td.columns_missing_from_B: - upgradeCommands.append( - 'post_meta.tables[%r].columns[%r].create()' % (tn, col)) - downgradeCommands.append( - 'post_meta.tables[%r].columns[%r].drop()' % (tn, col)) - for modelCol, databaseCol, modelDecl, databaseDecl in td.columns_different: - upgradeCommands.append( - 'assert False, "Can\'t alter columns: %s:%s=>%s"' % ( - tn, modelCol.name, databaseCol.name)) - downgradeCommands.append( - 'assert False, "Can\'t alter columns: %s:%s=>%s"' % ( - tn, modelCol.name, databaseCol.name)) - - return ( - '\n'.join(decls), - '\n'.join('%s%s' % (indent, line) for line in upgradeCommands), - '\n'.join('%s%s' % (indent, line) for line in downgradeCommands)) - - def _db_can_handle_this_change(self,td): - """Check if the database can handle going from B to A.""" - - if (td.columns_missing_from_B - and not td.columns_missing_from_A - and not td.columns_different): - # Even sqlite can handle column additions. - return True - else: - return not self.engine.url.drivername.startswith('sqlite') - - def runB2A(self): - """Goes from B to A. - - Was: applyModel. Apply model (A) to current database (B). 
- """ - - meta = sqlalchemy.MetaData(self.engine) - - for table in self._get_tables(missingA=True): - table = table.tometadata(meta) - table.drop() - for table in self._get_tables(missingB=True): - table = table.tometadata(meta) - table.create() - for modelTable in self._get_tables(modified=True): - tableName = modelTable.name - modelTable = modelTable.tometadata(meta) - dbTable = self.diff.metadataB.tables[tableName] - - td = self.diff.tables_different[tableName] - - if self._db_can_handle_this_change(td): - - for col in td.columns_missing_from_B: - modelTable.columns[col].create() - for col in td.columns_missing_from_A: - dbTable.columns[col].drop() - # XXX handle column changes here. - else: - # Sqlite doesn't support drop column, so you have to - # do more: create temp table, copy data to it, drop - # old table, create new table, copy data back. - # - # I wonder if this is guaranteed to be unique? - tempName = '_temp_%s' % modelTable.name - - def getCopyStatement(): - preparer = self.engine.dialect.preparer - commonCols = [] - for modelCol in modelTable.columns: - if modelCol.name in dbTable.columns: - commonCols.append(modelCol.name) - commonColsStr = ', '.join(commonCols) - return 'INSERT INTO %s (%s) SELECT %s FROM %s' % \ - (tableName, commonColsStr, commonColsStr, tempName) - - # Move the data in one transaction, so that we don't - # leave the database in a nasty state. 
- connection = self.engine.connect() - trans = connection.begin() - try: - connection.execute( - 'CREATE TEMPORARY TABLE %s as SELECT * from %s' % \ - (tempName, modelTable.name)) - # make sure the drop takes place inside our - # transaction with the bind parameter - modelTable.drop(bind=connection) - modelTable.create(bind=connection) - connection.execute(getCopyStatement()) - connection.execute('DROP TABLE %s' % tempName) - trans.commit() - except: - trans.rollback() - raise diff --git a/migrate/versioning/migrate_repository.py b/migrate/versioning/migrate_repository.py deleted file mode 100644 index 22bba47..0000000 --- a/migrate/versioning/migrate_repository.py +++ /dev/null @@ -1,96 +0,0 @@ -""" - Script to migrate repository from sqlalchemy <= 0.4.4 to the new - repository schema. This shouldn't use any other migrate modules, so - that it can work in any version. -""" - -import os -import sys -import logging - -log = logging.getLogger(__name__) - - -def usage(): - """Gives usage information.""" - print("Usage: %s repository-to-migrate" % sys.argv[0]) - print("Upgrade your repository to the new flat format.") - print("NOTE: You should probably make a backup before running this.") - sys.exit(1) - - -def delete_file(filepath): - """Deletes a file and prints a message.""" - log.info('Deleting file: %s' % filepath) - os.remove(filepath) - - -def move_file(src, tgt): - """Moves a file and prints a message.""" - log.info('Moving file %s to %s' % (src, tgt)) - if os.path.exists(tgt): - raise Exception( - 'Cannot move file %s because target %s already exists' % \ - (src, tgt)) - os.rename(src, tgt) - - -def delete_directory(dirpath): - """Delete a directory and print a message.""" - log.info('Deleting directory: %s' % dirpath) - os.rmdir(dirpath) - - -def migrate_repository(repos): - """Does the actual migration to the new repository format.""" - log.info('Migrating repository at: %s to new format' % repos) - versions = '%s/versions' % repos - dirs = 
os.listdir(versions) - # Only use int's in list. - numdirs = [int(dirname) for dirname in dirs if dirname.isdigit()] - numdirs.sort() # Sort list. - for dirname in numdirs: - origdir = '%s/%s' % (versions, dirname) - log.info('Working on directory: %s' % origdir) - files = os.listdir(origdir) - files.sort() - for filename in files: - # Delete compiled Python files. - if filename.endswith('.pyc') or filename.endswith('.pyo'): - delete_file('%s/%s' % (origdir, filename)) - - # Delete empty __init__.py files. - origfile = '%s/__init__.py' % origdir - if os.path.exists(origfile) and len(open(origfile).read()) == 0: - delete_file(origfile) - - # Move sql upgrade scripts. - if filename.endswith('.sql'): - version, dbms, operation = filename.split('.', 3)[0:3] - origfile = '%s/%s' % (origdir, filename) - # For instance: 2.postgres.upgrade.sql -> - # 002_postgres_upgrade.sql - tgtfile = '%s/%03d_%s_%s.sql' % ( - versions, int(version), dbms, operation) - move_file(origfile, tgtfile) - - # Move Python upgrade script. - pyfile = '%s.py' % dirname - pyfilepath = '%s/%s' % (origdir, pyfile) - if os.path.exists(pyfilepath): - tgtfile = '%s/%03d.py' % (versions, int(dirname)) - move_file(pyfilepath, tgtfile) - - # Try to remove directory. Will fail if it's not empty. - delete_directory(origdir) - - -def main(): - """Main function to be called when using this script.""" - if len(sys.argv) != 2: - usage() - migrate_repository(sys.argv[1]) - - -if __name__ == '__main__': - main() diff --git a/migrate/versioning/pathed.py b/migrate/versioning/pathed.py deleted file mode 100644 index fbee0e4..0000000 --- a/migrate/versioning/pathed.py +++ /dev/null @@ -1,75 +0,0 @@ -""" - A path/directory class. 
-""" - -import os -import shutil -import logging - -from migrate import exceptions -from migrate.versioning.config import * -from migrate.versioning.util import KeyedInstance - - -log = logging.getLogger(__name__) - -class Pathed(KeyedInstance): - """ - A class associated with a path/directory tree. - - Only one instance of this class may exist for a particular file; - __new__ will return an existing instance if possible - """ - parent = None - - @classmethod - def _key(cls, path): - return str(path) - - def __init__(self, path): - self.path = path - if self.__class__.parent is not None: - self._init_parent(path) - - def _init_parent(self, path): - """Try to initialize this object's parent, if it has one""" - parent_path = self.__class__._parent_path(path) - self.parent = self.__class__.parent(parent_path) - log.debug("Getting parent %r:%r" % (self.__class__.parent, parent_path)) - self.parent._init_child(path, self) - - def _init_child(self, child, path): - """Run when a child of this object is initialized. - - Parameters: the child object; the path to this object (its - parent) - """ - - @classmethod - def _parent_path(cls, path): - """ - Fetch the path of this object's parent from this object's path. - """ - # os.path.dirname(), but strip directories like files (like - # unix basename) - # - # Treat directories like files... 
- if path[-1] == '/': - path = path[:-1] - ret = os.path.dirname(path) - return ret - - @classmethod - def require_notfound(cls, path): - """Ensures a given path does not already exist""" - if os.path.exists(path): - raise exceptions.PathFoundError(path) - - @classmethod - def require_found(cls, path): - """Ensures a given path already exists""" - if not os.path.exists(path): - raise exceptions.PathNotFoundError(path) - - def __str__(self): - return self.path diff --git a/migrate/versioning/repository.py b/migrate/versioning/repository.py deleted file mode 100644 index 8c8cd3b..0000000 --- a/migrate/versioning/repository.py +++ /dev/null @@ -1,242 +0,0 @@ -""" - SQLAlchemy migrate repository management. -""" -import os -import shutil -import string -import logging - -from pkg_resources import resource_filename -from tempita import Template as TempitaTemplate - -from migrate import exceptions -from migrate.versioning import version, pathed, cfgparse -from migrate.versioning.template import Template -from migrate.versioning.config import * - - -log = logging.getLogger(__name__) - -class Changeset(dict): - """A collection of changes to be applied to a database. - - Changesets are bound to a repository and manage a set of - scripts from that repository. - - Behaves like a dict, for the most part. Keys are ordered based on step value. - """ - - def __init__(self, start, *changes, **k): - """ - Give a start version; step must be explicitly stated. - """ - self.step = k.pop('step', 1) - self.start = version.VerNum(start) - self.end = self.start - for change in changes: - self.add(change) - - def __iter__(self): - return iter(self.items()) - - def keys(self): - """ - In a series of upgrades x -> y, keys are version x. Sorted. 
- """ - ret = list(super(Changeset, self).keys()) - # Reverse order if downgrading - ret.sort(reverse=(self.step < 1)) - return ret - - def values(self): - return [self[k] for k in self.keys()] - - def items(self): - return zip(self.keys(), self.values()) - - def add(self, change): - """Add new change to changeset""" - key = self.end - self.end += self.step - self[key] = change - - def run(self, *p, **k): - """Run the changeset scripts""" - for ver, script in self: - script.run(*p, **k) - - -class Repository(pathed.Pathed): - """A project's change script repository""" - - _config = 'migrate.cfg' - _versions = 'versions' - - def __init__(self, path): - log.debug('Loading repository %s...' % path) - self.verify(path) - super(Repository, self).__init__(path) - self.config = cfgparse.Config(os.path.join(self.path, self._config)) - self.versions = version.Collection(os.path.join(self.path, - self._versions)) - log.debug('Repository %s loaded successfully' % path) - log.debug('Config: %r' % self.config.to_dict()) - - @classmethod - def verify(cls, path): - """ - Ensure the target path is a valid repository. - - :raises: :exc:`InvalidRepositoryError <migrate.exceptions.InvalidRepositoryError>` - """ - # Ensure the existence of required files - try: - cls.require_found(path) - cls.require_found(os.path.join(path, cls._config)) - cls.require_found(os.path.join(path, cls._versions)) - except exceptions.PathNotFoundError: - raise exceptions.InvalidRepositoryError(path) - - @classmethod - def prepare_config(cls, tmpl_dir, name, options=None): - """ - Prepare a project configuration file for a new project. 
- - :param tmpl_dir: Path to Repository template - :param config_file: Name of the config file in Repository template - :param name: Repository name - :type tmpl_dir: string - :type config_file: string - :type name: string - :returns: Populated config file - """ - if options is None: - options = {} - options.setdefault('version_table', 'migrate_version') - options.setdefault('repository_id', name) - options.setdefault('required_dbs', []) - options.setdefault('use_timestamp_numbering', False) - - tmpl = open(os.path.join(tmpl_dir, cls._config)).read() - ret = TempitaTemplate(tmpl).substitute(options) - - # cleanup - del options['__template_name__'] - - return ret - - @classmethod - def create(cls, path, name, **opts): - """Create a repository at a specified path""" - cls.require_notfound(path) - theme = opts.pop('templates_theme', None) - t_path = opts.pop('templates_path', None) - - # Create repository - tmpl_dir = Template(t_path).get_repository(theme=theme) - shutil.copytree(tmpl_dir, path) - - # Edit config defaults - config_text = cls.prepare_config(tmpl_dir, name, options=opts) - fd = open(os.path.join(path, cls._config), 'w') - fd.write(config_text) - fd.close() - - opts['repository_name'] = name - - # Create a management script - manager = os.path.join(path, 'manage.py') - Repository.create_manage_file(manager, templates_theme=theme, - templates_path=t_path, **opts) - - return cls(path) - - def create_script(self, description, **k): - """API to :meth:`migrate.versioning.version.Collection.create_new_python_version`""" - - k['use_timestamp_numbering'] = self.use_timestamp_numbering - self.versions.create_new_python_version(description, **k) - - def create_script_sql(self, database, description, **k): - """API to :meth:`migrate.versioning.version.Collection.create_new_sql_version`""" - k['use_timestamp_numbering'] = self.use_timestamp_numbering - self.versions.create_new_sql_version(database, description, **k) - - @property - def latest(self): - """API to 
:attr:`migrate.versioning.version.Collection.latest`""" - return self.versions.latest - - @property - def version_table(self): - """Returns version_table name specified in config""" - return self.config.get('db_settings', 'version_table') - - @property - def id(self): - """Returns repository id specified in config""" - return self.config.get('db_settings', 'repository_id') - - @property - def use_timestamp_numbering(self): - """Returns use_timestamp_numbering specified in config""" - if self.config.has_option('db_settings', 'use_timestamp_numbering'): - return self.config.getboolean('db_settings', 'use_timestamp_numbering') - return False - - def version(self, *p, **k): - """API to :attr:`migrate.versioning.version.Collection.version`""" - return self.versions.version(*p, **k) - - @classmethod - def clear(cls): - # TODO: deletes repo - super(Repository, cls).clear() - version.Collection.clear() - - def changeset(self, database, start, end=None): - """Create a changeset to migrate this database from ver. start to end/latest. 
- - :param database: name of database to generate changeset - :param start: version to start at - :param end: version to end at (latest if None given) - :type database: string - :type start: int - :type end: int - :returns: :class:`Changeset instance <migration.versioning.repository.Changeset>` - """ - start = version.VerNum(start) - - if end is None: - end = self.latest - else: - end = version.VerNum(end) - - if start <= end: - step = 1 - range_mod = 1 - op = 'upgrade' - else: - step = -1 - range_mod = 0 - op = 'downgrade' - - versions = range(int(start) + range_mod, int(end) + range_mod, step) - changes = [self.version(v).script(database, op) for v in versions] - ret = Changeset(start, step=step, *changes) - return ret - - @classmethod - def create_manage_file(cls, file_, **opts): - """Create a project management script (manage.py) - - :param file_: Destination file to be written - :param opts: Options that are passed to :func:`migrate.versioning.shell.main` - """ - mng_file = Template(opts.pop('templates_path', None))\ - .get_manage(theme=opts.pop('templates_theme', None)) - - tmpl = open(mng_file).read() - fd = open(file_, 'w') - fd.write(TempitaTemplate(tmpl).substitute(opts)) - fd.close() diff --git a/migrate/versioning/schema.py b/migrate/versioning/schema.py deleted file mode 100644 index b525cef..0000000 --- a/migrate/versioning/schema.py +++ /dev/null @@ -1,222 +0,0 @@ -""" - Database schema version management. 
-""" -import sys -import logging - -import six -from sqlalchemy import (Table, Column, MetaData, String, Text, Integer, - create_engine) -from sqlalchemy.sql import and_ -from sqlalchemy import exc as sa_exceptions -from sqlalchemy.sql import bindparam - -from migrate import exceptions -from migrate.changeset import SQLA_07 -from migrate.versioning import genmodel, schemadiff -from migrate.versioning.repository import Repository -from migrate.versioning.util import load_model -from migrate.versioning.version import VerNum - - -log = logging.getLogger(__name__) - -class ControlledSchema(object): - """A database under version control""" - - def __init__(self, engine, repository): - if isinstance(repository, six.string_types): - repository = Repository(repository) - self.engine = engine - self.repository = repository - self.meta = MetaData(engine) - self.load() - - def __eq__(self, other): - """Compare two schemas by repositories and versions""" - return (self.repository is other.repository \ - and self.version == other.version) - - def load(self): - """Load controlled schema version info from DB""" - tname = self.repository.version_table - try: - if not hasattr(self, 'table') or self.table is None: - self.table = Table(tname, self.meta, autoload=True) - - result = self.engine.execute(self.table.select( - self.table.c.repository_id == str(self.repository.id))) - - data = list(result)[0] - except: - cls, exc, tb = sys.exc_info() - six.reraise(exceptions.DatabaseNotControlledError, - exceptions.DatabaseNotControlledError(str(exc)), tb) - - self.version = data['version'] - return data - - def drop(self): - """ - Remove version control from a database. 
- """ - if SQLA_07: - try: - self.table.drop() - except sa_exceptions.DatabaseError: - raise exceptions.DatabaseNotControlledError(str(self.table)) - else: - try: - self.table.drop() - except (sa_exceptions.SQLError): - raise exceptions.DatabaseNotControlledError(str(self.table)) - - def changeset(self, version=None): - """API to Changeset creation. - - Uses self.version for start version and engine.name - to get database name. - """ - database = self.engine.name - start_ver = self.version - changeset = self.repository.changeset(database, start_ver, version) - return changeset - - def runchange(self, ver, change, step): - startver = ver - endver = ver + step - # Current database version must be correct! Don't run if corrupt! - if self.version != startver: - raise exceptions.InvalidVersionError("%s is not %s" % \ - (self.version, startver)) - # Run the change - change.run(self.engine, step) - - # Update/refresh database version - self.update_repository_table(startver, endver) - self.load() - - def update_repository_table(self, startver, endver): - """Update version_table with new information""" - update = self.table.update(and_(self.table.c.version == int(startver), - self.table.c.repository_id == str(self.repository.id))) - self.engine.execute(update, version=int(endver)) - - def upgrade(self, version=None): - """ - Upgrade (or downgrade) to a specified version, or latest version. - """ - changeset = self.changeset(version) - for ver, change in changeset: - self.runchange(ver, change, changeset.step) - - def update_db_from_model(self, model): - """ - Modify the database to match the structure of the current Python model. 
- """ - model = load_model(model) - - diff = schemadiff.getDiffOfModelAgainstDatabase( - model, self.engine, excludeTables=[self.repository.version_table] - ) - genmodel.ModelGenerator(diff,self.engine).runB2A() - - self.update_repository_table(self.version, int(self.repository.latest)) - - self.load() - - @classmethod - def create(cls, engine, repository, version=None): - """ - Declare a database to be under a repository's version control. - - :raises: :exc:`DatabaseAlreadyControlledError` - :returns: :class:`ControlledSchema` - """ - # Confirm that the version # is valid: positive, integer, - # exists in repos - if isinstance(repository, six.string_types): - repository = Repository(repository) - version = cls._validate_version(repository, version) - table = cls._create_table_version(engine, repository, version) - # TODO: history table - # Load repository information and return - return cls(engine, repository) - - @classmethod - def _validate_version(cls, repository, version): - """ - Ensures this is a valid version number for this repository. - - :raises: :exc:`InvalidVersionError` if invalid - :return: valid version number - """ - if version is None: - version = 0 - try: - version = VerNum(version) # raises valueerror - if version < 0 or version > repository.latest: - raise ValueError() - except ValueError: - raise exceptions.InvalidVersionError(version) - return version - - @classmethod - def _create_table_version(cls, engine, repository, version): - """ - Creates the versioning table in a database. 
- - :raises: :exc:`DatabaseAlreadyControlledError` - """ - # Create tables - tname = repository.version_table - meta = MetaData(engine) - - table = Table( - tname, meta, - Column('repository_id', String(250), primary_key=True), - Column('repository_path', Text), - Column('version', Integer), ) - - # there can be multiple repositories/schemas in the same db - if not table.exists(): - table.create() - - # test for existing repository_id - s = table.select(table.c.repository_id == bindparam("repository_id")) - result = engine.execute(s, repository_id=repository.id) - if result.fetchone(): - raise exceptions.DatabaseAlreadyControlledError - - # Insert data - engine.execute(table.insert().values( - repository_id=repository.id, - repository_path=repository.path, - version=int(version))) - return table - - @classmethod - def compare_model_to_db(cls, engine, model, repository): - """ - Compare the current model against the current database. - """ - if isinstance(repository, six.string_types): - repository = Repository(repository) - model = load_model(model) - - diff = schemadiff.getDiffOfModelAgainstDatabase( - model, engine, excludeTables=[repository.version_table]) - return diff - - @classmethod - def create_model(cls, engine, repository, declarative=False): - """ - Dump the current database as a Python model. - """ - if isinstance(repository, six.string_types): - repository = Repository(repository) - - diff = schemadiff.getDiffOfModelAgainstDatabase( - MetaData(), engine, excludeTables=[repository.version_table] - ) - return genmodel.ModelGenerator(diff, engine, declarative).genBDefinition() diff --git a/migrate/versioning/schemadiff.py b/migrate/versioning/schemadiff.py deleted file mode 100644 index d9477bf..0000000 --- a/migrate/versioning/schemadiff.py +++ /dev/null @@ -1,298 +0,0 @@ -""" - Schema differencing support. 
-""" - -import logging -import sqlalchemy - -from sqlalchemy.types import Float - -log = logging.getLogger(__name__) - -def getDiffOfModelAgainstDatabase(metadata, engine, excludeTables=None): - """ - Return differences of model against database. - - :return: object which will evaluate to :keyword:`True` if there \ - are differences else :keyword:`False`. - """ - db_metadata = sqlalchemy.MetaData(engine) - db_metadata.reflect() - - # sqlite will include a dynamically generated 'sqlite_sequence' table if - # there are autoincrement sequences in the database; this should not be - # compared. - if engine.dialect.name == 'sqlite': - if 'sqlite_sequence' in db_metadata.tables: - db_metadata.remove(db_metadata.tables['sqlite_sequence']) - - return SchemaDiff(metadata, db_metadata, - labelA='model', - labelB='database', - excludeTables=excludeTables) - - -def getDiffOfModelAgainstModel(metadataA, metadataB, excludeTables=None): - """ - Return differences of model against another model. - - :return: object which will evaluate to :keyword:`True` if there \ - are differences else :keyword:`False`. - """ - return SchemaDiff(metadataA, metadataB, excludeTables=excludeTables) - - -class ColDiff(object): - """ - Container for differences in one :class:`~sqlalchemy.schema.Column` - between two :class:`~sqlalchemy.schema.Table` instances, ``A`` - and ``B``. - - .. attribute:: col_A - - The :class:`~sqlalchemy.schema.Column` object for A. - - .. attribute:: col_B - - The :class:`~sqlalchemy.schema.Column` object for B. - - .. attribute:: type_A - - The most generic type of the :class:`~sqlalchemy.schema.Column` - object in A. - - .. attribute:: type_B - - The most generic type of the :class:`~sqlalchemy.schema.Column` - object in A. 
- - """ - - diff = False - - def __init__(self,col_A,col_B): - self.col_A = col_A - self.col_B = col_B - - self.type_A = col_A.type - self.type_B = col_B.type - - self.affinity_A = self.type_A._type_affinity - self.affinity_B = self.type_B._type_affinity - - if self.affinity_A is not self.affinity_B: - self.diff = True - return - - if isinstance(self.type_A,Float) or isinstance(self.type_B,Float): - if not (isinstance(self.type_A,Float) and isinstance(self.type_B,Float)): - self.diff=True - return - - for attr in ('precision','scale','length'): - A = getattr(self.type_A,attr,None) - B = getattr(self.type_B,attr,None) - if not (A is None or B is None) and A!=B: - self.diff=True - return - - def __nonzero__(self): - return self.diff - - __bool__ = __nonzero__ - - -class TableDiff(object): - """ - Container for differences in one :class:`~sqlalchemy.schema.Table` - between two :class:`~sqlalchemy.schema.MetaData` instances, ``A`` - and ``B``. - - .. attribute:: columns_missing_from_A - - A sequence of column names that were found in B but weren't in - A. - - .. attribute:: columns_missing_from_B - - A sequence of column names that were found in A but weren't in - B. - - .. attribute:: columns_different - - A dictionary containing information about columns that were - found to be different. - It maps column names to a :class:`ColDiff` objects describing the - differences found. - """ - __slots__ = ( - 'columns_missing_from_A', - 'columns_missing_from_B', - 'columns_different', - ) - - def __nonzero__(self): - return bool( - self.columns_missing_from_A or - self.columns_missing_from_B or - self.columns_different - ) - - __bool__ = __nonzero__ - -class SchemaDiff(object): - """ - Compute the difference between two :class:`~sqlalchemy.schema.MetaData` - objects. - - The string representation of a :class:`SchemaDiff` will summarise - the changes found between the two - :class:`~sqlalchemy.schema.MetaData` objects. 
- - The length of a :class:`SchemaDiff` will give the number of - changes found, enabling it to be used much like a boolean in - expressions. - - :param metadataA: - First :class:`~sqlalchemy.schema.MetaData` to compare. - - :param metadataB: - Second :class:`~sqlalchemy.schema.MetaData` to compare. - - :param labelA: - The label to use in messages about the first - :class:`~sqlalchemy.schema.MetaData`. - - :param labelB: - The label to use in messages about the second - :class:`~sqlalchemy.schema.MetaData`. - - :param excludeTables: - A sequence of table names to exclude. - - .. attribute:: tables_missing_from_A - - A sequence of table names that were found in B but weren't in - A. - - .. attribute:: tables_missing_from_B - - A sequence of table names that were found in A but weren't in - B. - - .. attribute:: tables_different - - A dictionary containing information about tables that were found - to be different. - It maps table names to a :class:`TableDiff` objects describing the - differences found. 
- """ - - def __init__(self, - metadataA, metadataB, - labelA='metadataA', - labelB='metadataB', - excludeTables=None): - - self.metadataA, self.metadataB = metadataA, metadataB - self.labelA, self.labelB = labelA, labelB - self.label_width = max(len(labelA),len(labelB)) - excludeTables = set(excludeTables or []) - - A_table_names = set(metadataA.tables.keys()) - B_table_names = set(metadataB.tables.keys()) - - self.tables_missing_from_A = sorted( - B_table_names - A_table_names - excludeTables - ) - self.tables_missing_from_B = sorted( - A_table_names - B_table_names - excludeTables - ) - - self.tables_different = {} - for table_name in A_table_names.intersection(B_table_names): - - td = TableDiff() - - A_table = metadataA.tables[table_name] - B_table = metadataB.tables[table_name] - - A_column_names = set(A_table.columns.keys()) - B_column_names = set(B_table.columns.keys()) - - td.columns_missing_from_A = sorted( - B_column_names - A_column_names - ) - - td.columns_missing_from_B = sorted( - A_column_names - B_column_names - ) - - td.columns_different = {} - - for col_name in A_column_names.intersection(B_column_names): - - cd = ColDiff( - A_table.columns.get(col_name), - B_table.columns.get(col_name) - ) - - if cd: - td.columns_different[col_name]=cd - - # XXX - index and constraint differences should - # be checked for here - - if td: - self.tables_different[table_name]=td - - def __str__(self): - ''' Summarize differences. 
''' - out = [] - column_template =' %%%is: %%r' % self.label_width - - for names,label in ( - (self.tables_missing_from_A,self.labelA), - (self.tables_missing_from_B,self.labelB), - ): - if names: - out.append( - ' tables missing from %s: %s' % ( - label,', '.join(sorted(names)) - ) - ) - - for name,td in sorted(self.tables_different.items()): - out.append( - ' table with differences: %s' % name - ) - for names,label in ( - (td.columns_missing_from_A,self.labelA), - (td.columns_missing_from_B,self.labelB), - ): - if names: - out.append( - ' %s missing these columns: %s' % ( - label,', '.join(sorted(names)) - ) - ) - for name,cd in td.columns_different.items(): - out.append(' column with differences: %s' % name) - out.append(column_template % (self.labelA,cd.col_A)) - out.append(column_template % (self.labelB,cd.col_B)) - - if out: - out.insert(0, 'Schema diffs:') - return '\n'.join(out) - else: - return 'No schema diffs' - - def __len__(self): - """ - Used in bool evaluation, return of 0 means no diffs. 
- """ - return ( - len(self.tables_missing_from_A) + - len(self.tables_missing_from_B) + - len(self.tables_different) - ) diff --git a/migrate/versioning/script/__init__.py b/migrate/versioning/script/__init__.py deleted file mode 100644 index c788eda..0000000 --- a/migrate/versioning/script/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -from migrate.versioning.script.base import BaseScript -from migrate.versioning.script.py import PythonScript -from migrate.versioning.script.sql import SqlScript diff --git a/migrate/versioning/script/base.py b/migrate/versioning/script/base.py deleted file mode 100644 index 22ca7b4..0000000 --- a/migrate/versioning/script/base.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import logging - -from migrate import exceptions -from migrate.versioning.config import operations -from migrate.versioning import pathed - - -log = logging.getLogger(__name__) - -class BaseScript(pathed.Pathed): - """Base class for other types of scripts. - All scripts have the following properties: - - source (script.source()) - The source code of the script - version (script.version()) - The version number of the script - operations (script.operations()) - The operations defined by the script: upgrade(), downgrade() or both. - Returns a tuple of operations. - Can also check for an operation with ex. script.operation(Script.ops.up) - """ # TODO: sphinxfy this and implement it correctly - - def __init__(self, path): - log.debug('Loading script %s...' 
% path) - self.verify(path) - super(BaseScript, self).__init__(path) - log.debug('Script %s loaded successfully' % path) - - @classmethod - def verify(cls, path): - """Ensure this is a valid script - This version simply ensures the script file's existence - - :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>` - """ - try: - cls.require_found(path) - except: - raise exceptions.InvalidScriptError(path) - - def source(self): - """:returns: source code of the script. - :rtype: string - """ - fd = open(self.path) - ret = fd.read() - fd.close() - return ret - - def run(self, engine): - """Core of each BaseScript subclass. - This method executes the script. - """ - raise NotImplementedError() diff --git a/migrate/versioning/script/py.py b/migrate/versioning/script/py.py deleted file mode 100644 index 92a8f6b..0000000 --- a/migrate/versioning/script/py.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import shutil -import warnings -import logging -import inspect - -import migrate -from migrate.versioning import genmodel, schemadiff -from migrate.versioning.config import operations -from migrate.versioning.template import Template -from migrate.versioning.script import base -from migrate.versioning.util import import_path, load_model, with_engine -from migrate.exceptions import MigrateDeprecationWarning, InvalidScriptError, ScriptError -import six -from six.moves import StringIO - -log = logging.getLogger(__name__) -__all__ = ['PythonScript'] - - -class PythonScript(base.BaseScript): - """Base for Python scripts""" - - @classmethod - def create(cls, path, **opts): - """Create an empty migration script at specified path - - :returns: :class:`PythonScript instance <migrate.versioning.script.py.PythonScript>`""" - cls.require_notfound(path) - - src = Template(opts.pop('templates_path', None)).get_script(theme=opts.pop('templates_theme', None)) - shutil.copy(src, path) - - return cls(path) - - @classmethod - def 
make_update_script_for_model(cls, engine, oldmodel, - model, repository, **opts): - """Create a migration script based on difference between two SA models. - - :param repository: path to migrate repository - :param oldmodel: dotted.module.name:SAClass or SAClass object - :param model: dotted.module.name:SAClass or SAClass object - :param engine: SQLAlchemy engine - :type repository: string or :class:`Repository instance <migrate.versioning.repository.Repository>` - :type oldmodel: string or Class - :type model: string or Class - :type engine: Engine instance - :returns: Upgrade / Downgrade script - :rtype: string - """ - - if isinstance(repository, six.string_types): - # oh dear, an import cycle! - from migrate.versioning.repository import Repository - repository = Repository(repository) - - oldmodel = load_model(oldmodel) - model = load_model(model) - - # Compute differences. - diff = schemadiff.getDiffOfModelAgainstModel( - model, - oldmodel, - excludeTables=[repository.version_table]) - # TODO: diff can be False (there is no difference?) - decls, upgradeCommands, downgradeCommands = \ - genmodel.ModelGenerator(diff,engine).genB2AMigration() - - # Store differences into file. 
- src = Template(opts.pop('templates_path', None)).get_script(opts.pop('templates_theme', None)) - f = open(src) - contents = f.read() - f.close() - - # generate source - search = 'def upgrade(migrate_engine):' - contents = contents.replace(search, '\n\n'.join((decls, search)), 1) - if upgradeCommands: - contents = contents.replace(' pass', upgradeCommands, 1) - if downgradeCommands: - contents = contents.replace(' pass', downgradeCommands, 1) - return contents - - @classmethod - def verify_module(cls, path): - """Ensure path is a valid script - - :param path: Script location - :type path: string - :raises: :exc:`InvalidScriptError <migrate.exceptions.InvalidScriptError>` - :returns: Python module - """ - # Try to import and get the upgrade() func - module = import_path(path) - try: - assert callable(module.upgrade) - except Exception as e: - raise InvalidScriptError(path + ': %s' % str(e)) - return module - - def preview_sql(self, url, step, **args): - """Mocks SQLAlchemy Engine to store all executed calls in a string - and runs :meth:`PythonScript.run <migrate.versioning.script.py.PythonScript.run>` - - :returns: SQL file - """ - buf = StringIO() - args['engine_arg_strategy'] = 'mock' - args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p) - - @with_engine - def go(url, step, **kw): - engine = kw.pop('engine') - self.run(engine, step) - return buf.getvalue() - - return go(url, step, **args) - - def run(self, engine, step): - """Core method of Script file. 
- Exectues :func:`update` or :func:`downgrade` functions - - :param engine: SQLAlchemy Engine - :param step: Operation to run - :type engine: string - :type step: int - """ - if step in ('downgrade', 'upgrade'): - op = step - elif step > 0: - op = 'upgrade' - elif step < 0: - op = 'downgrade' - else: - raise ScriptError("%d is not a valid step" % step) - - funcname = base.operations[op] - script_func = self._func(funcname) - - # check for old way of using engine - if not inspect.getargspec(script_func)[0]: - raise TypeError("upgrade/downgrade functions must accept engine" - " parameter (since version 0.5.4)") - - script_func(engine) - - @property - def module(self): - """Calls :meth:`migrate.versioning.script.py.verify_module` - and returns it. - """ - if not hasattr(self, '_module'): - self._module = self.verify_module(self.path) - return self._module - - def _func(self, funcname): - if not hasattr(self.module, funcname): - msg = "Function '%s' is not defined in this script" - raise ScriptError(msg % funcname) - return getattr(self.module, funcname) diff --git a/migrate/versioning/script/sql.py b/migrate/versioning/script/sql.py deleted file mode 100644 index 862bc9f..0000000 --- a/migrate/versioning/script/sql.py +++ /dev/null @@ -1,70 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import logging -import re -import shutil - -import sqlparse - -from migrate.versioning.script import base -from migrate.versioning.template import Template - - -log = logging.getLogger(__name__) - -class SqlScript(base.BaseScript): - """A file containing plain SQL statements.""" - - @classmethod - def create(cls, path, **opts): - """Create an empty migration script at specified path - - :returns: :class:`SqlScript instance <migrate.versioning.script.sql.SqlScript>`""" - cls.require_notfound(path) - - src = Template(opts.pop('templates_path', None)).get_sql_script(theme=opts.pop('templates_theme', None)) - shutil.copy(src, path) - return cls(path) - - # TODO: why is step 
parameter even here? - def run(self, engine, step=None): - """Runs SQL script through raw dbapi execute call""" - text = self.source() - # Don't rely on SA's autocommit here - # (SA uses .startswith to check if a commit is needed. What if script - # starts with a comment?) - conn = engine.connect() - try: - trans = conn.begin() - try: - # ignore transaction management statements that are - # redundant in SQL script context and result in - # operational error being returned. - # - # Note: we don't ignore ROLLBACK in migration scripts - # since its usage would be insane anyway, and we're - # better to fail on its occurance instead of ignoring it - # (and committing transaction, which is contradictory to - # the whole idea of ROLLBACK) - ignored_statements = ('BEGIN', 'END', 'COMMIT') - ignored_regex = re.compile('^\s*(%s).*;?$' % '|'.join(ignored_statements), - re.IGNORECASE) - - # NOTE(ihrachys): script may contain multiple statements, and - # not all drivers reliably handle multistatement queries or - # commands passed to .execute(), so split them and execute one - # by one - text = sqlparse.format(text, strip_comments=True, strip_whitespace=True) - for statement in sqlparse.split(text): - if statement: - if re.match(ignored_regex, statement): - log.warning('"%s" found in SQL script; ignoring' % statement) - else: - conn.execute(statement) - trans.commit() - except Exception as e: - log.error("SQL script %s failed: %s", self.path, e) - trans.rollback() - raise - finally: - conn.close() diff --git a/migrate/versioning/shell.py b/migrate/versioning/shell.py deleted file mode 100644 index 5fb86b1..0000000 --- a/migrate/versioning/shell.py +++ /dev/null @@ -1,216 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -"""The migrate command-line tool.""" - -import sys -import inspect -import logging -from optparse import OptionParser, BadOptionError - -from migrate import exceptions -from migrate.versioning import api -from migrate.versioning.config import * -from 
migrate.versioning.util import asbool -import six - - -alias = dict( - s=api.script, - vc=api.version_control, - dbv=api.db_version, - v=api.version, -) - -def alias_setup(): - global alias - for key, val in six.iteritems(alias): - setattr(api, key, val) -alias_setup() - - -class PassiveOptionParser(OptionParser): - - def _process_args(self, largs, rargs, values): - """little hack to support all --some_option=value parameters""" - - while rargs: - arg = rargs[0] - if arg == "--": - del rargs[0] - return - elif arg[0:2] == "--": - # if parser does not know about the option - # pass it along (make it anonymous) - try: - opt = arg.split('=', 1)[0] - self._match_long_opt(opt) - except BadOptionError: - largs.append(arg) - del rargs[0] - else: - self._process_long_opt(rargs, values) - elif arg[:1] == "-" and len(arg) > 1: - self._process_short_opts(rargs, values) - elif self.allow_interspersed_args: - largs.append(arg) - del rargs[0] - -def main(argv=None, **kwargs): - """Shell interface to :mod:`migrate.versioning.api`. - - kwargs are default options that can be overriden with passing - --some_option as command line option - - :param disable_logging: Let migrate configure logging - :type disable_logging: bool - """ - if argv is not None: - argv = argv - else: - argv = list(sys.argv[1:]) - commands = list(api.__all__) - commands.sort() - - usage = """%%prog COMMAND ... - - Available commands: - %s - - Enter "%%prog help COMMAND" for information on a particular command. 
- """ % '\n\t'.join(["%s - %s" % (command.ljust(28), api.command_desc.get(command)) for command in commands]) - - parser = PassiveOptionParser(usage=usage) - parser.add_option("-d", "--debug", - action="store_true", - dest="debug", - default=False, - help="Shortcut to turn on DEBUG mode for logging") - parser.add_option("-q", "--disable_logging", - action="store_true", - dest="disable_logging", - default=False, - help="Use this option to disable logging configuration") - help_commands = ['help', '-h', '--help'] - HELP = False - - try: - command = argv.pop(0) - if command in help_commands: - HELP = True - command = argv.pop(0) - except IndexError: - parser.print_help() - return - - command_func = getattr(api, command, None) - if command_func is None or command.startswith('_'): - parser.error("Invalid command %s" % command) - - parser.set_usage(inspect.getdoc(command_func)) - f_args, f_varargs, f_kwargs, f_defaults = inspect.getargspec(command_func) - for arg in f_args: - parser.add_option( - "--%s" % arg, - dest=arg, - action='store', - type="string") - - # display help of the current command - if HELP: - parser.print_help() - return - - options, args = parser.parse_args(argv) - - # override kwargs with anonymous parameters - override_kwargs = dict() - for arg in list(args): - if arg.startswith('--'): - args.remove(arg) - if '=' in arg: - opt, value = arg[2:].split('=', 1) - else: - opt = arg[2:] - value = True - override_kwargs[opt] = value - - # override kwargs with options if user is overwriting - for key, value in six.iteritems(options.__dict__): - if value is not None: - override_kwargs[key] = value - - # arguments that function accepts without passed kwargs - f_required = list(f_args) - candidates = dict(kwargs) - candidates.update(override_kwargs) - for key, value in six.iteritems(candidates): - if key in f_args: - f_required.remove(key) - - # map function arguments to parsed arguments - for arg in args: - try: - kw = f_required.pop(0) - except IndexError: - 
parser.error("Too many arguments for command %s: %s" % (command, - arg)) - kwargs[kw] = arg - - # apply overrides - kwargs.update(override_kwargs) - - # configure options - for key, value in six.iteritems(options.__dict__): - kwargs.setdefault(key, value) - - # configure logging - if not asbool(kwargs.pop('disable_logging', False)): - # filter to log =< INFO into stdout and rest to stderr - class SingleLevelFilter(logging.Filter): - def __init__(self, min=None, max=None): - self.min = min or 0 - self.max = max or 100 - - def filter(self, record): - return self.min <= record.levelno <= self.max - - logger = logging.getLogger() - h1 = logging.StreamHandler(sys.stdout) - f1 = SingleLevelFilter(max=logging.INFO) - h1.addFilter(f1) - h2 = logging.StreamHandler(sys.stderr) - f2 = SingleLevelFilter(min=logging.WARN) - h2.addFilter(f2) - logger.addHandler(h1) - logger.addHandler(h2) - - if options.debug: - logger.setLevel(logging.DEBUG) - else: - logger.setLevel(logging.INFO) - - log = logging.getLogger(__name__) - - # check if all args are given - try: - num_defaults = len(f_defaults) - except TypeError: - num_defaults = 0 - f_args_default = f_args[len(f_args) - num_defaults:] - required = list(set(f_required) - set(f_args_default)) - required.sort() - if required: - parser.error("Not enough arguments for command %s: %s not specified" \ - % (command, ', '.join(required))) - - # handle command - try: - ret = command_func(**kwargs) - if ret is not None: - log.info(ret) - except (exceptions.UsageError, exceptions.KnownError) as e: - parser.error(e.args[0]) - -if __name__ == "__main__": - main() diff --git a/migrate/versioning/template.py b/migrate/versioning/template.py deleted file mode 100644 index 8182e6b..0000000 --- a/migrate/versioning/template.py +++ /dev/null @@ -1,93 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os -import shutil -import sys - -from pkg_resources import resource_filename - -from migrate.versioning.config import * -from 
migrate.versioning import pathed - - -class Collection(pathed.Pathed): - """A collection of templates of a specific type""" - _mask = None - - def get_path(self, file): - return os.path.join(self.path, str(file)) - - -class RepositoryCollection(Collection): - _mask = '%s' - -class ScriptCollection(Collection): - _mask = '%s.py_tmpl' - -class ManageCollection(Collection): - _mask = '%s.py_tmpl' - -class SQLScriptCollection(Collection): - _mask = '%s.py_tmpl' - -class Template(pathed.Pathed): - """Finds the paths/packages of various Migrate templates. - - :param path: Templates are loaded from migrate package - if `path` is not provided. - """ - pkg = 'migrate.versioning.templates' - - def __new__(cls, path=None): - if path is None: - path = cls._find_path(cls.pkg) - return super(Template, cls).__new__(cls, path) - - def __init__(self, path=None): - if path is None: - path = Template._find_path(self.pkg) - super(Template, self).__init__(path) - self.repository = RepositoryCollection(os.path.join(path, 'repository')) - self.script = ScriptCollection(os.path.join(path, 'script')) - self.manage = ManageCollection(os.path.join(path, 'manage')) - self.sql_script = SQLScriptCollection(os.path.join(path, 'sql_script')) - - @classmethod - def _find_path(cls, pkg): - """Returns absolute path to dotted python package.""" - tmp_pkg = pkg.rsplit('.', 1) - - if len(tmp_pkg) != 1: - return resource_filename(tmp_pkg[0], tmp_pkg[1]) - else: - return resource_filename(tmp_pkg[0], '') - - def _get_item(self, collection, theme=None): - """Locates and returns collection. 
- - :param collection: name of collection to locate - :param type_: type of subfolder in collection (defaults to "_default") - :returns: (package, source) - :rtype: str, str - """ - item = getattr(self, collection) - theme_mask = getattr(item, '_mask') - theme = theme_mask % (theme or 'default') - return item.get_path(theme) - - def get_repository(self, *a, **kw): - """Calls self._get_item('repository', *a, **kw)""" - return self._get_item('repository', *a, **kw) - - def get_script(self, *a, **kw): - """Calls self._get_item('script', *a, **kw)""" - return self._get_item('script', *a, **kw) - - def get_sql_script(self, *a, **kw): - """Calls self._get_item('sql_script', *a, **kw)""" - return self._get_item('sql_script', *a, **kw) - - def get_manage(self, *a, **kw): - """Calls self._get_item('manage', *a, **kw)""" - return self._get_item('manage', *a, **kw) diff --git a/migrate/versioning/templates/__init__.py b/migrate/versioning/templates/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/migrate/versioning/templates/__init__.py +++ /dev/null diff --git a/migrate/versioning/templates/manage/default.py_tmpl b/migrate/versioning/templates/manage/default.py_tmpl deleted file mode 100644 index 971c70f..0000000 --- a/migrate/versioning/templates/manage/default.py_tmpl +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env python -from migrate.versioning.shell import main - -{{py: -import six -_vars = locals().copy() -del _vars['__template_name__'] -del _vars['six'] -_vars.pop('repository_name', None) -defaults = ", ".join(["%s='%s'" % var for var in six.iteritems(_vars)]) -}} - -if __name__ == '__main__': - main({{ defaults }}) diff --git a/migrate/versioning/templates/manage/pylons.py_tmpl b/migrate/versioning/templates/manage/pylons.py_tmpl deleted file mode 100644 index 0d6c32c..0000000 --- a/migrate/versioning/templates/manage/pylons.py_tmpl +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/python -# -*- coding: utf-8 -*- -import sys - -from sqlalchemy import 
engine_from_config -from paste.deploy.loadwsgi import ConfigLoader - -from migrate.versioning.shell import main -from {{ locals().pop('repository_name') }}.model import migrations - - -if '-c' in sys.argv: - pos = sys.argv.index('-c') - conf_path = sys.argv[pos + 1] - del sys.argv[pos:pos + 2] -else: - conf_path = 'development.ini' - -{{py: -import six -_vars = locals().copy() -del _vars['__template_name__'] -del _vars['six'] -defaults = ", ".join(["%s='%s'" % var for var in six.iteritems(_vars)]) -}} - -conf_dict = ConfigLoader(conf_path).parser._sections['app:main'] - -# migrate supports passing url as an existing Engine instance (since 0.6.0) -# usage: migrate -c path/to/config.ini COMMANDS -if __name__ == '__main__': - main(url=engine_from_config(conf_dict), repository=migrations.__path__[0],{{ defaults }}) diff --git a/migrate/versioning/templates/repository/__init__.py b/migrate/versioning/templates/repository/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/migrate/versioning/templates/repository/__init__.py +++ /dev/null diff --git a/migrate/versioning/templates/repository/default/README b/migrate/versioning/templates/repository/default/README deleted file mode 100644 index 6218f8c..0000000 --- a/migrate/versioning/templates/repository/default/README +++ /dev/null @@ -1,4 +0,0 @@ -This is a database migration repository. 
- -More information at -http://code.google.com/p/sqlalchemy-migrate/ diff --git a/migrate/versioning/templates/repository/default/__init__.py b/migrate/versioning/templates/repository/default/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/migrate/versioning/templates/repository/default/__init__.py +++ /dev/null diff --git a/migrate/versioning/templates/repository/default/migrate.cfg b/migrate/versioning/templates/repository/default/migrate.cfg deleted file mode 100644 index bcc33a7..0000000 --- a/migrate/versioning/templates/repository/default/migrate.cfg +++ /dev/null @@ -1,25 +0,0 @@ -[db_settings] -# Used to identify which repository this database is versioned under. -# You can use the name of your project. -repository_id={{ locals().pop('repository_id') }} - -# The name of the database table used to track the schema version. -# This name shouldn't already be used by your project. -# If this is changed once a database is under version control, you'll need to -# change the table name in each database too. -version_table={{ locals().pop('version_table') }} - -# When committing a change script, Migrate will attempt to generate the -# sql for all supported databases; normally, if one of them fails - probably -# because you don't have that database installed - it is ignored and the -# commit continues, perhaps ending successfully. -# Databases in this list MUST compile successfully during a commit, or the -# entire commit will fail. List the databases your application will actually -# be using to ensure your updates to that database work properly. -# This must be a list; example: ['postgres','sqlite'] -required_dbs={{ locals().pop('required_dbs') }} - -# When creating new change scripts, Migrate will stamp the new script with -# a version number. By default this is latest_version + 1. You can set this -# to 'true' to tell Migrate to use the UTC timestamp instead. 
-use_timestamp_numbering={{ locals().pop('use_timestamp_numbering') }} diff --git a/migrate/versioning/templates/repository/default/versions/__init__.py b/migrate/versioning/templates/repository/default/versions/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/migrate/versioning/templates/repository/default/versions/__init__.py +++ /dev/null diff --git a/migrate/versioning/templates/repository/pylons/README b/migrate/versioning/templates/repository/pylons/README deleted file mode 100644 index 6218f8c..0000000 --- a/migrate/versioning/templates/repository/pylons/README +++ /dev/null @@ -1,4 +0,0 @@ -This is a database migration repository. - -More information at -http://code.google.com/p/sqlalchemy-migrate/ diff --git a/migrate/versioning/templates/repository/pylons/__init__.py b/migrate/versioning/templates/repository/pylons/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/migrate/versioning/templates/repository/pylons/__init__.py +++ /dev/null diff --git a/migrate/versioning/templates/repository/pylons/migrate.cfg b/migrate/versioning/templates/repository/pylons/migrate.cfg deleted file mode 100644 index bcc33a7..0000000 --- a/migrate/versioning/templates/repository/pylons/migrate.cfg +++ /dev/null @@ -1,25 +0,0 @@ -[db_settings] -# Used to identify which repository this database is versioned under. -# You can use the name of your project. -repository_id={{ locals().pop('repository_id') }} - -# The name of the database table used to track the schema version. -# This name shouldn't already be used by your project. -# If this is changed once a database is under version control, you'll need to -# change the table name in each database too. 
-version_table={{ locals().pop('version_table') }} - -# When committing a change script, Migrate will attempt to generate the -# sql for all supported databases; normally, if one of them fails - probably -# because you don't have that database installed - it is ignored and the -# commit continues, perhaps ending successfully. -# Databases in this list MUST compile successfully during a commit, or the -# entire commit will fail. List the databases your application will actually -# be using to ensure your updates to that database work properly. -# This must be a list; example: ['postgres','sqlite'] -required_dbs={{ locals().pop('required_dbs') }} - -# When creating new change scripts, Migrate will stamp the new script with -# a version number. By default this is latest_version + 1. You can set this -# to 'true' to tell Migrate to use the UTC timestamp instead. -use_timestamp_numbering={{ locals().pop('use_timestamp_numbering') }} diff --git a/migrate/versioning/templates/repository/pylons/versions/__init__.py b/migrate/versioning/templates/repository/pylons/versions/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/migrate/versioning/templates/repository/pylons/versions/__init__.py +++ /dev/null diff --git a/migrate/versioning/templates/script/__init__.py b/migrate/versioning/templates/script/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/migrate/versioning/templates/script/__init__.py +++ /dev/null diff --git a/migrate/versioning/templates/script/default.py_tmpl b/migrate/versioning/templates/script/default.py_tmpl deleted file mode 100644 index 58d874b..0000000 --- a/migrate/versioning/templates/script/default.py_tmpl +++ /dev/null @@ -1,13 +0,0 @@ -from sqlalchemy import * -from migrate import * - - -def upgrade(migrate_engine): - # Upgrade operations go here. Don't create your own engine; bind - # migrate_engine to your metadata - pass - - -def downgrade(migrate_engine): - # Operations to reverse the above upgrade go here. 
- pass diff --git a/migrate/versioning/templates/script/pylons.py_tmpl b/migrate/versioning/templates/script/pylons.py_tmpl deleted file mode 100644 index 58d874b..0000000 --- a/migrate/versioning/templates/script/pylons.py_tmpl +++ /dev/null @@ -1,13 +0,0 @@ -from sqlalchemy import * -from migrate import * - - -def upgrade(migrate_engine): - # Upgrade operations go here. Don't create your own engine; bind - # migrate_engine to your metadata - pass - - -def downgrade(migrate_engine): - # Operations to reverse the above upgrade go here. - pass diff --git a/migrate/versioning/templates/sql_script/default.py_tmpl b/migrate/versioning/templates/sql_script/default.py_tmpl deleted file mode 100644 index e69de29..0000000 --- a/migrate/versioning/templates/sql_script/default.py_tmpl +++ /dev/null diff --git a/migrate/versioning/templates/sql_script/pylons.py_tmpl b/migrate/versioning/templates/sql_script/pylons.py_tmpl deleted file mode 100644 index e69de29..0000000 --- a/migrate/versioning/templates/sql_script/pylons.py_tmpl +++ /dev/null diff --git a/migrate/versioning/util/__init__.py b/migrate/versioning/util/__init__.py deleted file mode 100644 index 55c72c9..0000000 --- a/migrate/versioning/util/__init__.py +++ /dev/null @@ -1,187 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""".. currentmodule:: migrate.versioning.util""" - -import warnings -import logging -from decorator import decorator -from pkg_resources import EntryPoint - -import six -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine -from sqlalchemy.pool import StaticPool - -from migrate import exceptions -from migrate.versioning.util.keyedinstance import KeyedInstance -from migrate.versioning.util.importpath import import_path - - -log = logging.getLogger(__name__) - -def load_model(dotted_name): - """Import module and use module-level variable". - - :param dotted_name: path to model in form of string: ``some.python.module:Class`` - - .. 
versionchanged:: 0.5.4 - - """ - if isinstance(dotted_name, six.string_types): - if ':' not in dotted_name: - # backwards compatibility - warnings.warn('model should be in form of module.model:User ' - 'and not module.model.User', exceptions.MigrateDeprecationWarning) - dotted_name = ':'.join(dotted_name.rsplit('.', 1)) - - ep = EntryPoint.parse('x=%s' % dotted_name) - if hasattr(ep, 'resolve'): - # this is available on setuptools >= 10.2 - return ep.resolve() - else: - # this causes a DeprecationWarning on setuptools >= 11.3 - return ep.load(False) - else: - # Assume it's already loaded. - return dotted_name - -def asbool(obj): - """Do everything to use object as bool""" - if isinstance(obj, six.string_types): - obj = obj.strip().lower() - if obj in ['true', 'yes', 'on', 'y', 't', '1']: - return True - elif obj in ['false', 'no', 'off', 'n', 'f', '0']: - return False - else: - raise ValueError("String is not true/false: %r" % obj) - if obj in (True, False): - return bool(obj) - else: - raise ValueError("String is not true/false: %r" % obj) - -def guess_obj_type(obj): - """Do everything to guess object type from string - - Tries to convert to `int`, `bool` and finally returns if not succeded. - - .. versionadded: 0.5.4 - """ - - result = None - - try: - result = int(obj) - except: - pass - - if result is None: - try: - result = asbool(obj) - except: - pass - - if result is not None: - return result - else: - return obj - -@decorator -def catch_known_errors(f, *a, **kw): - """Decorator that catches known api errors - - .. versionadded: 0.5.4 - """ - - try: - return f(*a, **kw) - except exceptions.PathFoundError as e: - raise exceptions.KnownError("The path %s already exists" % e.args[0]) - -def construct_engine(engine, **opts): - """.. versionadded:: 0.5.4 - - Constructs and returns SQLAlchemy engine. 
- - Currently, there are 2 ways to pass create_engine options to :mod:`migrate.versioning.api` functions: - - :param engine: connection string or a existing engine - :param engine_dict: python dictionary of options to pass to `create_engine` - :param engine_arg_*: keyword parameters to pass to `create_engine` (evaluated with :func:`migrate.versioning.util.guess_obj_type`) - :type engine_dict: dict - :type engine: string or Engine instance - :type engine_arg_*: string - :returns: SQLAlchemy Engine - - .. note:: - - keyword parameters override ``engine_dict`` values. - - """ - if isinstance(engine, Engine): - return engine - elif not isinstance(engine, six.string_types): - raise ValueError("you need to pass either an existing engine or a database uri") - - # get options for create_engine - if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict): - kwargs = opts['engine_dict'] - else: - kwargs = dict() - - # DEPRECATED: handle echo the old way - echo = asbool(opts.get('echo', False)) - if echo: - warnings.warn('echo=True parameter is deprecated, pass ' - 'engine_arg_echo=True or engine_dict={"echo": True}', - exceptions.MigrateDeprecationWarning) - kwargs['echo'] = echo - - # parse keyword arguments - for key, value in six.iteritems(opts): - if key.startswith('engine_arg_'): - kwargs[key[11:]] = guess_obj_type(value) - - log.debug('Constructing engine') - # TODO: return create_engine(engine, poolclass=StaticPool, **kwargs) - # seems like 0.5.x branch does not work with engine.dispose and staticpool - return create_engine(engine, **kwargs) - -@decorator -def with_engine(f, *a, **kw): - """Decorator for :mod:`migrate.versioning.api` functions - to safely close resources after function usage. - - Passes engine parameters to :func:`construct_engine` and - resulting parameter is available as kw['engine']. - - Engine is disposed after wrapped function is executed. - - .. 
versionadded: 0.6.0 - """ - url = a[0] - engine = construct_engine(url, **kw) - - try: - kw['engine'] = engine - return f(*a, **kw) - finally: - if isinstance(engine, Engine) and engine is not url: - log.debug('Disposing SQLAlchemy engine %s', engine) - engine.dispose() - - -class Memoize(object): - """Memoize(fn) - an instance which acts like fn but memoizes its arguments - Will only work on functions with non-mutable arguments - - ActiveState Code 52201 - """ - def __init__(self, fn): - self.fn = fn - self.memo = {} - - def __call__(self, *args): - if args not in self.memo: - self.memo[args] = self.fn(*args) - return self.memo[args] diff --git a/migrate/versioning/util/importpath.py b/migrate/versioning/util/importpath.py deleted file mode 100644 index 529be89..0000000 --- a/migrate/versioning/util/importpath.py +++ /dev/null @@ -1,30 +0,0 @@ -import os -import sys - -PY33 = sys.version_info >= (3, 3) - -if PY33: - from importlib import machinery -else: - from six.moves import reload_module as reload - - -def import_path(fullpath): - """ Import a file with full path specification. Allows one to - import from anywhere, something __import__ does not do. 
- """ - if PY33: - name = os.path.splitext(os.path.basename(fullpath))[0] - return machinery.SourceFileLoader( - name, fullpath).load_module(name) - else: - # http://zephyrfalcon.org/weblog/arch_d7_2002_08_31.html - path, filename = os.path.split(fullpath) - filename, ext = os.path.splitext(filename) - sys.path.append(path) - try: - module = __import__(filename) - reload(module) # Might be out of date during tests - return module - finally: - del sys.path[-1] diff --git a/migrate/versioning/util/keyedinstance.py b/migrate/versioning/util/keyedinstance.py deleted file mode 100644 index a692e08..0000000 --- a/migrate/versioning/util/keyedinstance.py +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -class KeyedInstance(object): - """A class whose instances have a unique identifier of some sort - No two instances with the same unique ID should exist - if we try to create - a second instance, the first should be returned. - """ - - _instances = dict() - - def __new__(cls, *p, **k): - instances = cls._instances - clskey = str(cls) - if clskey not in instances: - instances[clskey] = dict() - instances = instances[clskey] - - key = cls._key(*p, **k) - if key not in instances: - instances[key] = super(KeyedInstance, cls).__new__(cls) - return instances[key] - - @classmethod - def _key(cls, *p, **k): - """Given a unique identifier, return a dictionary key - This should be overridden by child classes, to specify which parameters - should determine an object's uniqueness - """ - raise NotImplementedError() - - @classmethod - def clear(cls): - # Allow cls.clear() as well as uniqueInstance.clear(cls) - if str(cls) in cls._instances: - del cls._instances[str(cls)] diff --git a/migrate/versioning/version.py b/migrate/versioning/version.py deleted file mode 100644 index 0633e1b..0000000 --- a/migrate/versioning/version.py +++ /dev/null @@ -1,282 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import os -import re -import shutil -import 
logging - -from migrate import exceptions -from migrate.versioning import pathed, script -from datetime import datetime -import six - - -log = logging.getLogger(__name__) - -class VerNum(object): - """A version number that behaves like a string and int at the same time""" - - _instances = dict() - - def __new__(cls, value): - val = str(value) - if val not in cls._instances: - cls._instances[val] = super(VerNum, cls).__new__(cls) - ret = cls._instances[val] - return ret - - def __init__(self,value): - self.value = str(int(value)) - if self < 0: - raise ValueError("Version number cannot be negative") - - def __add__(self, value): - ret = int(self) + int(value) - return VerNum(ret) - - def __sub__(self, value): - return self + (int(value) * -1) - - def __eq__(self, value): - return int(self) == int(value) - - def __ne__(self, value): - return int(self) != int(value) - - def __lt__(self, value): - return int(self) < int(value) - - def __gt__(self, value): - return int(self) > int(value) - - def __ge__(self, value): - return int(self) >= int(value) - - def __le__(self, value): - return int(self) <= int(value) - - def __repr__(self): - return "<VerNum(%s)>" % self.value - - def __str__(self): - return str(self.value) - - def __int__(self): - return int(self.value) - - def __index__(self): - return int(self.value) - - if six.PY3: - def __hash__(self): - return hash(self.value) - - -class Collection(pathed.Pathed): - """A collection of versioning scripts in a repository""" - - FILENAME_WITH_VERSION = re.compile(r'^(\d{3,}).*') - - def __init__(self, path): - """Collect current version scripts in repository - and store them in self.versions - """ - super(Collection, self).__init__(path) - - # Create temporary list of files, allowing skipped version numbers. - files = os.listdir(path) - if '1' in files: - # deprecation - raise Exception('It looks like you have a repository in the old ' - 'format (with directories for each version). 
' - 'Please convert repository before proceeding.') - - tempVersions = dict() - for filename in files: - match = self.FILENAME_WITH_VERSION.match(filename) - if match: - num = int(match.group(1)) - tempVersions.setdefault(num, []).append(filename) - else: - pass # Must be a helper file or something, let's ignore it. - - # Create the versions member where the keys - # are VerNum's and the values are Version's. - self.versions = dict() - for num, files in tempVersions.items(): - self.versions[VerNum(num)] = Version(num, path, files) - - @property - def latest(self): - """:returns: Latest version in Collection""" - return max([VerNum(0)] + list(self.versions.keys())) - - def _next_ver_num(self, use_timestamp_numbering): - if use_timestamp_numbering == True: - return VerNum(int(datetime.utcnow().strftime('%Y%m%d%H%M%S'))) - else: - return self.latest + 1 - - def create_new_python_version(self, description, **k): - """Create Python files for new version""" - ver = self._next_ver_num(k.pop('use_timestamp_numbering', False)) - extra = str_to_filename(description) - - if extra: - if extra == '_': - extra = '' - elif not extra.startswith('_'): - extra = '_%s' % extra - - filename = '%03d%s.py' % (ver, extra) - filepath = self._version_path(filename) - - script.PythonScript.create(filepath, **k) - self.versions[ver] = Version(ver, self.path, [filename]) - - def create_new_sql_version(self, database, description, **k): - """Create SQL files for new version""" - ver = self._next_ver_num(k.pop('use_timestamp_numbering', False)) - self.versions[ver] = Version(ver, self.path, []) - - extra = str_to_filename(description) - - if extra: - if extra == '_': - extra = '' - elif not extra.startswith('_'): - extra = '_%s' % extra - - # Create new files. 
- for op in ('upgrade', 'downgrade'): - filename = '%03d%s_%s_%s.sql' % (ver, extra, database, op) - filepath = self._version_path(filename) - script.SqlScript.create(filepath, **k) - self.versions[ver].add_script(filepath) - - def version(self, vernum=None): - """Returns required version. - - If vernum is not given latest version will be returned otherwise - required version will be returned. - :raises: : exceptions.VersionNotFoundError if respective migration - script file of version is not present in the migration repository. - """ - if vernum is None: - vernum = self.latest - - try: - return self.versions[VerNum(vernum)] - except KeyError: - raise exceptions.VersionNotFoundError( - ("Database schema file with version %(args)s doesn't " - "exist.") % {'args': VerNum(vernum)}) - - @classmethod - def clear(cls): - super(Collection, cls).clear() - - def _version_path(self, ver): - """Returns path of file in versions repository""" - return os.path.join(self.path, str(ver)) - - -class Version(object): - """A single version in a collection - :param vernum: Version Number - :param path: Path to script files - :param filelist: List of scripts - :type vernum: int, VerNum - :type path: string - :type filelist: list - """ - - def __init__(self, vernum, path, filelist): - self.version = VerNum(vernum) - - # Collect scripts in this folder - self.sql = dict() - self.python = None - - for script in filelist: - self.add_script(os.path.join(path, script)) - - def script(self, database=None, operation=None): - """Returns SQL or Python Script""" - for db in (database, 'default'): - # Try to return a .sql script first - try: - return self.sql[db][operation] - except KeyError: - continue # No .sql script exists - - # TODO: maybe add force Python parameter? 
- ret = self.python - - assert ret is not None, \ - "There is no script for %d version" % self.version - return ret - - def add_script(self, path): - """Add script to Collection/Version""" - if path.endswith(Extensions.py): - self._add_script_py(path) - elif path.endswith(Extensions.sql): - self._add_script_sql(path) - - SQL_FILENAME = re.compile(r'^.*\.sql') - - def _add_script_sql(self, path): - basename = os.path.basename(path) - match = self.SQL_FILENAME.match(basename) - - if match: - basename = basename.replace('.sql', '') - parts = basename.split('_') - if len(parts) < 3: - raise exceptions.ScriptError( - "Invalid SQL script name %s " % basename + \ - "(needs to be ###_description_database_operation.sql)") - version = parts[0] - op = parts[-1] - # NOTE(mriedem): check for ibm_db_sa as the database in the name - if 'ibm_db_sa' in basename: - if len(parts) == 6: - dbms = '_'.join(parts[-4: -1]) - else: - raise exceptions.ScriptError( - "Invalid ibm_db_sa SQL script name '%s'; " - "(needs to be " - "###_description_ibm_db_sa_operation.sql)" % basename) - else: - dbms = parts[-2] - else: - raise exceptions.ScriptError( - "Invalid SQL script name %s " % basename + \ - "(needs to be ###_description_database_operation.sql)") - - # File the script into a dictionary - self.sql.setdefault(dbms, {})[op] = script.SqlScript(path) - - def _add_script_py(self, path): - if self.python is not None: - raise exceptions.ScriptError('You can only have one Python script ' - 'per version, but you have: %s and %s' % (self.python, path)) - self.python = script.PythonScript(path) - - -class Extensions(object): - """A namespace for file extensions""" - py = 'py' - sql = 'sql' - -def str_to_filename(s): - """Replaces spaces, (double and single) quotes - and double underscores to underscores - """ - - s = s.replace(' ', '_').replace('"', '_').replace("'", '_').replace(".", "_") - while '__' in s: - s = s.replace('__', '_') - return s |