package com.microsoft.ml.spark.nn;

import com.microsoft.ml.spark.core.contracts.HasAdditionalPythonMethods;
import com.microsoft.ml.spark.core.contracts.HasFeaturesCol;
import com.microsoft.ml.spark.core.contracts.HasLabelCol;
import com.microsoft.ml.spark.core.contracts.HasOutputCol;
import com.microsoft.ml.spark.nn.ConditionalKNNParams;
import com.microsoft.ml.spark.nn.KNNParams;
import java.io.IOException;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.ComplexParamsWritable;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.param.ConditionalBallTreeParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.ArrayType$;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructType;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: ConditionalKNN.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u0005g\u0001B\u0001\u0003\u00015\u00111cQ8oI&$\u0018n\u001c8bY.se*T8eK2T!a\u0001\u0003\u0002\u00059t'BA\u0003\u0007\u0003\u0015\u0019\b/\u0019:l\u0015\t9\u0001\"\u0001\u0002nY*\u0011\u0011BC\u0001\n[&\u001c'o\\:pMRT\u0011aC\u0001\u0004G>l7\u0001A\n\u0005\u00019QR\u0004E\u0002\u0010-ai\u0011\u0001\u0005\u0006\u0003\u000fEQ!!\u0002\n\u000b\u0005M!\u0012AB1qC\u000eDWMC\u0001\u0016\u0003\ry'oZ\u0005\u0003/A\u0011Q!T8eK2\u0004\"!\u0007\u0001\u000e\u0003\t\u0001\"aD\u000e\n\u0005q\u0001\"!F\"p[BdW\r\u001f)be\u0006l7o\u0016:ji\u0006\u0014G.\u001a\t\u00033yI!a\b\u0002\u0003)\r{g\u000eZ5uS>t\u0017\r\\&O\u001dB\u000b'/Y7t\u0011!\t\u0003A!b\u0001\n\u0003\u0011\u0013aA;jIV\t1\u0005\u0005\u0002%U9\u0011Q\u0005K\u0007\u0002M)\tq%A\u0003tG\u0006d\u0017-\u0003\u0002*M\u00051\u0001K]3eK\u001aL!a\u000b\u0017\u0003\rM#(/\u001b8h\u0015\tIc\u0005\u0003\u0005/\u0001\t\u0005\t\u0015!\u0003$\u0003\u0011)\u0018\u000e\u001a\u0011\t\u000bA\u0002A\u0011A\u0019\u0002\rqJg.\u001b;?)\tA\"\u0007C\u0003\"_\u0001\u00071\u0005C\u00031\u0001\u0011\u0005A\u0007F\u0001\u0019\u0011\u001d1\u0004\u00011A\u0005\n]\naC\u0019:pC\u0012\u001c\u0017m\u001d;fI6{G-\u001a7PaRLwN\\\u000b\u0002qA\u0019Q%O\u001e\n\u0005i2#AB(qi&|g\u000eE\u0002=\u007f\u0005k\u0011!\u0010\u0006\u0003}E\t\u0011B\u0019:pC\u0012\u001c\u0017m\u001d;\n\u0005\u0001k$!\u0003\"s_\u0006$7-Y:ua\r\u0011u\t\u0018\t\u00053\r+5,\u0003\u0002E\u0005\t\u00192i\u001c8eSRLwN\\1m\u0005\u0006dG\u000e\u0016:fKB\u0011ai\u0012\u0007\u0001\t%A\u0015*!A\u0001\u0002\u000b\u0005\u0011KA\u0002`IIBaA\u0013\u0001!B\u0013Y\u0015a\u00062s_\u0006$7-Y:uK\u0012lu\u000eZ3m\u001fB$\u0018n\u001c8!!\r)\u0013\b\u0014\t\u0004y}j\u0005g\u0001(Q3B!\u0011dQ(Y!\t1\u0005\u000bB\u0005I\u0013\u0006\u0005\t\u0011!B\u0001#F\u0011!+\u0016\t\u0003KMK!\u0001\u0016\u0014\u0003\u000f9{G\u000f[5oOB\u0011QEV\u0005\u0003/\u001a\u00121!\u00118z!\t1\u0015\fB\u0005[\u0013\u0006\u0005\t\u0011!B\u0001#\n\u0019q\fJ\u001a\u0011\u0005\u0019cF!\u0003.J\u0003\u
0003\u0005\tQ!\u0001R\u0011\u001dq\u0006\u00011A\u0005\n}\u000b!D\u0019:pC\u0012\u001c\u0017m\u001d;fI6{G-\u001a7PaRLwN\\0%KF$\"\u0001Y2\u0011\u0005\u0015\n\u0017B\u00012'\u0005\u0011)f.\u001b;\t\u000f\u0011l\u0016\u0011!a\u0001K\u0006\u0019\u0001\u0010J\u0019\u0011\u0007\u0015Jd\rE\u0002=\u007f\u001d\u00044\u0001\u001b6m!\u0011I2)[6\u0011\u0005\u0019SG!\u0003%J\u0003\u0003\u0005\tQ!\u0001R!\t1E\u000eB\u0005[\u0013\u0006\u0005\t\u0011!B\u0001#\"9a\u000e\u0001b\u0001\n\u0003y\u0017\u0001\u00032bY2$&/Z3\u0016\u0003A\u0004\"!\u001d;\u000e\u0003IT!a\u001d\t\u0002\u000bA\f'/Y7\n\u0005U\u0014(\u0001G\"p]\u0012LG/[8oC2\u0014\u0015\r\u001c7Ue\u0016,\u0007+\u0019:b[\"1q\u000f\u0001Q\u0001\nA\f\u0011BY1mYR\u0013X-\u001a\u0011\t\u000be\u0004A\u0011\u0001>\u0002\u0017\u001d,GOQ1mYR\u0013X-Z\u000b\u0002wB\"AP`A\u0002!\u0015I2)`A\u0001!\t1e\u0010B\u0005��q\u0006\u0005\t\u0011!B\u0001#\n\u0019q\f\n\u001b\u0011\u0007\u0019\u000b\u0019\u0001\u0002\u0006\u0002\u0006a\f\t\u0011!A\u0003\u0002E\u00131a\u0018\u00136\u0011\u001d\tI\u0001\u0001C\u0001\u0003\u0017\t1b]3u\u0005\u0006dG\u000e\u0016:fKR!\u0011QBA\b\u001b\u0005\u0001\u0001\u0002CA\t\u0003\u000f\u0001\r!a\u0005\u0002\u0003Y\u0004d!!\u0006\u0002\u001a\u0005}\u0001CB\rD\u0003/\ti\u0002E\u0002G\u00033!1\"a\u0007\u0002\u0010\u0005\u0005\t\u0011!B\u0001#\n\u0019q\f\n\u001c\u0011\u0007\u0019\u000by\u0002B\u0006\u0002\"\u0005=\u0011\u0011!A\u0001\u0006\u0003\t&aA0%o!9\u0011Q\u0005\u0001\u0005B\u0005\u001d\u0012\u0001B2paf$2\u0001GA\u0015\u0011!\tY#a\tA\u0002\u00055\u0012!B3yiJ\f\u0007cA9\u00020%\u0019\u0011\u0011\u0007:\u0003\u0011A\u000b'/Y7NCBDq!!\u000e\u0001\t\u0003\n9$A\u0005ue\u0006t7OZ8s[R!\u0011\u0011HA1!\u0011\tY$a\u0017\u000f\t\u0005u\u0012Q\u000b\b\u0005\u0003\u007f\t\tF\u0004\u0003\u0002B\u0005=c\u0002BA\"\u0003\u001brA!!\u0012\u0002L5\u0011\u0011q\t\u0006\u0004\u0003\u0013b\u0011A\u0002\u001fs_>$h(C\u0001\u0016\u0013\t\u0019B#\u0003\u0002\u0006%%\u0019\u00111K\t\u0002\u0007M\fH.\u0003\u0003\u0002X\u0005e\u0013a\u00029bG.\fw
-\u001a\u0006\u0004\u0003'\n\u0012\u0002BA/\u0003?\u0012\u0011\u0002R1uC\u001a\u0013\u0018-\\3\u000b\t\u0005]\u0013\u0011\f\u0005\t\u0003G\n\u0019\u00041\u0001\u0002f\u00059A-\u0019;bg\u0016$\b\u0007BA4\u0003c\u0002b!!\u001b\u0002l\u0005=TBAA-\u0013\u0011\ti'!\u0017\u0003\u000f\u0011\u000bG/Y:fiB\u0019a)!\u001d\u0005\u0017\u0005M\u0014\u0011MA\u0001\u0002\u0003\u0015\t!\u0015\u0002\u0004?\u0012B\u0004bBA<\u0001\u0011\u0005\u0013\u0011P\u0001\u0010iJ\fgn\u001d4pe6\u001c6\r[3nCR!\u00111PAD!\u0011\ti(a!\u000e\u0005\u0005}$\u0002BAA\u00033\nQ\u0001^=qKNLA!!\"\u0002��\tQ1\u000b\u001e:vGR$\u0016\u0010]3\t\u0011\u0005%\u0015Q\u000fa\u0001\u0003w\naa]2iK6\fwaBAG\u0005!\u0005\u0011qR\u0001\u0014\u0007>tG-\u001b;j_:\fGn\u0013(O\u001b>$W\r\u001c\t\u00043\u0005EeAB\u0001\u0003\u0011\u0003\t\u0019j\u0005\u0005\u0002\u0012\u0006U\u00151TAQ!\r)\u0013qS\u0005\u0004\u000333#AB!osJ+g\r\u0005\u0003\u0010\u0003;C\u0012bAAP!\t)2i\\7qY\u0016D\b+\u0019:b[N\u0014V-\u00193bE2,\u0007cA\u0013\u0002$&\u0019\u0011Q\u0015\u0014\u0003\u0019M+'/[1mSj\f'\r\\3\t\u000fA\n\t\n\"\u0001\u0002*R\u0011\u0011q\u0012\u0005\u000b\u0003[\u000b\t*!A\u0005\n\u0005=\u0016a\u0003:fC\u0012\u0014Vm]8mm\u0016$\"!!-\u0011\t\u0005M\u0016QX\u0007\u0003\u0003kSA!a.\u0002:\u0006!A.\u00198h\u0015\t\tY,\u0001\u0003kCZ\f\u0017\u0002BA`\u0003k\u0013aa\u00142kK\u000e$\b")
/* loaded from: input_file:com/microsoft/ml/spark/nn/ConditionalKNNModel.class */
/**
 * Fitted Spark ML model for conditional k-nearest-neighbour lookups.
 *
 * <p>The model holds a {@link ConditionalBallTree} in the {@code ballTree} param.
 * {@link #transform} broadcasts that tree once and caches the broadcast handle in
 * {@code broadcastedModelOption}. It then appends {@code getOutputCol()}, computed by a UDF over
 * the features column and the conditioner column. Each output cell is an array of structs
 * {@code (value, distance, label)}. The UDF receives {@code getK()}; presumably it returns up to
 * k neighbours from the tree, but its body lives in the synthetic class
 * {@code ConditionalKNNModel$$anonfun$2}, which is not visible here. TODO confirm.
 *
 * <p>NOTE(review): this file is decompiler output of {@code ConditionalKNN.scala} and is not
 * valid Java source as written:
 * <ul>
 *   <li>the {@code ..._setter_$...} methods assign {@code final} fields outside the constructor;
 *   <li>{@code MLWritable.class.save(...)} stands in for the Scala {@code MLWritable$class} helper.
 * </ul>
 * Make behavioral changes in the Scala source, not here.
 */
