← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~thorsten-merten/maas-site-manager:MAASENG-1497-pagination into maas-site-manager:main

 

Thorsten Merten has proposed merging ~thorsten-merten/maas-site-manager:MAASENG-1497-pagination into maas-site-manager:main.

Commit message:
feat: WIP add pagination and update tests

* add window function to get total count
* change tests to use async httpx client
  so we can make multiple calls

also

chore: update coding style to use black's default line length of 88


Requested reviews:
  MAAS Committers (maas-committers)

For more details, see:
https://code.launchpad.net/~thorsten-merten/maas-site-manager/+git/maas-site-manager/+merge/439629
-- 
Your team MAAS Committers is requested to review the proposed merge of ~thorsten-merten/maas-site-manager:MAASENG-1497-pagination into maas-site-manager:main.
diff --git a/backend/msm/__init__.py b/backend/msm/__init__.py
index 8506985..3651249 100644
--- a/backend/msm/__init__.py
+++ b/backend/msm/__init__.py
@@ -1,8 +1,16 @@
 from pkg_resources import get_distribution
 
-__all__ = ["PACKAGE", "__version__"]
+__all__ = [
+    "PACKAGE",
+    "__version__",
+    "__default_page_size__",
+    "__max_page_size__",
+]
 
 
 PACKAGE = get_distribution("msm")
 
 __version__ = PACKAGE.version
+
+__default_page_size__ = 20
+__max_page_size__ = 100
diff --git a/backend/msm/db/queries.py b/backend/msm/db/queries.py
index 30df706..7c34f0a 100644
--- a/backend/msm/db/queries.py
+++ b/backend/msm/db/queries.py
@@ -5,15 +5,24 @@ from datetime import (
     datetime,
     timedelta,
 )
-from typing import TYPE_CHECKING
+from typing import (
+    Any,
+    Sequence,
+    Type,
+    TYPE_CHECKING,
+    TypeVar,
+)
 from uuid import UUID
 
 from sqlalchemy import (
+    func,
+    Row,
     select,
     Table,
 )
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from .. import __max_page_size__
 from ..schema import (
     Site as SiteSchema,
     Token as TokenSchema,
@@ -63,8 +72,27 @@ def filters_from_arguments(
                 yield clause
 
 
+DerivedSchema = TypeVar("DerivedSchema", SiteSchema, TokenSchema)
+
+
+def extract_count_and_results(
+    schema: Type[DerivedSchema],
+    db_results: Sequence[Row[tuple[Any, Any, Any, Any, Any]]],
+) -> tuple[int, Iterable[DerivedSchema]]:
+    """Extract the count and result from a paginated query using a window function"""
+    schema_objects = (schema(**row._asdict()) for row in db_results)
+    count: int = 0
+    try:
+        count = db_results[0][0]
+    except IndexError:
+        count = 0
+    return count, list(schema_objects)
+
+
 async def get_filtered_sites(
     session: AsyncSession,
+    offset: int = 0,
+    limit: int = __max_page_size__,
     city: list[str] | None = [],
     country: list[str] | None = [],
     name: list[str] | None = [],
@@ -73,7 +101,7 @@ async def get_filtered_sites(
     street: list[str] | None = [],
     timezone: list[str] | None = [],
     url: list[str] | None = [],
-) -> Iterable[SiteSchema]:
+) -> tuple[int, Iterable[SiteSchema]]:
     filters = filters_from_arguments(
         Site,
         city=city,
@@ -85,47 +113,52 @@ async def get_filtered_sites(
         timezone=timezone,
         url=url,
     )
-    stmt = select(
-        Site.c.id,
-        Site.c.name,
-        Site.c.city,
-        Site.c.country,
-        Site.c.latitude,
-        Site.c.longitude,
-        Site.c.note,
-        Site.c.region,
-        Site.c.street,
-        Site.c.timezone,
-        Site.c.url,
+    stmt = (
+        select(
+            # use a window function to get the total count before the limit
+            # is applied; it is added as the first item of every result row
+            func.count().over(),  # type: ignore[no-untyped-call]
+            Site.c.id,
+            Site.c.name,
+            Site.c.city,
+            Site.c.country,
+            Site.c.latitude,
+            Site.c.longitude,
+            Site.c.note,
+            Site.c.region,
+            Site.c.street,
+            Site.c.timezone,
+            Site.c.url,
+        )
+        .select_from(Site)
+        .limit(limit)
+        .offset(offset)
     )
     for clause in filters:
-        stmt = stmt.where(clause)  # type: ignore
-    result = await session.execute(stmt)
-    return (SiteSchema(**row._asdict()) for row in result.all())
-
-
-async def get_sites(session: AsyncSession) -> Iterable[SiteSchema]:
-    stmt = select(
-        Site.c.id,
-        Site.c.name,
-        Site.c.city,
-        Site.c.latitude,
-        Site.c.longitude,
-        Site.c.note,
-        Site.c.region,
-        Site.c.street,
-        Site.c.timezone,
-        Site.c.url,
-    )
+        stmt = stmt.where(clause)  # type: ignore[arg-type]
+
     result = await session.execute(stmt)
-    return (SiteSchema(**row._asdict()) for row in result.all())
+    return extract_count_and_results(SiteSchema, result.all())
 
 
-async def get_tokens(session: AsyncSession) -> Iterable[TokenSchema]:
+async def get_tokens(
+    session: AsyncSession,
+    offset: int = 0,
+    limit: int = __max_page_size__,
+) -> tuple[int, Iterable[TokenSchema]]:
     result = await session.execute(
-        select(Token.c.id, Token.c.site_id, Token.c.value, Token.c.expiration)
+        select(
+            func.count().over(),  # type: ignore[no-untyped-call]
+            Token.c.id,
+            Token.c.site_id,
+            Token.c.value,
+            Token.c.expiration,
+        )
+        .select_from(Token)
+        .offset(offset)
+        .limit(limit)
     )
-    return (TokenSchema(**row._asdict()) for row in result.all())
+    return extract_count_and_results(TokenSchema, result.all())
 
 
 async def create_tokens(
diff --git a/backend/msm/schema.py b/backend/msm/schema.py
index b0e2acc..688f76e 100644
--- a/backend/msm/schema.py
+++ b/backend/msm/schema.py
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from datetime import (
     datetime,
     timedelta,
@@ -11,6 +12,20 @@ from pydantic import (
 )
 from pydantic.fields import Field
 
+from msm import __max_page_size__
+
+
+class PaginatedResults(BaseModel):
+    """
+    Base class to wrap objects in a paginated response.
+    Derived classes should override the items field.
+    """
+
+    total: int = Field(min=0)
+    page: int = Field(min=0)
+    size: int = Field(min=0, max=__max_page_size__)
+    items: Sequence[BaseModel]
+
 
 class CreateUser(BaseModel):
     """
@@ -75,6 +90,10 @@ class SiteData(CreateSiteData):
     id: int
 
 
+class PaginatedSites(PaginatedResults):
+    items: list[Site]
+
+
 class SiteWithData(Site):
 
     """
@@ -104,6 +123,10 @@ class Token(CreateToken):
     id: int
 
 
+class PaginatedTokens(PaginatedResults):
+    items: list[Token]
+
+
 class CreateTokensRequest(BaseModel):
     """
     Request to create one or more tokens, with a certain validity,
diff --git a/backend/msm/testing/app.py b/backend/msm/testing/app.py
index 152b7d5..f3eb369 100644
--- a/backend/msm/testing/app.py
+++ b/backend/msm/testing/app.py
@@ -1,7 +1,12 @@
-from typing import Iterable
+from typing import (
+    AsyncIterable,
+    Iterable,
+)
 
 from fastapi import FastAPI
-from fastapi.testclient import TestClient
+
+# from fastapi.testclient import TestClient
+from httpx import AsyncClient
 import pytest
 
 from ..db import Database
@@ -15,6 +20,7 @@ def user_app(db: Database) -> Iterable[FastAPI]:
 
 
 @pytest.fixture
-def user_app_client(user_app: FastAPI) -> Iterable[TestClient]:
+async def user_app_client(user_app: FastAPI) -> AsyncIterable[AsyncClient]:
     """Client for the user API."""
-    yield TestClient(user_app)
+    async with AsyncClient(app=user_app, base_url="http://test") as client:
+        yield client
diff --git a/backend/msm/testing/db.py b/backend/msm/testing/db.py
index f1d0a44..f95eb2e 100644
--- a/backend/msm/testing/db.py
+++ b/backend/msm/testing/db.py
@@ -84,6 +84,16 @@ class Fixture:
             await session.commit()
             return [row._asdict() for row in result]
 
+    async def select_all(
+        self,
+        table: str,
+    ) -> list[dict[str, Any]]:
+        """Take a peek at what is in there"""
+        async with self.db.session() as session:
+            result = await session.execute(METADATA.tables[table].select())
+            await session.commit()
+            return [row._asdict() for row in result]
+
 
 @pytest.fixture
 def fixture(db: Database) -> Iterator[Fixture]:
diff --git a/backend/msm/user_api/_base.py b/backend/msm/user_api/_base.py
index 1d49a7d..554fcb6 100644
--- a/backend/msm/user_api/_base.py
+++ b/backend/msm/user_api/_base.py
@@ -5,6 +5,8 @@ from fastapi import (
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from .. import (
+    __default_page_size__,
+    __max_page_size__,
     __version__,
     schema,
 )
@@ -14,6 +16,16 @@ from ..db import (
 )
 
 
+async def pagination_parameters(
+    page: int = Query(default=1, gte=1),
+    size: int = Query(
+        default=__default_page_size__, lte=__max_page_size__, gte=1
+    ),
+) -> dict[str, int]:
+    """Make parameters for pagination accessible as a dict"""
+    return {"page": page, "size": size, "offset": (page - 1) * size}
+
+
 async def root() -> dict[str, str]:
     return {"version": __version__}
 
@@ -28,27 +40,43 @@ async def sites(
     | None = Query(default=None, title="Filter for timezones"),
     url: list[str] | None = Query(default=None, title="Filter for urls"),
     session: AsyncSession = Depends(db_session),
-) -> list[schema.Site]:
+    pagination_params: dict[str, int] = Depends(pagination_parameters),
+) -> schema.PaginatedSites:
     """Return all sites"""
-    return list(
-        await queries.get_filtered_sites(
-            session,
-            city,
-            name,
-            note,
-            region,
-            street,
-            timezone,
-            url,
-        )
+    total, results = await queries.get_filtered_sites(
+        session,
+        pagination_params["offset"],
+        pagination_params["size"],
+        city,
+        name,
+        note,
+        region,
+        street,
+        timezone,
+        url,
+    )
+    return schema.PaginatedSites(
+        total=total,
+        page=pagination_params["page"],
+        size=pagination_params["size"],
+        items=list(results),
     )
 
 
 async def tokens(
     session: AsyncSession = Depends(db_session),
-) -> list[schema.Token]:
+    pagination_params: dict[str, int] = Depends(pagination_parameters),
+) -> schema.PaginatedTokens:
     """Return all tokens"""
-    return list(await queries.get_tokens(session))
+    total, results = await queries.get_tokens(
+        session, pagination_params["offset"], pagination_params["size"]
+    )
+    return schema.PaginatedTokens(
+        total=total,
+        page=pagination_params["page"],
+        size=pagination_params["size"],
+        items=list(results),
+    )
 
 
 async def tokens_post(
diff --git a/backend/msm/user_api/tests/test_handlers.py b/backend/msm/user_api/tests/test_handlers.py
index 93a27cc..0784840 100644
--- a/backend/msm/user_api/tests/test_handlers.py
+++ b/backend/msm/user_api/tests/test_handlers.py
@@ -3,21 +3,23 @@ from datetime import (
     timedelta,
 )
 
-from fastapi.testclient import TestClient
+# from fastapi.testclient import TestClient
+from httpx import AsyncClient
 import pytest
 
 from ...testing.db import Fixture
 
 
-def test_root(user_app_client: TestClient) -> None:
-    response = user_app_client.get("/")
+@pytest.mark.asyncio
+async def test_root(user_app_client: AsyncClient) -> None:
+    response = await user_app_client.get("/")
     assert response.status_code == 200
     assert response.json() == {"version": "0.0.1"}
 
 
 @pytest.mark.asyncio
 async def test_list_sites(
-    user_app_client: TestClient, fixture: Fixture
+    user_app_client: AsyncClient, fixture: Fixture
 ) -> None:
     site1 = {
         "id": 1,
@@ -35,19 +37,40 @@ async def test_list_sites(
     site2 = site1.copy()
     site2["id"] = 2
     site2["name"] = "BerlinHQ"
-    site2["timezone"] = "+1.00"
+    site2["timezone"] = "1.00"
     site2["city"] = "Berlin"
     site2["country"] = "de"
     await fixture.create("sites", [site1, site2])
-    response = user_app_client.get("/sites?city=onDo")  # vs London
-    assert response.status_code == 200
-    assert response.json() == [site1]
+    page1 = await user_app_client.get("/sites")
+    assert page1.status_code == 200
+    assert page1.json() == {
+        "page": 1,
+        "size": 20,
+        "total": 2,
+        "items": [site1, site2],
+    }
+    filtered = await user_app_client.get("/sites?city=onDo")  # vs London
+    assert filtered.status_code == 200
+    assert filtered.json() == {
+        "page": 1,
+        "size": 20,
+        "total": 1,
+        "items": [site1],
+    }
+    paginated = await user_app_client.get("/sites?page=2&size=1")
+    assert paginated.status_code == 200
+    assert paginated.json() == {
+        "page": 2,
+        "size": 1,
+        "total": 2,
+        "items": [site2],
+    }
 
 
 @pytest.mark.asyncio
-async def test_create_token(user_app_client: TestClient) -> None:
+async def test_create_token(user_app_client: AsyncClient) -> None:
     seconds = 100
-    response = user_app_client.post(
+    response = await user_app_client.post(
         "/tokens", json={"count": 5, "duration": seconds}
     )
     assert response.status_code == 200
@@ -55,13 +78,12 @@ async def test_create_token(user_app_client: TestClient) -> None:
     assert datetime.fromisoformat(result["expiration"]) < (
         datetime.utcnow() + timedelta(seconds=seconds)
     )
-
     assert len(result["tokens"]) == 5
 
 
 @pytest.mark.asyncio
 async def test_list_tokens(
-    user_app_client: TestClient, fixture: Fixture
+    user_app_client: AsyncClient, fixture: Fixture
 ) -> None:
     tokens = [
         {
@@ -78,6 +100,7 @@ async def test_list_tokens(
         },
     ]
     await fixture.create("tokens", tokens)
-    response = user_app_client.get("/tokens")
+    response = await user_app_client.get("/tokens")
     assert response.status_code == 200
-    assert len(response.json()) == 2
+    assert response.json()["total"] == 2
+    assert len(response.json()["items"]) == 2
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index c4b8d97..237b6af 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -13,6 +13,7 @@ profile = 'black'
 use_parentheses = true
 
 [tool.pytest.ini_options]
+asyncio_mode = "auto"
 testpaths = [
   "msm/"
 ]
diff --git a/backend/setup.cfg b/backend/setup.cfg
index 2209af8..6436e19 100644
--- a/backend/setup.cfg
+++ b/backend/setup.cfg
@@ -31,7 +31,7 @@ testing =
 lint_files = setup.py msm/
 
 [flake8]
-max-line-length = 79
+max-line-length = 88
 
 [tox:tox]
 minversion = 4.0.8

Follow ups