I am trying to find approximate nearest neighbors for a categorical dataset. For this, I am using the `MinHashLSH` model in Spark.
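`MinHashLSH` is an LSH family for Jaccard distance. Spark treats each input vector as the set of its non-zero indices, and under an ideal random hash function two sets get the same minimum hash value with probability equal to their Jaccard similarity. The plain-Java sketch below illustrates that estimate. It is not Spark's implementation: the class `MinHashSketch` is hypothetical, and the seeded SplitMix64 mixing function is an assumed stand-in for Spark's own hash family.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

/** Hypothetical MinHash illustration (not Spark's code). */
public class MinHashSketch {
    private final long[] seeds; // one seed per simulated hash function

    public MinHashSketch(int numHashes, long seed) {
        Random rnd = new Random(seed);
        seeds = new long[numHashes];
        for (int i = 0; i < numHashes; i++) {
            seeds[i] = rnd.nextLong();
        }
    }

    // SplitMix64 finalizer: a bijective 64-bit mixer, used as a stand-in for a
    // random hash function (distinct inputs never map to the same output).
    private static long mix64(long z) {
        z = (z ^ (z >>> 30)) * 0xbf58476d1ce4e5b9L;
        z = (z ^ (z >>> 27)) * 0x94d049bb133111ebL;
        return z ^ (z >>> 31);
    }

    /** For each hash function, the minimum hash value over the set's elements. */
    public long[] signature(Set<Integer> elements) {
        if (elements.isEmpty()) {
            // Spark likewise requires at least one non-zero entry per input vector.
            throw new IllegalArgumentException("MinHash is undefined for an empty set");
        }
        long[] sig = new long[seeds.length];
        Arrays.fill(sig, Long.MAX_VALUE);
        for (int x : elements) {
            for (int i = 0; i < seeds.length; i++) {
                sig[i] = Math.min(sig[i], mix64(seeds[i] ^ x));
            }
        }
        return sig;
    }

    /** Fraction of positions where two signatures agree: estimates Jaccard similarity. */
    public static double estimateSimilarity(long[] s1, long[] s2) {
        int same = 0;
        for (int i = 0; i < s1.length; i++) {
            if (s1[i] == s2[i]) same++;
        }
        return (double) same / s1.length;
    }

    /** Exact Jaccard similarity |A intersect B| / |A union B|, for comparison. */
    public static double jaccard(Set<Integer> a, Set<Integer> b) {
        Set<Integer> inter = new HashSet<>(a);
        inter.retainAll(b);
        Set<Integer> union = new HashSet<>(a);
        union.addAll(b);
        return (double) inter.size() / union.size();
    }

    public static void main(String[] args) {
        MinHashSketch mh = new MinHashSketch(256, 42L);
        Set<Integer> a = new HashSet<>(Arrays.asList(0, 1, 2, 3));
        Set<Integer> b = new HashSet<>(Arrays.asList(2, 3, 4, 5));
        System.out.println("exact:    " + jaccard(a, b)); // 0.3333333333333333
        System.out.println("estimate: " + estimateSimilarity(mh.signature(a), mh.signature(b)));
    }
}
```

More hash functions reduce the variance of the estimate; Spark additionally groups hash values into hash tables (`setNumHashTables`) to find candidate neighbors without comparing every pair.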
My dataset has categorical data, so I use `StringIndexer` followed by `OneHotEncoderEstimator` and `VectorAssembler` to convert the categorical values into numeric feature vectors.
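For a single column, indexing followed by one-hot encoding amounts to a label-to-position mapping that is learned during `fit` and then reused by `transform`. The plain-Java sketch below illustrates this under stated assumptions; `CategoricalEncoderSketch` is a hypothetical name, not a Spark class. It deliberately differs from Spark in two ways: it keeps every category (Spark's one-hot encoder drops the last category by default), and it breaks frequency ties alphabetically. Unseen labels go to one extra slot, mirroring `setHandleInvalid("keep")`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical, simplified StringIndexer + one-hot encoding for one column (not Spark code). */
public class CategoricalEncoderSketch {
    private final Map<String, Integer> index = new HashMap<>();

    /** "fit": most frequent label gets index 0; ties are broken alphabetically here. */
    public static CategoricalEncoderSketch fit(List<String> trainingLabels) {
        Map<String, Long> counts = new TreeMap<>(); // TreeMap iterates keys alphabetically
        for (String s : trainingLabels) counts.merge(s, 1L, Long::sum);
        List<String> ordered = new ArrayList<>(counts.keySet());
        // List.sort is stable, so equally frequent labels keep their alphabetical order.
        ordered.sort((x, y) -> Long.compare(counts.get(y), counts.get(x)));
        CategoricalEncoderSketch enc = new CategoricalEncoderSketch();
        for (int i = 0; i < ordered.size(); i++) enc.index.put(ordered.get(i), i);
        return enc;
    }

    /** Learned index of a label; unseen labels get the extra "unknown" index. */
    public int indexOf(String label) {
        return index.getOrDefault(label, index.size());
    }

    /** One-hot vector with one slot per known label plus a final "unknown" slot. */
    public double[] oneHot(String label) {
        double[] v = new double[index.size() + 1];
        v[indexOf(label)] = 1.0;
        return v;
    }

    public static void main(String[] args) {
        CategoricalEncoderSketch enc = fit(Arrays.asList("apple", "banana", "coconut"));
        System.out.println(Arrays.toString(enc.oneHot("banana")));   // [0.0, 1.0, 0.0, 0.0]
        System.out.println(Arrays.toString(enc.oneHot("some key"))); // [0.0, 0.0, 0.0, 1.0]
    }
}
```

Because the mapping lives in the fitted object, a query label can only be encoded consistently by that same fitted object; building the vector by hand risks using different positions.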
Now I want to find the nearest neighbors of a given key from my dataset, and `approxNearestNeighbors` requires this key to be a `Vector`. I cannot find a way to convert a categorical key into such a vector.
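```java
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.*;          // Pipeline, PipelineModel, PipelineStage
import org.apache.spark.ml.feature.*;  // StringIndexer, OneHotEncoderEstimator, VectorAssembler, MinHashLSH
import org.apache.spark.sql.*;         // Dataset, Row, RowFactory (`spark` is an existing SparkSession)
import org.apache.spark.sql.types.*;   // StructType, StructField, DataTypes, Metadata

List<Row> dataA = Arrays.asList(RowFactory.create(0, "apple"),
        RowFactory.create(1, "banana"),
        RowFactory.create(2, "coconut"));
StructType schema = new StructType(new StructField[] {
        new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
        new StructField("fruits", DataTypes.StringType, false, Metadata.empty()) });
Dataset<Row> dfA = spark.createDataFrame(dataA, schema);

// String -> index ("keep" gives unseen labels an extra index) -> one-hot -> assembled vector.
// OneHotEncoderEstimator is the Spark 2.3/2.4 name; Spark 3.x calls it OneHotEncoder.
StringIndexer stringIndexer = new StringIndexer()
        .setInputCol("fruits").setOutputCol("fruitIndex").setHandleInvalid("keep");
OneHotEncoderEstimator encoder = new OneHotEncoderEstimator()
        .setInputCols(new String[] {"fruitIndex"}).setOutputCols(new String[] {"fruitVec"});
String[] featuredCols = new String[] {"fruitIndex", "fruitVec"};
VectorAssembler assembler = new VectorAssembler().setInputCols(featuredCols).setOutputCol("features");
Pipeline sovPipeline = new Pipeline().setStages(new PipelineStage[] {stringIndexer, encoder, assembler});

// Feature transformation
PipelineModel plModel = sovPipeline.fit(dfA);
Dataset<Row> dfT = plModel.transform(dfA);

MinHashLSH mh = new MinHashLSH().setNumHashTables(5).setInputCol("features").setOutputCol("hashes");
MinHashLSHModel model = mh.fit(dfT);
// model.approxNearestNeighbors(dfT, key, 2).show();  // `key` must be a Vector: how do I build it?
```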
How can I create the key (a numeric vector) for the `approxNearestNeighbors` method from a categorical key?
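The `Vector` you pass as the key must be produced by the same transformations as the training data, i.e. by the fitted `PipelineModel` (`plModel`). Since a `PipelineModel` transforms a `Dataset` rather than a single item, the quickest solution is to wrap the key in a single-row `Dataset` with the same schema. Because the `StringIndexer` uses `setHandleInvalid("keep")`, a label that never occurred in the training data is mapped to an extra index instead of raising an error: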
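```java
import org.apache.spark.ml.linalg.Vector;
// Encode the key with the *fitted* pipeline (one-row Dataset, same schema as dfA).
Vector key = plModel.transform(spark.createDataFrame(Arrays.asList(
        RowFactory.create(0, "some key")), schema)).first().getAs("features");
model.approxNearestNeighbors(dfT, key, 2).show();
```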
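Because `approxNearestNeighbors` is approximate, it can help to compare it against an exact search on a small sample. `MinHashLSH` treats each vector as the set of its non-zero indices and measures Jaccard distance, so an exact brute-force search is easy to sketch in plain Java. The class `JaccardBruteForce` and the two-column example data (fruit and color, one-hot encoded) are hypothetical.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Hypothetical exact k-nearest-neighbour search under Jaccard distance, treating each
 * vector as the set of its non-zero indices. Scans every row, so use it only on small samples.
 */
public class JaccardBruteForce {
    /** Indices of the non-zero entries of a dense vector. */
    public static Set<Integer> nonZero(double[] v) {
        Set<Integer> s = new HashSet<>();
        for (int i = 0; i < v.length; i++) {
            if (v[i] != 0.0) s.add(i);
        }
        return s;
    }

    /** Jaccard distance 1 - |A intersect B| / |A union B| (assumes the sets are not both empty). */
    public static double jaccardDistance(Set<Integer> a, Set<Integer> b) {
        Set<Integer> inter = new HashSet<>(a);
        inter.retainAll(b);
        Set<Integer> union = new HashSet<>(a);
        union.addAll(b);
        return 1.0 - (double) inter.size() / union.size();
    }

    /** Row indices of the k rows closest to the key, nearest first. */
    public static List<Integer> nearest(double[][] rows, double[] key, int k) {
        Set<Integer> keySet = nonZero(key);
        List<Integer> ids = new ArrayList<>();
        for (int i = 0; i < rows.length; i++) ids.add(i);
        // Stable sort: rows at equal distance keep their original order.
        ids.sort(Comparator.comparingDouble(id -> jaccardDistance(nonZero(rows[id]), keySet)));
        return ids.subList(0, Math.min(k, ids.size()));
    }

    public static void main(String[] args) {
        // Hypothetical one-hot layout: [apple, banana, coconut, red, yellow]
        double[][] rows = {
            {1, 0, 0, 1, 0}, // 0: apple, red
            {0, 1, 0, 0, 1}, // 1: banana, yellow
            {0, 0, 1, 0, 1}, // 2: coconut, yellow
        };
        double[] key = {0, 0, 1, 0, 1}; // coconut, yellow
        System.out.println(nearest(rows, key, 2)); // [2, 1]
    }
}
```

Running the same key through `model.approxNearestNeighbors(dfT, key, 2)` on such a sample shows how closely the LSH result matches the exact one; per the Spark documentation, more hash tables generally increase accuracy at the cost of running time.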