diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 933a5f1d52ed9..ae336982092d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -745,12 +745,12 @@ class DistributedLDAModel private[clustering] ( val N_wk = vertex._2 val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k - (eta - 1.0) * sum(phi_wk.map(math.log)) + sumPrior + (eta - 1.0) * sum(phi_wk.map(math.log)) } else { val N_kj = vertex._2 val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0) val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) - (alpha - 1.0) * sum(theta_kj.map(math.log)) + sumPrior + (alpha - 1.0) * sum(theta_kj.map(math.log)) } } graph.vertices.aggregate(0.0)(seqOp, _ + _) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 3f39deddf20b4..9aa11fbdbe868 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -260,6 +260,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6) assert(Vectors.dense(model.getDocConcentration) ~== Vectors.dense(model2.getDocConcentration) absTol 1e-6) + val logPrior = model.asInstanceOf[DistributedLDAModel].logPrior + val logPrior2 = model2.asInstanceOf[DistributedLDAModel].logPrior + val trainingLogLikelihood = + model.asInstanceOf[DistributedLDAModel].trainingLogLikelihood + val trainingLogLikelihood2 = + model2.asInstanceOf[DistributedLDAModel].trainingLogLikelihood + assert(logPrior ~== logPrior2 absTol 1e-6) + assert(trainingLogLikelihood ~== trainingLogLikelihood2 absTol 1e-6) } val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset,