← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~ack/maas-site-manager:db-session-no-commit-default into maas-site-manager:main

 

Alberto Donato has proposed merging ~ack/maas-site-manager:db-session-no-commit-default into maas-site-manager:main.

Commit message:
don't commit by default when creating test fixtures



Requested reviews:
  MAAS Lander (maas-lander): unittests
  MAAS Committers (maas-committers)

For more details, see:
https://code.launchpad.net/~ack/maas-site-manager/+git/site-manager/+merge/443470
-- 
Your team MAAS Committers is requested to review the proposed merge of ~ack/maas-site-manager:db-session-no-commit-default into maas-site-manager:main.
diff --git a/backend/tests/fixtures/app.py b/backend/tests/fixtures/app.py
index 11f3443..3058e90 100644
--- a/backend/tests/fixtures/app.py
+++ b/backend/tests/fixtures/app.py
@@ -91,6 +91,7 @@ async def authenticated_user_app_client(
             "full_name": "Admin",
             "password": phash,
         },
+        commit=True,
     )
     async with AuthAsyncClient(app=user_app, base_url="http://test";) as client:
         await client.login("admin@xxxxxxxxxxx", "admin")
diff --git a/backend/tests/fixtures/db.py b/backend/tests/fixtures/db.py
index cb4cf14..9f2a5e1 100644
--- a/backend/tests/fixtures/db.py
+++ b/backend/tests/fixtures/db.py
@@ -79,25 +79,30 @@ async def session(db: Database) -> AsyncGenerator[AsyncSession, None]:
     """A database session."""
     async with db.session() as session:
         yield session
+        await session.rollback()
 
 
 class Fixture:
     """Helper for creating test fixtures."""
 
-    def __init__(self, db: Database):
-        self.db = db
+    def __init__(self, session: AsyncSession):
+        self.session = session
+
+    async def commit(self) -> None:
+        await self.session.commit()
 
     async def create(
         self,
         table: str,
         data: dict[str, Any] | list[dict[str, Any]] | None = None,
+        commit: bool = False,
     ) -> list[dict[str, Any]]:
-        async with self.db.session() as session:
-            result = await session.execute(
-                METADATA.tables[table].insert().returning("*"), data
-            )
-            await session.commit()
-            return [row._asdict() for row in result]
+        result = await self.session.execute(
+            METADATA.tables[table].insert().returning("*"), data
+        )
+        if commit:
+            await self.session.commit()
+        return [row._asdict() for row in result]
 
     async def get(
         self,
@@ -105,15 +110,14 @@ class Fixture:
         *filters: ColumnOperators,
     ) -> list[dict[str, Any]]:
         """Take a peak what is in there"""
-        async with self.db.session() as session:
-            result = await session.execute(
-                METADATA.tables[table]
-                .select()
-                .where(*filters)  # type: ignore[arg-type]
-            )
-            return [row._asdict() for row in result]
+        result = await self.session.execute(
+            METADATA.tables[table]
+            .select()
+            .where(*filters)  # type: ignore[arg-type]
+        )
+        return [row._asdict() for row in result]
 
 
 @pytest.fixture
-def fixture(db: Database) -> Iterator[Fixture]:
-    yield Fixture(db)
+def fixture(session: AsyncSession) -> Iterator[Fixture]:
+    yield Fixture(session)
diff --git a/backend/tests/user_api/test_handlers.py b/backend/tests/user_api/test_handlers.py
index e2b1910..ebb9dc8 100644
--- a/backend/tests/user_api/test_handlers.py
+++ b/backend/tests/user_api/test_handlers.py
@@ -66,6 +66,7 @@ class TestSitesHandler:
                 site_details(city="London"),
                 site_details(name="BerlinHQ", city="Berlin"),
             ],
+            commit=True,
         )
         for site in sites:
             site["stats"] = None
@@ -110,6 +111,7 @@ class TestSitesHandler:
                 site_details(),
                 site_details(name="BerlinHQ", accepted=False),
             ],
+            commit=True,
         )
         created_site["stats"] = None
         del created_site["created"]
@@ -133,6 +135,7 @@ class TestSitesHandler:
                 site_details(timezone="Europe/London"),
                 site_details(name="BerlinHQ", timezone="Europe/Berlin"),
             ],
