package com.microsoft.ml.spark.lime;

import com.microsoft.ml.spark.FluentAPI$;
import com.microsoft.ml.spark.core.contracts.HasAdditionalPythonMethods;
import com.microsoft.ml.spark.core.contracts.HasInputCol;
import com.microsoft.ml.spark.core.contracts.HasOutputCol;
import com.microsoft.ml.spark.core.contracts.Wrappable;
import com.microsoft.ml.spark.core.schema.DatasetExtensions$;
import com.microsoft.ml.spark.lime.LIMEBase;
import com.microsoft.ml.spark.lime.LIMEParams;
import java.io.IOException;
import org.apache.spark.ml.ComplexParamsWritable;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.SQLDataTypes$;
import org.apache.spark.ml.param.DoubleArrayParam;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.TransformerParam;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.StructType;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: LIME.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u0015v!B\u0001\u0003\u0011\u0003i\u0011\u0001\u0005+bEVd\u0017M\u001d'J\u001b\u0016ku\u000eZ3m\u0015\t\u0019A!\u0001\u0003mS6,'BA\u0003\u0007\u0003\u0015\u0019\b/\u0019:l\u0015\t9\u0001\"\u0001\u0002nY*\u0011\u0011BC\u0001\n[&\u001c'o\\:pMRT\u0011aC\u0001\u0004G>l7\u0001\u0001\t\u0003\u001d=i\u0011A\u0001\u0004\u0006!\tA\t!\u0005\u0002\u0011)\u0006\u0014W\u000f\\1s\u0019&kU)T8eK2\u001cRa\u0004\n\u0019\u0003\u000b\u0003\"a\u0005\f\u000e\u0003QQ\u0011!F\u0001\u0006g\u000e\fG.Y\u0005\u0003/Q\u0011a!\u00118z%\u00164\u0007cA\r!E5\t!D\u0003\u0002\b7)\u0011Q\u0001\b\u0006\u0003;y\ta!\u00199bG\",'\"A\u0010\u0002\u0007=\u0014x-\u0003\u0002\"5\t)2i\\7qY\u0016D\b+\u0019:b[N\u0014V-\u00193bE2,\u0007C\u0001\b$\r\u0011\u0001\"\u0001\u0001\u0013\u0014\t\r*\u0003f\u000b\t\u00043\u0019\u0012\u0013BA\u0014\u001b\u0005\u0015iu\u000eZ3m!\tq\u0011&\u0003\u0002+\u0005\tAA*S'F\u0005\u0006\u001cX\r\u0005\u0002-c5\tQF\u0003\u0002/_\u0005I1m\u001c8ue\u0006\u001cGo\u001d\u0006\u0003a\u0011\tAaY8sK&\u0011!'\f\u0002\n/J\f\u0007\u000f]1cY\u0016D\u0001\u0002N\u0012\u0003\u0006\u0004%\t!N\u0001\u0004k&$W#\u0001\u001c\u0011\u0005]RdBA\n9\u0013\tID#\u0001\u0004Qe\u0016$WMZ\u0005\u0003wq\u0012aa\u0015;sS:<'BA\u001d\u0015\u0011!q4E!A!\u0002\u00131\u0014\u0001B;jI\u0002BQ\u0001Q\u0012\u0005\u0002\u0005\u000ba\u0001P5oSRtDC\u0001\u0012C\u0011\u0015!t\b1\u00017\u0011\u0015\u00015\u0005\"\u0001E)\u0005\u0011\u0003b\u0002$$\u0005\u0004%\taR\u0001\fG>dW/\u001c8NK\u0006t7/F\u0001I!\tIE*D\u0001K\u0015\tY%$A\u0003qCJ\fW.\u0003\u0002N\u0015\n\u0001Bi\\;cY\u0016\f%O]1z!\u0006\u0014\u0018-\u001c\u0005\u0007\u001f\u000e\u0002\u000b\u0011\u0002%\u0002\u0019\r|G.^7o\u001b\u0016\fgn\u001d\u0011\t\u000bE\u001bC\u0011\u0001*\u0002\u001d\u001d,GoQ8mk6tW*Z1ogV\t1\u000bE\u0002\u0014)ZK!!\u0016\u000b\u0003\u000b\u0005\u0013(/Y=\u0011\u0005M9\u0016B\u0001-\u0015\u0005\u0019!u.\u001e2mK\")!l\tC\u00017\u0006q1/\u001a;D_2,XN\\'fC:\u001cHC\u0001/^\u001b\u0005\u0019\u0003\"\u00020Z\u00
01\u0004\u0019\u0016!\u0001<\t\u000f\u0001\u001c#\u0019!C\u0001\u000f\u0006Q1m\u001c7v[:\u001cF\u000bR:\t\r\t\u001c\u0003\u0015!\u0003I\u0003-\u0019w\u000e\\;n]N#Fi\u001d\u0011\t\u000b\u0011\u001cC\u0011\u0001*\u0002\u001b\u001d,GoQ8mk6t7\u000b\u0016#t\u0011\u001517\u0005\"\u0001h\u00035\u0019X\r^\"pYVlgn\u0015+EgR\u0011A\f\u001b\u0005\u0006=\u0016\u0004\ra\u0015\u0005\u0006U\u000e\"Ia[\u0001\u0016a\u0016\u0014H/\u001e:cK\u0012$UM\\:f-\u0016\u001cGo\u001c:t)\tag\u0010E\u0002nkbt!A\\:\u000f\u0005=\u0014X\"\u00019\u000b\u0005Ed\u0011A\u0002\u001fs_>$h(C\u0001\u0016\u0013\t!H#A\u0004qC\u000e\\\u0017mZ3\n\u0005Y<(aA*fc*\u0011A\u000f\u0006\t\u0003srl\u0011A\u001f\u0006\u0003wj\ta\u0001\\5oC2<\u0017BA?{\u0005-!UM\\:f-\u0016\u001cGo\u001c:\t\u000byK\u0007\u0019\u0001=\t\u0013\u0005\u00051E1A\u0005\n\u0005\r\u0011\u0001\u00079feR,(OY3e\t\u0016t7/\u001a,fGR|'o]+E\rV\u0011\u0011Q\u0001\t\u0005\u0003\u000f\t\t\"\u0004\u0002\u0002\n)!\u00111BA\u0007\u0003-)\u0007\u0010\u001d:fgNLwN\\:\u000b\u0007\u0005=1$A\u0002tc2LA!a\u0005\u0002\n\t\u0019Rk]3s\t\u00164\u0017N\\3e\rVt7\r^5p]\"A\u0011qC\u0012!\u0002\u0013\t)!A\rqKJ$XO\u001d2fI\u0012+gn]3WK\u000e$xN]:V\t\u001a\u0003\u0003bBA\u000eG\u0011\u0005\u0013QD\u0001\niJ\fgn\u001d4pe6$B!a\b\u0002<A!\u0011\u0011EA\u001b\u001d\u0011\t\u0019#a\r\u000f\t\u0005\u0015\u0012\u0011\u0007\b\u0005\u0003O\tyC\u0004\u0003\u0002*\u00055bbA8\u0002,%\tq$\u0003\u0002\u001e=%\u0011Q\u0001H\u0005\u0004\u0003\u001fY\u0012b\u0001;\u0002\u000e%!\u0011qGA\u001d\u0005%!\u0015\r^1Ge\u0006lWMC\u0002u\u0003\u001bA\u0001\"!\u0010\u0002\u001a\u0001\u0007\u0011qH\u0001\bI\u0006$\u0018m]3ua\u0011\t\t%!\u0014\u0011\r\u0005\r\u0013QIA%\u001b\t\ti!\u0003\u0003\u0002H\u00055!a\u0002#bi\u0006\u001cX\r\u001e\t\u0005\u0003\u0017\ni\u0005\u0004\u0001\u0005\u0019\u0005=\u00131HA\u0001\u0002\u0003\u0015\t!!\u0015\u0003\u0007}##'\u0005\u0003\u0002T\u0005e\u0003cA\n\u0002V%\u0019\u0011q\u000b\u000b\u0003\u000f9{G\u000f[5oOB\u00191#a\u0017\n\u0007\u0005uCCA\u0002B]fDq!!\u0019$\t\u
0003\n\u0019'\u0001\u0003d_BLHc\u0001\u0012\u0002f!A\u0011qMA0\u0001\u0004\tI'A\u0003fqR\u0014\u0018\rE\u0002J\u0003WJ1!!\u001cK\u0005!\u0001\u0016M]1n\u001b\u0006\u0004\bbBA9G\u0011\u0005\u00131O\u0001\u0010iJ\fgn\u001d4pe6\u001c6\r[3nCR!\u0011QOAA!\u0011\t9(! \u000e\u0005\u0005e$\u0002BA>\u0003\u001b\tQ\u0001^=qKNLA!a \u0002z\tQ1\u000b\u001e:vGR$\u0016\u0010]3\t\u0011\u0005\r\u0015q\u000ea\u0001\u0003k\naa]2iK6\f\u0007cA\n\u0002\b&\u0019\u0011\u0011\u0012\u000b\u0003\u0019M+'/[1mSj\f'\r\\3\t\r\u0001{A\u0011AAG)\u0005i\u0001\"CAI\u001f\u0005\u0005I\u0011BAJ\u0003-\u0011X-\u00193SKN|GN^3\u0015\u0005\u0005U\u0005\u0003BAL\u0003Ck!!!'\u000b\t\u0005m\u0015QT\u0001\u0005Y\u0006twM\u0003\u0002\u0002 \u0006!!.\u0019<b\u0013\u0011\t\u0019+!'\u0003\r=\u0013'.Z2u\u0001")
/* loaded from: input_file:com/microsoft/ml/spark/lime/TabularLIMEModel.class */
/**
 * Spark ML {@code Model} that produces LIME-style local explanations for a vector-valued input
 * column by probing a wrapped {@link Transformer} ({@link #getModel()}).
 *
 * <p>For every input row, {@link #transform} does the following:
 * <ol>
 *   <li>generates {@link #getNSamples()} perturbed copies of the input vector;</li>
 *   <li>scores those copies with the wrapped model;</li>
 *   <li>regroups the scored copies per original row;</li>
 *   <li>calls {@code fitLassoUDF()} with the stacked perturbations, the stacked predictions and
 *       {@link #getRegularization()}, storing the result in {@link #getOutputCol()}.</li>
 * </ol>
 * The per-column means/standard deviations ({@link #columnMeans()}, {@link #columnSTDs()}) are
 * documented as the statistics used for perturbation.
 *
 * <p>NOTE(review): this class is decompiler output of Scala bytecode, not hand-written Java. The
 * giveaways are the {@code Cclass} / {@code $init$} / {@code MODULE$} / {@code _setter_$} naming
 * patterns. Several constructs only make sense as renderings of Scala trait plumbing and will not
 * compile as Java source:
 * <ul>
 *   <li>{@code final} fields are assigned from ordinary methods;</li>
 *   <li>{@code MLWritable.class.save(...)} and {@code HasPredictionCol.class.getPredictionCol(...)}
 *       are written as static calls on a class literal.</li>
 * </ul>
 * Make behavioral changes in the Scala source instead.
 */
