From a71869472c974b6270d7360726f05e8f83e3cfc3 Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Wed, 2 Aug 2017 17:31:21 -0700 Subject: [PATCH 1/2] LogisticRegressionModel.toString should summarize model --- .../spark/ml/classification/LogisticRegression.scala | 4 ++++ .../spark/ml/classification/LogisticRegressionSuite.scala | 6 ++++++ python/pyspark/ml/classification.py | 7 +++++++ python/pyspark/mllib/classification.py | 3 +++ 4 files changed, 20 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 21957d94e2dc3..058bec6826a68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1169,6 +1169,10 @@ class LogisticRegressionModel private[spark] ( */ @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + + override def toString: String = { + s"${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures" + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 0570499e74516..ad888ee3c2286 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2594,6 +2594,12 @@ class LogisticRegressionSuite assert(model.getFamily === family) } } + + test("toString") { + val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0) + val expected = "logReg, numClasses = 2, numFeatures = 3" + assert(model.toString === expected) + } } object LogisticRegressionSuite { diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index bccf8e7f636f1..16976cd1eacce 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -237,6 +237,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti True >>> blorModel.intercept == model2.intercept True + >>> model2._resetUid("logReg") + uid = logReg, numClasses = 2, numFeatures = 2 .. versionadded:: 1.3.0 """ @@ -558,6 +560,11 @@ def evaluate(self, dataset): java_blr_summary = self._call_java("evaluate", dataset) return BinaryLogisticRegressionSummary(java_blr_summary) + def __repr__(self): + numClasses = str(self._call_java("numClasses")) + numFeatures = str(self._call_java("numFeatures")) + return "uid = %s, numClasses = %s, numFeatures = %s" % (self.uid, numClasses, numFeatures) + class LogisticRegressionSummary(JavaWrapper): """ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index e04eeb2b60d71..b2e43afea6ef8 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -257,6 +257,9 @@ def load(cls, sc, path): model.setThreshold(threshold) return model + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionWithSGD(object): """ From 189acfb7a32d545e74ef660a316084739f32f88d Mon Sep 17 00:00:00 2001 From: bravo-zhang Date: Fri, 22 Jun 2018 21:00:01 -0700 Subject: [PATCH 2/2] Pyspark repr by calling toString --- .../spark/ml/classification/LogisticRegression.scala | 3 ++- .../spark/ml/classification/LogisticRegressionSuite.scala | 2 +- python/pyspark/ml/classification.py | 8 +++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index e5e1a1e8e198b..92e342ed4a464 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1204,7 +1204,8 @@ class LogisticRegressionModel private[spark] ( override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) override def toString: String = { - s"${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures" + s"LogisticRegressionModel: " + + s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures" } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 8eca49b4c0b5b..75c2aeb146786 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2754,7 +2754,7 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest { test("toString") { val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0) - val expected = "logReg, numClasses = 2, numFeatures = 3" + val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3" assert(model.toString === expected) } } diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index e1a0d3617cf6b..d5963f4f7042c 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -239,8 +239,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti True >>> blorModel.intercept == model2.intercept True - >>> model2._resetUid("logReg") - uid = logReg, numClasses = 2, numFeatures = 2 + >>> model2 + LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2 .. versionadded:: 1.3.0 """ @@ -565,9 +565,7 @@ def evaluate(self, dataset): return BinaryLogisticRegressionSummary(java_blr_summary) def __repr__(self): - numClasses = str(self._call_java("numClasses")) - numFeatures = str(self._call_java("numFeatures")) - return "uid = %s, numClasses = %s, numFeatures = %s" % (self.uid, numClasses, numFeatures) + return self._call_java("toString") class LogisticRegressionSummary(JavaWrapper):