Blob Blame History Raw
From 29cf9b8d63ef3437ba11aa29502af8773faa17a7 Mon Sep 17 00:00:00 2001
From: Paul Kehrer <paul.l.kehrer@gmail.com>
Date: Wed, 14 Apr 2021 13:15:57 -0500
Subject: [PATCH 3/4] switch to using EVP_PKEY_derive instead of DH_compute_key
 in DH (#5972)

* switch to using EVP_PKEY_derive instead of DH_compute_key in DH

Where checks are occurring is changing in OpenSSL 3.0 and this makes it
easier to be consistent (and is the API we should be using anyway). The
tests change because EVP_PKEY_derive now verifies that we have shared
parameters, which the test previously only verified by asserting that
the derived keys didn't match

* review feedback

* type ignores required for typeerror tests. some day i will remember this
---
 src/_cffi_src/openssl/dh.py                   |  1 -
 .../hazmat/backends/openssl/dh.py             | 57 ++++++++++++-------
 tests/hazmat/primitives/test_dh.py            | 19 ++++---
 3 files changed, 45 insertions(+), 32 deletions(-)

diff --git a/src/_cffi_src/openssl/dh.py b/src/_cffi_src/openssl/dh.py
index 979dafa9..50989e45 100644
--- a/src/_cffi_src/openssl/dh.py
+++ b/src/_cffi_src/openssl/dh.py
@@ -18,7 +18,6 @@ DH *DH_new(void);
 void DH_free(DH *);
 int DH_size(const DH *);
 int DH_generate_key(DH *);
-int DH_compute_key(unsigned char *, const BIGNUM *, DH *);
 DH *DHparams_dup(DH *);
 
 /* added in 1.1.0 when the DH struct was opaqued */
diff --git a/src/cryptography/hazmat/backends/openssl/dh.py b/src/cryptography/hazmat/backends/openssl/dh.py
index 65ddaeec..b928f024 100644
--- a/src/cryptography/hazmat/backends/openssl/dh.py
+++ b/src/cryptography/hazmat/backends/openssl/dh.py
@@ -127,35 +127,48 @@ class _DHPrivateKey(dh.DHPrivateKey):
         )
 
     def exchange(self, peer_public_key: dh.DHPublicKey) -> bytes:
-        buf = self._backend._ffi.new("unsigned char[]", self._key_size_bytes)
-        pub_key = self._backend._ffi.new("BIGNUM **")
-        self._backend._lib.DH_get0_key(
-            peer_public_key._dh_cdata,  # type: ignore[attr-defined]
-            pub_key,
-            self._backend._ffi.NULL,
+        if not isinstance(peer_public_key, _DHPublicKey):
+            raise TypeError("peer_public_key must be a DHPublicKey")
+
+        ctx = self._backend._lib.EVP_PKEY_CTX_new(
+            self._evp_pkey, self._backend._ffi.NULL
         )
-        self._backend.openssl_assert(pub_key[0] != self._backend._ffi.NULL)
-        res = self._backend._lib.DH_compute_key(
-            buf, pub_key[0], self._dh_cdata
+        self._backend.openssl_assert(ctx != self._backend._ffi.NULL)
+        ctx = self._backend._ffi.gc(ctx, self._backend._lib.EVP_PKEY_CTX_free)
+        res = self._backend._lib.EVP_PKEY_derive_init(ctx)
+        self._backend.openssl_assert(res == 1)
+        res = self._backend._lib.EVP_PKEY_derive_set_peer(
+            ctx, peer_public_key._evp_pkey
+        )
+        # Invalid kex errors here in OpenSSL 3.0 because checks were moved
+        # to EVP_PKEY_derive_set_peer
+        self._exchange_assert(res == 1)
+        keylen = self._backend._ffi.new("size_t *")
+        res = self._backend._lib.EVP_PKEY_derive(
+            ctx, self._backend._ffi.NULL, keylen
         )
+        # Invalid kex errors here in OpenSSL < 3
+        self._exchange_assert(res == 1)
+        self._backend.openssl_assert(keylen[0] > 0)
+        buf = self._backend._ffi.new("unsigned char[]", keylen[0])
+        res = self._backend._lib.EVP_PKEY_derive(ctx, buf, keylen)
+        self._backend.openssl_assert(res == 1)
 
-        if res == -1:
+        key = self._backend._ffi.buffer(buf, keylen[0])[:]
+        pad = self._key_size_bytes - len(key)
+
+        if pad > 0:
+            key = (b"\x00" * pad) + key
+
+        return key
+
+    def _exchange_assert(self, ok):
+        if not ok:
             errors_with_text = self._backend._consume_errors_with_text()
             raise ValueError(
-                "Error computing shared key. Public key is likely invalid "
-                "for this exchange.",
+                "Error computing shared key.",
                 errors_with_text,
             )
-        else:
-            self._backend.openssl_assert(res >= 1)
-
-            key = self._backend._ffi.buffer(buf)[:res]
-            pad = self._key_size_bytes - len(key)
-
-            if pad > 0:
-                key = (b"\x00" * pad) + key
-
-            return key
 
     def public_key(self) -> dh.DHPublicKey:
         dh_cdata = _dh_params_dup(self._dh_cdata, self._backend)
diff --git a/tests/hazmat/primitives/test_dh.py b/tests/hazmat/primitives/test_dh.py
index bb29919f..2914f7e7 100644
--- a/tests/hazmat/primitives/test_dh.py
+++ b/tests/hazmat/primitives/test_dh.py
@@ -296,6 +296,12 @@ class TestDH(object):
         assert isinstance(key.private_numbers(), dh.DHPrivateNumbers)
         assert isinstance(key.parameters(), dh.DHParameters)
 
+    def test_exchange_wrong_type(self, backend):
+        parameters = FFDH3072_P.parameters(backend)
+        key1 = parameters.generate_private_key()
+        with pytest.raises(TypeError):
+            key1.exchange(b"invalidtype")  # type: ignore[arg-type]
+
     def test_exchange(self, backend):
         parameters = FFDH3072_P.parameters(backend)
         assert isinstance(parameters, dh.DHParameters)
@@ -386,16 +392,11 @@ class TestDH(object):
         key2 = private2.private_key(backend)
         pub_key2 = key2.public_key()
 
-        if pub_key2.public_numbers().y >= parameters1.p:
-            with pytest.raises(ValueError):
-                key1.exchange(pub_key2)
-        else:
-            symkey1 = key1.exchange(pub_key2)
-            assert symkey1
-
-            symkey2 = key2.exchange(pub_key1)
+        with pytest.raises(ValueError):
+            key1.exchange(pub_key2)
 
-            assert symkey1 != symkey2
+        with pytest.raises(ValueError):
+            key2.exchange(pub_key1)
 
     @pytest.mark.skip_fips(reason="key_size too small for FIPS")
     @pytest.mark.supported(
-- 
2.31.1