← Back to team overview

launchpad-reviewers team mailing list archive

[Merge] ~cjwatson/launchpad:sqlbase-connect-set-role-after-connecting into launchpad:master

 

Colin Watson has proposed merging ~cjwatson/launchpad:sqlbase-connect-set-role-after-connecting into launchpad:master.

Commit message:
Rework sqlbase.connect to honour set_role_after_connecting

Requested reviews:
  Launchpad code reviewers (launchpad-reviewers)

For more details, see:
https://code.launchpad.net/~cjwatson/launchpad/+git/launchpad/+merge/439909

Commit 31b283765f2cb0c3c6551c46c02dd84cd51b505e reworked `LaunchpadDatabase.raw_connect` to add the ability to switch database role after connecting; this covers everything that uses the Storm store interface.  However, a few scripts (including those that run without the Zope component architecture and those that do some particularly arcane things with database connections) construct database connections a little more directly, typically via `lp.services.database.sqlbase.connect`.  The one I've noticed is `librarian-gc`, which failed to work in a charmed deployment configured with `set_role_after_connecting: True`.

To fix this, rework `sqlbase.connect` so that it honours `set_role_after_connecting` as well.  It now always uses `dbconfig.dbuser` rather than accepting an optional `user` argument, which makes it easier to follow which role a connection ends up using.  The `connect_string` helper is removed, and the `sqlbaseconnect.rst` doctest is replaced by unit tests in `test_sqlbase.py`.  Ideally we'd now also rework `LaunchpadDatabase.raw_connect` on top of `sqlbase.connect` so that there is only one implementation of this logic.  That's difficult because there the actual connection is made by code in a base class, so I've just left an XXX comment for now.
-- 
Your team Launchpad code reviewers is requested to review the proposed merge of ~cjwatson/launchpad:sqlbase-connect-set-role-after-connecting into launchpad:master.
diff --git a/cronscripts/librarian-gc.py b/cronscripts/librarian-gc.py
index 7f8d39e..9f7ae7d 100755
--- a/cronscripts/librarian-gc.py
+++ b/cronscripts/librarian-gc.py
@@ -14,7 +14,7 @@ import _pythonpath  # noqa: F401
 
 import logging
 
-from lp.services.config import config, dbconfig
+from lp.services.config import config
 from lp.services.database.sqlbase import ISOLATION_LEVEL_AUTOCOMMIT, connect
 from lp.services.librarianserver import librariangc
 from lp.services.scripts.base import LaunchpadCronScript
@@ -78,9 +78,7 @@ class LibrarianGC(LaunchpadCronScript):
         if self.options.loglevel <= logging.DEBUG:
             librariangc.debug = True
 
-        conn = connect(
-            user=dbconfig.dbuser, isolation=ISOLATION_LEVEL_AUTOCOMMIT
-        )
+        conn = connect(isolation=ISOLATION_LEVEL_AUTOCOMMIT)
 
         # Refuse to run if we have significant clock skew between the
         # librarian and the database.
diff --git a/lib/lp/services/database/doc/sqlbaseconnect.rst b/lib/lp/services/database/doc/sqlbaseconnect.rst
deleted file mode 100644
index 6e9385e..0000000
--- a/lib/lp/services/database/doc/sqlbaseconnect.rst
+++ /dev/null
@@ -1,41 +0,0 @@
-Ensure that lp.services.database.sqlbase connects as we expect.
-
-    >>> from lp.services.config import config
-    >>> from lp.services.database.sqlbase import (
-    ...     connect,
-    ...     ISOLATION_LEVEL_DEFAULT,
-    ...     ISOLATION_LEVEL_SERIALIZABLE,
-    ... )
-
-    >>> def do_connect(user, dbname=None, isolation=ISOLATION_LEVEL_DEFAULT):
-    ...     con = connect(user=user, dbname=dbname, isolation=isolation)
-    ...     cur = con.cursor()
-    ...     cur.execute("SHOW session_authorization")
-    ...     who = cur.fetchone()[0]
-    ...     cur.execute("SELECT current_database()")
-    ...     where = cur.fetchone()[0]
-    ...     cur.execute("SHOW transaction_isolation")
-    ...     how = cur.fetchone()[0]
-    ...     print(
-    ...         "Connected as %s to %s in %s isolation." % (who, where, how)
-    ...     )
-    ...
-
-Specifying the user connects as that user.
-
-    >>> do_connect(user=config.launchpad_session.dbuser)
-    Connected as session to ... in read committed isolation.
-
-Specifying the database name connects to that database.
-
-    >>> do_connect(user=config.launchpad.dbuser, dbname="launchpad_empty")
-    Connected as launchpad_main to launchpad_empty in read committed
-    isolation.
-
-Specifying the isolation level works too.
-
-    >>> do_connect(
-    ...     user=config.launchpad.dbuser,
-    ...     isolation=ISOLATION_LEVEL_SERIALIZABLE,
-    ... )
-    Connected as launchpad_main to ... in serializable isolation.
diff --git a/lib/lp/services/database/sqlbase.py b/lib/lp/services/database/sqlbase.py
index c02d042..7ebdc9b 100644
--- a/lib/lp/services/database/sqlbase.py
+++ b/lib/lp/services/database/sqlbase.py
@@ -37,6 +37,8 @@ from psycopg2.extensions import (
     ISOLATION_LEVEL_READ_COMMITTED,
     ISOLATION_LEVEL_REPEATABLE_READ,
     ISOLATION_LEVEL_SERIALIZABLE,
+    make_dsn,
+    parse_dsn,
 )
 from storm.databases.postgres import compile as postgres_compile
 from storm.expr import State
