Skip to content

Commit 09eac5d

Browse files
jorisvandenbosschepitrou
authored andcommitted
ARROW-7168: [Python] Respect the specified dictionary type for pd.Categorical conversion
https://issues.apache.org/jira/browse/ARROW-7168 This change ensures that if you specify a `type` in `pa.array`, we ensure the output actually has this type when converting to dictionary array (as we also do for other types). The PR now implements this change, but we might want to do this with a deprecation first, as this can break people's code. Closes #5866 from jorisvandenbossche/ARROW-7168-categorical-specified-type and squashes the following commits: 39ff8e8 <Joris Van den Bossche> more python 2 e4dbb2c <Joris Van den Bossche> try fix python 2 003e653 <Joris Van den Bossche> for now use deprecation warnings instead of error bfb8237 <Joris Van den Bossche> additional tests 3535a56 <Joris Van den Bossche> ARROW-7168: Respect the specified dictionary type when converting pd.Categorical Authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com> Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent ea75dfd commit 09eac5d

2 files changed

Lines changed: 135 additions & 9 deletions

File tree

python/pyarrow/array.pxi

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import warnings
19+
1820

1921
cdef _sequence_to_array(object sequence, object mask, object size,
2022
DataType type, CMemoryPool* pool, c_bool from_pandas):
@@ -84,6 +86,19 @@ cdef _ndarray_to_array(object values, object mask, DataType type,
8486
return pyarrow_wrap_array(chunked_out.get().chunk(0))
8587

8688

89+
cdef _codes_to_indices(object codes, object mask, DataType type,
90+
MemoryPool memory_pool):
91+
"""
92+
Convert the codes of a pandas Categorical to indices for a pyarrow
93+
DictionaryArray, taking into account missing values + mask
94+
"""
95+
if mask is None:
96+
mask = codes == -1
97+
else:
98+
mask = mask | (codes == -1)
99+
return array(codes, mask=mask, type=type, memory_pool=memory_pool)
100+
101+
87102
def _handle_arrow_array_protocol(obj, type, mask, size):
88103
if mask is not None or size is not None:
89104
raise ValueError(
@@ -199,11 +214,50 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,
199214
if hasattr(values, '__arrow_array__'):
200215
return _handle_arrow_array_protocol(values, type, mask, size)
201216
elif pandas_api.is_categorical(values):
217+
if type is not None:
218+
if type.id != Type_DICTIONARY:
219+
return _ndarray_to_array(
220+
np.asarray(values), mask, type, c_from_pandas, safe,
221+
pool)
222+
index_type = type.index_type
223+
value_type = type.value_type
224+
if values.ordered != type.ordered:
225+
warnings.warn(
226+
"The 'ordered' flag of the passed categorical values "
227+
"does not match the 'ordered' of the specified type. "
228+
"Using the flag of the values, but in the future this "
229+
"mismatch will raise a ValueError.",
230+
FutureWarning, stacklevel=2)
231+
else:
232+
index_type = None
233+
value_type = None
234+
235+
indices = _codes_to_indices(
236+
values.codes, mask, index_type, memory_pool)
237+
try:
238+
dictionary = array(
239+
values.categories.values, type=value_type,
240+
memory_pool=memory_pool)
241+
except TypeError:
242+
# TODO when removing the deprecation warning, this whole
243+
# try/except can be removed (to bubble the TypeError of
244+
# the first array(..) call)
245+
if value_type is not None:
246+
warnings.warn(
247+
"The dtype of the 'categories' of the passed "
248+
"categorical values ({0}) does not match the "
249+
"specified type ({1}). For now ignoring the specified "
250+
"type, but in the future this mismatch will raise a "
251+
"TypeError".format(
252+
values.categories.dtype, value_type),
253+
FutureWarning, stacklevel=2)
254+
dictionary = array(
255+
values.categories.values, memory_pool=memory_pool)
256+
else:
257+
raise
258+
202259
return DictionaryArray.from_arrays(
203-
values.codes, values.categories.values,
204-
mask=mask, ordered=values.ordered,
205-
from_pandas=True, safe=safe,
206-
memory_pool=memory_pool)
260+
indices, dictionary, ordered=values.ordered, safe=safe)
207261
else:
208262
if pandas_api.have_pandas:
209263
values, type = pandas_api.compat.get_datetimetz_type(
@@ -1553,11 +1607,9 @@ cdef class DictionaryArray(Array):
15531607
_indices = indices
15541608
else:
15551609
if from_pandas:
1556-
if mask is None:
1557-
mask = indices == -1
1558-
else:
1559-
mask = mask | (indices == -1)
1560-
_indices = array(indices, mask=mask, memory_pool=memory_pool)
1610+
_indices = _codes_to_indices(indices, mask, None, memory_pool)
1611+
else:
1612+
_indices = array(indices, mask=mask, memory_pool=memory_pool)
15611613

15621614
if isinstance(dictionary, Array):
15631615
_dictionary = dictionary

python/pyarrow/tests/test_pandas.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3212,6 +3212,80 @@ def test_variable_dictionary_to_pandas():
32123212
tm.assert_series_equal(result_dense, expected_dense)
32133213

32143214

3215+
def test_dictionary_from_pandas():
3216+
cat = pd.Categorical([u'a', u'b', u'a'])
3217+
expected_type = pa.dictionary(pa.int8(), pa.string())
3218+
3219+
result = pa.array(cat)
3220+
assert result.to_pylist() == ['a', 'b', 'a']
3221+
assert result.type.equals(expected_type)
3222+
3223+
# with missing values in categorical
3224+
cat = pd.Categorical([u'a', u'b', None, u'a'])
3225+
3226+
result = pa.array(cat)
3227+
assert result.to_pylist() == ['a', 'b', None, 'a']
3228+
assert result.type.equals(expected_type)
3229+
3230+
# with additional mask
3231+
result = pa.array(cat, mask=np.array([False, False, False, True]))
3232+
assert result.to_pylist() == ['a', 'b', None, None]
3233+
assert result.type.equals(expected_type)
3234+
3235+
3236+
def test_dictionary_from_pandas_specified_type():
3237+
# ARROW-7168 - ensure specified type is always respected
3238+
3239+
# the same as cat = pd.Categorical(['a', 'b']) but explicit about dtypes
3240+
cat = pd.Categorical.from_codes(
3241+
np.array([0, 1], dtype='int8'), np.array(['a', 'b'], dtype=object))
3242+
3243+
# different index type -> allow this
3244+
# (the type of the 'codes' in pandas is not part of the data type)
3245+
typ = pa.dictionary(index_type=pa.int16(), value_type=pa.string())
3246+
result = pa.array(cat, type=typ)
3247+
assert result.type.equals(typ)
3248+
assert result.to_pylist() == ['a', 'b']
3249+
3250+
# mismatching values type -> raise error (for now a deprecation warning)
3251+
typ = pa.dictionary(index_type=pa.int8(), value_type=pa.int64())
3252+
with pytest.warns(FutureWarning):
3253+
result = pa.array(cat, type=typ)
3254+
assert result.to_pylist() == ['a', 'b']
3255+
3256+
# mismatching order -> raise error (for now a deprecation warning)
3257+
typ = pa.dictionary(
3258+
index_type=pa.int8(), value_type=pa.string(), ordered=True)
3259+
with pytest.warns(FutureWarning, match="The 'ordered' flag of the passed"):
3260+
result = pa.array(cat, type=typ)
3261+
assert result.to_pylist() == ['a', 'b']
3262+
3263+
# with mask
3264+
typ = pa.dictionary(index_type=pa.int16(), value_type=pa.string())
3265+
result = pa.array(cat, type=typ, mask=np.array([False, True]))
3266+
assert result.type.equals(typ)
3267+
assert result.to_pylist() == ['a', None]
3268+
3269+
# empty categorical -> be flexible in values type to allow
3270+
cat = pd.Categorical([])
3271+
3272+
typ = pa.dictionary(index_type=pa.int8(), value_type=pa.string())
3273+
result = pa.array(cat, type=typ)
3274+
assert result.type.equals(typ)
3275+
assert result.to_pylist() == []
3276+
typ = pa.dictionary(index_type=pa.int8(), value_type=pa.int64())
3277+
result = pa.array(cat, type=typ)
3278+
assert result.type.equals(typ)
3279+
assert result.to_pylist() == []
3280+
3281+
# passing non-dictionary type
3282+
cat = pd.Categorical(['a', 'b'])
3283+
result = pa.array(cat, type=pa.string())
3284+
expected = pa.array(['a', 'b'], type=pa.string())
3285+
assert result.equals(expected)
3286+
assert result.to_pylist() == ['a', 'b']
3287+
3288+
32153289
# ----------------------------------------------------------------------
32163290
# Array protocol in pandas conversions tests
32173291

0 commit comments

Comments
 (0)