Commit d9816b7

[SPARK-51873][ML] For OneVsRest algorithm, allow using save / load to replace cache

### What changes were proposed in this pull request?

For OneVsRest algorithm, allow using save / load to replace cache.

### Why are the changes needed?

Dataframe persisting is not well supported in certain cases, so we need a replacement.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50672 from WeichenXu123/one-vs-rest-cache.

Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>

1 parent 632e681 commit d9816b7

File tree

4 files changed: +113 additions, -75 deletions

  python/pyspark/ml/classification.py
  python/pyspark/ml/tests/test_algorithms.py
  python/pyspark/ml/tests/test_tuning.py
  python/pyspark/ml/util.py


python/pyspark/ml/classification.py

Lines changed: 64 additions & 63 deletions
@@ -89,6 +89,7 @@
     try_remote_read,
     try_remote_write,
     try_remote_attribute_relation,
+    _cache_spark_dataset,
 )
 from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
 from pyspark.ml.common import inherit_doc
@@ -3603,46 +3604,47 @@ def _fit(self, dataset: DataFrame) -> "OneVsRestModel":
 
         # persist if underlying dataset is not persistent.
         handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
-        if handlePersistence:
-            multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
 
-        def _oneClassFitTasks(numClasses: int) -> List[Callable[[], Tuple[int, CM]]]:
-            indices = iter(range(numClasses))
-
-            def trainSingleClass() -> Tuple[int, CM]:
-                index = next(indices)
-
-                binaryLabelCol = "mc2b$" + str(index)
-                trainingDataset = multiclassLabeled.withColumn(
-                    binaryLabelCol,
-                    F.when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
-                )
-                paramMap = dict(
-                    [
-                        (classifier.labelCol, binaryLabelCol),
-                        (classifier.featuresCol, featuresCol),
-                        (classifier.predictionCol, predictionCol),
-                    ]
-                )
-                if weightCol:
-                    paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
-                return index, classifier.fit(trainingDataset, paramMap)
-
-            return [trainSingleClass] * numClasses
-
-        tasks = map(
-            inheritable_thread_target(dataset.sparkSession),
-            _oneClassFitTasks(numClasses),
-        )
-        pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
-
-        subModels = [None] * numClasses
-        for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
-            assert subModels is not None
-            subModels[j] = subModel
+        with _cache_spark_dataset(
+            multiclassLabeled,
+            storageLevel=StorageLevel.MEMORY_AND_DISK,
+            enable=handlePersistence,
+        ) as multiclassLabeled:
+
+            def _oneClassFitTasks(numClasses: int) -> List[Callable[[], Tuple[int, CM]]]:
+                indices = iter(range(numClasses))
+
+                def trainSingleClass() -> Tuple[int, CM]:
+                    index = next(indices)
+
+                    binaryLabelCol = "mc2b$" + str(index)
+                    trainingDataset = multiclassLabeled.withColumn(
+                        binaryLabelCol,
+                        F.when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
+                    )
+                    paramMap = dict(
+                        [
+                            (classifier.labelCol, binaryLabelCol),
+                            (classifier.featuresCol, featuresCol),
+                            (classifier.predictionCol, predictionCol),
+                        ]
+                    )
+                    if weightCol:
+                        paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
+                    return index, classifier.fit(trainingDataset, paramMap)
+
+                return [trainSingleClass] * numClasses
+
+            tasks = map(
+                inheritable_thread_target(dataset.sparkSession),
+                _oneClassFitTasks(numClasses),
+            )
+            pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
 
-        if handlePersistence:
-            multiclassLabeled.unpersist()
+            subModels = [None] * numClasses
+            for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
+                assert subModels is not None
+                subModels[j] = subModel
 
         return self._copyValues(OneVsRestModel(models=cast(List[ClassificationModel], subModels)))
 
@@ -3868,32 +3870,31 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
 
         # persist if underlying dataset is not persistent.
         handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
-        if handlePersistence:
-            newDataset.persist(StorageLevel.MEMORY_AND_DISK)
-
-        # update the accumulator column with the result of prediction of models
-        aggregatedDataset = newDataset
-        for index, model in enumerate(self.models):
-            rawPredictionCol = self.getRawPredictionCol()
-
-            columns = origCols + [rawPredictionCol, accColName]
-
-            # add temporary column to store intermediate scores and update
-            tmpColName = "mbc$tmp" + str(uuid.uuid4())
-            transformedDataset = model.transform(aggregatedDataset).select(*columns)
-            updatedDataset = transformedDataset.withColumn(
-                tmpColName,
-                F.array_append(accColName, SF.vector_get(F.col(rawPredictionCol), F.lit(1))),
-            )
-            newColumns = origCols + [tmpColName]
-
-            # switch out the intermediate column with the accumulator column
-            aggregatedDataset = updatedDataset.select(*newColumns).withColumnRenamed(
-                tmpColName, accColName
-            )
+        with _cache_spark_dataset(
+            newDataset,
+            storageLevel=StorageLevel.MEMORY_AND_DISK,
+            enable=handlePersistence,
+        ) as newDataset:
+            # update the accumulator column with the result of prediction of models
+            aggregatedDataset = newDataset
+            for index, model in enumerate(self.models):
+                rawPredictionCol = self.getRawPredictionCol()
+
+                columns = origCols + [rawPredictionCol, accColName]
+
+                # add temporary column to store intermediate scores and update
+                tmpColName = "mbc$tmp" + str(uuid.uuid4())
+                transformedDataset = model.transform(aggregatedDataset).select(*columns)
+                updatedDataset = transformedDataset.withColumn(
+                    tmpColName,
+                    F.array_append(accColName, SF.vector_get(F.col(rawPredictionCol), F.lit(1))),
+                )
+                newColumns = origCols + [tmpColName]
 
-        if handlePersistence:
-            newDataset.unpersist()
+                # switch out the intermediate column with the accumulator column
+                aggregatedDataset = updatedDataset.select(*newColumns).withColumnRenamed(
+                    tmpColName, accColName
+                )
 
         if self.getRawPredictionCol():
             aggregatedDataset = aggregatedDataset.withColumn(
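
Both hunks above follow the same shape: the manual persist() / unpersist() bookkeeping around the reused DataFrame becomes a single "with _cache_spark_dataset(...)" block, and cleanup happens automatically when the block exits. A minimal sketch of that calling pattern outside OneVsRest (illustrative only; _cache_spark_dataset is an internal pyspark.ml.util helper rather than public API, and the DataFrame and repeated actions here are placeholders):

```python
from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel

from pyspark.ml.util import _cache_spark_dataset  # internal helper, shown for illustration

spark = SparkSession.builder.getOrCreate()
df = spark.range(1000)

# Mirror the OneVsRest logic: only manage caching when the caller has not
# already persisted the input DataFrame.
handlePersistence = df.storageLevel == StorageLevel(False, False, False, False)

with _cache_spark_dataset(
    df,
    storageLevel=StorageLevel.MEMORY_AND_DISK,
    enable=handlePersistence,
) as df:
    # Every action in this block reuses the cached DataFrame (or, when a
    # temp DFS path is configured, a copy that was saved and reloaded);
    # the context manager cleans up when the block exits.
    df.count()
    df.count()
```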

python/pyspark/ml/tests/test_algorithms.py

Lines changed: 23 additions & 1 deletion
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
 from shutil import rmtree
 import tempfile
 import unittest
@@ -154,6 +154,28 @@ def test_support_for_weightCol(self):
         ovr2 = OneVsRest(classifier=dt, weightCol="weight")
         self.assertIsNotNone(ovr2.fit(df))
 
+    def test_tmp_dfs_cache(self):
+        from pyspark.ml.util import _SPARKML_TEMP_DFS_PATH
+
+        with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
+            os.environ[_SPARKML_TEMP_DFS_PATH] = d
+            try:
+                df = self.spark.createDataFrame(
+                    [
+                        (0.0, Vectors.dense(1.0, 0.8)),
+                        (1.0, Vectors.sparse(2, [], [])),
+                        (2.0, Vectors.dense(0.5, 0.5)),
+                    ],
+                    ["label", "features"],
+                )
+                lr = LogisticRegression(maxIter=5, regParam=0.01)
+                ovr = OneVsRest(classifier=lr, parallelism=1)
+                model = ovr.fit(df)
+                model.transform(df)
+                assert len(os.listdir(d)) == 0
+            finally:
+                os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)
+
 
 class KMeansTests(SparkSessionTestCase):
     def test_kmeans_cosine_distance(self):
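
The new test doubles as a usage example for the save / load path: when the _SPARKML_TEMP_DFS_PATH environment variable points at a DFS location, OneVsRest stages its temporary DataFrames there instead of persisting them, and the directory is expected to be empty again once fit and transform return. A standalone sketch of the same flow outside the unittest harness (session setup added here; the data and parameters are copied from the test; the env-var constant is an internal name):

```python
import os
import tempfile

from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.linalg import Vectors
from pyspark.ml.util import _SPARKML_TEMP_DFS_PATH  # name of the env var read by the helper
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
    os.environ[_SPARKML_TEMP_DFS_PATH] = d
    try:
        df = spark.createDataFrame(
            [
                (0.0, Vectors.dense(1.0, 0.8)),
                (1.0, Vectors.sparse(2, [], [])),
                (2.0, Vectors.dense(0.5, 0.5)),
            ],
            ["label", "features"],
        )
        ovr = OneVsRest(classifier=LogisticRegression(maxIter=5, regParam=0.01), parallelism=1)
        model = ovr.fit(df)
        model.transform(df)  # transform is lazy; the check below only verifies cleanup
        # All temporary save/load artifacts should already have been removed.
        assert len(os.listdir(d)) == 0
    finally:
        # Restore the environment so later jobs fall back to persist().
        os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)
```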

python/pyspark/ml/tests/test_tuning.py

Lines changed: 15 additions & 9 deletions
@@ -73,12 +73,15 @@ def test_train_validation_split(self):
 
         with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
             os.environ[_SPARKML_TEMP_DFS_PATH] = d
-            tvs_model2 = tvs.fit(dataset)
-            assert len(os.listdir(d)) == 0
-            self.assertTrue(np.isclose(tvs_model2.validationMetrics[0], 0.5, atol=1e-4))
-            self.assertTrue(
-                np.isclose(tvs_model2.validationMetrics[1], 0.8857142857142857, atol=1e-4)
-            )
+            try:
+                tvs_model2 = tvs.fit(dataset)
+                assert len(os.listdir(d)) == 0
+                self.assertTrue(np.isclose(tvs_model2.validationMetrics[0], 0.5, atol=1e-4))
+                self.assertTrue(
+                    np.isclose(tvs_model2.validationMetrics[1], 0.8857142857142857, atol=1e-4)
+                )
+            finally:
+                os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)
 
         # save & load
         with tempfile.TemporaryDirectory(prefix="train_validation_split") as d:
@@ -131,9 +134,12 @@ def test_cross_validator(self):
 
         with tempfile.TemporaryDirectory(prefix="ml_tmp_dir") as d:
             os.environ[_SPARKML_TEMP_DFS_PATH] = d
-            model2 = cv.fit(dataset)
-            assert len(os.listdir(d)) == 0
-            self.assertTrue(np.isclose(model2.avgMetrics[0], 0.5, atol=1e-4))
+            try:
+                model2 = cv.fit(dataset)
+                assert len(os.listdir(d)) == 0
+                self.assertTrue(np.isclose(model2.avgMetrics[0], 0.5, atol=1e-4))
+            finally:
+                os.environ.pop(_SPARKML_TEMP_DFS_PATH, None)
 
         output = model.transform(dataset)
         self.assertEqual(

python/pyspark/ml/util.py

Lines changed: 11 additions & 2 deletions
@@ -41,6 +41,7 @@
 from pyspark.ml.common import inherit_doc
 from pyspark.sql import SparkSession
 from pyspark.sql.utils import is_remote
+from pyspark.storagelevel import StorageLevel
 from pyspark.util import VersionUtils
 
 if TYPE_CHECKING:
@@ -1138,7 +1139,15 @@ def _remove_dfs_dir(path: str, spark_session: "SparkSession") -> None:
 
 
 @contextmanager
-def _cache_spark_dataset(dataset: "DataFrame") -> Iterator[Any]:
+def _cache_spark_dataset(
+    dataset: "DataFrame",
+    storageLevel: "StorageLevel" = StorageLevel.MEMORY_AND_DISK_DESER,
+    enable: bool = True,
+) -> Iterator[Any]:
+    if not enable:
+        yield dataset
+        return
+
     spark_session = dataset._session
     tmp_dfs_path = os.environ.get(_SPARKML_TEMP_DFS_PATH)
 
@@ -1150,7 +1159,7 @@ def _cache_spark_dataset(dataset: "DataFrame") -> Iterator[Any]:
         finally:
             _remove_dfs_dir(tmp_cache_path, spark_session)
     else:
-        dataset.cache()
+        dataset.persist(storageLevel)
         try:
             yield dataset
         finally:
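
Only the head and tail of _cache_spark_dataset are visible in these hunks. For orientation, here is a plausible reconstruction of the whole helper after this change. The save / load branch in the middle is an assumption: the tmp_cache_path layout, the write.save / read.load calls, and the closing unpersist() are inferred from the visible _remove_dfs_dir cleanup and the old persist-based code, and _SPARKML_TEMP_DFS_PATH / _remove_dfs_dir are the module's existing names assumed to be in scope.

```python
import os
import uuid
from contextlib import contextmanager
from typing import Any, Iterator

from pyspark.sql import DataFrame
from pyspark.storagelevel import StorageLevel


@contextmanager
def _cache_spark_dataset(
    dataset: "DataFrame",
    storageLevel: "StorageLevel" = StorageLevel.MEMORY_AND_DISK_DESER,
    enable: bool = True,
) -> Iterator[Any]:
    # New in this commit: callers such as OneVsRest can skip caching entirely
    # when the input DataFrame is already persisted.
    if not enable:
        yield dataset
        return

    spark_session = dataset._session
    tmp_dfs_path = os.environ.get(_SPARKML_TEMP_DFS_PATH)

    if tmp_dfs_path:
        # Save / load path (reconstructed): stage the DataFrame on DFS and
        # hand back the reloaded copy; delete the staged data on exit.
        tmp_cache_path = os.path.join(tmp_dfs_path, uuid.uuid4().hex)  # assumed layout
        dataset.write.save(tmp_cache_path)  # assumed I/O call
        try:
            yield spark_session.read.load(tmp_cache_path)  # assumed I/O call
        finally:
            _remove_dfs_dir(tmp_cache_path, spark_session)
    else:
        # Persist path, now honoring the caller-provided storage level
        # instead of a plain dataset.cache().
        dataset.persist(storageLevel)
        try:
            yield dataset
        finally:
            dataset.unpersist()  # inferred; the visible diff ends at the finally:
```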
