← Back to team overview

sts-sponsors team mailing list archive

[Merge] ~lloydwaltersj/maas-site-manager:add-login into maas-site-manager:main

 

Jack Lloyd-Walters has proposed merging ~lloydwaltersj/maas-site-manager:add-login into maas-site-manager:main.

Commit message:
Add login/authentication to backend

Requested reviews:
  MAAS Lander (maas-lander): unittests
  MAAS Committers (maas-committers)

For more details, see:
https://code.launchpad.net/~lloydwaltersj/maas-site-manager/+git/site-manager/+merge/440870

Follows and extends the work done in ~thorsten-merten/MAASENG-1485-add-authentication
-- 
Your team MAAS Committers is requested to review the proposed merge of ~lloydwaltersj/maas-site-manager:add-login into maas-site-manager:main.
diff --git a/Makefile b/Makefile
index ad5452d..86cccb6 100644
--- a/Makefile
+++ b/Makefile
@@ -54,7 +54,7 @@ ci-backend-build:  # nothing to do since everything is run in tox envs
 .PHONY: ci-backend-build
 
 ci-backend-lint:
-	env -C backend tox -e lint,check
+	env -C backend tox -e format,lint,check
 .PHONY: ci-backend-lint
 
 ci-backend-test:
diff --git a/backend/msm/db/queries.py b/backend/msm/db/queries.py
index f70c850..be38092 100644
--- a/backend/msm/db/queries.py
+++ b/backend/msm/db/queries.py
@@ -5,6 +5,7 @@ from datetime import (
 )
 from uuid import UUID
 
