
Spark ML indexer cannot resolve DataFrame column name with dots?

I have a DataFrame with a column named a.b. When I specify a.b as the input column name to a StringIndexer, I get an AnalysisException with the message "cannot resolve 'a.b' given input columns a.b". I'm using Spark 1.6.0.

I'm aware that older versions of Spark may have had issues with dots in column names, but that in more recent versions, backquotes can be used around column names in the Spark shell and in SQL queries. For instance, that's the resolution to another question, How to escape column names with hyphen in Spark SQL. Some of these issues were reported in SPARK-6898, Special chars in column names is broken, but that was resolved back in 1.4.0.
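The backquote escaping mentioned above can be wrapped in a small helper so that column names are quoted consistently when building column references such as df.col(...). This is a minimal sketch: the class and method names (ColumnQuoting, quote) are hypothetical, and doubling an embedded backtick is an assumption based on the identifier-quoting convention of later Spark SQL versions, so check it against the Spark version in use. As the second example below shows, quoting made df.col("`a.b`") work but did not make StringIndexer produce an output column, so this helps with DataFrame column references, not necessarily with ML stages.

```java
public class ColumnQuoting {
    // Wrap a column name in backticks so that a dot inside it is not read
    // as a name-part separator. Embedded backticks are doubled; this escaping
    // rule is an assumption to verify for your Spark version.
    public static String quote(String name) {
        return "`" + name.replace("`", "``") + "`";
    }

    public static void main(String[] args) {
        System.out.println(quote("a.b")); // prints `a.b`
        System.out.println(quote("x`y")); // prints `x``y`
    }
}
```

In the question's setup this would be used as df.col(ColumnQuoting.quote("a.b")), which builds the same string as the hand-written df.col("`a.b`").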

Here's a minimal example and stacktrace:

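import java.util.Collections;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class SparkMLDotColumn {
    public static void main(String[] args) {
        // Get the contexts
        SparkConf conf = new SparkConf()
                .setMaster("local[*]")
                .setAppName("test")
                .set("spark.ui.enabled", "false"); // http://permalink.gmane.org/gmane.comp.lang.scala.spark.user/21385
        JavaSparkContext sparkContext = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sparkContext);

        // Create a schema with a single string column named "a.b"
        StructType schema = new StructType(new StructField[] {
                DataTypes.createStructField("a.b", DataTypes.StringType, false)
        });

        // Create an empty RDD and DataFrame
        JavaRDD<Row> rdd = sparkContext.parallelize(Collections.<Row>emptyList());
        DataFrame df = sqlContext.createDataFrame(rdd, schema);

        // fit() throws AnalysisException: cannot resolve 'a.b' given input columns a.b
        StringIndexer indexer = new StringIndexer()
            .setInputCol("a.b")
            .setOutputCol("a.b_index");
        df = indexer.fit(df).transform(df);
    }
}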

Now, it's worth trying the same kind of example with backquoted column names, because the results are strange. Here's an example with the same schema, but this time the DataFrame contains data. Before attempting any indexing, we copy the column named a.b to a column named a_b. That requires backticks, and it works without a problem. Next, we index the a_b column, which also works without a problem. Then something really strange happens when we try to index the a.b column using backticks: there is no error, but there is no result either:

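import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class SparkMLDotColumn {
    public static void main(String[] args) {
        // Get the contexts
        SparkConf conf = new SparkConf()
                .setMaster("local[*]")
                .setAppName("test")
                .set("spark.ui.enabled", "false");
        JavaSparkContext sparkContext = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sparkContext);

        // Create a schema with a single string column named "a.b"
        StructType schema = new StructType(new StructField[] {
                DataTypes.createStructField("a.b", DataTypes.StringType, false)
        });

        // Create an RDD with two rows and a DataFrame over it
        List<Row> rows = Arrays.asList(RowFactory.create("foo"), RowFactory.create("bar"));
        JavaRDD<Row> rdd = sparkContext.parallelize(rows);
        DataFrame df = sqlContext.createDataFrame(rdd, schema);

        // Copy "a.b" to "a_b"; the backticks are required here and work
        df = df.withColumn("a_b", df.col("`a.b`"));

        // Indexing the dot-free column works
        StringIndexer indexer0 = new StringIndexer();
        indexer0.setInputCol("a_b");
        indexer0.setOutputCol("a_bIndex");
        df = indexer0.fit(df).transform(df);

        // Indexing the backquoted column raises no error, but no abIndex column appears
        StringIndexer indexer1 = new StringIndexer();
        indexer1.setInputCol("`a.b`");
        indexer1.setOutputCol("abIndex");
        df = indexer1.fit(df).transform(df);

        df.show();
    }
}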
Output:

+---+---+--------+
|a.b|a_b|a_bIndex|  // where's the abIndex column?
+---+---+--------+
|foo|foo|     0.0|
|bar|bar|     1.0|
+---+---+--------+

Stack trace from the first example:

Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve 'a.b' given input columns a.b;
    at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:60)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:57)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:319)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:319)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:53)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:318)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:316)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:316)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:265)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at scala.collection.Iterator$class.foreach(Iterator.scala:727)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
    at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
    at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
    at scala.collection.AbstractIterator.to(Iterator.scala:1157)
    at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
    at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
    at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
    at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:305)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:316)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:316)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:316)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:265)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at scala.collection.Iterator$class.foreach(Iterator.scala:727)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
    at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
    at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
    at scala.collection.AbstractIterator.to(Iterator.scala:1157)
    at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
    at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
    at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
    at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:305)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:316)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionUp$1(QueryPlan.scala:107)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$recursiveTransform$2(QueryPlan.scala:117)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$org$apache$spark$sql$catalyst$plans$QueryPlan$$recursiveTransform$2$1.apply(QueryPlan.scala:121)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:244)
    at scala.collection.AbstractTraversable.map(Traversable.scala:105)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$recursiveTransform$2(QueryPlan.scala:121)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2.apply(QueryPlan.scala:125)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
    at scala.collection.Iterator$class.foreach(Iterator.scala:727)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
    at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
    at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
    at scala.collection.AbstractIterator.to(Iterator.scala:1157)
    at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
    at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
    at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
    at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
    at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsUp(QueryPlan.scala:125)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:57)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1.apply(CheckAnalysis.scala:50)
    at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:105)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.checkAnalysis(CheckAnalysis.scala:50)
    at org.apache.spark.sql.catalyst.analysis.Analyzer.checkAnalysis(Analyzer.scala:44)
    at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:34)
    at org.apache.spark.sql.DataFrame.<init>(DataFrame.scala:133)
    at org.apache.spark.sql.DataFrame.org$apache$spark$sql$DataFrame$$withPlan(DataFrame.scala:2165)
    at org.apache.spark.sql.DataFrame.select(DataFrame.scala:751)
    at org.apache.spark.ml.feature.StringIndexer.fit(StringIndexer.scala:84)
    at SparkMLDotColumn.main(SparkMLDotColumn.java:38)
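The trace ends in StringIndexer.fit calling DataFrame.select, so the input column name is resolved as a column reference string. Such strings are split on unquoted dots: a.b is read as two name parts, a and b (for example, field b of a column a), and no column a exists, whereas `a.b` with backticks names the single column a.b. The sketch below is a simplified model of that splitting rule, for illustration only; ColumnNameParts and split are hypothetical names, and Spark's own parser differs in details such as escaped backticks and error handling for malformed names.

```java
import java.util.ArrayList;
import java.util.List;

public class ColumnNameParts {
    // Split a column reference into name parts. Dots separate parts, except
    // inside a backtick-quoted section. Simplified model, not Spark's parser.
    public static List<String> split(String name) {
        List<String> parts = new ArrayList<>();
        StringBuilder current = new StringBuilder();
        boolean inQuotes = false;
        for (int i = 0; i < name.length(); i++) {
            char c = name.charAt(i);
            if (c == '`') {
                inQuotes = !inQuotes;          // toggle quoting; the backtick itself is dropped
            } else if (c == '.' && !inQuotes) {
                parts.add(current.toString()); // an unquoted dot ends the current part
                current.setLength(0);
            } else {
                current.append(c);
            }
        }
        parts.add(current.toString());
        return parts;
    }

    public static void main(String[] args) {
        System.out.println(split("a.b"));   // prints [a, b]
        System.out.println(split("`a.b`")); // prints [a.b]
    }
}
```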
Joshua Taylor asked Jan 22 '16 18:01



1 Answer

I experienced the same issue on Spark 2.1. I ended up creating a function that "validifies" (TM) all column names by stripping all dots from them. Scala implementation:

import org.apache.spark.sql.{DataFrame, Dataset}

// Strip every dot from every column name, e.g. "a.b" becomes "ab".
// No SparkSession or broadcast is needed: the renaming happens on the driver.
// Caution: two columns can collide after renaming (e.g. "a.b" and "ab").
def validifyColumnnames[T](df: Dataset[T]): DataFrame =
  df.toDF(df.columns.map(_.replace(".", "")): _*)
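For the Java API used in the question, the same renaming idea can be sketched as a plain helper over the column-name array. This is a hypothetical helper (ColumnNameSanitizer, sanitize), not part of Spark. Unlike the Scala version, it takes the replacement string as a parameter (the Scala version strips dots, which corresponds to a replacement of "") and throws if two columns would end up with the same name, since stripping dots turns a.b into ab and could clash with an existing ab column.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

public class ColumnNameSanitizer {
    // Replace every dot in each column name with `replacement` and return the
    // new names in the same order. Throws IllegalArgumentException if two
    // columns would end up with the same name after renaming.
    public static String[] sanitize(String[] names, String replacement) {
        String[] result = new String[names.length];
        Set<String> seen = new HashSet<>();
        for (int i = 0; i < names.length; i++) {
            result[i] = names[i].replace(".", replacement); // literal replace, not a regex
            if (!seen.add(result[i])) {
                throw new IllegalArgumentException(
                        "Column '" + names[i] + "' would be renamed to duplicate name '" + result[i] + "'");
            }
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(sanitize(new String[] {"a.b", "c"}, "_"))); // prints [a_b, c]
    }
}
```

Assuming a DataFrame df, the result should be usable as df.toDF(ColumnNameSanitizer.sanitize(df.columns(), "_")), which renames every column positionally; that Spark call is not exercised in the sketch.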

Sorry, this is probably not the answer you were hoping for, but it was too long for a comment.

Boern answered Sep 19 '22 12:09