@@ -568,7 +570,7 @@ def reset_store(func):
     return mergeFunctionMetadata(func, reset_store_decorator)
 
 
-def connect(user=None, dbname=None, isolation=ISOLATION_LEVEL_DEFAULT):
+def connect(dbname=None, isolation=ISOLATION_LEVEL_DEFAULT):
     """Return a fresh DB-API connection to the MAIN PRIMARY database.
 
     Can be used without first setting up the Component Architecture,
@@ -576,27 +578,33 @@ def connect(user=None, dbname=None, isolation=ISOLATION_LEVEL_DEFAULT):
 
     Default database name is the one specified in the main configuration file.
     """
-    con = psycopg2.connect(connect_string(user=user, dbname=dbname))
-    con.set_isolation_level(isolation)
-    return con
-
-
-def connect_string(user=None, dbname=None):
-    """Return a PostgreSQL connection string.
-
-    Allows you to pass the generated connection details to external
-    programs like pg_dump or embed in slonik scripts.
-    """
     # We must connect to the read-write DB here, so we use rw_main_primary
     # directly.
-    from lp.services.database.postgresql import ConnectionString
-
-    con_str = ConnectionString(dbconfig.rw_main_primary)
-    if user is not None:
-        con_str.user = user
+    parsed_dsn = parse_dsn(dbconfig.rw_main_primary)
+    dsn_kwargs = {}
     if dbname is not None:
-        con_str.dbname = dbname
-    return str(con_str)
+        dsn_kwargs["dbname"] = dbname
+    if dbconfig.set_role_after_connecting:
+        assert "user" in parsed_dsn, (
+            "With set_role_after_connecting, database username must be "
+            "specified in connection string (%s)." % dbconfig.rw_main_primary
+        )
+    else:
+        assert "user" not in parsed_dsn, (
+            "Database username must not be specified in connection string "
+            "(%s)." % dbconfig.rw_main_primary
+        )
+        dsn_kwargs["user"] = dbconfig.dbuser
+    dsn = make_dsn(dbconfig.rw_main_primary, **dsn_kwargs)
+
+    con = psycopg2.connect(dsn)
+    role = dbconfig.dbuser
+    if dbconfig.set_role_after_connecting and role != parsed_dsn["user"]:
+        with con.cursor() as cur:
+            cur.execute("SET ROLE %s", (role,))
+        con.commit()  # Persist the role even if the caller rolls back.
+    con.set_isolation_level(isolation)
+    return con
 
 
 class cursor:
diff --git a/lib/lp/services/database/tests/test_sqlbase.py b/lib/lp/services/database/tests/test_sqlbase.py
index 35f12a7..1b37592 100644
--- a/lib/lp/services/database/tests/test_sqlbase.py
+++ b/lib/lp/services/database/tests/test_sqlbase.py
@@ -4,16 +4,169 @@
 import doctest
 import unittest
 from doctest import ELLIPSIS, NORMALIZE_WHITESPACE, REPORT_NDIFF
+from typing import Tuple
 
+from psycopg2.errors import InsufficientPrivilege
+from psycopg2.extensions import connection, parse_dsn
+
+from lp.services.config import config, dbconfig
 from lp.services.database import sqlbase
