← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~adam-collard/maas-site-manager:MAASENG-1290-add-filter-to-sites into maas-site-manager:main

 

Adam Collard has proposed merging ~adam-collard/maas-site-manager:MAASENG-1290-add-filter-to-sites into maas-site-manager:main.

Commit message:
feat: add filters to sites list view

Use Pydantic models in queries

see MAASENG-1390 and MAASENG-1290


Requested reviews:
  MAAS Committers (maas-committers)

For more details, see:
https://code.launchpad.net/~adam-collard/maas-site-manager/+git/site-manager/+merge/439240
-- 
Your team MAAS Committers is requested to review the proposed merge of ~adam-collard/maas-site-manager:MAASENG-1290-add-filter-to-sites into maas-site-manager:main.
diff --git a/backend/msm/db/queries.py b/backend/msm/db/queries.py
index ce2bd93..c50c798 100644
--- a/backend/msm/db/queries.py
+++ b/backend/msm/db/queries.py
@@ -1,21 +1,115 @@
+from __future__ import annotations
+
 from collections.abc import Iterable
 from datetime import (
     datetime,
     timedelta,
 )
-from typing import Any
+from typing import TYPE_CHECKING
 from uuid import UUID
 
-from sqlalchemy import select
+from sqlalchemy import (
+    and_,
+    or_,
+    select,
+    Table,
+    true,
+)
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from ..schema import (
+    Site as SiteSchema,
+    Token as TokenSchema,
+)
 from ._tables import (
     Site,
     Token,
 )
 
+if TYPE_CHECKING:
+    from sqlalchemy import (
+        ColumnElement,
+        Operators,
+    )
+    from sqlalchemy.sql._typing import _ColumnExpressionArgument
+
+
+def filters_from_arguments(
+    table: Table,
+    **kwargs: list[str] | None,
+) -> Iterable[Operators]:
+    """
+    Yield one clause per filtered column; the caller joins them with AND.
+    All needles given for a single column are joined with OR, so that
+
+      ?name=name1&name=name2&city=city
+
+    becomes a WHERE clause such as
+
+      (name ilike %name1% OR name ilike %name2%) AND city ilike %city%
+
+    :param table: the table to create the WHERE clause for
+    :param kwargs: column names of ``table`` as keys, each mapped to a
+                   list of substrings matched case-insensitively, or to
+                   None / [] when that column should not be filtered
+    :returns: a generator yielding one clause per filtered column, with
+              all of that column's needles joined by OR
+    """
+    for dimension, needles in kwargs.items():
+        column = table.c[dimension]
+
+        match needles:
+            case [needle]:
+                # If there's only one we don't need any ORs
+                yield column.icontains(needle, autoescape=True)
+            case [_, *_]:
+                # More than one thing to match against, join them with OR
+                yield or_(
+                    *(column.icontains(n, autoescape=True) for n in needles)
+                )
+            # Anything else (None or []) yields no clause for this column
+
+
+async def get_filtered_sites(
+    session: AsyncSession,
+    city: list[str] | None = None,
+    name: list[str] | None = None,
+    note: list[str] | None = None,
+    region: list[str] | None = None,
+    street: list[str] | None = None,
+    timezone: list[str] | None = None,
+    url: list[str] | None = None,
+) -> Iterable[SiteSchema]:
+    """Return sites matching every filter; None or [] disables a filter."""
+    filters = filters_from_arguments(
+        Site,
+        city=city,
+        name=name,
+        note=note,
+        region=region,
+        street=street,
+        timezone=timezone,
+        url=url,
+    )
+    stmt = select(
+        Site.c.id,
+        Site.c.name,
+        Site.c.identifier,
+        Site.c.city,
+        Site.c.latitude,
+        Site.c.longitude,
+        Site.c.note,
+        Site.c.region,
+        Site.c.street,
+        Site.c.timezone,
+        Site.c.url,
+    )
+    stmt = stmt.where(*filters)  # type: ignore
+    result = await session.execute(stmt)
+    return (SiteSchema(**row._asdict()) for row in result.all())
+
 
