Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,13 @@ def __hash__(self):
"pyspark.pandas.typedef.typehints",
# unittests
"pyspark.pandas.tests.test_dataframe",
"pyspark.pandas.tests.test_config",

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May I sort these tests in my last PR? That would make resolving other PRs' conflicts much easier.

"pyspark.pandas.tests.test_default_index",
"pyspark.pandas.tests.test_extension",
"pyspark.pandas.tests.test_internal",
"pyspark.pandas.tests.test_numpy_compat",
"pyspark.pandas.tests.test_typedef",
"pyspark.pandas.tests.test_utils",
"pyspark.pandas.tests.test_dataframe_conversion",
"pyspark.pandas.tests.test_dataframe_spark_io",
"pyspark.pandas.tests.test_frame_spark",
Expand Down
155 changes: 155 additions & 0 deletions python/pyspark/pandas/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import pandas as ps
from pyspark.pandas import config
from pyspark.pandas.config import Option, DictWrapper
from pyspark.pandas.testing.utils import ReusedSQLTestCase


class ConfigTest(ReusedSQLTestCase):
    """Exercise ``get_option``/``set_option``/``reset_option`` and the ``ps.options`` namespace.

    Temporary ``test.config*`` options are registered in ``config._options_dict`` before each
    test and removed again afterwards.
    """

    # Keys of the temporary options registered by setUp and removed by tearDown.
    _TEST_OPTION_KEYS = (
        "test.config",
        "test.config.list",
        "test.config.float",
        "test.config.int",
        "test.config.int.none",
    )

    def setUp(self):
        """Register the temporary options used by the tests below."""
        config._options_dict["test.config"] = Option(key="test.config", doc="", default="default")

        config._options_dict["test.config.list"] = Option(
            key="test.config.list", doc="", default=[], types=list
        )
        config._options_dict["test.config.float"] = Option(
            key="test.config.float", doc="", default=1.2, types=float
        )

        config._options_dict["test.config.int"] = Option(
            key="test.config.int",
            doc="",
            default=1,
            types=int,
            check_func=(lambda v: v > 0, "bigger then 0"),
        )
        # The option's own key must match the key it is registered under.
        config._options_dict["test.config.int.none"] = Option(
            key="test.config.int.none", doc="", default=None, types=(int, type(None))
        )

    def tearDown(self):
        """Reset every temporary option, then unregister it.

        All options are reset (not only ``test.config``) so a value set by one test cannot
        carry over into the next one.
        """
        for key in self._TEST_OPTION_KEYS:
            try:
                ps.reset_option(key)
            finally:
                # Unregister even if the reset failed, so later tests start clean.
                del config._options_dict[key]

    def test_get_set_reset_option(self):
        self.assertEqual(ps.get_option("test.config"), "default")

        ps.set_option("test.config", "value")
        self.assertEqual(ps.get_option("test.config"), "value")

        ps.reset_option("test.config")
        self.assertEqual(ps.get_option("test.config"), "default")

    def test_get_set_reset_option_different_types(self):
        ps.set_option("test.config.list", [1, 2, 3, 4])
        self.assertEqual(ps.get_option("test.config.list"), [1, 2, 3, 4])

        ps.set_option("test.config.float", 5.0)
        self.assertEqual(ps.get_option("test.config.float"), 5.0)

        ps.set_option("test.config.int", 123)
        self.assertEqual(ps.get_option("test.config.int"), 123)

        self.assertIsNone(ps.get_option("test.config.int.none"))  # default None
        ps.set_option("test.config.int.none", 123)
        self.assertEqual(ps.get_option("test.config.int.none"), 123)
        ps.set_option("test.config.int.none", None)
        self.assertIsNone(ps.get_option("test.config.int.none"))

    def test_different_types(self):
        with self.assertRaisesRegex(ValueError, "was <class 'int'>"):
            ps.set_option("test.config.list", 1)

        with self.assertRaisesRegex(ValueError, "however, expected types are"):
            ps.set_option("test.config.float", "abc")

        # Brackets and parentheses are escaped so they are matched literally instead of
        # being interpreted as a regex character class / group.
        with self.assertRaisesRegex(ValueError, r"\[<class 'int'>\]"):
            ps.set_option("test.config.int", "abc")

        with self.assertRaisesRegex(ValueError, r"\(<class 'int'>, <class 'NoneType'>\)"):
            ps.set_option("test.config.int.none", "abc")

    def test_check_func(self):
        with self.assertRaisesRegex(ValueError, "bigger then 0"):
            ps.set_option("test.config.int", -1)

    def test_unknown_option(self):
        with self.assertRaisesRegex(config.OptionError, "No such option"):
            ps.get_option("unknown")

        with self.assertRaisesRegex(config.OptionError, "Available options"):
            ps.set_option("unknown", "value")

        with self.assertRaisesRegex(config.OptionError, "test.config"):
            ps.reset_option("unknown")

    def test_namespace_access(self):
        try:
            self.assertEqual(ps.options.compute.max_rows, ps.get_option("compute.max_rows"))
            ps.options.compute.max_rows = 0
            self.assertEqual(ps.options.compute.max_rows, 0)
            self.assertIsInstance(ps.options.compute, DictWrapper)

            wrapper = ps.options.compute
            self.assertEqual(wrapper.max_rows, ps.get_option("compute.max_rows"))
            wrapper.max_rows = 1000
            self.assertEqual(ps.options.compute.max_rows, 1000)

            # Reading a prefix that is not a complete option name must fail.
            self.assertRaisesRegex(config.OptionError, "No such option", lambda: ps.options.compu)
            self.assertRaisesRegex(
                config.OptionError, "No such option", lambda: ps.options.compute.max
            )
            self.assertRaisesRegex(
                config.OptionError, "No such option", lambda: ps.options.max_rows1
            )

            # Assigning to such a name must fail as well.
            with self.assertRaisesRegex(config.OptionError, "No such option"):
                ps.options.compute.max = 0
            with self.assertRaisesRegex(config.OptionError, "No such option"):
                ps.options.compute = 0
            with self.assertRaisesRegex(config.OptionError, "No such option"):
                ps.options.com = 0
        finally:
            # Restore the real option changed above regardless of the outcome.
            ps.reset_option("compute.max_rows")

    def test_dir_options(self):
        self.assertIn("compute.default_index_type", dir(ps.options))
        self.assertIn("plotting.sample_ratio", dir(ps.options))

        self.assertIn("default_index_type", dir(ps.options.compute))
        self.assertNotIn("sample_ratio", dir(ps.options.compute))

        self.assertNotIn("default_index_type", dir(ps.options.plotting))
        self.assertIn("sample_ratio", dir(ps.options.plotting))


