← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~ack/maas-site-manager:pending-sites-action into maas-site-manager:main

 

Alberto Donato has proposed merging ~ack/maas-site-manager:pending-sites-action into maas-site-manager:main.

Commit message:
add POST handler for /requests



Requested reviews:
  MAAS Committers (maas-committers)

For more details, see:
https://code.launchpad.net/~ack/maas-site-manager/+git/site-manager/+merge/442472
-- 
Your team MAAS Committers is requested to review the proposed merge of ~ack/maas-site-manager:pending-sites-action into maas-site-manager:main.
diff --git a/backend/msm/db/queries.py b/backend/msm/db/queries.py
index 3e51fbb..2b83dbf 100644
--- a/backend/msm/db/queries.py
+++ b/backend/msm/db/queries.py
@@ -12,11 +12,13 @@ from uuid import UUID
 from sqlalchemy import (
     case,
     ColumnOperators,
+    delete,
     func,
     select,
     String,
     Table,
     Text,
+    update,
 )
 from sqlalchemy.ext.asyncio import AsyncSession
 
@@ -35,6 +37,14 @@ from .models import (
 )
 
 
+class InvalidPendingSites(Exception):
+    """Raised when unknown pending site IDs are provided."""
+
+    def __init__(self, ids: Iterable[int]):
+        self.ids = sorted(ids)
+        super().__init__("Unknown pending sites")
+
+
 async def get_user(
     session: AsyncSession, email: str
 ) -> UserWithPasswordSchema | None:
