Blob Blame History Raw
From adc9461c316b5e6f693d362140bf5483aa77ad81 Mon Sep 17 00:00:00 2001
From: Thomas A Caswell <tcaswell@gmail.com>
Date: Mon, 19 Jun 2023 21:28:02 -0400
Subject: [PATCH 7/8] MNT: py312 deprecates pickling objects in itertools

Signed-off-by: Elliott Sales de Andrade <quantum.analyst@gmail.com>
---
 lib/matplotlib/cbook/__init__.py   |  3 +++
 lib/matplotlib/figure.py           | 11 +++++++++++
 lib/matplotlib/tests/test_cbook.py |  7 +++++++
 3 files changed, 21 insertions(+)

diff --git a/lib/matplotlib/cbook/__init__.py b/lib/matplotlib/cbook/__init__.py
index 1e51f6a834..b0d06cddf6 100644
--- a/lib/matplotlib/cbook/__init__.py
+++ b/lib/matplotlib/cbook/__init__.py
@@ -206,9 +206,11 @@ class CallbackRegistry:
                           for s, d in self.callbacks.items()},
             # It is simpler to reconstruct this from callbacks in __setstate__.
             "_func_cid_map": None,
+            "_cid_gen": next(self._cid_gen)
         }
 
     def __setstate__(self, state):
+        cid_count = state.pop('_cid_gen')
         vars(self).update(state)
         self.callbacks = {
             s: {cid: _weak_or_strong_ref(func, self._remove_proxy)
@@ -217,6 +219,7 @@ class CallbackRegistry:
         self._func_cid_map = {
             s: {proxy: cid for cid, proxy in d.items()}
             for s, d in self.callbacks.items()}
+        self._cid_gen = itertools.count(cid_count)
 
     def connect(self, signal, func):
         """Register *func* to be called when signal *signal* is generated."""
diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py
index c6df929e04..b64c677b5d 100644
--- a/lib/matplotlib/figure.py
+++ b/lib/matplotlib/figure.py
@@ -106,6 +106,17 @@ class _AxesStack:
         """Return the active axes, or None if the stack is empty."""
         return max(self._axes, key=self._axes.__getitem__, default=None)
 
+    def __getstate__(self):
+        return {
+            **vars(self),
+            "_counter": max(self._axes.values(), default=0)
+        }
+
+    def __setstate__(self, state):
+        next_counter = state.pop('_counter')
+        vars(self).update(state)
+        self._counter = itertools.count(next_counter)
+
 
 class SubplotParams:
     """
diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py
index aa5c999b70..da3868b0f8 100644
--- a/lib/matplotlib/tests/test_cbook.py
+++ b/lib/matplotlib/tests/test_cbook.py
@@ -207,6 +207,13 @@ class Test_callback_registry:
         assert self.callbacks._func_cid_map != {}
         assert self.callbacks.callbacks != {}
 
+    def test_cid_restore(self):
+        cb = cbook.CallbackRegistry()
+        cb.connect('a', lambda: None)
+        cb2 = pickle.loads(pickle.dumps(cb))
+        cid = cb2.connect('c', lambda: None)
+        assert cid == 1
+
     @pytest.mark.parametrize('pickle', [True, False])
     def test_callback_complete(self, pickle):
         # ensure we start with an empty registry
-- 
2.41.0