public class TabularLIMEModel extends Model<TabularLIMEModel> implements LIMEBase, Wrappable {
    // Backing storage for this model's uid, params and UDFs. In the Scala original these are trait
    // vals. The decompiler marks them final, but most are assigned through the *_setter_$* methods
    // below during trait initialization (see the constructor).
    private final String uid;
    private final DoubleArrayParam columnMeans;
    private final DoubleArrayParam columnSTDs;
    private final UserDefinedFunction perturbedDenseVectorsUDF;
    private final UserDefinedFunction arrToMatUDF;
    private final UserDefinedFunction arrToVectUDF;
    private final UserDefinedFunction fitLassoUDF;
    private final UserDefinedFunction getSampleUDF;
    private final TransformerParam model;
    private final IntParam nSamples;
    private final DoubleParam samplingFraction;
    private final DoubleParam regularization;
    private final Param<String> predictionCol;
    private final Param<String> outputCol;
    private final Param<String> inputCol;

    /**
     * Loads a persisted model from {@code str} by delegating to the companion object
     * {@code TabularLIMEModel$}.
     *
     * <p>NOTE(review): the {@code Object} return type looks like a decompiler erasure of the Scala
     * signature. The runtime value is presumably a {@code TabularLIMEModel}; confirm before
     * casting.
     */
    public static Object load(String str) {
        return TabularLIMEModel$.MODULE$.load(str);
    }

