← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~ack/maas-site-manager:filters-group into maas-site-manager:main

 

Alberto Donato has proposed merging ~ack/maas-site-manager:filters-group into maas-site-manager:main.

Commit message:
Group filter and pagination API query parameters

This reduces the number of handler parameters by using Depends.
It also moves the schema module into a package, with the pagination logic in a separate file.
Finally, it adds the missing filter for country.

Requested reviews:
  MAAS Committers (maas-committers)

For more details, see:
https://code.launchpad.net/~ack/maas-site-manager/+git/site-manager/+merge/441103
-- 
Your team MAAS Committers is requested to review the proposed merge of ~ack/maas-site-manager:filters-group into maas-site-manager:main.
diff --git a/backend/msm/__init__.py b/backend/msm/__init__.py
index b4e4731..1667eae 100644
--- a/backend/msm/__init__.py
+++ b/backend/msm/__init__.py
@@ -3,14 +3,9 @@ from pkg_resources import get_distribution
 __all__ = [
     "PACKAGE",
     "__version__",
-    "DEFAULT_PAGE_SIZE",
-    "MAX_PAGE_SIZE",
 ]
 
 
 PACKAGE = get_distribution("msm")
 
 __version__ = PACKAGE.version
-
-DEFAULT_PAGE_SIZE = 20
-MAX_PAGE_SIZE = 100
diff --git a/backend/msm/db/queries.py b/backend/msm/db/queries.py
index f70c850..f75dcc7 100644
--- a/backend/msm/db/queries.py
+++ b/backend/msm/db/queries.py
@@ -3,19 +3,21 @@ from datetime import (
     datetime,
     timedelta,
 )
+from functools import reduce
+from operator import or_
 from uuid import UUID
 
 from sqlalchemy import (
     case,
+    ColumnOperators,
     func,
-    Operators,
     select,
     Table,
 )
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from .. import MAX_PAGE_SIZE
 from ..schema import (
+    MAX_PAGE_SIZE,
     Site as SiteSchema,
     Token as TokenSchema,
 )