@@ -202,6 +212,34 @@ async def get_pending_sites(
     return count, [PendingSiteSchema(**row._asdict()) for row in result.all()]
 
 
+async def accept_reject_pending_sites(
+    session: AsyncSession,
+    ids: list[int],
+    accept: bool,
+) -> None:
+    site_ids: set[int] = set(ids)
+    stmt = (
+        select(Site.c.id)
+        .select_from(Site)
+        .where(
+            Site.c.id.in_(site_ids),
+            Site.c.accepted == False,  # noqa
+        )
+    )
+    result = await session.execute(stmt)
+    pending_ids = set(row[0] for row in result.all())
+    if unknown_ids := site_ids - pending_ids:
+        raise InvalidPendingSites(unknown_ids)
+
+    if accept:
+        await session.execute(
+            update(Site).where(Site.c.id.in_(site_ids)).values(accepted=True)
+        )
+    else:
+        await session.execute(delete(Site).where(Site.c.id.in_(site_ids)))
+    return None
+
+
 async def get_tokens(
     session: AsyncSession,
     offset: int = 0,
diff --git a/backend/msm/user_api/_handlers.py b/backend/msm/user_api/_handlers.py
index 7d0bfac..04d1e4d 100644
--- a/backend/msm/user_api/_handlers.py
+++ b/backend/msm/user_api/_handlers.py
@@ -14,6 +14,7 @@ from ..db import (
     queries,
 )
 from ..db.models import User
+from ..db.queries import InvalidPendingSites
 from ..schema import (
     pagination_params,
     PaginationParams,
@@ -35,6 +36,7 @@ from ._schema import (
     PaginatedPendingSites,
     PaginatedSites,
     PaginatedTokens,
+    PendingSitesActionRequest,
     UserLoginRequest,
 )
 
@@ -45,8 +47,8 @@ async def root() -> dict[str, str]:
 
 
 async def sites(
+    session: Annotated[AsyncSession, Depends(db_session)],
     authenticated_user: Annotated[User, Depends(get_authenticated_user)],
-    session: AsyncSession = Depends(db_session),
     pagination_params: PaginationParams = Depends(pagination_params),
     filter_params: SiteFilterParams = Depends(site_filter_parameters),
 ) -> PaginatedSites:
@@ -66,8 +68,8 @@ async def sites(
 
 
 async def pending_sites(
+    session: Annotated[AsyncSession, Depends(db_session)],
     authenticated_user: Annotated[User, Depends(get_authenticated_user)],
-    session: AsyncSession = Depends(db_session),
     pagination_params: PaginationParams = Depends(pagination_params),
 ) -> PaginatedPendingSites:
     """Return pending sites."""
@@ -84,9 +86,30 @@ async def pending_sites(
     )
 
 
+async def pending_sites_post(
+    session: Annotated[AsyncSession, Depends(db_session)],
+    authenticated_user: Annotated[User, Depends(get_authenticated_user)],
+    action: PendingSitesActionRequest,
+) -> None:
+    """Accept or reject pending sites."""
+    try:
+        await queries.accept_reject_pending_sites(
+            session,
+            action.ids,
+            action.accept,
+        )
+    except InvalidPendingSites as error:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={"message": str(error), "ids": error.ids},
+        )
+
+    return None
+
+
 async def tokens(
+    session: Annotated[AsyncSession, Depends(db_session)],
     authenticated_user: Annotated[User, Depends(get_authenticated_user)],
-    session: AsyncSession = Depends(db_session),
     pagination_params: PaginationParams = Depends(pagination_params),
 ) -> PaginatedTokens:
     """Return all tokens"""
@@ -102,9 +125,9 @@ async def tokens(
 
 
 async def tokens_post(
+    session: Annotated[AsyncSession, Depends(db_session)],
     authenticated_user: Annotated[User, Depends(get_authenticated_user)],
     create_request: CreateTokensRequest,
-    session: AsyncSession = Depends(db_session),
 ) -> CreateTokensResponse:
     """
     Create one or more tokens.
@@ -119,8 +142,8 @@ async def tokens_post(
 
 
 async def login_for_access_token(
+    session: Annotated[AsyncSession, Depends(db_session)],
     user_login: UserLoginRequest,
-    session: AsyncSession = Depends(db_session),
 ) -> JSONWebToken:
     user = await authenticate_user(
         session, user_login.username, user_login.password
@@ -141,7 +164,7 @@ async def login_for_access_token(
 
 
 async def read_users_me(
+    session: Annotated[AsyncSession, Depends(db_session)],
     authenticated_user: Annotated[User, Depends(get_authenticated_user)],
-    session: AsyncSession = Depends(db_session),
 ) -> User:
     return authenticated_user
diff --git a/backend/msm/user_api/_schema.py b/backend/msm/user_api/_schema.py
index af1da4f..f481522 100644
--- a/backend/msm/user_api/_schema.py
+++ b/backend/msm/user_api/_schema.py
@@ -31,6 +31,13 @@ class CreateTokensResponse(BaseModel):
     tokens: list[UUID]
 
 
+class PendingSitesActionRequest(BaseModel):
+    """Request to accept/reject sites."""
+
+    ids: list[int]
+    accept: bool
+
+
 class PaginatedSites(PaginatedResults):
     items: list[Site]
 
diff --git a/backend/msm/user_api/_setup.py b/backend/msm/user_api/_setup.py
index d9e976d..8a7f332 100644
--- a/backend/msm/user_api/_setup.py
+++ b/backend/msm/user_api/_setup.py
@@ -43,6 +43,12 @@ def create_app(db_dsn: str | None = None) -> FastAPI:
     app.router.add_api_route(
         "/requests", _handlers.pending_sites, methods=["GET"]
     )
+    app.router.add_api_route(
+        "/requests",
+        _handlers.pending_sites_post,
+        methods=["POST"],
+        status_code=204,
+    )
     app.router.add_api_route("/sites", _handlers.sites, methods=["GET"])
     app.router.add_api_route("/tokens", _handlers.tokens, methods=["GET"])
     app.router.add_api_route(
diff --git a/backend/tests/fixtures/db.py b/backend/tests/fixtures/db.py
index 7e326a3..366e354 100644
--- a/backend/tests/fixtures/db.py
+++ b/backend/tests/fixtures/db.py
@@ -7,7 +7,10 @@ from typing import (
 import pytest
 from pytest_postgresql.executor import PostgreSQLExecutor
 from pytest_postgresql.janitor import DatabaseJanitor
-from sqlalchemy import create_engine
+from sqlalchemy import (
+    ColumnOperators,
+    create_engine,
+)
 
 from msm.db import (
     Database,
@@ -87,14 +90,18 @@ class Fixture:
             await session.commit()
             return [row._asdict() for row in result]
 
-    async def select_all(
+    async def get(
         self,
         table: str,
+        *filters: ColumnOperators,
     ) -> list[dict[str, Any]]:
         """Take a peak what is in there"""
         async with self.db.session() as session:
-            result = await session.execute(METADATA.tables[table].select())
-            await session.commit()
+            result = await session.execute(
+                METADATA.tables[table]
+                .select()
+                .where(*filters)  # type: ignore[arg-type]
+            )
             return [row._asdict() for row in result]
 
 
diff --git a/backend/tests/user_api/test_handlers.py b/backend/tests/user_api/test_handlers.py
index a1301f0..e31c0bf 100644
--- a/backend/tests/user_api/test_handlers.py
+++ b/backend/tests/user_api/test_handlers.py
@@ -290,6 +290,45 @@ async def test_list_pending_sites(
 
 
 @pytest.mark.asyncio
+async def test_accept_pending_sites(
+    authenticated_user_app_client: AuthAsyncClient, fixture: Fixture
+) -> None:
+    site = {
+        "name": "LondonHQ",
+        "url": "https://londoncalling.example.com",
+        "accepted": False,
+    }
+    [pending_site] = await fixture.create("site", [site])
+
+    response = await authenticated_user_app_client.post(
+        "/requests",
+        json={"ids": [pending_site["id"]], "accept": True},
+    )
+    assert response.status_code == 204
+    [created_site] = await fixture.get("site")
+    assert created_site["accepted"]
+
+
+@pytest.mark.asyncio
+async def test_reject_pending_sites(
+    authenticated_user_app_client: AuthAsyncClient, fixture: Fixture
+) -> None:
+    site = {
+        "name": "LondonHQ",
+        "url": "https://londoncalling.example.com",
+        "accepted": False,
+    }
+    [pending_site] = await fixture.create("site", [site])
+
+    response = await authenticated_user_app_client.post(
+        "/requests",
+        json={"ids": [pending_site["id"]], "accept": False},
+    )
+    assert response.status_code == 204
+    assert await fixture.get("site") == []
+
+
+@pytest.mark.asyncio
 @pytest.mark.parametrize("time_format", ["ISO 8601", "Float"])
 async def test_token_time_format(
     time_format: str, authenticated_user_app_client: AuthAsyncClient

Follow ups