    /** Returns the {@link MLReader} provided by the companion object {@code TabularLIMEModel$}. */
    public static MLReader<TabularLIMEModel> read() {
        return TabularLIMEModel$.MODULE$.read();
    }

    /** Delegates to the {@code HasAdditionalPythonMethods} trait implementation. */
    @Override // com.microsoft.ml.spark.core.contracts.HasAdditionalPythonMethods
    public String additionalPythonMethods() {
        return HasAdditionalPythonMethods.Cclass.additionalPythonMethods(this);
    }

    /** UDF from {@code LIMEBase}; {@link #transform} applies it to the regrouped input column. */
    @Override // com.microsoft.ml.spark.lime.LIMEBase
    public UserDefinedFunction arrToMatUDF() {
        return this.arrToMatUDF;
    }

    /** UDF from {@code LIMEBase}; {@link #transform} applies it to the regrouped prediction column. */
    @Override // com.microsoft.ml.spark.lime.LIMEBase
    public UserDefinedFunction arrToVectUDF() {
        return this.arrToVectUDF;
    }

    /** UDF from {@code LIMEBase}; {@link #transform} uses it to produce the output column. */
    @Override // com.microsoft.ml.spark.lime.LIMEBase
    public UserDefinedFunction fitLassoUDF() {
        return this.fitLassoUDF;
    }