+from lp.testing import TestCase
+from lp.testing.layers import DatabaseLayer, ZopelessDatabaseLayer
 
 
-def test_suite():
-    optionflags = ELLIPSIS | NORMALIZE_WHITESPACE | REPORT_NDIFF
-    dt_suite = doctest.DocTestSuite(sqlbase, optionflags=optionflags)
-    return unittest.TestSuite((dt_suite,))
+class TestConnect(TestCase):
+
+    layer = ZopelessDatabaseLayer
+
+    @staticmethod
+    def examineConnection(con: connection) -> Tuple[str, str, str]:
+        with con.cursor() as cur:
+            cur.execute("SHOW session_authorization")
+            who = cur.fetchone()[0]
+            cur.execute("SELECT current_database()")
+            where = cur.fetchone()[0]
+            cur.execute("SHOW transaction_isolation")
+            how = cur.fetchone()[0]
+        return (who, where, how)
+
+    def test_honours_dbconfig_dbuser(self):
+        dbconfig.override(dbuser=config.launchpad_session.dbuser)
+        con = sqlbase.connect()
+        who, _, how = self.examineConnection(con)
+        self.assertEqual(("session", "read committed"), (who, how))
+
+    def test_honours_dbname(self):
+        dbconfig.override(dbuser=config.launchpad.dbuser)
+        con = sqlbase.connect(dbname="launchpad_empty")
+        self.assertEqual(
+            ("launchpad_main", "launchpad_empty", "read committed"),
+            self.examineConnection(con),
+        )
+
+    def test_honours_isolation(self):
+        dbconfig.override(dbuser=config.launchpad.dbuser)
+        con = sqlbase.connect(isolation=sqlbase.ISOLATION_LEVEL_SERIALIZABLE)
+        who, _, how = self.examineConnection(con)
+        self.assertEqual(("launchpad_main", "serializable"), (who, how))
+
+    def getCurrentUser(self, con: connection) -> str:
+        with con.cursor() as cur:
+            cur.execute("SELECT current_user")  # effective role, not login
+            return cur.fetchone()[0]
+
+    def test_refuses_connstring_with_user(self):
+        connstr = "dbname=%s user=foo" % DatabaseLayer._db_fixture.dbname
+        dbconfig.override(rw_main_primary=connstr)
+        self.assertRaisesWithContent(
+            AssertionError,
+            "Database username must not be specified in connection string "
+            "(%s)." % connstr,
+            sqlbase.connect,
+        )
 
+    def test_refuses_connstring_uri_with_user(self):
+        connstr = "postgresql://foo@/%s" % DatabaseLayer._db_fixture.dbname
+        dbconfig.override(rw_main_primary=connstr)
+        self.assertRaisesWithContent(
+            AssertionError,
+            "Database username must not be specified in connection string "
+            "(%s)." % connstr,
+            sqlbase.connect,
+        )
 
