diff options
Diffstat (limited to 'alembic/ddl/postgresql.py')
-rw-r--r-- | alembic/ddl/postgresql.py | 73 |
1 files changed, 59 insertions, 14 deletions
diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index 4ffc2eb..247838b 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -21,8 +21,10 @@ from sqlalchemy.dialects.postgresql import BIGINT from sqlalchemy.dialects.postgresql import ExcludeConstraint from sqlalchemy.dialects.postgresql import INTEGER from sqlalchemy.schema import CreateIndex +from sqlalchemy.sql import operators from sqlalchemy.sql.elements import ColumnClause from sqlalchemy.sql.elements import TextClause +from sqlalchemy.sql.elements import UnaryExpression from sqlalchemy.types import NULLTYPE from .base import alter_column @@ -53,6 +55,7 @@ if TYPE_CHECKING: from sqlalchemy.dialects.postgresql.json import JSON from sqlalchemy.dialects.postgresql.json import JSONB from sqlalchemy.sql.elements import BinaryExpression + from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.schema import Table @@ -248,11 +251,14 @@ class PostgresqlImpl(DefaultImpl): if not sqla_compat.sqla_2: self._skip_functional_indexes(metadata_indexes, conn_indexes) - def _cleanup_index_expr(self, index: Index, expr: str) -> str: + def _cleanup_index_expr( + self, index: Index, expr: str, remove_suffix: str + ) -> str: # start = expr expr = expr.lower() expr = expr.replace('"', "") if index.table is not None: + # should not be needed, since include_table=False is in compile expr = expr.replace(f"{index.table.name.lower()}.", "") while expr and expr[0] == "(" and expr[-1] == ")": @@ -261,25 +267,64 @@ class PostgresqlImpl(DefaultImpl): # strip :: cast. types can have spaces in them expr = re.sub(r"(::[\w ]+\w)", "", expr) + if remove_suffix and expr.endswith(remove_suffix): + expr = expr[: -len(remove_suffix)] + # print(f"START: {start} END: {expr}") return expr + def _default_modifiers(self, exp: ClauseElement) -> str: + to_remove = "" + while isinstance(exp, UnaryExpression): + if exp.modifier is None: + exp = exp.element + else: + op = exp.modifier + if isinstance(exp.element, UnaryExpression): + inner_op = exp.element.modifier + else: + inner_op = None + if inner_op is None: + if op == operators.asc_op: + # default is asc + to_remove = " asc" + elif op == operators.nullslast_op: + # default is nulls last + to_remove = " nulls last" + else: + if ( + inner_op == operators.asc_op + and op == operators.nullslast_op + ): + # default is asc nulls last + to_remove = " asc nulls last" + elif ( + inner_op == operators.desc_op + and op == operators.nullsfirst_op + ): + # default for desc is nulls first + to_remove = " nulls first" + break + return to_remove + def create_index_sig(self, index: Index) -> Tuple[Any, ...]: - if sqla_compat.is_expression_index(index): - return tuple( - self._cleanup_index_expr( - index, - e + return tuple( + self._cleanup_index_expr( + index, + *( + (e, "") if isinstance(e, str) - else e.compile( - dialect=self.dialect, - compile_kwargs={"literal_binds": True}, - ).string, - ) - for e in index.expressions + else (self._compile_element(e), self._default_modifiers(e)) + ), ) - else: - return super().create_index_sig(index) + for e in index.expressions + ) + + def _compile_element(self, element: ClauseElement) -> str: + return element.compile( + dialect=self.dialect, + compile_kwargs={"literal_binds": True, "include_table": False}, + ).string def render_type( self, type_: TypeEngine, autogen_context: AutogenContext |