    /** UDF from {@code LIMEBase}; it is not used within this class. */
    @Override // com.microsoft.ml.spark.lime.LIMEBase
    public UserDefinedFunction getSampleUDF() {
        return this.getSampleUDF;
    }

    // Compiler-generated initialization hooks for LIMEBase's vals. They are presumably invoked from
    // LIMEBase.Cclass.$init$(this) in the constructor; not meant to be called by user code.
    @Override // com.microsoft.ml.spark.lime.LIMEBase
    public void com$microsoft$ml$spark$lime$LIMEBase$_setter_$arrToMatUDF_$eq(UserDefinedFunction userDefinedFunction) {
        this.arrToMatUDF = userDefinedFunction;
    }

    @Override // com.microsoft.ml.spark.lime.LIMEBase
    public void com$microsoft$ml$spark$lime$LIMEBase$_setter_$arrToVectUDF_$eq(UserDefinedFunction userDefinedFunction) {
        this.arrToVectUDF = userDefinedFunction;
    }

    @Override // com.microsoft.ml.spark.lime.LIMEBase
    public void com$microsoft$ml$spark$lime$LIMEBase$_setter_$fitLassoUDF_$eq(UserDefinedFunction userDefinedFunction) {
        this.fitLassoUDF = userDefinedFunction;
    }

    @Override // com.microsoft.ml.spark.lime.LIMEBase
    public void com$microsoft$ml$spark$lime$LIMEBase$_setter_$getSampleUDF_$eq(UserDefinedFunction userDefinedFunction) {
        this.getSampleUDF = userDefinedFunction;
    }

    /** Delegates to the {@code LIMEBase} trait implementation of {@code getSamples}. */
    @Override // com.microsoft.ml.spark.lime.LIMEBase
    public Seq<Seq<Object>> getSamples(int i) {
        return LIMEBase.Cclass.getSamples(this, i);
    }

    /** Delegates to the {@code LIMEBase} trait implementation of {@code arrToMat}. */
    @Override // com.microsoft.ml.spark.lime.LIMEBase
    public DenseMatrix arrToMat(Seq<DenseVector> seq) {
        return LIMEBase.Cclass.arrToMat(this, seq);
    }

    /** Delegates to the {@code LIMEBase} trait implementation of {@code arrToVect}. */
    @Override // com.microsoft.ml.spark.lime.LIMEBase
    public DenseVector arrToVect(Seq<Object> seq) {
        return LIMEBase.Cclass.arrToVect(this, seq);
    }

    /** Returns the writer supplied by the {@code ComplexParamsWritable} trait. */
    @Override // org.apache.spark.ml.ComplexParamsWritable
    public MLWriter write() {
        return ComplexParamsWritable.Cclass.write(this);
    }

    /**
     * Saves this model to {@code str} via the {@code MLWritable} trait implementation.
     * NOTE(review): the {@code MLWritable.class.save(...)} form is a decompiler rendering of a
     * trait static-implementation call and is not valid Java.
     */
    public void save(String str) throws IOException {
        MLWritable.class.save(this, str);
    }

    /** Param holding the wrapped {@link Transformer} that {@link #transform} applies. */
    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public TransformerParam model() {
        return this.model;
    }

    /** Param read by {@link #getNSamples()}; it sets the number of perturbed copies generated per row. */
    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public IntParam nSamples() {
        return this.nSamples;
    }

