sts-sponsors team mailing list archive
-
sts-sponsors team
-
Mailing list archive
-
Message #06986
[Merge] ~ack/maas-site-manager:count-queries into maas-site-manager:main
Alberto Donato has proposed merging ~ack/maas-site-manager:count-queries into maas-site-manager:main.
Commit message:
Get counts for tokens and sites with separate queries
Requested reviews:
MAAS Committers (maas-committers)
For more details, see:
https://code.launchpad.net/~ack/maas-site-manager/+git/site-manager/+merge/440968
--
Your team MAAS Committers is requested to review the proposed merge of ~ack/maas-site-manager:count-queries into maas-site-manager:main.
diff --git a/backend/msm/db/queries.py b/backend/msm/db/queries.py
index 3a35472..57ee03c 100644
--- a/backend/msm/db/queries.py
+++ b/backend/msm/db/queries.py
@@ -3,19 +3,12 @@ from datetime import (
datetime,
timedelta,
)
-from typing import (
- Any,
- Sequence,
- Type,
- TypeVar,
-)
from uuid import UUID
from sqlalchemy import (
case,
func,
Operators,
- Row,
select,
Table,
)
@@ -69,25 +62,6 @@ def filters_from_arguments(
yield clause
-DerivedSchema = TypeVar("DerivedSchema", SiteSchema, TokenSchema)
-
-
-def extract_count_and_results(
- schema: Type[DerivedSchema],
- db_results: Sequence[Row[tuple[Any, Any, Any, Any, Any, Any]]],
-) -> tuple[int, Iterable[DerivedSchema]]:
- """
- Extract the count and result from a paginated query using a window function
- """
- schema_objects = (schema(**row._asdict()) for row in db_results)
- count: int = 0
- try:
- count = db_results[0][0]
- except IndexError:
- count = 0
- return count, list(schema_objects)
-
-
async def get_filtered_sites(
session: AsyncSession,
offset: int = 0,
@@ -101,22 +75,28 @@ async def get_filtered_sites(
timezone: list[str] | None = [],
url: list[str] | None = [],
) -> tuple[int, Iterable[SiteSchema]]:
- filters = filters_from_arguments(
- Site,
- city=city,
- country=country,
- name=name,
- note=note,
- region=region,
- street=street,
- timezone=timezone,
- url=url,
+ filters = list(
+ filters_from_arguments(
+ Site,
+ city=city,
+ country=country,
+ name=name,
+ note=note,
+ region=region,
+ street=street,
+ timezone=timezone,
+ url=url,
+ )
)
+ count = (
+ await session.execute(
+ select(func.count())
+ .select_from(Site)
+ .where(*filters) # type: ignore[arg-type]
+ )
+ ).scalar() or 0
stmt = (
select(
- # use a window function to get the count before limit
- # will be added as the first item of every results
- func.count().over(), # type: ignore[no-untyped-call]
Site.c.id,
Site.c.name,
Site.c.city,
@@ -150,14 +130,12 @@ async def get_filtered_sites(
.select_from(
Site.join(SiteData, SiteData.c.site_id == Site.c.id, isouter=True)
)
+ .where(*filters) # type: ignore[arg-type]
.limit(limit)
.offset(offset)
)
- for clause in filters:
- stmt = stmt.where(clause) # type: ignore[arg-type]
-
result = await session.execute(stmt)
- return extract_count_and_results(SiteSchema, result.all())
+ return count, [SiteSchema(**row._asdict()) for row in result.all()]
async def get_tokens(
@@ -165,9 +143,11 @@ async def get_tokens(
offset: int = 0,
limit: int = MAX_PAGE_SIZE,
) -> tuple[int, Iterable[TokenSchema]]:
+ count = (
+ await session.execute(select(func.count()).select_from(Token))
+ ).scalar() or 0
result = await session.execute(
select(
- func.count().over(), # type: ignore[no-untyped-call]
Token.c.id,
Token.c.site_id,
Token.c.value,
@@ -178,7 +158,7 @@ async def get_tokens(
.offset(offset)
.limit(limit)
)
- return extract_count_and_results(TokenSchema, result.all())
+ return count, [TokenSchema(**row._asdict()) for row in result.all()]
async def create_tokens(
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py
index f5c9b21..e28f267 100644
--- a/backend/tests/conftest.py
+++ b/backend/tests/conftest.py
@@ -1,3 +1,5 @@
+import pytest
+
from .fixtures.app import (
user_app,
user_app_client,
@@ -9,3 +11,11 @@ from .fixtures.db import (
)
__all__ = ["db", "db_setup", "user_app", "user_app_client", "fixture"]
+
+
+def pytest_addoption(parser: pytest.Parser) -> None:
+ parser.addoption(
+ "--sqlalchemy-debug",
+ help="print out SQLALchemy queries",
+ action="store_true",
+ )
diff --git a/backend/tests/fixtures/app.py b/backend/tests/fixtures/app.py
index 1a379d0..d325e81 100644
--- a/backend/tests/fixtures/app.py
+++ b/backend/tests/fixtures/app.py
@@ -12,9 +12,13 @@ from msm.user_api import create_app
@pytest.fixture
-def user_app(db: Database) -> Iterable[FastAPI]:
+def user_app(
+ request: pytest.FixtureRequest, db: Database
+) -> Iterable[FastAPI]:
"""The API for users."""
- yield create_app(db.dsn)
+ app = create_app(db.dsn)
+ app.state.db._engine.echo = request.config.getoption("sqlalchemy_debug")
+ yield app
@pytest.fixture
diff --git a/backend/tests/fixtures/db.py b/backend/tests/fixtures/db.py
index 30848bc..7e326a3 100644
--- a/backend/tests/fixtures/db.py
+++ b/backend/tests/fixtures/db.py
@@ -55,9 +55,12 @@ def db_setup(postgresql_proc: PostgreSQLExecutor) -> Iterator[TestDSN]:
@pytest.fixture
-def db(db_setup: TestDSN) -> Iterator[Database]:
+def db(
+ request: pytest.FixtureRequest, db_setup: TestDSN
+) -> Iterator[Database]:
"""Set up the database schema."""
- engine = create_engine(db_setup.sync_dsn)
+ echo = request.config.getoption("sqlalchemy_debug")
+ engine = create_engine(db_setup.sync_dsn, echo=echo)
with engine.connect() as conn:
with conn.begin():
METADATA.create_all(conn)
Follow-ups