-async def get_sites(session: AsyncSession) -> Iterable[dict[str, Any]]:
+async def get_sites(session: AsyncSession) -> Iterable[SiteSchema]:
     stmt = select(
         Site.c.id,
         Site.c.name,
@@ -30,14 +124,14 @@ async def get_sites(session: AsyncSession) -> Iterable[dict[str, Any]]:
         Site.c.url,
     )
     result = await session.execute(stmt)
-    return (row._asdict() for row in result.all())
+    return (SiteSchema(**row._asdict()) for row in result.all())
 
 
-async def get_tokens(session: AsyncSession) -> Iterable[dict[str, Any]]:
+async def get_tokens(session: AsyncSession) -> Iterable[TokenSchema]:
     result = await session.execute(
         select(Token.c.id, Token.c.site_id, Token.c.value, Token.c.expiration)
     )
-    return (row._asdict() for row in result.all())
+    return (TokenSchema(**row._asdict()) for row in result.all())
 
 
 async def create_tokens(
diff --git a/backend/msm/user_api/_schema.py b/backend/msm/schema.py
similarity index 100%
rename from backend/msm/user_api/_schema.py
rename to backend/msm/schema.py
diff --git a/backend/msm/user_api/_base.py b/backend/msm/user_api/_base.py
index d1e72f1..1d49a7d 100644
--- a/backend/msm/user_api/_base.py
+++ b/backend/msm/user_api/_base.py
@@ -1,8 +1,13 @@
-from fastapi import Depends
+from fastapi import (
+    Depends,
+    Query,
+)
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from . import _schema as schema
-from .. import __version__
+from .. import (
+    __version__,
+    schema,
+)
 from ..db import (
     db_session,
     queries,
@@ -14,19 +19,36 @@ async def root() -> dict[str, str]:
 
 
 async def sites(
+    city: list[str] | None = Query(default=None, title="Filter for cities"),
+    name: list[str] | None = Query(default=None, title="Filter for names"),
+    note: list[str] | None = Query(default=None, title="Filter for notes"),
+    region: list[str] | None = Query(default=None, title="Filter for regions"),
+    street: list[str] | None = Query(default=None, title="Filter for streets"),
+    timezone: list[str]
+    | None = Query(default=None, title="Filter for timezones"),
+    url: list[str] | None = Query(default=None, title="Filter for urls"),
     session: AsyncSession = Depends(db_session),
 ) -> list[schema.Site]:
-    """Return all sites"""
-    return [schema.Site(**entry) for entry in await queries.get_sites(session)]
+    """Return all sites, optionally filtered by the query parameters"""
+    return list(
+        await queries.get_filtered_sites(
+            session,
+            city=city,
+            name=name,
+            note=note,
+            region=region,
+            street=street,
+            timezone=timezone,
+            url=url,
+        )
+    )
 
 
 async def tokens(
     session: AsyncSession = Depends(db_session),
 ) -> list[schema.Token]:
     """Return all tokens"""
-    return [
-        schema.Token(**entry) for entry in await queries.get_tokens(session)
-    ]
+    return list(await queries.get_tokens(session))
 
 
 async def tokens_post(
diff --git a/backend/msm/user_api/tests/test_handlers.py b/backend/msm/user_api/tests/test_handlers.py
index 0fd45a2..20f9b5e 100644
--- a/backend/msm/user_api/tests/test_handlers.py
+++ b/backend/msm/user_api/tests/test_handlers.py
@@ -35,10 +35,12 @@ async def test_list_sites(
     site2 = site1.copy()
     site2["id"] = 2
     site2["identifier"] = "site two"
+    site2["name"] = "BerlinHQ"
+    site2["city"] = "Berlin"
     await fixture.create("site", [site1, site2])
-    response = user_app_client.get("/sites")
+    response = user_app_client.get("/sites?city=onDo")  # vs London
     assert response.status_code == 200
-    assert response.json() == [site1, site2]
+    assert response.json() == [site1]
 
 
 @pytest.mark.asyncio
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index f5c7d8a..c4b8d97 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -25,3 +25,6 @@ non_interactive = true
 strict = true
 warn_return_any = true
 warn_unused_configs = true
+plugins = [
+  "pydantic.mypy"
+]

Follow ups