Blob Blame History Raw
From f37269ae603bc0e09210379e03a476ff2193c74c Mon Sep 17 00:00:00 2001
From: Mike Bayer <mike_mp@zzzcomputing.com>
Date: Mon, 24 Aug 2020 12:00:04 -0400
Subject: [PATCH] Accommodate immutable URL api

SQLAlchemy 1.4 has modified the URL object to be immutable.
This patch makes the adjustments needed, using duck-typing
to check if the URL object is of the new style.  When
requirements reach SQLAlchemy 1.4 as the minimum required
version, these conditionals should be removed in favor of
the noted option in each.

Change-Id: Id2f0663b13ed0f81e91a8d44f73d8541015bf844
---
 oslo_db/sqlalchemy/engines.py                | 12 +++++++++++-
 oslo_db/sqlalchemy/provision.py              | 11 ++++++++++-
 oslo_db/tests/sqlalchemy/test_exc_filters.py | 12 ++++++++++--
 oslo_db/tests/sqlalchemy/test_sqlalchemy.py  | 20 +++++++++++++++-----
 4 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/oslo_db/sqlalchemy/engines.py b/oslo_db/sqlalchemy/engines.py
index 25215b96..0e19c183 100644
--- a/oslo_db/sqlalchemy/engines.py
+++ b/oslo_db/sqlalchemy/engines.py
@@ -105,6 +105,14 @@ def _setup_logging(connection_debug=0):
 
 
 def _extend_url_parameters(url, connection_parameters):
+    # TODO(zzzeek): remove hasattr() conditional when SQLAlchemy 1.4 is the
+    # minimum version in requirements; call update_query_string()
+    # unconditionally
+    if hasattr(url, "update_query_string"):
+        return url.update_query_string(connection_parameters, append=True)
+
+    # TODO(zzzeek): remove the remainder of this method when SQLAlchemy 1.4
+    # is the minimum version in requirements
     for key, value in parse.parse_qs(
             connection_parameters).items():
         if key in url.query:
@@ -118,6 +126,8 @@ def _extend_url_parameters(url, connection_parameters):
         if len(value) == 1:
             url.query[key] = value[0]
 
+    return url
+
 
 def _vet_url(url):
     if "+" not in url.drivername and not url.drivername.startswith("sqlite"):
@@ -153,7 +163,7 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None,
     url = sqlalchemy.engine.url.make_url(sql_connection)
 
     if connection_parameters:
-        _extend_url_parameters(url, connection_parameters)
+        url = _extend_url_parameters(url, connection_parameters)
 
     _vet_url(url)
 
diff --git a/oslo_db/sqlalchemy/provision.py b/oslo_db/sqlalchemy/provision.py
index 2fc3a543..5addab45 100644
--- a/oslo_db/sqlalchemy/provision.py
+++ b/oslo_db/sqlalchemy/provision.py
@@ -495,7 +495,16 @@ def provisioned_database_url(self, base_url, ident):
         """
 
         url = sa_url.make_url(str(base_url))
-        url.database = ident
+
+        # TODO(zzzeek): remove hasattr() conditional in favor of "url.set()"
+        # when SQLAlchemy 1.4 is the minimum version in requirements
+        if hasattr(url, "set"):
+            url = url.set(database=ident)
+        else:
+            # TODO(zzzeek): remove when SQLAlchemy 1.4
+            # is the minimum version in requirements
+            url.database = ident
+
         return url
 
 
diff --git a/oslo_db/tests/sqlalchemy/test_exc_filters.py b/oslo_db/tests/sqlalchemy/test_exc_filters.py
index 0451c895..7d55bb74 100644
--- a/oslo_db/tests/sqlalchemy/test_exc_filters.py
+++ b/oslo_db/tests/sqlalchemy/test_exc_filters.py
@@ -386,8 +386,16 @@ def setUp(self):
         super(TestNonExistentDatabase, self).setUp()
 
         url = sqla_url.make_url(str(self.engine.url))
-        url.database = 'non_existent_database'
-        self.url = url
+
+        # TODO(zzzeek): remove hasattr() conditional in favor of "url.set()"
+        # when SQLAlchemy 1.4 is the minimum version in requirements
+        if hasattr(url, "set"):
+            self.url = url.set(database="non_existent_database")
+        else:
+            # TODO(zzzeek): remove when SQLAlchemy 1.4
+            # is the minimum version in requirements
+            url.database = 'non_existent_database'
+            self.url = url
 
     def test_raise(self):
         matched = self.assertRaises(
diff --git a/oslo_db/tests/sqlalchemy/test_sqlalchemy.py b/oslo_db/tests/sqlalchemy/test_sqlalchemy.py
index 080f418c..44b7453a 100644
--- a/oslo_db/tests/sqlalchemy/test_sqlalchemy.py
+++ b/oslo_db/tests/sqlalchemy/test_sqlalchemy.py
@@ -229,6 +229,16 @@ def _mock_create_engine(*arg, **kw):
             "oslo_db.sqlalchemy.engines.sqlalchemy.create_engine",
             side_effect=_mock_create_engine)
 
+    def _normalize_query_dict(self, qdict):
+        # SQLAlchemy 1.4 returns url.query as:
+        # immutabledict({k1: v1, k2: (v2a, v2b, ...), ...})
+        # that is with tuples not lists for multiparams
+
+        return {
+            k: list(v) if isinstance(v, tuple) else v
+            for k, v in qdict.items()
+        }
+
     def test_add_assorted_params(self):
         with self._fixture() as ce:
             engines.create_engine(
@@ -236,7 +246,7 @@ def test_add_assorted_params(self):
                 connection_parameters="foo=bar&bat=hoho&bat=param2")
 
         self.assertEqual(
-            ce.mock_calls[0][1][0].query,
+            self._normalize_query_dict(ce.mock_calls[0][1][0].query),
             {'bat': ['hoho', 'param2'], 'foo': 'bar'}
         )
 
@@ -247,7 +257,7 @@ def test_add_no_params(self):
 
         self.assertEqual(
             ce.mock_calls[0][1][0].query,
-            {}
+            self._normalize_query_dict({})
         )
 
     def test_combine_params(self):
@@ -260,7 +270,7 @@ def test_combine_params(self):
                                       "bind_host=192.168.1.5")
 
         self.assertEqual(
-            ce.mock_calls[0][1][0].query,
+            self._normalize_query_dict(ce.mock_calls[0][1][0].query),
             {
                 'bind_host': '192.168.1.5',
                 'charset': 'utf8',
@@ -280,7 +290,7 @@ def test_combine_multi_params(self):
                                       "bind_host=192.168.1.5")
 
         self.assertEqual(
-            ce.mock_calls[0][1][0].query,
+            self._normalize_query_dict(ce.mock_calls[0][1][0].query),
             {
                 'bind_host': '192.168.1.5',
                 'charset': 'utf8',
@@ -751,7 +761,7 @@ def test_warn_on_missing_driver(self):
         def warn_interpolate(msg, args):
             # test the interpolation itself to ensure the password
             # is concealed
-            warnings.warning(msg % args)
+            warnings.warning(msg % (args, ))
 
         with mock.patch(
                 "oslo_db.sqlalchemy.engines.LOG.warning",