    /** Param from {@code LIMEParams}; it is not read directly within this class. */
    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public DoubleParam samplingFraction() {
        return this.samplingFraction;
    }

    /** Param whose value {@link #transform} passes as the last argument to {@code fitLassoUDF()}. */
    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public DoubleParam regularization() {
        return this.regularization;
    }

    // Compiler-generated initialization hooks for LIMEParams' vals. They are presumably invoked from
    // LIMEParams.Cclass.$init$(this) in the constructor.
    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public void com$microsoft$ml$spark$lime$LIMEParams$_setter_$model_$eq(TransformerParam transformerParam) {
        this.model = transformerParam;
    }

    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public void com$microsoft$ml$spark$lime$LIMEParams$_setter_$nSamples_$eq(IntParam intParam) {
        this.nSamples = intParam;
    }

    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public void com$microsoft$ml$spark$lime$LIMEParams$_setter_$samplingFraction_$eq(DoubleParam doubleParam) {
        this.samplingFraction = doubleParam;
    }

    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public void com$microsoft$ml$spark$lime$LIMEParams$_setter_$regularization_$eq(DoubleParam doubleParam) {
        this.regularization = doubleParam;
    }

    // The getters/setters below delegate to the LIMEParams trait implementation. The LIMEParams
    // return type of the setters is how the decompiler renders Scala's `this.type`.
    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public LIMEParams setPredictionCol(String str) {
        return LIMEParams.Cclass.setPredictionCol(this, str);
    }

    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public Transformer getModel() {
        return LIMEParams.Cclass.getModel(this);
    }

    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public LIMEParams setModel(Transformer transformer) {
        return LIMEParams.Cclass.setModel(this, transformer);
    }

    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public int getNSamples() {
        return LIMEParams.Cclass.getNSamples(this);
    }

    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public LIMEParams setNSamples(int i) {
        return LIMEParams.Cclass.setNSamples(this, i);
    }

    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public double getSamplingFraction() {
        return LIMEParams.Cclass.getSamplingFraction(this);
    }

    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public LIMEParams setSamplingFraction(double d) {
        return LIMEParams.Cclass.setSamplingFraction(this, d);
    }

    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public double getRegularization() {
        return LIMEParams.Cclass.getRegularization(this);
    }

    @Override // com.microsoft.ml.spark.lime.LIMEParams
    public LIMEParams setRegularization(double d) {
        return LIMEParams.Cclass.setRegularization(this, d);
    }

    /** Param naming the column the wrapped model writes predictions to (Spark {@code HasPredictionCol}). */
    public final Param<String> predictionCol() {
        return this.predictionCol;
    }

    // Initialization hook for HasPredictionCol's val; presumably invoked from HasPredictionCol.$init$.
    public final void org$apache$spark$ml$param$shared$HasPredictionCol$_setter_$predictionCol_$eq(Param param) {
        this.predictionCol = param;
    }

    /** Delegates to Spark's {@code HasPredictionCol} trait implementation (decompiler-rendered call). */
    public final String getPredictionCol() {
        return HasPredictionCol.class.getPredictionCol(this);
    }

    /** Param naming the column {@link #transform} writes the {@code fitLassoUDF()} result to. */
    @Override // com.microsoft.ml.spark.core.contracts.HasOutputCol
    public Param<String> outputCol() {
        return this.outputCol;
    }