+# from passlib.context import CryptContext
 from sqlalchemy import (
     case,
     func,
@@ -18,14 +19,37 @@ from .. import MAX_PAGE_SIZE
 from ..schema import (
     Site as SiteSchema,
     Token as TokenSchema,
+    UserWithPassword as UserPWSchema,
 )
 from ._tables import (
     Site,
     SiteData,
     Token,
+    User,
 )
 
 
+async def get_user(session: AsyncSession, email: str) -> UserPWSchema | None:
+    """
+    Get a user (with password hash) by their unique email, or None.
+    """
+    stmt = select(
+        User.c.id,
+        User.c.disabled,
+        User.c.email,
+        User.c.full_name,
+        User.c.password,
+    ).where(User.c.email == email)
+    result = await session.execute(stmt)
+    # execute() returns a Result, never None; first() is None on no match.
+    user = result.first()
+    if user is None:
+        return None
+    # NOTE(review): UserWithPassword has no ``id`` field, so pydantic (by
+    # default) ignores the selected id when building the schema.
+    return UserPWSchema(**user._asdict())
+
+
 def filters_from_arguments(
     table: Table,
     **kwargs: list[str] | None,
@@ -62,6 +86,29 @@ def filters_from_arguments(
                 yield clause
 
 
+<<<<<<< backend/msm/db/queries.py
+=======
+DerivedSchema = TypeVar("DerivedSchema", SiteSchema, TokenSchema)
+
+
+def extract_count_and_results(
+    schema: Type[DerivedSchema],
+    db_results: Sequence[Row[tuple[Any, Any, Any, Any, Any]]]
+    | Sequence[Row[tuple[Any, Any, Any, Any, Any, Any]]],
+) -> tuple[int, Iterable[DerivedSchema]]:
+    """
+    Extract the count and results from a paginated query.
+
+    The total count, computed by a window function, is expected in the
+    first column of every row, so it is read from the first row (0 if
+    there are no rows). Each row is converted into a ``schema`` object.
+    """
+    count: int = db_results[0][0] if db_results else 0
+    # Row._asdict() maps each column label to its value
+    return count, [schema(**row._asdict()) for row in db_results]
+
+
+>>>>>>> backend/msm/db/queries.py
 async def get_filtered_sites(
     session: AsyncSession,
     offset: int = 0,
diff --git a/backend/msm/schema.py b/backend/msm/schema.py
index d1772d6..7864bdc 100644
--- a/backend/msm/schema.py
+++ b/backend/msm/schema.py
@@ -27,19 +27,31 @@ class PaginatedResults(BaseModel):
     items: Sequence[BaseModel]
 
 
-class CreateUser(BaseModel):
+class ReadUser(BaseModel):
     """
     A MAAS Site Manager User
+    We never want to send the password (hash) around, so it is omitted.
     """
 
     email: EmailStr = Field(title="email@xxxxxxxxxxx")
     full_name: str
-    # use password.get_secret_value() to retrieve the value
-    password: SecretStr = Field(min_length=8, max_length=50)
     disabled: bool
 
 
-class User(CreateUser):
+class UserWithPassword(ReadUser):
+    """
+    A user plus their password: plain when creating, the hash when read.
+    """
+
+    # use password.get_secret_value() to retrieve the value
+    password: SecretStr = Field(min_length=8, max_length=100)
+
+
+class User(ReadUser):
+    """
+    A user as stored in the DB, identified by its ``id``.
+    """
+
     id: int
 
 
@@ -120,6 +132,23 @@ class Token(CreateToken):
     id: int
 
 
+class JWTToken(BaseModel):
+    """
+    Access token returned by the login endpoint.
+    """
+
+    access_token: str
+    token_type: str
+
+
+class JWTTokenData(BaseModel):
+    """
+    Data carried by a decoded access token (its ``sub`` claim).
+    """
+
+    email: str
+
+
 class PaginatedTokens(PaginatedResults):
     items: list[Token]
 
diff --git a/backend/msm/user_api/_base.py b/backend/msm/user_api/_base.py
index a3797a2..42dfe5a 100644
--- a/backend/msm/user_api/_base.py
+++ b/backend/msm/user_api/_base.py
@@ -1,7 +1,13 @@
+from datetime import timedelta
+from typing import Annotated
+
 from fastapi import (
     Depends,
+    HTTPException,
     Query,
+    status,
 )
+from fastapi.security import OAuth2PasswordRequestForm
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from .. import (
@@ -14,6 +20,12 @@ from ..db import (
     db_session,
     queries,
 )
+from ._jwt import (
+    ACCESS_TOKEN_EXPIRE_MINUTES,
+    authenticate_user,
+    create_access_token,
+    get_current_active_user,
+)
 
 
 async def pagination_parameters(
@@ -90,4 +102,45 @@ async def tokens_post(
         create_request.duration,
         count=create_request.count,
     )
+<<<<<<< backend/msm/user_api/_base.py
     return schema.CreateTokensResponse(expired=expired, tokens=tokens)
+=======
+    return schema.CreateTokensResponse(expiration=expiration, tokens=tokens)
+
+
+async def login_for_access_token(
+    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
+    session: AsyncSession = Depends(db_session),
+) -> schema.JWTToken:
+    """
+    Exchange a username (the user's email) and password for a bearer JWT.
+
+    Raise HTTP 401 if the credentials are wrong or the user is disabled.
+    """
+    user = await authenticate_user(
+        session, form_data.username, form_data.password
+    )
+    if not user:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Incorrect username or password",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
+    access_token = create_access_token(
+        data={"sub": user.email}, expires_delta=access_token_expires
+    )
+    return schema.JWTToken(access_token=access_token, token_type="bearer")
+
+
+async def read_users_me(
+    current_user: Annotated[schema.User, Depends(get_current_active_user)]
+) -> schema.User:
+    """
+    Return the user that the request's bearer token was issued to.
+    """
+    # NOTE(review): the dependency produces a UserWithPassword, which has
+    # no ``id``; validating it against the schema.User response model is
+    # expected to fail until the two agree.
+    return current_user
+>>>>>>> backend/msm/user_api/_base.py
diff --git a/backend/msm/user_api/_jwt.py b/backend/msm/user_api/_jwt.py
new file mode 100644
index 0000000..313b177
--- /dev/null
+++ b/backend/msm/user_api/_jwt.py
@@ -0,0 +1,144 @@
+from __future__ import annotations
+
+from datetime import (
+    datetime,
+    timedelta,
+    timezone,
+)
+from typing import (
+    Annotated,
+    Any,
+)
+
+from fastapi import (
+    Depends,
+    HTTPException,
+    status,
+)
+from fastapi.security import OAuth2PasswordBearer
+from jose import (
+    jwt,
+    JWTError,
+)
+from passlib.context import CryptContext
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from ..db import db_session
+from ..db.queries import get_user
+from ..schema import (
+    JWTTokenData,
+    UserWithPassword,
+)
+
+# NOTE(review): this is the example key from the FastAPI tutorial and is
+# public, so anyone could forge valid tokens with it. Load a secret per
+# deployment (e.g. from configuration) before this is released.
+# To get a string like this run:
+# openssl rand -hex 32
+SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
+ALGORITHM = "HS256"
+ACCESS_TOKEN_EXPIRE_MINUTES = 30
+
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+# tokenUrl must name the route that issues tokens ("/login", registered in
+# create_app); the OpenAPI docs' login form posts credentials there.
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
+
+
+def verify_password(
+    plain_password: str | bytes, hashed_password: str | bytes | None
+) -> bool:
+    """
+    Verify a plain password against a password hash created by passlib.
+
+    A ``None`` hash is accepted: passlib then runs a dummy verification
+    and returns False, taking about as long as checking a real hash.
+    """
+    return pwd_context.verify(plain_password, hashed_password)
+
+
+def get_password_hash(password: str | bytes) -> str:
+    """
+    Get a bcrypt hash for a password.
+    """
+    return pwd_context.hash(password)
+
+
+def create_access_token(
+    data: dict[str, Any], expires_delta: timedelta | None = None
+) -> str:
+    """
+    Encode ``data`` as a signed JWT carrying an ``exp`` claim.
+
+    The token expires ``expires_delta`` from now, or 15 minutes from now
+    when no (or a zero) delta is given. ``data`` itself is not modified.
+    """
+    to_encode = data.copy()
+    if not expires_delta:
+        expires_delta = timedelta(minutes=15)
+    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
+    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
+
+
+async def authenticate_user(
+    session: AsyncSession, email: str, password: str
+) -> UserWithPassword | None:
+    """
+    Return the enabled user with this email and password, else None.
+    """
+    user = await get_user(session, email)
+    if user is None:
+        # Run a dummy hash check so unknown emails take about as long to
+        # reject as wrong passwords, not revealing which emails exist.
+        verify_password(password, None)
+        return None
+    if not verify_password(password, user.password.get_secret_value()):
+        return None
+    if user.disabled:
+        return None
+    return user
+
+
+async def get_current_user(
+    token: Annotated[str, Depends(oauth2_scheme)],
+    session: AsyncSession = Depends(db_session),
+) -> UserWithPassword:
+    """
+    Resolve the request's bearer token into the user it was issued to.
+
+    Raise HTTP 401 if the token fails to decode or verify (this includes
+    expiry), has no ``sub`` claim, or names a user that does not exist.
+    """
+    credentials_exception = HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail="Could not validate credentials",
+        headers={"WWW-Authenticate": "Bearer"},
+    )
+    try:
+        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+    except JWTError:
+        raise credentials_exception from None
+    email: str | None = payload.get("sub")
+    if email is None:
+        raise credentials_exception
+    token_data = JWTTokenData(email=email)
+    user = await get_user(session, email=token_data.email)
+    if user is None:
+        raise credentials_exception
+    return user
+
+
+async def get_current_active_user(
+    current_user: Annotated[UserWithPassword, Depends(get_current_user)]
+) -> UserWithPassword:
+    """
+    Like get_current_user, but reject disabled users with HTTP 400.
+    """
+    # NOTE(review): the returned model carries the password hash but no
+    # ``id``, whereas /users/me declares schema.User (which requires an
+    # ``id``) as its response model.
+    if current_user.disabled:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
+        )
+    return current_user
diff --git a/backend/msm/user_api/_setup.py b/backend/msm/user_api/_setup.py
index 03b9d18..87f392a 100644
--- a/backend/msm/user_api/_setup.py
+++ b/backend/msm/user_api/_setup.py
@@ -52,4 +52,8 @@ def create_app(db_dsn: str = DEFAULT_DB_DSN) -> FastAPI:
     app.router.add_api_route("/sites", _base.sites, methods=["GET"])
     app.router.add_api_route("/tokens", _base.tokens, methods=["GET"])
     app.router.add_api_route("/tokens", _base.tokens_post, methods=["POST"])
+    app.router.add_api_route(
+        "/login", _base.login_for_access_token, methods=["POST"]
+    )
+    app.router.add_api_route("/users/me", _base.read_users_me, methods=["GET"])
     return app
diff --git a/backend/msm/user_api/tests/test_handlers.py b/backend/msm/user_api/tests/test_handlers.py
new file mode 100644
index 0000000..4af6625
--- /dev/null
+++ b/backend/msm/user_api/tests/test_handlers.py
@@ -0,0 +1,140 @@
+<<<<<<< backend/msm/user_api/tests/test_handlers.py
+=======
+from datetime import (
+    datetime,
+    timedelta,
+)
+
+from httpx import AsyncClient
+import pytest
+
+from ...testing.db import Fixture
+
+
+@pytest.mark.asyncio
+async def test_root(user_app_client: AsyncClient) -> None:
+    response = await user_app_client.get("/")
+    assert response.status_code == 200
+    assert response.json() == {"version": "0.0.1"}
+
+
+@pytest.mark.asyncio
+async def test_list_sites(
+    user_app_client: AsyncClient, fixture: Fixture
+) -> None:
+    site1 = {
+        "id": 1,
+        "name": "LondonHQ",
+        "city": "London",
+        "country": "gb",
+        "latitude": "51.509865",
+        "longitude": "-0.118092",
+        "note": "the first site",
+        "region": "Blue Fin Bldg",
+        "street": "110 Southwark St",
+        "timezone": "0.00",
+        "url": "https://londoncalling.example.com",
+    }
+    site2 = site1.copy()
+    site2["id"] = 2
+    site2["name"] = "BerlinHQ"
+    site2["timezone"] = "1.00"
+    site2["city"] = "Berlin"
+    site2["country"] = "de"
+    await fixture.create("sites", [site1, site2])
+    page1 = await user_app_client.get("/sites")
+    assert page1.status_code == 200
+    assert page1.json() == {
+        "page": 1,
+        "size": 20,
+        "total": 2,
+        "items": [site1, site2],
+    }
+    filtered = await user_app_client.get("/sites?city=onDo")  # vs London
+    assert filtered.status_code == 200
+    assert filtered.json() == {
+        "page": 1,
+        "size": 20,
+        "total": 1,
+        "items": [site1],
+    }
+    paginated = await user_app_client.get("/sites?page=2&size=1")
+    assert paginated.status_code == 200
+    assert paginated.json() == {
+        "page": 2,
+        "size": 1,
+        "total": 2,
+        "items": [site2],
+    }
+
+
+@pytest.mark.asyncio
+async def test_create_token(user_app_client: AsyncClient) -> None:
+    seconds = 100
+    response = await user_app_client.post(
+        "/tokens", json={"count": 5, "duration": seconds}
+    )
+    assert response.status_code == 200
+    result = response.json()
+    assert datetime.fromisoformat(result["expiration"]) < (
+        datetime.utcnow() + timedelta(seconds=seconds)
+    )
+    assert len(result["tokens"]) == 5
+
+
+@pytest.mark.asyncio
+async def test_list_tokens(
+    user_app_client: AsyncClient, fixture: Fixture
+) -> None:
+    tokens = [
+        {
+            "id": 1,
+            "site_id": None,
+            "value": "c54e5ba6-d214-40dd-b601-01ebb1019c07",
+            "expiration": datetime.fromisoformat("2023-02-23T09:09:51.103703"),
+        },
+        {
+            "id": 2,
+            "site_id": None,
+            "value": "b67c449e-fcf6-4014-887d-909859f9fb70",
+            "expiration": datetime.fromisoformat("2023-02-23T11:28:54.382456"),
+        },
+    ]
+    await fixture.create("tokens", tokens)
+    response = await user_app_client.get("/tokens")
+    assert response.status_code == 200
+    assert response.json()["total"] == 2
+    assert len(response.json()["items"]) == 2
+
+
+@pytest.mark.asyncio
+async def test_user_login(
+    user_app_client: AsyncClient, fixture: Fixture
+) -> None:
+    backend_data = {
+        "id": 1,
+        "email": "admin@xxxxxxxxxxx",
+        "full_name": "Admin",
+        "disabled": False,
+        "password": "$2b$12$F5sgrhRNtWAOehcoVO.XK.oSvupmcg8.0T2jCHOTg15M8N8LrpRwS",
+    }
+
+    email = "admin@xxxxxxxxxxx"
+    password = "admin"
+
+    await fixture.create("users", [backend_data])
+    login_response = await user_app_client.post(
+        "/login", data={"username": email, "password": password}
+    )
+    assert login_response.status_code == 200
+    token = login_response.json()["access_token"]
+    token_type = login_response.json()["token_type"].capitalize()
+    # /users/me is expected to get its DB session server-side; a client
+    # cannot supply one, so only the bearer token is sent.
+    user_response = await user_app_client.get(
+        "/users/me",
+        headers={"Authorization": f"{token_type} {token}"},
+    )
+    assert user_response.status_code == 200
+    assert user_response.json()["email"] == email
+>>>>>>> backend/msm/user_api/tests/test_handlers.py
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 74f6260..eaa17fa 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -18,7 +18,10 @@ authors = [
 requires-python = ">=3.10"
 dependencies = [
   "fastapi",
+  "passlib[bcrypt]",
   "pydantic[email]",
+  "python-jose[cryptography]",
+  "python-multipart",
   "SQLAlchemy[postgresql_asyncpg]",
 ]
 [project.optional-dependencies]