← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~ack/maas-site-manager:count-queries into maas-site-manager:main

 

Alberto Donato has proposed merging ~ack/maas-site-manager:count-queries into maas-site-manager:main.

Commit message:
get counts for tokens and sites with separate COUNT queries instead of a window function

Requested reviews:
  MAAS Committers (maas-committers)

For more details, see:
https://code.launchpad.net/~ack/maas-site-manager/+git/site-manager/+merge/440968
-- 
Your team MAAS Committers is requested to review the proposed merge of ~ack/maas-site-manager:count-queries into maas-site-manager:main.
diff --git a/backend/msm/db/queries.py b/backend/msm/db/queries.py
index 3a35472..57ee03c 100644
--- a/backend/msm/db/queries.py
+++ b/backend/msm/db/queries.py
@@ -3,19 +3,12 @@ from datetime import (
     datetime,
     timedelta,
 )
-from typing import (
-    Any,
-    Sequence,
-    Type,
-    TypeVar,
-)
 from uuid import UUID
 
 from sqlalchemy import (
     case,
     func,
     Operators,
-    Row,
     select,
     Table,
 )
@@ -69,25 +62,6 @@ def filters_from_arguments(
                 yield clause
 
 
-DerivedSchema = TypeVar("DerivedSchema", SiteSchema, TokenSchema)
-
-
-def extract_count_and_results(
-    schema: Type[DerivedSchema],
-    db_results: Sequence[Row[tuple[Any, Any, Any, Any, Any, Any]]],
-) -> tuple[int, Iterable[DerivedSchema]]:
-    """
-    Extract the count and result from a paginated query using a window function
-    """
-    schema_objects = (schema(**row._asdict()) for row in db_results)
-    count: int = 0
-    try:
-        count = db_results[0][0]
-    except IndexError:
-        count = 0
-    return count, list(schema_objects)
-
-
 async def get_filtered_sites(
     session: AsyncSession,
     offset: int = 0,
@@ -101,22 +75,28 @@ async def get_filtered_sites(
     timezone: list[str] | None = [],
     url: list[str] | None = [],
 ) -> tuple[int, Iterable[SiteSchema]]:
-    filters = filters_from_arguments(
-        Site,
-        city=city,
-        country=country,
-        name=name,
-        note=note,
-        region=region,
-        street=street,
-        timezone=timezone,
-        url=url,
+    filters = list(
+        filters_from_arguments(
+            Site,
+            city=city,
+            country=country,
+            name=name,
+            note=note,
+            region=region,
+            street=street,
+            timezone=timezone,
+            url=url,
+        )
     )
+    count = (
+        await session.execute(
+            select(func.count())
+            .select_from(Site)
+            .where(*filters)  # type: ignore[arg-type]
+        )
+    ).scalar() or 0
     stmt = (
         select(
-            # use a window function to get the count before limit
-            # will be added as the first item of every results
-            func.count().over(),  # type: ignore[no-untyped-call]
             Site.c.id,
             Site.c.name,
             Site.c.city,
@@ -150,14 +130,12 @@ async def get_filtered_sites(
         .select_from(
             Site.join(SiteData, SiteData.c.site_id == Site.c.id, isouter=True)
         )
+        .where(*filters)  # type: ignore[arg-type]
         .limit(limit)
         .offset(offset)
     )
-    for clause in filters:
-        stmt = stmt.where(clause)  # type: ignore[arg-type]
-
     result = await session.execute(stmt)
-    return extract_count_and_results(SiteSchema, result.all())
+    return count, [SiteSchema(**row._asdict()) for row in result.all()]
 
 
 async def get_tokens(
@@ -165,9 +143,11 @@ async def get_tokens(
     offset: int = 0,
     limit: int = MAX_PAGE_SIZE,
 ) -> tuple[int, Iterable[TokenSchema]]:
+    count = (
+        await session.execute(select(func.count()).select_from(Token))
+    ).scalar() or 0
     result = await session.execute(
         select(
-            func.count().over(),  # type: ignore[no-untyped-call]
             Token.c.id,
             Token.c.site_id,
             Token.c.value,
@@ -178,7 +158,7 @@ async def get_tokens(
         .offset(offset)
         .limit(limit)
     )
-    return extract_count_and_results(TokenSchema, result.all())
+    return count, [TokenSchema(**row._asdict()) for row in result.all()]
 
 
 async def create_tokens(
diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py
index f5c9b21..e28f267 100644
--- a/backend/tests/conftest.py
+++ b/backend/tests/conftest.py
@@ -1,3 +1,5 @@
+import pytest
+
 from .fixtures.app import (
     user_app,
     user_app_client,
@@ -9,3 +11,11 @@ from .fixtures.db import (
 )
 
 __all__ = ["db", "db_setup", "user_app", "user_app_client", "fixture"]
+
+
+def pytest_addoption(parser: pytest.Parser) -> None:
+    parser.addoption(
+        "--sqlalchemy-debug",
+        help="print out SQLAlchemy queries",
+        action="store_true",
+    )
diff --git a/backend/tests/fixtures/app.py b/backend/tests/fixtures/app.py
index 1a379d0..d325e81 100644
--- a/backend/tests/fixtures/app.py
+++ b/backend/tests/fixtures/app.py
@@ -12,9 +12,15 @@ from msm.user_api import create_app


 @pytest.fixture
-def user_app(db: Database) -> Iterable[FastAPI]:
+def user_app(
+    request: pytest.FixtureRequest, db: Database
+) -> Iterable[FastAPI]:
     """The API for users."""
-    yield create_app(db.dsn)
+    app = create_app(db.dsn)
+    # Echo SQL from the app's engine when --sqlalchemy-debug is given.
+    # NOTE(review): `_engine` looks private; a public toggle would be safer.
+    app.state.db._engine.echo = request.config.getoption("sqlalchemy_debug")
+    yield app
 
 
 @pytest.fixture
diff --git a/backend/tests/fixtures/db.py b/backend/tests/fixtures/db.py
index 30848bc..7e326a3 100644
--- a/backend/tests/fixtures/db.py
+++ b/backend/tests/fixtures/db.py
@@ -55,9 +55,12 @@ def db_setup(postgresql_proc: PostgreSQLExecutor) -> Iterator[TestDSN]:
 
 
 @pytest.fixture
-def db(db_setup: TestDSN) -> Iterator[Database]:
+def db(
+    request: pytest.FixtureRequest, db_setup: TestDSN
+) -> Iterator[Database]:
     """Set up the database schema."""
-    engine = create_engine(db_setup.sync_dsn)
+    echo = request.config.getoption("sqlalchemy_debug")
+    engine = create_engine(db_setup.sync_dsn, echo=echo)
     with engine.connect() as conn:
         with conn.begin():
             METADATA.create_all(conn)

Follow-ups