← Back to team overview

sts-sponsors team mailing list archive

Re: [Merge] ~lloydwaltersj/maas-site-manager:add-login into maas-site-manager:main

 

Nice work!

A few comments inline, but it looks mostly good.

Diff comments:

> diff --git a/backend/msm/db/queries.py b/backend/msm/db/queries.py
> index 60ecee9..99331ea 100644
> --- a/backend/msm/db/queries.py
> +++ b/backend/msm/db/queries.py
> @@ -23,14 +24,37 @@ from ..schema import (
>      MAX_PAGE_SIZE,
>      Site as SiteSchema,
>      Token as TokenSchema,
> +    UserWithPassword as UserPWSchema,
>  )
>  from ._tables import (
>      Site,
>      SiteData,
>      Token,
> +    User,
>  )
>  
>  
> +async def get_user(session: AsyncSession, email: str) -> UserPWSchema | None:
> +    """
> +    Gets a user by its unique identifier: their email
> +    """
> +    stmt = select(
> +        User.c.id,
> +        User.c.disabled,
> +        User.c.email,
> +        User.c.full_name,
> +        User.c.password,
> +    ).where(User.c.email == email)

I'd add an explicit .select_from(User) as well.

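Also, you can call .one_or_none() on the result of session.execute(), which avoids the nested if/else and the need for first(). Note that session.execute() returns a Result object and never None, so the `if result is None` check is unnecessary.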

> +    result = await session.execute(stmt)
> +    if result is None:
> +        return None
> +    else:
> +        user = result.first()
> +        if user is None:
> +            return None
> +        return UserPWSchema(**user._asdict())
> +
> +
>  def filters_from_arguments(
>      table: Table,
>      **filter_args: list[Any] | None,
> diff --git a/backend/msm/schema/_models.py b/backend/msm/schema/_models.py
> index 9a80683..a4e7ba3 100644
> --- a/backend/msm/schema/_models.py
> +++ b/backend/msm/schema/_models.py
> @@ -19,19 +19,31 @@ from ._pagination import PaginatedResults
>  TimeZone = StrEnum("TimeZone", pytz.all_timezones)
>  
>  
> -class CreateUser(BaseModel):
> +class ReadUser(BaseModel):
>      """
>      A MAAS Site Manager User
> +    We never want to sent the password (hash) around
>      """
>  
>      email: EmailStr = Field(title="email@xxxxxxxxxxx")
>      full_name: str
> -    # use password.get_secret_value() to retrieve the value
> -    password: SecretStr = Field(min_length=8, max_length=50)
>      disabled: bool

As discussed on Mattermost, I think we can drop the disabled flag since we don't need it for now. This should simplify the logic a bit, since we won't have to check it in several places.

>  
>  
> -class User(CreateUser):
> +class UserWithPassword(ReadUser):
> +    """
> +    To create a user we need a password as well.
> +    """
> +
> +    # use password.get_secret_value() to retrieve the value
> +    password: SecretStr = Field(min_length=8, max_length=100)
> +
> +
> +class User(ReadUser):
> +    """
> +    To read a user from the DB it comes with an ID
> +    """
> +
>      id: int
>  
>  
> diff --git a/backend/msm/user_api/_base.py b/backend/msm/user_api/_base.py
> index 38d645c..45da2ed 100644
> --- a/backend/msm/user_api/_base.py
> +++ b/backend/msm/user_api/_base.py
> @@ -75,3 +94,30 @@ async def tokens_post(
>          count=create_request.count,
>      )
>      return CreateTokensResponse(expired=expired, tokens=tokens)
> +
> +
> +async def login_for_access_token(
> +    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
> +    session: AsyncSession = Depends(db_session),
> +) -> JWTToken:
> +    user = await authenticate_user(
> +        session, form_data.username, form_data.password
> +    )
> +    if not user:
> +        raise HTTPException(
> +            status_code=status.HTTP_401_UNAUTHORIZED,
> +            detail="Incorrect username or password",
> +            headers={"WWW-Authenticate": "Bearer"},
> +        )
> +    access_token_expires = timedelta(minutes=int(ACCESS_TOKEN_EXPIRE_MINUTES))
> +    access_token = create_access_token(
> +        data={"sub": user.email}, expires_delta=access_token_expires
> +    )
> +    return JWTToken(access_token=access_token, token_type="bearer")
> +
> +
> +async def read_users_me(
> +    current_user: Annotated[User, Depends(get_current_active_user)],

As mentioned before, I'd change this to just:

user: Annotated[User, Depends(get_authenticated_user)],

> +    session: AsyncSession = Depends(db_session),
> +) -> User:
> +    return current_user
> diff --git a/backend/msm/user_api/_jwt.py b/backend/msm/user_api/_jwt.py
> new file mode 100644
> index 0000000..641f271
> --- /dev/null
> +++ b/backend/msm/user_api/_jwt.py
> @@ -0,0 +1,125 @@
> +from __future__ import annotations

This shouldn't be needed.

> +
> +from datetime import (
> +    datetime,
> +    timedelta,
> +)
> +from logging import getLogger
> +from os import (
> +    environ,
> +    getenv,
> +)
> +from typing import (
> +    Annotated,
> +    Any,
> +)
> +
> +from fastapi import (
> +    Depends,
> +    HTTPException,
> +    status,
> +)
> +from fastapi.security import OAuth2PasswordBearer
> +from jose import (
> +    jwt,
> +    JWTError,
> +)
> +from passlib.context import CryptContext
> +from sqlalchemy.ext.asyncio import AsyncSession
> +
> +from ..db import db_session
> +
> +# from ..db import db_session
> +from ..db.queries import get_user
> +from ..schema import (
> +    JWTTokenData,
> +    User,
> +    UserWithPassword,
> +)
> +
> +logger = getLogger("site-manager.jwt")
> +
> +# to get a string like this run:
> +# openssl rand -hex 32
> +SECRET_KEY = getenv(
> +    "SECRET_KEY",
> +    "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7",
> +)
> +# XXX: Require this in the config
> +ALGORITHM = "HS256"
> +ACCESS_TOKEN_EXPIRE_MINUTES = getenv("TOKEN_EXPIRATION_TIME", 30)

Please move these to the msm.settings.Settings object and access them via the SETTINGS variable here.

> +
> +if "SECRET_KEY" not in environ:
> +    logger.critical("Secret key not defined in environment!")
> +
> +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
> +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
> +
> +
> +def verify_password(plain_password: str, hashed_password: str | None) -> bool:
> +    """
> +    Verify a plain password against a password hash created by passlib
> +    """
> +    return bool(pwd_context.verify(plain_password, hashed_password))
> +
> +
> +def get_password_hash(password: str) -> str:
> +    """
> +    Get a hash for a password
> +    """
> +    return str(pwd_context.hash(password))
> +
> +
> +def create_access_token(
> +    data: dict[str, Any], expires_delta: timedelta | None = None
> +) -> str:
> +    to_encode = data.copy()
> +    if expires_delta:
> +        expire = datetime.utcnow() + expires_delta
> +    else:
> +        expire = datetime.utcnow() + timedelta(minutes=15)

nitpick:

a simpler way to write this would be:

if not expires_delta:
    expires_delta = timedelta(minutes=15)  # this should actually use ACCESS_TOKEN_EXPIRE_MINUTES

expire = datetime.utcnow() + expires_delta

> +    to_encode.update({"exp": expire})
> +    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
> +    return str(encoded_jwt)
> +
> +
> +async def authenticate_user(
> +    session: AsyncSession, email: str, password: str
> +) -> UserWithPassword | None:
> +    user = await get_user(session, email)
> +    if not user or user.disabled:
> +        return None
> +    if not verify_password(password, user.password.get_secret_value()):
> +        return None
> +    return user
> +
> +
> +async def get_current_user(
> +    token: Annotated[str, Depends(oauth2_scheme)],
> +    session: AsyncSession = Depends(db_session),
> +) -> UserWithPassword | None:
> +    credentials_exception = HTTPException(
> +        status_code=status.HTTP_401_UNAUTHORIZED,
> +        detail="Could not validate credentials",
> +        headers={"WWW-Authenticate": "Bearer"},
> +    )
> +    try:
> +        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
> +        email: str | None = payload.get("sub")
> +        if email is None:
> +            raise credentials_exception
> +        token_data = JWTTokenData(email=email)
> +    except JWTError:
> +        raise credentials_exception
> +    user = await get_user(session, email=token_data.email)
> +    if user is None:
> +        raise credentials_exception
> +    return user
> +
> +
> +async def get_current_active_user(
> +    current_user: Annotated[User, Depends(get_current_user)],
> +) -> User | None:
> +    if current_user.disabled:
> +        raise HTTPException(status_code=400, detail="Inactive user")
> +    return current_user

I think this function and the one above could be merged into a single get_authenticated_user.

> diff --git a/backend/tests/fixtures/app.py b/backend/tests/fixtures/app.py
> index d325e81..1c0fc7f 100644
> --- a/backend/tests/fixtures/app.py
> +++ b/backend/tests/fixtures/app.py
> @@ -1,15 +1,64 @@
> +from __future__ import annotations
> +
>  from typing import (
>      AsyncIterable,
>      Iterable,
>  )
>  
>  from fastapi import FastAPI
> -from httpx import AsyncClient
> +from httpx import (
> +    AsyncClient,
> +    Response,
> +)
>  import pytest
>  
>  from msm.db import Database
>  from msm.user_api import create_app
>  
> +from .db import Fixture
> +
> +
> +class AuthAsyncClient(AsyncClient):
> +    """Equivalent to AsyncClient, but has the ability to send
> +    requests from an authorized login"""
> +
> +    def __init__(self, **kwargs) -> None:  # type: ignore
> +        super().__init__(**kwargs)
> +        self.email: str = ""
> +        self._token: str = ""
> +        self._token_type: str = ""
> +
> +    async def login(self, email: str, password: str) -> None:
> +        """login this client with the email and password"""
> +        response = await self.post(
> +            "/login", data={"username": email, "password": password}
> +        )
> +        assert (
> +            response.status_code == 200
> +        ), f"Could not login user: {response.text}"
> +        self.email = email
> +        self._token = response.json()["access_token"]
> +        self._token_type = response.json()["token_type"].capitalize()
> +
> +    @property
> +    def authed(self) -> bool:
> +        """Are we logged in?"""
> +        return self._token is not None

This check is always True, since self._token defaults to "" (not None), so the client will always consider itself logged in. You can return bool(self._token) instead.

> +
> +    async def request(self, *args, **kwargs) -> Response:  # type: ignore
> +        """Generate a request with the authorized payload attached if the user
> +        has been logged in. All methods (get, post, push, ...) use this in
> +        the backend to construct their requests"""
> +        if self.authed:
> +            kwargs.update(
> +                {
> +                    "headers": {
> +                        "Authorization": f"{self._token_type} {self._token}"
> +                    },
> +                }
> +            )
> +        return await super().request(*args, **kwargs)
> +
>  
>  @pytest.fixture
>  def user_app(
> @@ -26,3 +75,24 @@ async def user_app_client(user_app: FastAPI) -> AsyncIterable[AsyncClient]:
>      """Client for the user API."""
>      async with AsyncClient(app=user_app, base_url="http://test") as client:
>          yield client
> +
> +
> +@pytest.fixture
> +async def authenticated_user_app_client(
> +    user_app: FastAPI, fixture: Fixture
> +) -> AsyncIterable[AuthAsyncClient]:
> +    """Authenticated Client for the user API."""
> +    phash = "$2b$12$F5sgrhRNtWAOehcoVO.XK.oSvupmcg8.0T2jCHOTg15M8N8LrpRwS"
> +    async with AuthAsyncClient(app=user_app, base_url="http://test") as client:
> +        await fixture.create(
> +            "user",
> +            {
> +                "id": 1,
> +                "email": "admin@xxxxxxxxxxx",
> +                "full_name": "Admin",
> +                "disabled": False,
> +                "password": phash,
> +            },
> +        )

I'd move the user creation outside of the context manager, since it's unrelated to the client.

> +        await client.login("admin@xxxxxxxxxxx", "admin")
> +        yield client


-- 
https://code.launchpad.net/~lloydwaltersj/maas-site-manager/+git/site-manager/+merge/440870
Your team MAAS Committers is requested to review the proposed merge of ~lloydwaltersj/maas-site-manager:add-login into maas-site-manager:main.