public class ConditionalKNNModel extends Model<ConditionalKNNModel> implements ComplexParamsWritable, ConditionalKNNParams {
    private final String uid;
    // Cached broadcast of the ball tree. It is None until the first transform() and is reset to
    // None by setBallTree(). NOTE(review): it is read and written without synchronization.
    private Option<Broadcast<ConditionalBallTree<?, ?>>> broadcastedModelOption;
    // The ballTree param is created directly in the constructor.
    private final ConditionalBallTreeParam ballTree;
    // The remaining Param handles are assigned once each, by the trait setter methods that the
    // constructor calls (or, for valuesCol/leafSize/k, presumably by KNNParams.Cclass.$init$).
    private final Param<String> conditionerCol;
    private final Param<String> labelCol;
    private final Param<String> valuesCol;
    private final IntParam leafSize;
    private final IntParam k;
    private final Param<String> outputCol;
    private final Param<String> featuresCol;

    /** Loads a saved model from path {@code str} by delegating to the companion object. */
    public static Object load(String str) {
        return ConditionalKNNModel$.MODULE$.load(str);
    }

    /** Returns the companion object's reader for this model type. */
    public static MLReader<ConditionalKNNModel> read() {
        return ConditionalKNNModel$.MODULE$.read();
    }

    // ---- ConditionalKNNParams: conditioner column (getter/setter delegate to the trait helper) ----

