summaryrefslogtreecommitdiff
path: root/alembic/ddl/postgresql.py
diff options
context:
space:
mode:
Diffstat (limited to 'alembic/ddl/postgresql.py')
-rw-r--r--alembic/ddl/postgresql.py73
1 files changed, 59 insertions, 14 deletions
diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py
index 4ffc2eb..247838b 100644
--- a/alembic/ddl/postgresql.py
+++ b/alembic/ddl/postgresql.py
@@ -21,8 +21,10 @@ from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.schema import CreateIndex
+from sqlalchemy.sql import operators
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import TextClause
+from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.types import NULLTYPE
from .base import alter_column
@@ -53,6 +55,7 @@ if TYPE_CHECKING:
from sqlalchemy.dialects.postgresql.json import JSON
from sqlalchemy.dialects.postgresql.json import JSONB
from sqlalchemy.sql.elements import BinaryExpression
+ from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import Table
@@ -248,11 +251,14 @@ class PostgresqlImpl(DefaultImpl):
if not sqla_compat.sqla_2:
self._skip_functional_indexes(metadata_indexes, conn_indexes)
- def _cleanup_index_expr(self, index: Index, expr: str) -> str:
+ def _cleanup_index_expr(
+ self, index: Index, expr: str, remove_suffix: str
+ ) -> str:
# start = expr
expr = expr.lower()
expr = expr.replace('"', "")
if index.table is not None:
+ # should not be needed, since include_table=False is in compile
expr = expr.replace(f"{index.table.name.lower()}.", "")
while expr and expr[0] == "(" and expr[-1] == ")":
@@ -261,25 +267,64 @@ class PostgresqlImpl(DefaultImpl):
# strip :: cast. types can have spaces in them
expr = re.sub(r"(::[\w ]+\w)", "", expr)
+ if remove_suffix and expr.endswith(remove_suffix):
+ expr = expr[: -len(remove_suffix)]
+
# print(f"START: {start} END: {expr}")
return expr
+ def _default_modifiers(self, exp: ClauseElement) -> str:
+ to_remove = ""
+ while isinstance(exp, UnaryExpression):
+ if exp.modifier is None:
+ exp = exp.element
+ else:
+ op = exp.modifier
+ if isinstance(exp.element, UnaryExpression):
+ inner_op = exp.element.modifier
+ else:
+ inner_op = None
+ if inner_op is None:
+ if op == operators.asc_op:
+ # default is asc
+ to_remove = " asc"
+ elif op == operators.nullslast_op:
+ # default is nulls last
+ to_remove = " nulls last"
+ else:
+ if (
+ inner_op == operators.asc_op
+ and op == operators.nullslast_op
+ ):
+ # default is asc nulls last
+ to_remove = " asc nulls last"
+ elif (
+ inner_op == operators.desc_op
+ and op == operators.nullsfirst_op
+ ):
+ # default for desc is nulls first
+ to_remove = " nulls first"
+ break
+ return to_remove
+
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
- if sqla_compat.is_expression_index(index):
- return tuple(
- self._cleanup_index_expr(
- index,
- e
+ return tuple(
+ self._cleanup_index_expr(
+ index,
+ *(
+ (e, "")
if isinstance(e, str)
- else e.compile(
- dialect=self.dialect,
- compile_kwargs={"literal_binds": True},
- ).string,
- )
- for e in index.expressions
+ else (self._compile_element(e), self._default_modifiers(e))
+ ),
)
- else:
- return super().create_index_sig(index)
+ for e in index.expressions
+ )
+
+ def _compile_element(self, element: ClauseElement) -> str:
+ return element.compile(
+ dialect=self.dialect,
+ compile_kwargs={"literal_binds": True, "include_table": False},
+ ).string
def render_type(
self, type_: TypeEngine, autogen_context: AutogenContext