← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~ack/maas-site-manager:row-count-helper into maas-site-manager:main

 

Alberto Donato has proposed merging ~ack/maas-site-manager:row-count-helper into maas-site-manager:main.

Commit message:
add a row_count helper for counting entries



Requested reviews:
  MAAS Lander (maas-lander): unittests
  MAAS Committers (maas-committers)

For more details, see:
https://code.launchpad.net/~ack/maas-site-manager/+git/site-manager/+merge/442474
-- 
Your team MAAS Committers is requested to review the proposed merge of ~ack/maas-site-manager:row-count-helper into maas-site-manager:main.
diff --git a/backend/msm/db/queries.py b/backend/msm/db/queries.py
index b46f105..803aa81 100644
--- a/backend/msm/db/queries.py
+++ b/backend/msm/db/queries.py
@@ -8,7 +8,6 @@ from operator import or_
 from typing import Any
 from uuid import UUID
 
-# from passlib.context import CryptContext
 from sqlalchemy import (
     case,
     ColumnOperators,
@@ -21,6 +20,7 @@ from sqlalchemy import (
     update,
 )
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.sql.expression import FromClause
 
 from ..schema import MAX_PAGE_SIZE
 from ._tables import (
@@ -45,6 +45,18 @@ class InvalidPendingSites(Exception):
         super().__init__("Unknown pending sites")
 
 
+async def row_count(
+    session: AsyncSession, what: FromClause, *filters: ColumnOperators
+) -> int:
+    """Return the number of rows in ``what`` matching all ``filters``."""
+    stmt = (
+        select(func.count())
+        .select_from(what)
+        .where(*filters)  # type: ignore[arg-type]
+    )
+    return (await session.execute(stmt)).scalar() or 0
+
+
 async def get_user(
     session: AsyncSession, email: str
 ) -> UserWithPasswordSchema | None:
@@ -133,13 +145,7 @@ async def get_sites(
         url=url,
     )
     filters.append(Site.c.accepted == True)  # noqa
-    count = (
-        await session.execute(
-            select(func.count())
-            .select_from(Site)
-            .where(*filters)  # type: ignore[arg-type]
-        )
-    ).scalar() or 0
+    count = await row_count(session, Site, *filters)
     stmt = (
         select(
             Site.c.id,
@@ -190,11 +196,7 @@ async def get_pending_sites(
     limit: int = MAX_PAGE_SIZE,
 ) -> tuple[int, Iterable[PendingSiteSchema]]:
     filters = [Site.c.accepted == False]  # noqa
-    count = (
-        await session.execute(
-            select(func.count()).select_from(Site).where(*filters)
-        )
-    ).scalar() or 0
+    count = await row_count(session, Site, *filters)
     stmt = (
         select(
             Site.c.id,
@@ -245,9 +247,7 @@ async def get_tokens(
     offset: int = 0,
     limit: int = MAX_PAGE_SIZE,
 ) -> tuple[int, Iterable[TokenSchema]]:
-    count = (
-        await session.execute(select(func.count()).select_from(Token))
-    ).scalar() or 0
+    count = await row_count(session, Token)
     result = await session.execute(
         select(
             Token.c.id,