Skip to content

Commit 43f2c80

Browse files
committed
Add SerializationContext.clone method. Add pandas_serialization_context member that uses pickle for NumPy arrays with unsupported tensor types.
Change-Id: Ia70c26954ff9ab3af435281bcafbf298c8c0cf28
1 parent c944023 commit 43f2c80

4 files changed

Lines changed: 53 additions & 17 deletions

File tree

python/pyarrow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@
125125
localfs = LocalFileSystem.get_instance()
126126

127127
from pyarrow.serialization import (_default_serialization_context,
128+
pandas_serialization_context,
128129
register_default_serialization_handlers)
129130

130131
import pyarrow.types as types

python/pyarrow/serialization.pxi

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,23 @@ cdef class SerializationContext:
5757
self.custom_serializers = dict()
5858
self.custom_deserializers = dict()
5959

60+
def clone(self):
61+
"""
62+
Return a copy of this SerializationContext.
63+
64+
Returns
65+
-------
66+
clone : SerializationContext
67+
"""
68+
result = SerializationContext()
69+
result.type_to_type_id = self.type_to_type_id.copy()
70+
result.whitelisted_types = self.whitelisted_types.copy()
71+
result.types_to_pickle = self.types_to_pickle.copy()
72+
result.custom_serializers = self.custom_serializers.copy()
73+
result.custom_deserializers = self.custom_deserializers.copy()
74+
75+
return result
76+
6077
def register_type(self, type_, type_id,
6178
custom_serializer=None, custom_deserializer=None):
6279
"""EXPERIMENTAL: Add type to the list of types we can serialize.

python/pyarrow/serialization.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,28 @@
3131
cloudpickle = pickle
3232

3333

34+
# ----------------------------------------------------------------------
35+
# Set up serialization for numpy with dtype object (primitive types are
36+
# handled efficiently with Arrow's Tensor facilities, see
37+
# python_to_arrow.cc)
38+
39+
def _serialize_numpy_array_list(obj):
40+
return obj.tolist(), obj.dtype.str
41+
42+
43+
def _deserialize_numpy_array_list(data):
44+
return np.array(data[0], dtype=np.dtype(data[1]))
45+
46+
47+
def _serialize_numpy_array_pickle(obj):
48+
pickled = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
49+
return frombuffer(pickled)
50+
51+
52+
def _deserialize_numpy_array_pickle(data):
53+
return pickle.loads(memoryview(data))
54+
55+
3456
def register_default_serialization_handlers(serialization_context):
3557

3658
# ----------------------------------------------------------------------
@@ -81,22 +103,10 @@ def _deserialize_default_dict(data):
81103
custom_serializer=cloudpickle.dumps,
82104
custom_deserializer=cloudpickle.loads)
83105

84-
# ----------------------------------------------------------------------
85-
# Set up serialization for numpy with dtype object (primitive types are
86-
# handled efficiently with Arrow's Tensor facilities, see
87-
# python_to_arrow.cc)
88-
89-
def _serialize_numpy_array(obj):
90-
pickled = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
91-
return frombuffer(pickled)
92-
93-
def _deserialize_numpy_array(data):
94-
return pickle.loads(memoryview(data))
95-
96106
serialization_context.register_type(
97107
np.ndarray, 'np.array',
98-
custom_serializer=_serialize_numpy_array,
99-
custom_deserializer=_deserialize_numpy_array)
108+
custom_serializer=_serialize_numpy_array_list,
109+
custom_deserializer=_deserialize_numpy_array_list)
100110

101111
# ----------------------------------------------------------------------
102112
# Set up serialization for pandas Series and DataFrame
@@ -155,3 +165,10 @@ def _deserialize_torch_tensor(data):
155165

156166

157167
register_default_serialization_handlers(_default_serialization_context)
168+
169+
pandas_serialization_context = _default_serialization_context.clone()
170+
171+
pandas_serialization_context.register_type(
172+
np.ndarray, 'np.array',
173+
custom_serializer=_serialize_numpy_array_pickle,
174+
custom_deserializer=_deserialize_numpy_array_pickle)

python/pyarrow/tests/test_serialization.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,11 @@ def make_serialization_context():
212212
serialization_context = make_serialization_context()
213213

214214

215-
def serialization_roundtrip(value, f):
215+
def serialization_roundtrip(value, f, ctx=serialization_context):
216216
f.seek(0)
217-
pa.serialize_to(value, f, serialization_context)
217+
pa.serialize_to(value, f, ctx)
218218
f.seek(0)
219-
result = pa.deserialize_from(f, None, serialization_context)
219+
result = pa.deserialize_from(f, None, ctx)
220220
assert_equal(value, result)
221221

222222
_check_component_roundtrip(value)
@@ -249,6 +249,7 @@ def test_primitive_serialization(large_memory_map):
249249
with pa.memory_map(large_memory_map, mode="r+") as mmap:
250250
for obj in PRIMITIVE_OBJECTS:
251251
serialization_roundtrip(obj, mmap)
252+
serialization_roundtrip(obj, mmap, pa.pandas_serialization_context)
252253

253254

254255
def test_serialize_to_buffer():

0 commit comments

Comments
 (0)