我已经能够创建一个允许我一次索引多个字符串列的管道,但是我对它们进行了编码,因为与索引不同,编码器不是估算器所以我根本不会根据OneHotEncoder示例调用文档.
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler, OneHotEncoder} import org.apache.spark.ml.Pipeline val data = sqlContext.read.parquet("s3n://map2-test/forecaster/intermediate_data") val df = data.select("win","bid_price","domain","size", "form_factor").na.drop() //indexing columns val stringColumns = Array("domain","size", "form_factor") val index_transformers: Array[org.apache.spark.ml.PipelineStage] = stringColumns.map( cname => new StringIndexer() .setInputCol(cname) .setOutputCol(s"${cname}_index") ) // Add the rest of your pipeline like VectorAssembler and algorithm val index_pipeline = new Pipeline().setStages(index_transformers) val index_model = index_pipeline.fit(df) val df_indexed = index_model.transform(df) //encoding columns val indexColumns = df_indexed.columns.filter(x => x contains "index") val one_hot_encoders: Array[org.apache.spark.ml.PipelineStage] = indexColumns.map( cname => new OneHotEncoder() .setInputCol(cname) .setOutputCol(s"${cname}_vec") ) val one_hot_pipeline = new Pipeline().setStages(one_hot_encoders) val df_encoded = one_hot_pipeline.transform(df_indexed)
OneHotEncoder对象没有fit方法,因此将它放在与索引器不同的管道中也不起作用 - 当我在管道上调用fit时会抛出错误.我也不能调用我用管道阶段数组生成的管道上的变换one_hot_encoders
.
我没有找到一个很好的解决方案,使用OneHotEncoder而不单独创建和调用转换为我想要编码的所有列转换自身
Spark> = 2.3
Spark 2.3引入了新类OneHotEncoderEstimator
,OneHotEncoder
即使在外部使用也需要进行拟合OneHotEncoderEstimator
,并且同时在多个列上运行.
import org.apache.spark.ml.feature.{OneHotEncoder, OneHotEncoderModel} val encoder = new OneHotEncoder() .setInputCols(indexColumns) .setOutputCols(indexColumns map (name => s"${name}_vec"))
Spark <2.3
即使您使用的变压器不需要拟合,您也必须使用OneHotEncoderModel
方法来创建Pipeline
可用于转换数据的方法.
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, OneHotEncoderModel} val encoder = new OneHotEncoderEstimator() .setInputCols(indexColumns) .setOutputCols(indexColumns map (name => s"${name}_vec")) encoder.fit(df_indexed).transform(df_indexed)
另外,您可以将索引和编码组合成一个fit
:
one_hot_pipeline.fit(df_indexed).transform(df_indexed)
编辑:
您看到错误表示您的某列包含空PipelineModel
.它被索引器接受但不能用于编码.根据您的要求,您可以删除它们或使用虚拟标签.不幸的是,Pipeline
直到SPARK-11569)才能解决.