-if __name__ == "__main__":
-    runner = unittest.TextTestRunner()
-    runner.run(test_suite())
+    def test_accepts_connstring_without_user(self):
+        connstr = "dbname=%s" % DatabaseLayer._db_fixture.dbname
+        dbconfig.override(rw_main_primary=connstr, dbuser="ro")
+        con = sqlbase.connect()
+        self.assertEqual(
+            {"dbname": DatabaseLayer._db_fixture.dbname, "user": "ro"},
+            parse_dsn(con.dsn),
+        )
+        self.assertEqual("ro", self.getCurrentUser(con))
+
+    def test_accepts_connstring_uri_without_user(self):
+        connstr = "postgresql:///%s" % DatabaseLayer._db_fixture.dbname
+        dbconfig.override(rw_main_primary=connstr, dbuser="ro")
+        con = sqlbase.connect()
+        self.assertEqual(
+            {"dbname": DatabaseLayer._db_fixture.dbname, "user": "ro"},
+            parse_dsn(con.dsn),
+        )
+        self.assertEqual("ro", self.getCurrentUser(con))
+
+    def test_set_role_after_connecting_refuses_connstring_without_user(self):
+        connstr = "dbname=%s" % DatabaseLayer._db_fixture.dbname
+        dbconfig.override(
+            set_role_after_connecting=True,
+            rw_main_primary=connstr,
+            dbuser="read",
+        )
+        self.assertRaisesWithContent(
+            AssertionError,
+            "With set_role_after_connecting, database username must be "
+            "specified in connection string (%s)." % connstr,
+            sqlbase.connect,
+        )
+
+    def test_set_role_after_connecting_refuses_connstring_uri_without_user(
+        self,
+    ):
+        connstr = "postgresql:///%s" % DatabaseLayer._db_fixture.dbname
+        dbconfig.override(
+            set_role_after_connecting=True,
+            rw_main_primary=connstr,
+            dbuser="read",
+        )
+        self.assertRaisesWithContent(
+            AssertionError,
+            "With set_role_after_connecting, database username must be "
+            "specified in connection string (%s)." % connstr,
+            sqlbase.connect,
+        )
+
+    def test_set_role_after_connecting_accepts_connstring_with_user(self):
+        connstr = "dbname=%s user=ro" % DatabaseLayer._db_fixture.dbname
+        dbconfig.override(
+            set_role_after_connecting=True,
+            rw_main_primary=connstr,
+            dbuser="read",
+        )
+        con = sqlbase.connect()
+        self.assertEqual(
+            {"dbname": DatabaseLayer._db_fixture.dbname, "user": "ro"},
+            parse_dsn(con.dsn),
+        )
+        self.assertEqual("read", self.getCurrentUser(con))
+
+    def test_set_role_after_connecting_accepts_connstring_uri_with_user(self):
+        connstr = "postgresql://ro@/%s" % DatabaseLayer._db_fixture.dbname
+        dbconfig.override(
+            set_role_after_connecting=True,
+            rw_main_primary=connstr,
+            dbuser="read",
+        )
+        con = sqlbase.connect()
+        self.assertEqual(
+            {"dbname": DatabaseLayer._db_fixture.dbname, "user": "ro"},
+            parse_dsn(con.dsn),
+        )
+        self.assertEqual("read", self.getCurrentUser(con))
+
+    def test_set_role_after_connecting_not_member(self):
+        connstr = "dbname=%s user=ro" % DatabaseLayer._db_fixture.dbname
+        dbconfig.override(
+            set_role_after_connecting=True,
+            rw_main_primary=connstr,
+            dbuser="launchpad_main",
+        )
+        self.assertRaises(InsufficientPrivilege, sqlbase.connect)
+
+
+def test_suite():
+    suite = unittest.TestSuite()
+    loader = unittest.TestLoader()
+    suite.addTest(loader.loadTestsFromTestCase(TestConnect))
+    optionflags = ELLIPSIS | NORMALIZE_WHITESPACE | REPORT_NDIFF
+    suite.addTest(doctest.DocTestSuite(sqlbase, optionflags=optionflags))
+    return suite
diff --git a/lib/lp/services/librarianserver/tests/test_gc.py b/lib/lp/services/librarianserver/tests/test_gc.py
index 41087cc..a75d9cf 100644
--- a/lib/lp/services/librarianserver/tests/test_gc.py
+++ b/lib/lp/services/librarianserver/tests/test_gc.py
@@ -90,10 +90,7 @@ class TestLibrarianGarbageCollectionBase:
                 content.filesize = len(content_bytes)
         transaction.commit()
 
-        self.con = connect(
-            user=config.librarian_gc.dbuser,
-            isolation=ISOLATION_LEVEL_AUTOCOMMIT,
-        )
+        self.con = connect(isolation=ISOLATION_LEVEL_AUTOCOMMIT)
 
     def tearDown(self):
         self.con.rollback()
@@ -1363,10 +1360,7 @@ class TestBlobCollection(TestCase):
         switch_dbuser(config.librarian_gc.dbuser)
 
         # Open a connection for our test
-        self.con = connect(
-            user=config.librarian_gc.dbuser,
-            isolation=ISOLATION_LEVEL_AUTOCOMMIT,
-        )
+        self.con = connect(isolation=ISOLATION_LEVEL_AUTOCOMMIT)
 
         self.patch(librariangc, "log", BufferLogger())
 
diff --git a/lib/lp/services/webapp/adapter.py b/lib/lp/services/webapp/adapter.py
index 2e16397..a9910a5 100644
--- a/lib/lp/services/webapp/adapter.py
+++ b/lib/lp/services/webapp/adapter.py
@@ -422,6 +422,9 @@ class LaunchpadDatabase(Postgres):
         self.name = uri.database
 
     def raw_connect(self):
+        # XXX cjwatson 2023-03-28: Can we rework this on top of
+        # lp.services.database.sqlbase.connect somehow so that we only need
+        # one implementation of the set_role_after_connecting logic?
         try:
             realm, flavor = self._uri.database.split("-")
         except ValueError:
diff --git a/lib/lp/testing/pgsql.py b/lib/lp/testing/pgsql.py
index 7069d83..68aa01f 100644
--- a/lib/lp/testing/pgsql.py
+++ b/lib/lp/testing/pgsql.py
@@ -126,6 +126,15 @@ class CursorWrapper:
             CursorWrapper.last_executed_sql.append(args[0])
         return self.real_cursor.execute(*args, **kwargs)
 
+    def __enter__(self):
+        # Return the wrapper rather than the real cursor, so that
+        # statements run inside a "with" block still go through execute().
+        self.real_cursor.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        return self.real_cursor.__exit__(exc_type, exc_value, exc_tb)
+
     def __getattr__(self, key):
         return getattr(self.real_cursor, key)
 

Follow ups