    @Override // com.microsoft.ml.spark.nn.ConditionalKNNParams
    public Param<String> conditionerCol() {
        return this.conditionerCol;
    }

    @Override // com.microsoft.ml.spark.nn.ConditionalKNNParams
    public void com$microsoft$ml$spark$nn$ConditionalKNNParams$_setter_$conditionerCol_$eq(Param param) {
        this.conditionerCol = param;
    }

    @Override // com.microsoft.ml.spark.nn.ConditionalKNNParams
    public String getConditionerCol() {
        return ConditionalKNNParams.Cclass.getConditionerCol(this);
    }

    @Override // com.microsoft.ml.spark.nn.ConditionalKNNParams
    public ConditionalKNNParams setConditionerCol(String str) {
        return ConditionalKNNParams.Cclass.setConditionerCol(this, str);
    }

    // ---- HasLabelCol ----

    @Override // com.microsoft.ml.spark.core.contracts.HasLabelCol
    public Param<String> labelCol() {
        return this.labelCol;
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasLabelCol
    public void com$microsoft$ml$spark$core$contracts$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasLabelCol
    public HasLabelCol setLabelCol(String str) {
        return HasLabelCol.Cclass.setLabelCol(this, str);
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasLabelCol
    public String getLabelCol() {
        return HasLabelCol.Cclass.getLabelCol(this);
    }

    // ---- KNNParams: values column, leaf size, k ----

    @Override // com.microsoft.ml.spark.nn.KNNParams
    public Param<String> valuesCol() {
        return this.valuesCol;
    }

    @Override // com.microsoft.ml.spark.nn.KNNParams
    public IntParam leafSize() {
        return this.leafSize;
    }

    @Override // com.microsoft.ml.spark.nn.KNNParams
    public IntParam k() {
        return this.k;
    }

    @Override // com.microsoft.ml.spark.nn.KNNParams
    public void com$microsoft$ml$spark$nn$KNNParams$_setter_$valuesCol_$eq(Param param) {
        this.valuesCol = param;
    }

    @Override // com.microsoft.ml.spark.nn.KNNParams
    public void com$microsoft$ml$spark$nn$KNNParams$_setter_$leafSize_$eq(IntParam intParam) {
        this.leafSize = intParam;
    }

    @Override // com.microsoft.ml.spark.nn.KNNParams
    public void com$microsoft$ml$spark$nn$KNNParams$_setter_$k_$eq(IntParam intParam) {
        this.k = intParam;
    }

    @Override // com.microsoft.ml.spark.nn.KNNParams
    public String getValuesCol() {
        return KNNParams.Cclass.getValuesCol(this);
    }

    @Override // com.microsoft.ml.spark.nn.KNNParams
    public KNNParams setValuesCol(String str) {
        return KNNParams.Cclass.setValuesCol(this, str);
    }

    @Override // com.microsoft.ml.spark.nn.KNNParams
    public int getLeafSize() {
        return KNNParams.Cclass.getLeafSize(this);
    }

    @Override // com.microsoft.ml.spark.nn.KNNParams
    public KNNParams setLeafSize(int i) {
        return KNNParams.Cclass.setLeafSize(this, i);
    }

    @Override // com.microsoft.ml.spark.nn.KNNParams
    public int getK() {
        return KNNParams.Cclass.getK(this);
    }

    @Override // com.microsoft.ml.spark.nn.KNNParams
    public KNNParams setK(int i) {
        return KNNParams.Cclass.setK(this, i);
    }

    // ---- HasOutputCol ----

    @Override // com.microsoft.ml.spark.core.contracts.HasOutputCol
    public Param<String> outputCol() {
        return this.outputCol;
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasOutputCol
    public void com$microsoft$ml$spark$core$contracts$HasOutputCol$_setter_$outputCol_$eq(Param param) {
        this.outputCol = param;
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasOutputCol
    public HasOutputCol setOutputCol(String str) {
        return HasOutputCol.Cclass.setOutputCol(this, str);
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasOutputCol
    public String getOutputCol() {
        return HasOutputCol.Cclass.getOutputCol(this);
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasAdditionalPythonMethods
    public String additionalPythonMethods() {
        return HasAdditionalPythonMethods.Cclass.additionalPythonMethods(this);
    }

    // ---- HasFeaturesCol ----

    @Override // com.microsoft.ml.spark.core.contracts.HasFeaturesCol
    public Param<String> featuresCol() {
        return this.featuresCol;
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasFeaturesCol
    public void com$microsoft$ml$spark$core$contracts$HasFeaturesCol$_setter_$featuresCol_$eq(Param param) {
        this.featuresCol = param;
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasFeaturesCol
    public HasFeaturesCol setFeaturesCol(String str) {
        return HasFeaturesCol.Cclass.setFeaturesCol(this, str);
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasFeaturesCol
    public String getFeaturesCol() {
        return HasFeaturesCol.Cclass.getFeaturesCol(this);
    }

    // ---- Persistence ----

    /** Returns the writer supplied by the ComplexParamsWritable trait helper. */
    @Override // org.apache.spark.ml.ComplexParamsWritable
    public MLWriter write() {
        return ComplexParamsWritable.Cclass.write(this);
    }

    /**
     * Saves this model to path {@code str}.
     *
     * <p>NOTE(review): {@code MLWritable.class.save} is how the decompiler renders the Scala
     * {@code MLWritable$class.save} trait helper. It is not a real Java call.
     *
     * @throws IOException as declared by the signature
     */
    public void save(String str) throws IOException {
        MLWritable.class.save(this, str);
    }

    /** Returns the unique identifier passed to the constructor. */
    public String uid() {
        return this.uid;
    }

    // Private accessor pair for the cached broadcast (Scala var getter/setter encoding).
    private Option<Broadcast<ConditionalBallTree<?, ?>>> broadcastedModelOption() {
        return this.broadcastedModelOption;
    }

    private void broadcastedModelOption_$eq(Option<Broadcast<ConditionalBallTree<?, ?>>> option) {
        this.broadcastedModelOption = option;
    }

    /** Returns the Param handle that stores the ball tree. */
    public ConditionalBallTreeParam ballTree() {
        return this.ballTree;
    }

    /** Returns the ball tree currently stored in the {@code ballTree} param. */
    public ConditionalBallTree<?, ?> getBallTree() {
        return (ConditionalBallTree) $(ballTree());
    }

    /**
     * Replaces the ball tree and invalidates the cached broadcast, so the next
     * {@link #transform} broadcasts the new tree.
     *
     * @return this model (the result of {@code set}, cast back to this type)
     */
    public ConditionalKNNModel setBallTree(ConditionalBallTree<?, ?> conditionalBallTree) {
        // Applied to the old broadcast, if any. The body of $$anonfun$setBallTree$1 is not visible;
        // presumably it unpersists or destroys the stale broadcast. TODO confirm.
        broadcastedModelOption().foreach(new ConditionalKNNModel$$anonfun$setBallTree$1(this));
        broadcastedModelOption_$eq(None$.MODULE$);
        return (ConditionalKNNModel) set(ballTree(), conditionalBallTree);
    }

    /**
     * Copies this model with extra params via Spark's {@code defaultCopy}.
     * The Scala source name is {@code copy}; the decompiler renamed it, as its comment notes.
     * The cached broadcast is not a Param. Presumably the copy starts with an empty cache,
     * because the constructor sets it to None. Verify against Spark's {@code defaultCopy}.
     */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public ConditionalKNNModel m670copy(ParamMap paramMap) {
        return (ConditionalKNNModel) defaultCopy(paramMap);
    }

    /**
     * Appends the neighbour column {@code getOutputCol()} to {@code dataset}.
     *
     * <p>On the first call, {@code getBallTree()} is broadcast through the dataset's SparkContext,
     * and the handle is cached for later calls. A UDF is then applied to
     * {@code (col(getFeaturesCol()), col(getConditionerCol()))}. The UDF is built from the
     * broadcast and {@code getK()}. Its declared result type is
     * {@code array<struct<value, distance: double, label>>}.
     *
     * <p>NOTE(review): the {@code value} and {@code label} field types are read from
     * <em>this</em> dataset's schema ({@code getValuesCol()}, {@code getLabelCol()}). The dataset
     * being transformed must therefore contain both columns, even though the tree supplies the
     * neighbour values.
     * NOTE(review): the isEmpty/assign pair below is an unsynchronized check-then-set.
     */
    public Dataset<Row> transform(Dataset<?> dataset) {
        if (broadcastedModelOption().isEmpty()) {
            broadcastedModelOption_$eq(new Some(dataset.sparkSession().sparkContext().broadcast(getBallTree(), ClassTag$.MODULE$.apply(ConditionalBallTree.class))));
        }
        // Declare the UDF's result type, then apply it to the features and conditioner columns.
        return dataset.toDF().withColumn(getOutputCol(), functions$.MODULE$.udf(new ConditionalKNNModel$$anonfun$2(this, (Broadcast) broadcastedModelOption().get(), getK()), ArrayType$.MODULE$.apply(new StructType().add("value", dataset.schema().apply(getValuesCol()).dataType()).add("distance", DoubleType$.MODULE$).add("label", dataset.schema().apply(getLabelCol()).dataType()))).apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(getFeaturesCol()), functions$.MODULE$.col(getConditionerCol())})));
    }

    /**
     * Returns {@code structType} with the output column appended. The output column's type is the
     * same {@code array<struct<value, distance: double, label>>} that {@link #transform} declares.
     * The {@code value} and {@code label} types are taken from {@code structType}, so the input
     * schema must contain {@code getValuesCol()} and {@code getLabelCol()}.
     */
    public StructType transformSchema(StructType structType) {
        return structType.add(getOutputCol(), ArrayType$.MODULE$.apply(new StructType().add("value", structType.apply(getValuesCol()).dataType()).add("distance", DoubleType$.MODULE$).add("label", structType.apply(getLabelCol()).dataType())));
    }

    /**
     * Creates a model with the given uid and initializes every Param handle.
     *
     * <p>The statement order is as emitted by scalac (presumably trait initialization order).
     * Do not reorder it.
     */
    public ConditionalKNNModel(String str) {
        this.uid = str;
        MLWritable.class.$init$(this);
        ComplexParamsWritable.Cclass.$init$(this);
        com$microsoft$ml$spark$core$contracts$HasFeaturesCol$_setter_$featuresCol_$eq(new Param(this, "featuresCol", "The name of the features column"));
        HasAdditionalPythonMethods.Cclass.$init$(this);
        com$microsoft$ml$spark$core$contracts$HasOutputCol$_setter_$outputCol_$eq(new Param(this, "outputCol", "The name of the output column"));
        KNNParams.Cclass.$init$(this);
        com$microsoft$ml$spark$core$contracts$HasLabelCol$_setter_$labelCol_$eq(new Param(this, "labelCol", "The name of the label column"));
        // NOTE(review): this description does not obviously describe a conditioning column.
        // Verify the wording in the Scala source.
        com$microsoft$ml$spark$nn$ConditionalKNNParams$_setter_$conditionerCol_$eq(new Param(this, "conditionerCol", "column holding identifiers for features that will be returned when queried"));
        this.broadcastedModelOption = None$.MODULE$;
        // The last argument's body ($$anonfun$1) is not visible; presumably a validator. TODO confirm.
        // NOTE(review): "perfoming" is a typo in a runtime doc string; fix it in the Scala source.
        this.ballTree = new ConditionalBallTreeParam(this, "ballTree", "the ballTree model used for perfoming queries", new ConditionalKNNModel$$anonfun$1(this));
    }

    /** Creates a model with a random uid prefixed {@code "ConditionalKNNModel"}. */
    public ConditionalKNNModel() {
        this(Identifiable$.MODULE$.randomUID("ConditionalKNNModel"));
    }
}
