← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~ack/maas-site-manager:timezone-filter-fix into maas-site-manager:main

 

Alberto Donato has proposed merging ~ack/maas-site-manager:timezone-filter-fix into maas-site-manager:main with ~ack/maas-site-manager:filters-group as a prerequisite.

Commit message:
fix comparison in filter for non-text fields (timezone)

Requested reviews:
  MAAS Committers (maas-committers)

For more details, see:
https://code.launchpad.net/~ack/maas-site-manager/+git/site-manager/+merge/441168
-- 
Your team MAAS Committers is requested to review the proposed merge of ~ack/maas-site-manager:timezone-filter-fix into maas-site-manager:main.
diff --git a/backend/msm/db/queries.py b/backend/msm/db/queries.py
index f75dcc7..047ef72 100644
--- a/backend/msm/db/queries.py
+++ b/backend/msm/db/queries.py
@@ -5,6 +5,7 @@ from datetime import (
 )
 from functools import reduce
 from operator import or_
+from typing import Any
 from uuid import UUID
 
 from sqlalchemy import (
@@ -12,7 +13,9 @@ from sqlalchemy import (
     ColumnOperators,
     func,
     select,
+    String,
     Table,
+    Text,
 )
 from sqlalchemy.ext.asyncio import AsyncSession
 
@@ -30,7 +33,7 @@ from ._tables import (
 
 def filters_from_arguments(
     table: Table,
-    **filter_args: list[str] | None,
+    **filter_args: list[Any] | None,
 ) -> list[ColumnOperators]:
     """Yields clauses to join with AND and all entries for a single arg by OR.
     This enables to convert query params such as
@@ -44,17 +47,25 @@ def filters_from_arguments(
     :param table: the table to create the WHERE clause for
     :param filter_args: the parameters matching the table's column name
                         as keys and lists of strings that will be matched
-                        via ilike
     :returns: a list with clauses to filter table values, which are meant to be
               used in AND. Clauses for each column are joined with OR.
+
+    Matching is performed using `ilike` for text-based fields, exact match
+    otherwise.
+
     """
+
+    def compare_expr(name: str, value: Any) -> ColumnOperators:
+        column = table.c[name]
+        if isinstance(column.type, (Text, String)):
+            return column.icontains(value, autoescape=True)
+        else:
+            return column.__eq__(value)
+
     return [
         reduce(
             or_,
-            (
-                table.c[name].icontains(value, autoescape=True)
-                for value in values
-            ),
+            (compare_expr(name, value) for value in values),
         )
         for name, values in filter_args.items()
         if values
@@ -71,7 +82,7 @@ async def get_filtered_sites(
     note: list[str] | None = None,
     region: list[str] | None = None,
     street: list[str] | None = None,
-    timezone: list[str] | None = None,
+    timezone: list[float] | None = None,
     url: list[str] | None = None,
 ) -> tuple[int, Iterable[SiteSchema]]:
     filters = filters_from_arguments(
diff --git a/backend/msm/user_api/_forms.py b/backend/msm/user_api/_forms.py
index d6d2cf7..0f2d9b6 100644
--- a/backend/msm/user_api/_forms.py
+++ b/backend/msm/user_api/_forms.py
@@ -12,7 +12,7 @@ class SiteFilterParams(NamedTuple):
     note: list[str] | None
     region: list[str] | None
     street: list[str] | None
-    timezone: list[str] | None
+    timezone: list[float] | None
     url: list[str] | None
 
 
@@ -24,7 +24,7 @@ async def site_filter_parameters(
     note: list[str] | None = Query(default=None, title="Filter for notes"),
     region: list[str] | None = Query(default=None, title="Filter for regions"),
     street: list[str] | None = Query(default=None, title="Filter for streets"),
-    timezone: list[str]
+    timezone: list[float]
     | None = Query(default=None, title="Filter for timezones"),
     url: list[str] | None = Query(default=None, title="Filter for urls"),
 ) -> SiteFilterParams:
diff --git a/backend/tests/user_api/test_handlers.py b/backend/tests/user_api/test_handlers.py
index a36c5aa..9a2511e 100644
--- a/backend/tests/user_api/test_handlers.py
+++ b/backend/tests/user_api/test_handlers.py
@@ -89,6 +89,40 @@ async def test_list_sites(
 
 
 @pytest.mark.asyncio
+async def test_list_sites_filter_timezone(
+    user_app_client: AsyncClient, fixture: Fixture
+) -> None:
+    site1 = {
+        "name": "LondonHQ",
+        "city": "London",
+        "country": "gb",
+        "latitude": "51.509865",
+        "longitude": "-0.118092",
+        "note": "the first site",
+        "region": "Blue Fin Bldg",
+        "street": "110 Southwark St",
+        "timezone": "3.00",
+        "url": "https://londoncalling.example.com",
+    }
+    site2 = site1.copy()
+    site2["name"] = "BerlinHQ"
+    site2["timezone"] = "1.00"
+    site2["city"] = "Berlin"
+    site2["country"] = "de"
+    [created_site, _] = await fixture.create("site", [site1, site2])
+    created_site["timezone"] = str(created_site["timezone"])
+    created_site["stats"] = None
+    page1 = await user_app_client.get("/sites?timezone=3.0")
+    assert page1.status_code == 200
+    assert page1.json() == {
+        "page": 1,
+        "size": 20,
+        "total": 1,
+        "items": [created_site],
+    }
+
+
+@pytest.mark.asyncio
 async def test_list_sites_with_stats(
     user_app_client: AsyncClient, fixture: Fixture
 ) -> None:

Follow ups