    // Initialization hook for HasOutputCol's val; called directly from the constructor.
    @Override // com.microsoft.ml.spark.core.contracts.HasOutputCol
    public void com$microsoft$ml$spark$core$contracts$HasOutputCol$_setter_$outputCol_$eq(Param param) {
        this.outputCol = param;
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasOutputCol
    public HasOutputCol setOutputCol(String str) {
        return HasOutputCol.Cclass.setOutputCol(this, str);
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasOutputCol
    public String getOutputCol() {
        return HasOutputCol.Cclass.getOutputCol(this);
    }

    /** Param naming the vector column whose rows are perturbed and explained by {@link #transform}. */
    @Override // com.microsoft.ml.spark.core.contracts.HasInputCol
    public Param<String> inputCol() {
        return this.inputCol;
    }

    // Initialization hook for HasInputCol's val; called directly from the constructor.
    @Override // com.microsoft.ml.spark.core.contracts.HasInputCol
    public void com$microsoft$ml$spark$core$contracts$HasInputCol$_setter_$inputCol_$eq(Param param) {
        this.inputCol = param;
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasInputCol
    public HasInputCol setInputCol(String str) {
        return HasInputCol.Cclass.setInputCol(this, str);
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasInputCol
    public String getInputCol() {
        return HasInputCol.Cclass.getInputCol(this);
    }

    /** Unique identifier of this pipeline stage. */
    public String uid() {
        return this.uid;
    }

    /** Param: the means of each of the columns, documented as used for perturbation. */
    public DoubleArrayParam columnMeans() {
        return this.columnMeans;
    }

    /** Returns the current value of {@link #columnMeans()}. */
    public double[] getColumnMeans() {
        return (double[]) $(columnMeans());
    }

    /** Sets {@link #columnMeans()} and returns this model for chaining. */
    public TabularLIMEModel setColumnMeans(double[] dArr) {
        return (TabularLIMEModel) set(columnMeans(), dArr);
    }

    /** Param: the standard deviations of each of the columns, documented as used for perturbation. */
    public DoubleArrayParam columnSTDs() {
        return this.columnSTDs;
    }

    /** Returns the current value of {@link #columnSTDs()}. */
    public double[] getColumnSTDs() {
        return (double[]) $(columnSTDs());
    }

    /** Sets {@link #columnSTDs()} and returns this model for chaining. */
    public TabularLIMEModel setColumnSTDs(double[] dArr) {
        return (TabularLIMEModel) set(columnSTDs(), dArr);
    }

    /**
     * Builds a sequence of {@link #getNSamples()} perturbed vectors for {@code denseVector}. Each
     * element is produced by the anonymous function
     * {@code TabularLIMEModel$$anonfun$...$perturbedDenseVectors$1}, which is not shown in this file.
     *
     * <p>NOTE(review): the mangled public name presumably comes from a private Scala method that is
     * referenced from a closure class. Do not call it directly.
     */
    public Seq<DenseVector> com$microsoft$ml$spark$lime$TabularLIMEModel$$perturbedDenseVectors(DenseVector denseVector) {
        return Seq$.MODULE$.fill(getNSamples(), new TabularLIMEModel$$anonfun$com$microsoft$ml$spark$lime$TabularLIMEModel$$perturbedDenseVectors$1(this, denseVector));
    }

    /** UDF built in the constructor; maps an input value to an array of vectors (see {@link #transform}). */
    private UserDefinedFunction perturbedDenseVectorsUDF() {
        return this.perturbedDenseVectorsUDF;
    }

    /**
     * Explains each row of {@code dataset} by probing the wrapped model around its input vector.
     *
     * <p>Steps, in the order of the chained call below:
     * <ol>
     *   <li>Pick three column names not already used in the input: {@code id}, {@code states},
     *       {@code inputCol2} (or collision-free variants).</li>
     *   <li>Tag each row with {@code monotonically_increasing_id()} in the id column.</li>
     *   <li>Move the original input column aside to the {@code inputCol2} name.</li>
     *   <li>Refill {@link #getInputCol()} with {@code explode_outer} of
     *       {@code perturbedDenseVectorsUDF()}, giving one row per perturbed vector.</li>
     *   <li>Score the exploded rows with {@link #getModel()} via {@code mlTransform}.</li>
     *   <li>Regroup the input and prediction columns by the id column with
     *       {@code LIMEUtils.localAggregateBy}. Presumably this collects each group's values into
     *       arrays; confirm against LIMEUtils.</li>
     *   <li>Convert the grouped inputs with {@code arrToMatUDF()} and the grouped predictions with
     *       {@code arrToVectUDF()}.</li>
     *   <li>Write {@code fitLassoUDF(input, prediction, lit(getRegularization()))} to
     *       {@link #getOutputCol()}.</li>
     *   <li>Drop the helper, prediction and perturbed-input columns.</li>
     *   <li>Rename the saved original input back to {@link #getInputCol()}.</li>
     * </ol>
     *
     * <p>NOTE(review): the {@code states} column name is only ever passed to {@code drop}. Nothing
     * in this method creates it, so the drop is effectively a no-op ({@code Dataset.drop} ignores
     * missing columns). It looks like a leftover; confirm.
     *
     * @param dataset input data containing {@link #getInputCol()}
     * @return the input columns plus {@link #getOutputCol()}; the prediction column is dropped
     */
    public Dataset<Row> transform(Dataset<?> dataset) {
        Dataset<?> df = dataset.toDF();
        String findUnusedColumnName = DatasetExtensions$.MODULE$.findUnusedColumnName("id", df);
        String findUnusedColumnName2 = DatasetExtensions$.MODULE$.findUnusedColumnName("states", df);
        String findUnusedColumnName3 = DatasetExtensions$.MODULE$.findUnusedColumnName("inputCol2", df);
        return LIMEUtils$.MODULE$.localAggregateBy(FluentAPI$.MODULE$.toSugaredDF(df.withColumn(findUnusedColumnName, functions$.MODULE$.monotonically_increasing_id()).withColumnRenamed(getInputCol(), findUnusedColumnName3).withColumn(getInputCol(), functions$.MODULE$.explode_outer(perturbedDenseVectorsUDF().apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(findUnusedColumnName3)}))))).mlTransform(getModel()), findUnusedColumnName, (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{getInputCol(), getPredictionCol()}))).withColumn(getInputCol(), arrToMatUDF().apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(getInputCol())}))).withColumn(getPredictionCol(), arrToVectUDF().apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(getPredictionCol())}))).withColumn(getOutputCol(), fitLassoUDF().apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(getInputCol()), functions$.MODULE$.col(getPredictionCol()), functions$.MODULE$.lit(BoxesRunTime.boxToDouble(getRegularization()))}))).drop(Predef$.MODULE$.wrapRefArray(new String[]{findUnusedColumnName2, getPredictionCol(), findUnusedColumnName, getInputCol()})).withColumnRenamed(findUnusedColumnName3, getInputCol());
    }

    /**
     * Returns a copy of this model with {@code paramMap} applied, via {@code defaultCopy}. This is
     * the Scala {@code copy} method; the decompiler renamed it and merged its bridge methods.
     */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public TabularLIMEModel m640copy(ParamMap paramMap) {
        return (TabularLIMEModel) defaultCopy(paramMap);
    }

    /**
     * Appends {@link #getOutputCol()} with Spark's vector type to {@code structType}.
     *
     * <p>The input column is not validated here, and output-name collisions are not checked.
     * NOTE(review): this assumes {@code fitLassoUDF()} returns a vector; confirm in LIMEBase.
     */
    public StructType transformSchema(StructType structType) {
        return structType.add(getOutputCol(), SQLDataTypes$.MODULE$.VectorType());
    }

    /**
     * Creates a model with the given uid. Initialization order mirrors the Scala trait
     * linearization:
     * <ol>
     *   <li>create the input/output column params (inlined HasInputCol/HasOutputCol init);</li>
     *   <li>run the remaining trait initializers;</li>
     *   <li>define this class's own params;</li>
     *   <li>build the perturbation UDF, whose result type is
     *       {@code ArrayType(VectorType, containsNull = true)}.</li>
     * </ol>
     */
    public TabularLIMEModel(String str) {
        this.uid = str;
        com$microsoft$ml$spark$core$contracts$HasInputCol$_setter_$inputCol_$eq(new Param(this, "inputCol", "The name of the input column"));
        com$microsoft$ml$spark$core$contracts$HasOutputCol$_setter_$outputCol_$eq(new Param(this, "outputCol", "The name of the output column"));
        HasPredictionCol.class.$init$(this);
        LIMEParams.Cclass.$init$(this);
        MLWritable.class.$init$(this);
        ComplexParamsWritable.Cclass.$init$(this);
        LIMEBase.Cclass.$init$(this);
        HasAdditionalPythonMethods.Cclass.$init$(this);
        this.columnMeans = new DoubleArrayParam(this, "columnMeans", "the means of each of the columns for perturbation");
        this.columnSTDs = new DoubleArrayParam(this, "columnSTDs", "the standard deviations of each of the columns for perturbation");
        this.perturbedDenseVectorsUDF = functions$.MODULE$.udf(new TabularLIMEModel$$anonfun$12(this), new ArrayType(SQLDataTypes$.MODULE$.VectorType(), true));
    }

    /** Creates a model with a random uid prefixed by {@code "TabularLIMEModel"}. */
    public TabularLIMEModel() {
        this(Identifiable$.MODULE$.randomUID("TabularLIMEModel"));
    }
}