if __name__ == "__main__":
    import unittest
    from pyspark.pandas.tests.test_config import *  # noqa: F401

    # Use the XML-report runner when xmlrunner is importable; otherwise pass None so
    # unittest falls back to its default runner.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
51 changes: 51 additions & 0 deletions python/pyspark/pandas/tests/test_default_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pandas as pd

from pyspark import pandas as ps
from pyspark.pandas.testing.utils import ReusedSQLTestCase


class DefaultIndexTest(ReusedSQLTestCase):
    """Check the index attached under each ``compute.default_index_type`` setting."""

    def test_default_index_sequence(self):
        expected = pd.DataFrame({"id": list(range(1000))})
        with ps.option_context("compute.default_index_type", "sequence"):
            actual = ps.DataFrame(self.spark.range(1000))
            self.assert_eq(actual, expected)

    def test_default_index_distributed_sequence(self):
        expected = pd.DataFrame({"id": list(range(1000))})
        with ps.option_context("compute.default_index_type", "distributed-sequence"):
            actual = ps.DataFrame(self.spark.range(1000))
            self.assert_eq(actual, expected)

    def test_default_index_distributed(self):
        with ps.option_context("compute.default_index_type", "distributed"):
            local = ps.DataFrame(self.spark.range(1000)).to_pandas()
            # Only uniqueness of the index values is checked here, not their order.
            distinct_labels = set(local.index)
            self.assertEqual(len(distinct_labels), len(local))


if __name__ == "__main__":
    import unittest
    from pyspark.pandas.tests.test_default_index import *  # noqa: F401

    # Use the XML-report runner when xmlrunner is importable; otherwise pass None so
    # unittest falls back to its default runner.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
151 changes: 151 additions & 0 deletions python/pyspark/pandas/tests/test_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import contextlib

import numpy as np
import pandas as pd

from pyspark import pandas as ps
from pyspark.pandas.testing.utils import assert_produces_warning, ReusedSQLTestCase
from pyspark.pandas.extensions import (
register_dataframe_accessor,
register_series_accessor,
register_index_accessor,
)


@contextlib.contextmanager
def ensure_removed(obj, attr):
    """
    Context manager that deletes attribute ``attr`` from ``obj`` on exit.

    The deletion runs whether or not the managed body raised, and a missing
    attribute is silently ignored.
    """
    try:
        yield
    finally:
        # The attribute may never have been attached; that is not an error.
        with contextlib.suppress(AttributeError):
            delattr(obj, attr)


