summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/config.py31
-rw-r--r--lib/sqlalchemy/testing/fixtures.py16
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py3
-rw-r--r--lib/sqlalchemy/testing/requirements.py12
4 files changed, 54 insertions, 8 deletions
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
index 268a56421..c1ca670da 100644
--- a/lib/sqlalchemy/testing/config.py
+++ b/lib/sqlalchemy/testing/config.py
@@ -6,6 +6,11 @@
# the MIT License: https://www.opensource.org/licenses/mit-license.php
import collections
+import typing
+from typing import Any
+from typing import Iterable
+from typing import Tuple
+from typing import Union
from .. import util
@@ -20,10 +25,15 @@ any_async = False
_current = None
ident = "main"
-_fixture_functions = None # installed by plugin_base
+if typing.TYPE_CHECKING:
+ from .plugin.plugin_base import FixtureFunctions
+ _fixture_functions: FixtureFunctions
+else:
+ _fixture_functions = None # installed by plugin_base
-def combinations(*comb, **kw):
+
+def combinations(*comb: Union[Any, Tuple[Any, ...]], **kw: str):
r"""Deliver multiple versions of a test based on positional combinations.
This is a facade over pytest.mark.parametrize.
@@ -89,25 +99,32 @@ def combinations(*comb, **kw):
return _fixture_functions.combinations(*comb, **kw)
-def combinations_list(arg_iterable, **kw):
+def combinations_list(
+ arg_iterable: Iterable[
+ Tuple[
+ Any,
+ ]
+ ],
+ **kw,
+):
"As combination, but takes a single iterable"
return combinations(*arg_iterable, **kw)
-def fixture(*arg, **kw):
+def fixture(*arg: Any, **kw: Any) -> Any:
return _fixture_functions.fixture(*arg, **kw)
-def get_current_test_name():
+def get_current_test_name() -> str:
return _fixture_functions.get_current_test_name()
-def mark_base_test_class():
+def mark_base_test_class() -> Any:
return _fixture_functions.mark_base_test_class()
class _AddToMarker:
- def __getattr__(self, attr):
+ def __getattr__(self, attr: str) -> Any:
return getattr(_fixture_functions.add_to_marker, attr)
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
index ecc20f163..7228e5afe 100644
--- a/lib/sqlalchemy/testing/fixtures.py
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -20,6 +20,7 @@ from .util import drop_all_tables_from_metadata
from .. import event
from .. import util
from ..orm import declarative_base
+from ..orm import DeclarativeBase
from ..orm import registry
from ..schema import sort_tables_and_constraints
@@ -82,6 +83,21 @@ class TestBase:
yield reg
reg.dispose()
+ @config.fixture
+ def decl_base(self, metadata):
+ _md = metadata
+
+ class Base(DeclarativeBase):
+ metadata = _md
+ type_annotation_map = {
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb"
+ )
+ }
+
+ yield Base
+ Base.registry.dispose()
+
@config.fixture()
def future_connection(self, future_engine, connection):
# integrate the future_engine and connection fixtures so
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
index 0b4451b3c..52e42bb97 100644
--- a/lib/sqlalchemy/testing/plugin/plugin_base.py
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -19,6 +19,7 @@ import logging
import os
import re
import sys
+from typing import Any
from sqlalchemy.testing import asyncio
@@ -738,7 +739,7 @@ class FixtureFunctions(abc.ABC):
raise NotImplementedError()
@abc.abstractmethod
- def mark_base_test_class(self):
+ def mark_base_test_class(self) -> Any:
raise NotImplementedError()
@abc.abstractproperty
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index 410ab26ed..41e5d6772 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -1326,6 +1326,18 @@ class SuiteRequirements(Requirements):
return exclusions.only_if(check)
@property
+ def no_sqlalchemy2_stubs(self):
+ def check(config):
+ try:
+ __import__("sqlalchemy-stubs.ext.mypy")
+ except ImportError:
+ return False
+ else:
+ return True
+
+ return exclusions.skip_if(check)
+
+ @property
def python38(self):
return exclusions.only_if(
lambda: util.py38, "Python 3.8 or above required"