How to interpret probability column in spark logistic regression prediction?

I'm getting predictions through spark.ml.classification.LogisticRegressionModel.predict. A number of the rows have the prediction column as 1.0 and probability column as .04. The model.getThreshold is 0.5 so I'd assume the model is classifying everything over a 0.5 probability threshold as 1.0.

How am I supposed to interpret a result with a 1.0 prediction and a probability of 0.04?

Walrus the Cat asked May 18 '17 23:05


1 Answer

The probability column produced by a LogisticRegression model holds a vector (not a single number) whose length equals the number of classes. The element at index i is the predicted probability of class i, and the elements sum to 1. Here is a small two-class example for illustration:
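A minimal plain-Scala sketch (no Spark dependency) of how such a vector can arise. It assumes the standard logistic-regression formulas: a sigmoid of one margin for two classes, and a softmax over one margin per class for more. The object and method names are hypothetical, and this is not Spark's internal code.

```scala
// Sketch (not Spark's source): turning logistic-regression margins into
// the per-class probability vector seen in the "probability" column.
object ProbabilityVector {
  // Binary (binomial) case: p(class 1) = sigmoid(margin), p(class 0) = 1 - p(class 1).
  def binomial(margin: Double): Array[Double] = {
    val p1 = 1.0 / (1.0 + math.exp(-margin))
    Array(1.0 - p1, p1)
  }

  // Multiclass (multinomial) case: softmax over one margin per class.
  // Subtracting the largest margin keeps exp() from overflowing.
  def multinomial(margins: Array[Double]): Array[Double] = {
    val maxMargin = margins.max
    val exps = margins.map(m => math.exp(m - maxMargin))
    val total = exps.sum
    exps.map(_ / total)
  }

  def main(args: Array[String]): Unit = {
    println(binomial(0.0).mkString("[", ",", "]"))    // prints [0.5,0.5]
    println(multinomial(Array(1.0, 2.0, 3.0)).length) // prints 3
  }
}
```

Either way the vector has one entry per class, which is why a single number like 0.04 only makes sense once you know which index it came from.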

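import org.apache.spark.ml.feature.VectorAssembler
import spark.implicits._ // needed for .toDF(); already in scope in spark-shell

case class Person(label: Double, age: Double, height: Double, weight: Double)
val df = List(Person(0.0, 15, 175, 67),
      Person(0.0, 30, 190, 100),
      Person(1.0, 40, 155, 57),
      Person(1.0, 50, 160, 56),
      Person(0.0, 15, 170, 56),
      Person(1.0, 80, 180, 88)).toDF()

// Combine the numeric columns into a single "features" vector column
val assembler = new VectorAssembler()
  .setInputCols(Array("age", "height", "weight"))
  .setOutputCol("features")
// select is a DataFrame method, so apply it to the transformed output
val df2 = assembler.transform(df).select("label", "features")
df2.show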

+-----+------------------+
|label|          features|
+-----+------------------+
|  0.0| [15.0,175.0,67.0]|
|  0.0|[30.0,190.0,100.0]|
|  1.0| [40.0,155.0,57.0]|
|  1.0| [50.0,160.0,56.0]|
|  0.0| [15.0,170.0,56.0]|
|  1.0| [80.0,180.0,88.0]|
+-----+------------------+

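import org.apache.spark.ml.classification.LogisticRegression

val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
// The first weight (0.7) goes to the first element, so that one is the training set
val Array(training, testing) = df2.randomSplit(Array(0.7, 0.3))

val model = lr.fit(training)
val predictions = model.transform(testing)
predictions.select("probability", "prediction").show(false)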


+----------------------------------------+----------+
|probability                             |prediction|
+----------------------------------------+----------+
|[0.7487950501224138,0.2512049498775863] |0.0       |
|[0.6458452667523259,0.35415473324767416]|0.0       |
|[0.3888393314864866,0.6111606685135134] |1.0       |
+----------------------------------------+----------+
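As a quick sanity check on this table, here is a plain-Scala sketch (no Spark needed) that copies the three printed rows and confirms that each vector sums to 1 and that each prediction equals the index of the largest probability. The object and helper names are hypothetical.

```scala
// Check the three rows printed above: each probability vector sums to 1,
// and the predicted class is the index of the largest probability.
object CheckPredictions {
  // (probability vector, prediction) pairs copied from the output table
  val rows: Seq[(Array[Double], Double)] = Seq(
    (Array(0.7487950501224138, 0.2512049498775863), 0.0),
    (Array(0.6458452667523259, 0.35415473324767416), 0.0),
    (Array(0.3888393314864866, 0.6111606685135134), 1.0)
  )

  // Index of the largest element, as a Double like Spark's prediction column
  def argmax(probs: Array[Double]): Double =
    probs.indices.maxBy(i => probs(i)).toDouble

  def main(args: Array[String]): Unit =
    rows.foreach { case (probs, prediction) =>
      println(f"sum=${probs.sum}%.6f argmax=${argmax(probs)} prediction=$prediction")
    }
}
```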

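These are the per-class probabilities together with the final prediction. With the default threshold of 0.5, the predicted class is the one with the highest probability. So a row with prediction 1.0 whose probability you read as 0.04 is most likely showing the first element of the vector, the probability of class 0; the probability of class 1 for that row would then be 0.96.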
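The two-class decision can be sketched in plain Scala as follows. This assumes the binary model compares the class-1 probability (index 1) against `getThreshold` with a strict `>`; verify that against your Spark version. The `[0.04, 0.96]` row is a hypothetical reading of the question's case, and the object name is illustrative only.

```scala
// Sketch of the binary decision rule: predict 1.0 when the class-1
// probability (index 1 of the probability vector) exceeds the threshold.
object ThresholdRule {
  def predict(probability: Array[Double], threshold: Double = 0.5): Double =
    if (probability(1) > threshold) 1.0 else 0.0

  def main(args: Array[String]): Unit = {
    // Hypothetical row: 0.04 is the class-0 entry, 0.96 the class-1 entry
    val row = Array(0.04, 0.96)
    println(predict(row))       // prints 1.0
    println(predict(row, 0.97)) // prints 0.0
  }
}
```

Raising the threshold above the class-1 probability flips the prediction to 0.0 without changing the probability vector itself.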

Shaido answered Sep 22 '22 16:09