class CustomAccessor:
    """Minimal accessor used by the registration tests.

    It keeps a reference to the object it is attached to in ``obj`` and exposes a
    constant ``"item"`` through an attribute, a property and a method.
    """

    def __init__(self, obj):
        self.obj = obj
        self.item = "item"

    @property
    def prop(self):
        """Return the stored ``item`` value."""
        return self.item

    def method(self):
        """Return the stored ``item`` value."""
        return self.item

    def check_length(self, col=None):
        """Return the length of ``obj[col]`` or, for non-DataFrame objects, of ``obj``.

        ``obj[col]`` is measured when ``col`` is given or when ``obj`` is a
        ``ps.DataFrame``; otherwise ``len(obj)`` is returned.

        Raises
        ------
        ValueError
            If ``len(obj)`` fails; the original exception is chained as the cause.
        """
        if col is not None or isinstance(self.obj, ps.DataFrame):
            return len(self.obj[col])

        try:
            return len(self.obj)
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise ValueError(str(e)) from e


class ExtensionTest(ReusedSQLTestCase):
    """Tests for the accessor-registration helpers of ``pyspark.pandas.extensions``."""

    @property
    def pdf(self):
        # A new pandas frame (with a new random index) is built on every access.
        return pd.DataFrame(
            {"a": [1, 2, 3, 4, 5, 6, 7, 8, 9], "b": [4, 5, 6, 3, 2, 1, 0, 0, 0]},
            index=np.random.rand(9),
        )

    @property
    def kdf(self):
        return ps.from_pandas(self.pdf)

    @property
    def accessor(self):
        return CustomAccessor(self.kdf)

    def test_setup(self):
        self.assertEqual("item", self.accessor.item)

    def test_dataframe_register(self):
        with ensure_removed(ps.DataFrame, "test"):
            register_dataframe_accessor("test")(CustomAccessor)
            self.assertEqual(self.kdf.test.prop, "item")
            self.assertEqual(self.kdf.test.method(), "item")
            self.assertEqual(len(self.kdf["a"]), self.kdf.test.check_length("a"))

    def test_series_register(self):
        with ensure_removed(ps.Series, "test"):
            register_series_accessor("test")(CustomAccessor)
            self.assertEqual(self.kdf.a.test.prop, "item")
            self.assertEqual(self.kdf.a.test.method(), "item")
            self.assertEqual(self.kdf.a.test.check_length(), len(self.kdf["a"]))

    def test_index_register(self):
        with ensure_removed(ps.Index, "test"):
            register_index_accessor("test")(CustomAccessor)
            self.assertEqual(self.kdf.index.test.prop, "item")
            self.assertEqual(self.kdf.index.test.method(), "item")
            self.assertEqual(self.kdf.index.test.check_length(), self.kdf.index.size)

    def test_accessor_works(self):
        # Clean up the registered accessor so it does not leak into other tests.
        with ensure_removed(ps.Series, "test"):
            register_series_accessor("test")(CustomAccessor)

            s = ps.Series([1, 2])
            self.assertIs(s.test.obj, s)
            self.assertEqual(s.test.prop, "item")
            self.assertEqual(s.test.method(), "item")

    def test_overwrite_warns(self):
        # Registering over an existing attribute should warn; restore it afterwards.
        mean = ps.Series.mean
        try:
            with assert_produces_warning(UserWarning, raise_on_extra_warnings=False) as w:
                register_series_accessor("mean")(CustomAccessor)
                s = ps.Series([1, 2])
                self.assertEqual(s.mean.prop, "item")
            msg = str(w[0].message)
            self.assertIn("mean", msg)
            self.assertIn("CustomAccessor", msg)
            self.assertIn("Series", msg)
        finally:
            ps.Series.mean = mean

    def test_raises_attr_error(self):
        with ensure_removed(ps.Series, "bad"):

            class Bad:
                def __init__(self, data):
                    raise AttributeError("whoops")

            # Register the failing accessor; without this the attribute lookup below
            # would raise AttributeError simply because "bad" does not exist.
            register_series_accessor("bad")(Bad)

            with self.assertRaises(AttributeError):
                ps.Series([1, 2], dtype=object).bad


if __name__ == "__main__":
    import unittest
    from pyspark.pandas.tests.test_extension import *  # noqa: F401

    # Use the XML-report runner when xmlrunner is importable; otherwise pass None so
    # unittest falls back to its default runner.
    testRunner = None
    try:
        import xmlrunner  # type: ignore[import]
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)
Loading