@@ -28,10 +30,9 @@ from ._tables import (
 
 def filters_from_arguments(
     table: Table,
-    **kwargs: list[str] | None,
-) -> Iterable[Operators]:
-    """
-    Yields clauses to join with AND and all entries for a single arg by OR.
+    **filter_args: list[str] | None,
+) -> list[ColumnOperators]:
+    """Yields clauses to join with AND and all entries for a single arg by OR.
     This enables to convert query params such as
 
       ?name=name1&name=name2&city=city
@@ -41,52 +42,48 @@ def filters_from_arguments(
       (name ilike %name1% OR name ilike %name2%) AND city ilike %city%
 
     :param table: the table to create the WHERE clause for
-    :param kwargs: the parameters matching the table's column name
-                   as keys and lists of strings that will be matched
-                   via ilike
-    :returns: a generator yielding where clause that joins all queries
-              per column with OR and all columns with AND
+    :param filter_args: the parameters matching the table's column name
+                        as keys and lists of strings that will be matched
+                        via ilike
+    :returns: a list with clauses to filter table values, which are meant to be
+              used in AND. Clauses for each column are joined with OR.
     """
-    for dimension, needles in kwargs.items():
-        column = table.c[dimension]
-
-        match needles:
-            case [needle]:
-                # If there's only one we don't need any ORs
-                yield column.icontains(needle, autoescape=True)
-            case [needle, *other_needles]:
-                # More than one thing to match against, join them with OR
-                clause = column.icontains(needle, autoescape=True) | False
-                for needle in other_needles:
-                    clause |= column.icontains(needle, autoescape=True)
-                yield clause
+    return [
+        reduce(
+            or_,
+            (
+                table.c[name].icontains(value, autoescape=True)
+                for value in values
+            ),
+        )
+        for name, values in filter_args.items()
+        if values
+    ]
 
 
 async def get_filtered_sites(
     session: AsyncSession,
     offset: int = 0,
     limit: int = MAX_PAGE_SIZE,
-    city: list[str] | None = [],
-    country: list[str] | None = [],
-    name: list[str] | None = [],
-    note: list[str] | None = [],
-    region: list[str] | None = [],
-    street: list[str] | None = [],
-    timezone: list[str] | None = [],
-    url: list[str] | None = [],
+    city: list[str] | None = None,
+    country: list[str] | None = None,
+    name: list[str] | None = None,
+    note: list[str] | None = None,
+    region: list[str] | None = None,
+    street: list[str] | None = None,
+    timezone: list[str] | None = None,
+    url: list[str] | None = None,
 ) -> tuple[int, Iterable[SiteSchema]]:
-    filters = list(
-        filters_from_arguments(
-            Site,
-            city=city,
-            country=country,
-            name=name,
-            note=note,
-            region=region,
-            street=street,
-            timezone=timezone,
-            url=url,
-        )
+    filters = filters_from_arguments(
+        Site,
+        city=city,
+        country=country,
+        name=name,
+        note=note,
+        region=region,
+        street=street,
+        timezone=timezone,
+        url=url,
     )
     count = (
         await session.execute(
diff --git a/backend/msm/schema/__init__.py b/backend/msm/schema/__init__.py
new file mode 100644
index 0000000..1865bd1
--- /dev/null
+++ b/backend/msm/schema/__init__.py
@@ -0,0 +1,29 @@
+"""API schema definitions."""
+
+from ._models import (
+    CreateTokensRequest,
+    CreateTokensResponse,
+    PaginatedSites,
+    PaginatedTokens,
+    Site,
+    Token,
+)
+from ._pagination import (
+    MAX_PAGE_SIZE,
+    PaginatedResults,
+    pagination_params,
+    PaginationParams,
+)
+
+__all__ = [
+    "CreateTokensRequest",
+    "CreateTokensResponse",
+    "Site",
+    "Token",
+    "pagination_params",
+    "PaginationParams",
+    "PaginatedResults",
+    "PaginatedSites",
+    "PaginatedTokens",
+    "MAX_PAGE_SIZE",
+]
diff --git a/backend/msm/schema.py b/backend/msm/schema/_models.py
similarity index 93%
rename from backend/msm/schema.py
rename to backend/msm/schema/_models.py
index 69c4a66..5f71293 100644
--- a/backend/msm/schema.py
+++ b/backend/msm/schema/_models.py
@@ -1,4 +1,3 @@
-from collections.abc import Sequence
 from datetime import (
     datetime,
     timedelta,
@@ -12,19 +11,7 @@ from pydantic import (
 )
 from pydantic.fields import Field
 
-from msm import MAX_PAGE_SIZE
-
-
-class PaginatedResults(BaseModel):
-    """
-    Base class to wrap objects in a pagination.
-    Derived classes should overwrite the items property
-    """
-
-    total: int = Field(min=0)
-    page: int = Field(min=0)
-    size: int = Field(min=0, max=MAX_PAGE_SIZE)
-    items: Sequence[BaseModel]
+from ._pagination import PaginatedResults
 
 
 class CreateUser(BaseModel):
diff --git a/backend/msm/schema/_pagination.py b/backend/msm/schema/_pagination.py
new file mode 100644
index 0000000..c6b2bb4
--- /dev/null
+++ b/backend/msm/schema/_pagination.py
@@ -0,0 +1,39 @@
+from collections.abc import Sequence
+from typing import NamedTuple
+
+from fastapi import Query
+from pydantic import (
+    BaseModel,
+    Field,
+)
+
+DEFAULT_PAGE_SIZE = 20
+MAX_PAGE_SIZE = 100
+
+
+class PaginatedResults(BaseModel):
+    """
+    Base class to wrap objects in a pagination.
+    Derived classes should overwrite the items property
+    """
+
+    total: int = Field(min=0)
+    page: int = Field(min=0)
+    size: int = Field(min=0, max=MAX_PAGE_SIZE)
+    items: Sequence[BaseModel]
+
+
+class PaginationParams(NamedTuple):
+    """Pagination parameters."""
+
+    page: int
+    size: int
+    offset: int
+
+
+async def pagination_params(
+    page: int = Query(default=1, gte=1),
+    size: int = Query(default=DEFAULT_PAGE_SIZE, lte=MAX_PAGE_SIZE, gte=1),
+) -> PaginationParams:
+    """Return pagination parameters."""
+    return PaginationParams(page=page, size=size, offset=(page - 1) * size)
diff --git a/backend/msm/user_api/_base.py b/backend/msm/user_api/_base.py
index a3797a2..38d645c 100644
--- a/backend/msm/user_api/_base.py
+++ b/backend/msm/user_api/_base.py
@@ -1,86 +1,70 @@
-from fastapi import (
-    Depends,
-    Query,
-)
+from fastapi import Depends
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from .. import (
-    __version__,
-    DEFAULT_PAGE_SIZE,
-    MAX_PAGE_SIZE,
-    schema,
-)
+from .. import __version__
 from ..db import (
     db_session,
     queries,
 )
-
-
-async def pagination_parameters(
-    page: int = Query(default=1, gte=1),
-    size: int = Query(default=DEFAULT_PAGE_SIZE, lte=MAX_PAGE_SIZE, gte=1),
-) -> dict[str, int]:
-    """Make parameters for pagination accessible as a dict"""
-    return {"page": page, "size": size, "offset": (page - 1) * size}
+from ..schema import (
+    CreateTokensRequest,
+    CreateTokensResponse,
+    PaginatedSites,
+    PaginatedTokens,
+    pagination_params,
+    PaginationParams,
+)
+from ._forms import (
+    site_filter_parameters,
+    SiteFilterParams,
+)
 
 
 async def root() -> dict[str, str]:
+    """Root endpoint."""
     return {"version": __version__}
 
 
 async def sites(
-    city: list[str] | None = Query(default=None, title="Filter for cities"),
-    name: list[str] | None = Query(default=None, title="Filter for names"),
-    note: list[str] | None = Query(default=None, title="Filter for notes"),
-    region: list[str] | None = Query(default=None, title="Filter for regions"),
-    street: list[str] | None = Query(default=None, title="Filter for streets"),
-    timezone: list[str]
-    | None = Query(default=None, title="Filter for timezones"),
-    url: list[str] | None = Query(default=None, title="Filter for urls"),
     session: AsyncSession = Depends(db_session),
-    pagination_params: dict[str, int] = Depends(pagination_parameters),
-) -> schema.PaginatedSites:
-    """Return all sites"""
+    pagination_params: PaginationParams = Depends(pagination_params),
+    filter_params: SiteFilterParams = Depends(site_filter_parameters),
+) -> PaginatedSites:
+    """Return all sites."""
     total, results = await queries.get_filtered_sites(
         session,
-        pagination_params["offset"],
-        pagination_params["size"],
-        city,
-        name,
-        note,
-        region,
-        street,
-        timezone,
-        url,
+        offset=pagination_params.offset,
+        limit=pagination_params.size,
+        **filter_params._asdict(),
     )
-    return schema.PaginatedSites(
+    return PaginatedSites(
         total=total,
-        page=pagination_params["page"],
-        size=pagination_params["size"],
+        page=pagination_params.page,
+        size=pagination_params.size,
         items=list(results),
     )
 
 
 async def tokens(
     session: AsyncSession = Depends(db_session),
-    pagination_params: dict[str, int] = Depends(pagination_parameters),
-) -> schema.PaginatedTokens:
+    pagination_params: PaginationParams = Depends(pagination_params),
+) -> PaginatedTokens:
     """Return all tokens"""
     total, results = await queries.get_tokens(
-        session, pagination_params["offset"], pagination_params["size"]
+        session, pagination_params.offset, pagination_params.size
     )
-    return schema.PaginatedTokens(
+    return PaginatedTokens(
         total=total,
-        page=pagination_params["page"],
-        size=pagination_params["size"],
+        page=pagination_params.page,
+        size=pagination_params.size,
         items=list(results),
     )
 
 
 async def tokens_post(
-    create_request: schema.CreateTokensRequest,
+    create_request: CreateTokensRequest,
     session: AsyncSession = Depends(db_session),
-) -> schema.CreateTokensResponse:
+) -> CreateTokensResponse:
     """
     Create one or more tokens.
     Token duration (TTL) is expressed in seconds.
@@ -90,4 +74,4 @@ async def tokens_post(
         create_request.duration,
         count=create_request.count,
     )
-    return schema.CreateTokensResponse(expired=expired, tokens=tokens)
+    return CreateTokensResponse(expired=expired, tokens=tokens)
diff --git a/backend/msm/user_api/_forms.py b/backend/msm/user_api/_forms.py
new file mode 100644
index 0000000..d6d2cf7
--- /dev/null
+++ b/backend/msm/user_api/_forms.py
@@ -0,0 +1,41 @@
+from typing import NamedTuple
+
+from fastapi import Query
+
+
+class SiteFilterParams(NamedTuple):
+    """Site filtering parameters."""
+
+    city: list[str] | None
+    country: list[str] | None
+    name: list[str] | None
+    note: list[str] | None
+    region: list[str] | None
+    street: list[str] | None
+    timezone: list[str] | None
+    url: list[str] | None
+
+
+async def site_filter_parameters(
+    city: list[str] | None = Query(default=None, title="Filter for cities"),
+    country: list[str]
+    | None = Query(default=None, title="Filter for country"),
+    name: list[str] | None = Query(default=None, title="Filter for names"),
+    note: list[str] | None = Query(default=None, title="Filter for notes"),
+    region: list[str] | None = Query(default=None, title="Filter for regions"),
+    street: list[str] | None = Query(default=None, title="Filter for streets"),
+    timezone: list[str]
+    | None = Query(default=None, title="Filter for timezones"),
+    url: list[str] | None = Query(default=None, title="Filter for urls"),
+) -> SiteFilterParams:
+    """Return parameters for site filtering."""
+    return SiteFilterParams(
+        city=city,
+        country=country,
+        name=name,
+        note=note,
+        region=region,
+        street=street,
+        timezone=timezone,
+        url=url,
+    )

Follow ups