+            commit=True,
         )
         created_site["stats"] = None
         del created_site["created"]
@@ -151,7 +154,7 @@ class TestSitesHandler:
     async def test_get_with_stats(
         self, authenticated_user_app_client: AuthAsyncClient, fixture: Fixture
     ) -> None:
-        [site] = await fixture.create("site", [site_details()])
+        [site] = await fixture.create("site", [site_details()], commit=True)
         [site_data] = await fixture.create(
             "site_data",
             [
@@ -165,6 +168,7 @@ class TestSitesHandler:
                     "last_seen": datetime.utcnow(),
                 }
             ],
+            commit=True,
         )
         del site_data["id"]
         del site_data["site_id"]
@@ -212,14 +216,20 @@ class TestSitesHandler:
             return [site["city"] for site in resp.json()["items"]]
 
         await fixture.create(
-            "site", [site_details(city="Milan", country="IT")]
+            "site",
+            [site_details(city="Milan", country="IT")],
+            commit=True,
+        )
+        await fixture.create(
+            "site",
+            [site_details(city="Paris", country="FR")],
+            commit=True,
         )
         await fixture.create(
-            "site", [site_details(city="Paris", country="FR")]
+            "site", [site_details(city="Rome", country="IT")], commit=True
         )
-        await fixture.create("site", [site_details(city="Rome", country="IT")])
         await fixture.create(
-            "site", [site_details(city="London", country="GB")]
+            "site", [site_details(city="London", country="GB")], commit=True
         )
 
         response = await authenticated_user_app_client.get(
@@ -238,7 +248,7 @@ class TestSitesHandler:
         query_params: str,
     ) -> None:
         await fixture.create(
-            "site", [site_details(city="Milan", country="IT")]
+            "site", [site_details(city="Milan", country="IT")], commit=True
         )
 
         # not sortable
@@ -259,6 +269,7 @@ class TestPendingSitesHandler:
                 site_details(),
                 site_details(name="BerlinHQ", accepted=False),
             ],
+            commit=True,
         )
 
         response = await authenticated_user_app_client.get("/requests")
@@ -281,7 +292,7 @@ class TestPendingSitesHandler:
         self, authenticated_user_app_client: AuthAsyncClient, fixture: Fixture
     ) -> None:
         [pending_site] = await fixture.create(
-            "site", [site_details(accepted=False)]
+            "site", [site_details(accepted=False)], commit=True
         )
 
         response = await authenticated_user_app_client.post(
@@ -296,7 +307,7 @@ class TestPendingSitesHandler:
         self, authenticated_user_app_client: AuthAsyncClient, fixture: Fixture
     ) -> None:
         [pending_site] = await fixture.create(
-            "site", [site_details(accepted=False)]
+            "site", [site_details(accepted=False)], commit=True
         )
 
         response = await authenticated_user_app_client.post(
@@ -309,7 +320,7 @@ class TestPendingSitesHandler:
     async def test_post_invalid_ids(
         self, authenticated_user_app_client: AuthAsyncClient, fixture: Fixture
     ) -> None:
-        [site] = await fixture.create("site", [site_details()])
+        [site] = await fixture.create("site", [site_details()], commit=True)
         # unknown IDs and IDs for non-pending sites are invalid
         ids = [site["id"], 10000]
         response = await authenticated_user_app_client.post(
@@ -368,6 +379,7 @@ class TestTokensHandler:
                     ),
                 },
             ],
+            commit=True,
         )
         for token in tokens:
             token["expired"] = token["expired"].isoformat()
@@ -391,7 +403,7 @@ class TestLoginHandler:
             "full_name": "Admin",
             "password": phash,
         }
-        await fixture.create("user", userdata)
+        await fixture.create("user", userdata, commit=True)
         response = await user_app_client.post(
             "/login",
             json={"username": userdata["email"], "password": "admin"},
@@ -409,7 +421,7 @@ class TestLoginHandler:
             "full_name": "Admin",
             "password": phash,
         }
-        await fixture.create("user", userdata)
+        await fixture.create("user", userdata, commit=True)
 
         fail_response = await user_app_client.post(
             "/login",