← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~ack/maas-site-manager:token-csv-export into maas-site-manager:main

 

Alberto Donato has proposed merging ~ack/maas-site-manager:token-csv-export into maas-site-manager:main.

Commit message:
add support for CSV export of active tokens



Requested reviews:
  MAAS Lander (maas-lander): unittests
  MAAS Committers (maas-committers)

For more details, see:
https://code.launchpad.net/~ack/maas-site-manager/+git/site-manager/+merge/443428
-- 
Your team MAAS Committers is requested to review the proposed merge of ~ack/maas-site-manager:token-csv-export into maas-site-manager:main.
diff --git a/backend/msm/db/queries.py b/backend/msm/db/queries.py
index bba64d6..adfdccf 100644
--- a/backend/msm/db/queries.py
+++ b/backend/msm/db/queries.py
@@ -1,11 +1,13 @@
-from collections.abc import Iterable
 from datetime import (
     datetime,
     timedelta,
 )
 from functools import reduce
 from operator import or_
-from typing import Any
+from typing import (
+    Any,
+    Iterable,
+)
 from uuid import UUID
 
 from sqlalchemy import (
@@ -145,7 +147,7 @@ async def get_sites(
     street: list[str] | None = None,
     timezone: list[str] | None = None,
     url: list[str] | None = None,
-) -> tuple[int, Iterable[SiteSchema]]:
+) -> tuple[int, list[SiteSchema]]:
     filters = filters_from_arguments(
         Site,
         city=city,
@@ -220,7 +222,7 @@ async def get_pending_sites(
     session: AsyncSession,
     offset: int = 0,
     limit: int | None = None,
-) -> tuple[int, Iterable[PendingSiteSchema]]:
+) -> tuple[int, list[PendingSiteSchema]]:
     filters = [Site.c.accepted == False]  # noqa
     count = await row_count(session, Site, *filters)
     stmt = (
@@ -273,7 +275,7 @@ async def get_tokens(
     session: AsyncSession,
     offset: int = 0,
     limit: int | None = None,
-) -> tuple[int, Iterable[TokenSchema]]:
+) -> tuple[int, list[TokenSchema]]:
     count = await row_count(session, Token)
     stmt = (
         select(
@@ -293,6 +295,29 @@ async def get_tokens(
     return count, [TokenSchema(**row._asdict()) for row in result.all()]
 
 
+async def get_active_tokens(session: AsyncSession) -> list[TokenSchema]:
+    """Return every token whose expiration time is still in the future.
+
+    Tokens are ordered by ID. Expiration is compared against the SQL
+    ``now()`` evaluated by the database, not the application clock.
+    """
+    stmt = (
+        select(
+            Token.c.id,
+            Token.c.site_id,
+            Token.c.value,
+            Token.c.expired,
+            Token.c.created,
+        )
+        .select_from(Token)
+        .where(Token.c.expired > func.now())
+        .order_by(Token.c.id)
+    )
+    result = await session.execute(stmt)
+    rows = result.all()
+    return [TokenSchema(**row._asdict()) for row in rows]
+
+
 async def create_tokens(
     session: AsyncSession, duration: timedelta, count: int = 1
 ) -> tuple[datetime, list[UUID]]:
diff --git a/backend/msm/user_api/_csv.py b/backend/msm/user_api/_csv.py
new file mode 100644
index 0000000..f0e3c4e
--- /dev/null
+++ b/backend/msm/user_api/_csv.py
@@ -0,0 +1,30 @@
+import csv
+from io import StringIO
+
+from fastapi import Response
+from pydantic import BaseModel
+
+
+class CSVResponse(Response):
+    """Return a CSV response serializing a list of pydantic models.
+
+    The first row holds the field names of the first model in ``content``;
+    each following row holds one model's field values in the same order. An
+    empty list renders an empty body, as there is no model to take names from.
+    """
+
+    media_type = "text/csv"
+
+    def render(self, content: list[BaseModel]) -> bytes:
+        if not content:
+            return b""
+
+        stream = StringIO()
+        writer = csv.writer(stream)
+        writer.writerow(list(content[0].__fields__))
+        # Iterating a pydantic model yields (name, value) pairs in field
+        # order; keep only the values so each row lines up with the header.
+        writer.writerows([value for _, value in entry] for entry in content)
+        # Encode with the charset Starlette advertises in Content-Type rather
+        # than relying on str.encode()'s default.
+        return stream.getvalue().encode(self.charset)
diff --git a/backend/msm/user_api/_handlers.py b/backend/msm/user_api/_handlers.py
index 544f656..0ea2f20 100644
--- a/backend/msm/user_api/_handlers.py
+++ b/backend/msm/user_api/_handlers.py
@@ -24,6 +24,7 @@ from ..schema._sorting import (
     SortParamParser,
 )
 from ..settings import SETTINGS
+from ._csv import CSVResponse
 from ._forms import (
     site_filter_parameters,
     SiteFilterParams,
@@ -143,6 +144,24 @@ async def tokens_get(
     )
 
 
+async def tokens_export_get(
+    session: Annotated[AsyncSession, Depends(db_session)],
+    authenticated_user: Annotated[User, Depends(get_authenticated_user)],
+) -> CSVResponse:
+    """Return the active (not yet expired) tokens as a CSV attachment.
+
+    ``authenticated_user`` is not used in the body; it is declared so the
+    ``get_authenticated_user`` dependency runs for this route.
+    """
+    tokens = await queries.get_active_tokens(session)
+    # Content-Disposition asks browsers to download the response and
+    # suggests a file name with a .csv extension.
+    return CSVResponse(
+        content=tokens,
+        headers={"Content-Disposition": 'attachment; filename="tokens.csv"'},
+    )
+
+
 async def tokens_post(
     session: Annotated[AsyncSession, Depends(db_session)],
     authenticated_user: Annotated[User, Depends(get_authenticated_user)],
diff --git a/backend/msm/user_api/_setup.py b/backend/msm/user_api/_setup.py
index 02aee09..326bacf 100644
--- a/backend/msm/user_api/_setup.py
+++ b/backend/msm/user_api/_setup.py
@@ -50,6 +50,9 @@ def create_app(db_dsn: str | None = None) -> FastAPI:
     app.router.add_api_route("/sites", _handlers.sites_get, methods=["GET"])
     app.router.add_api_route("/tokens", _handlers.tokens_get, methods=["GET"])
     app.router.add_api_route(
+        "/tokens/export", _handlers.tokens_export_get, methods=["GET"]
+    )
+    app.router.add_api_route(
         "/tokens", _handlers.tokens_post, methods=["POST"]
     )
     app.router.add_api_route(
diff --git a/backend/tests/user_api/test_csv.py b/backend/tests/user_api/test_csv.py
new file mode 100644
index 0000000..3fbf56f
--- /dev/null
+++ b/backend/tests/user_api/test_csv.py
@@ -0,0 +1,41 @@
+from datetime import (
+    datetime,
+    timedelta,
+)
+import uuid
+
+from msm.db.models import Token
+from msm.user_api._csv import CSVResponse
+
+
+def isoformat(t: datetime) -> str:
+    """Return ``str(t)``, which is how ``csv.writer`` renders a datetime."""
+    return t.isoformat(sep=" ")
+
+
+def test_csv_response() -> None:
+    """Render a header row plus one CRLF-terminated row per model."""
+    now = datetime.utcnow()
+    specs = [
+        # (id, value, site_id, offset from now for expired/created)
+        (1, str(uuid.uuid4()), 10, timedelta(hours=1)),
+        (2, str(uuid.uuid4()), 20, timedelta(hours=2)),
+    ]
+    tokens = [
+        Token(
+            id=token_id,
+            value=value,
+            site_id=site_id,
+            expired=now + offset,
+            created=now - offset,
+        )
+        for token_id, value, site_id, offset in specs
+    ]
+    response = CSVResponse(content=tokens)
+
+    expected = "id,value,site_id,expired,created\r\n" + "".join(
+        f"{token_id},{value},{site_id},"
+        f"{isoformat(now + offset)},{isoformat(now - offset)}\r\n"
+        for token_id, value, site_id, offset in specs
+    )
+    assert response.body.decode() == expected