package com.microsoft.ml.spark.train;

import com.microsoft.ml.spark.core.contracts.HasAdditionalPythonMethods;
import com.microsoft.ml.spark.core.contracts.HasFeaturesCol;
import com.microsoft.ml.spark.core.contracts.HasLabelCol;
import com.microsoft.ml.spark.core.contracts.HasOutputCol;
import com.microsoft.ml.spark.core.schema.CategoricalUtilities$;
import com.microsoft.ml.spark.core.schema.SparkSchema$;
import com.microsoft.ml.spark.featurize.Featurize;
import com.microsoft.ml.spark.featurize.FeaturizeUtilities$;
import com.microsoft.ml.spark.featurize.ValueIndexer;
import com.microsoft.ml.spark.featurize.ValueIndexerModel;
import com.microsoft.ml.spark.train.AutoTrainer;
import java.io.IOException;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.ml.ComplexParamsWritable;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.classification.GBTClassifier;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.EstimatorParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.Seq$;
import scala.collection.immutable.Map;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: TrainClassifier.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005%h\u0001B\u0001\u0003\u00015\u0011q\u0002\u0016:bS:\u001cE.Y:tS\u001aLWM\u001d\u0006\u0003\u0007\u0011\tQ\u0001\u001e:bS:T!!\u0002\u0004\u0002\u000bM\u0004\u0018M]6\u000b\u0005\u001dA\u0011AA7m\u0015\tI!\"A\u0005nS\u000e\u0014xn]8gi*\t1\"A\u0002d_6\u001c\u0001aE\u0002\u0001\u001dq\u00012a\u0004\f\u0019\u001b\u0005\u0001\"BA\u0004\u0012\u0015\t)!C\u0003\u0002\u0014)\u00051\u0011\r]1dQ\u0016T\u0011!F\u0001\u0004_J<\u0017BA\f\u0011\u0005%)5\u000f^5nCR|'\u000f\u0005\u0002\u001a55\t!!\u0003\u0002\u001c\u0005\t1BK]1j]\u0016$7\t\\1tg&4\u0017.\u001a:N_\u0012,G\u000eE\u0002\u001a;aI!A\b\u0002\u0003\u0017\u0005+Ho\u001c+sC&tWM\u001d\u0005\tA\u0001\u0011)\u0019!C!C\u0005\u0019Q/\u001b3\u0016\u0003\t\u0002\"aI\u0015\u000f\u0005\u0011:S\"A\u0013\u000b\u0003\u0019\nQa]2bY\u0006L!\u0001K\u0013\u0002\rA\u0013X\rZ3g\u0013\tQ3F\u0001\u0004TiJLgn\u001a\u0006\u0003Q\u0015B\u0001\"\f\u0001\u0003\u0002\u0003\u0006IAI\u0001\u0005k&$\u0007\u0005C\u00030\u0001\u0011\u0005\u0001'\u0001\u0004=S:LGO\u0010\u000b\u0003cI\u0002\"!\u0007\u0001\t\u000b\u0001r\u0003\u0019\u0001\u0012\t\u000b=\u0002A\u0011\u0001\u001b\u0015\u0003EBQA\u000e\u0001\u0005B\u0005\n\u0001\"\\8eK2$un\u0019\u0005\bq\u0001\u0011\r\u0011\"\u0001:\u00031\u0011X-\u001b8eKbd\u0015MY3m+\u0005Q\u0004CA\u001e?\u001b\u0005a$BA\u001f\u0011\u0003\u0015\u0001\u0018M]1n\u0013\tyDH\u0001\u0007C_>dW-\u00198QCJ\fW\u000e\u0003\u0004B\u0001\u0001\u0006IAO\u0001\u000ee\u0016Lg\u000eZ3y\u0019\u0006\u0014W\r\u001c\u0011\t\u000b\r\u0003A\u0011\u0001#\u0002\u001f\u001d,GOU3j]\u0012,\u0007\u0010T1cK2,\u0012!\u0012\t\u0003I\u0019K!aR\u0013\u0003\u000f\t{w\u000e\\3b]\")\u0011\n\u0001C\u0001\u0015\u0006y1/\u001a;SK&tG-\u001a=MC\n,G\u000e\u0006\u0002L\u00196\t\u0001\u0001C\u0003N\u0011\u0002\u0007Q)A\u0003wC2,X\rC\u0004P\u0001\t\u0007I\u0011\u0001)\u0002\r1\f'-\u001a7t+\u0005\t\u0006CA\u001eS\u0013\t\u0019FH\u0001\tTiJLgnZ!se\u0006L\b+\u0019:b[\"1Q\u000b\u0001Q\u0001\nE\u000bq\u0001\\1cK2\u001c\b\u0005C\u0003X
\u0001\u0011\u0005\u0001,A\u0005hKRd\u0015MY3mgV\t\u0011\fE\u0002%5\nJ!aW\u0013\u0003\u000b\u0005\u0013(/Y=\t\u000bu\u0003A\u0011\u00010\u0002\u0013M,G\u000fT1cK2\u001cHCA&`\u0011\u0015iE\f1\u0001Z\u0011\u0015\t\u0007\u0001\"\u0011c\u0003\r1\u0017\u000e\u001e\u000b\u00031\rDQ\u0001\u001a1A\u0002\u0015\fq\u0001Z1uCN,G\u000f\r\u0002g]B\u0019qM\u001b7\u000e\u0003!T!![\t\u0002\u0007M\fH.\u0003\u0002lQ\n9A)\u0019;bg\u0016$\bCA7o\u0019\u0001!\u0011b\\2\u0002\u0002\u0003\u0005)\u0011\u00019\u0003\u0007}#\u0013'\u0005\u0002riB\u0011AE]\u0005\u0003g\u0016\u0012qAT8uQ&tw\r\u0005\u0002%k&\u0011a/\n\u0002\u0004\u0003:L\b\"\u0002=\u0001\t\u0003I\u0018AE4fi\u001a+\u0017\r^;sSj,\u0007+\u0019:b[N,\u0012A\u001f\t\u0006Im,U)`\u0005\u0003y\u0016\u0012a\u0001V;qY\u0016\u001c\u0004C\u0001\u0013\u007f\u0013\tyXEA\u0002J]RDq!a\u0001\u0001\t\u0003\t)!\u0001\u0007d_:4XM\u001d;MC\n,G\u000e\u0006\u0005\u0002\b\u0005\u0005\u0013QJA)!\u001d!\u0013\u0011BA\u0007\u0003cI1!a\u0003&\u0005\u0019!V\u000f\u001d7feA!\u0011qBA\u0016\u001d\u0011\t\t\"a\n\u000f\t\u0005M\u0011Q\u0005\b\u0005\u0003+\t\u0019C\u0004\u0003\u0002\u0018\u0005\u0005b\u0002BA\r\u0003?i!!a\u0007\u000b\u0007\u0005uA\"\u0001\u0004=e>|GOP\u0005\u0002+%\u00111\u0003F\u0005\u0003\u000bII!![\t\n\u0007\u0005%\u0002.A\u0004qC\u000e\\\u0017mZ3\n\t\u00055\u0012q\u0006\u0002\n\t\u0006$\u0018M\u0012:b[\u0016T1!!\u000bi!\u0015!\u00131GA\u001c\u0013\r\t)$\n\u0002\u0007\u001fB$\u0018n\u001c81\t\u0005e\u0012Q\b\t\u0005Ii\u000bY\u0004E\u0002n\u0003{!1\"a\u0010\u0002\u0002\u0005\u0005\t\u0011!B\u0001a\n\u0019q\f\n\u001d\t\u000f\u0011\f\t\u00011\u0001\u0002DA\"\u0011QIA%!\u00119'.a\u0012\u0011\u00075\fI\u0005B\u0006\u0002L\u0005\u0005\u0013\u0011!A\u0001\u0006\u0003\u0001(aA0%m!9\u0011qJA\u0001\u0001\u0004\u0011\u0013a\u00037bE\u0016d7i\u001c7v[:D\u0001\"a\u0015\u0002\u0002\u0001\u0007\u0011QK\u0001\fY\u0006\u0014W\r\u001c,bYV,7\u000fE\u0003%\u0003g\t9\u0006\r\u0003\u0002Z\u0005u\u0003\u0003\u0002\u0013[\u00037\u00022!\\A/\t-\ty&!\u0015\u0002\u0002\u
0003\u0005)\u0011\u00019\u0003\u0007}#s\u0007C\u0004\u0002d\u0001!\t%!\u001a\u0002\t\r|\u0007/\u001f\u000b\u0004\u001d\u0005\u001d\u0004\u0002CA5\u0003C\u0002\r!a\u001b\u0002\u000b\u0015DHO]1\u0011\u0007m\ni'C\u0002\u0002pq\u0012\u0001\u0002U1sC6l\u0015\r\u001d\u0005\b\u0003g\u0002A\u0011IA;\u0003=!(/\u00198tM>\u0014XnU2iK6\fG\u0003BA<\u0003\u0007\u0003B!!\u001f\u0002��5\u0011\u00111\u0010\u0006\u0004\u0003{B\u0017!\u0002;za\u0016\u001c\u0018\u0002BAA\u0003w\u0012!b\u0015;sk\u000e$H+\u001f9f\u0011!\t))!\u001dA\u0002\u0005]\u0014AB:dQ\u0016l\u0017\r\u000b\u0003\u0002r\u0005%\u0005\u0003BAF\u0003#k!!!$\u000b\u0007\u0005=\u0015#\u0001\u0006b]:|G/\u0019;j_:LA!a%\u0002\u000e\naA)\u001a<fY>\u0004XM]!qS\"\u001a\u0001!a&\u0011\t\u0005e\u00151U\u0007\u0003\u00037SA!!(\u0002 \u0006\u0019QM\u001c<\u000b\u0007\u0005\u0005F!\u0001\u0003d_J,\u0017\u0002BAS\u00037\u0013q\"\u00138uKJt\u0017\r\\,sCB\u0004XM]\u0004\b\u0003S\u0013\u0001\u0012AAV\u0003=!&/Y5o\u00072\f7o]5gS\u0016\u0014\bcA\r\u0002.\u001a1\u0011A\u0001E\u0001\u0003_\u001b\u0002\"!,\u00022\u0006]\u0016Q\u0018\t\u0004I\u0005M\u0016bAA[K\t1\u0011I\\=SK\u001a\u0004BaDA]c%\u0019\u00111\u0018\t\u0003+\r{W\u000e\u001d7fqB\u000b'/Y7t%\u0016\fG-\u00192mKB\u0019A%a0\n\u0007\u0005\u0005WE\u0001\u0007TKJL\u0017\r\\5{C\ndW\rC\u00040\u0003[#\t!!2\u0015\u0005\u0005-\u0006\u0002CAe\u0003[#\t!a3\u0002/Y\fG.\u001b3bi\u0016$&/\u00198tM>\u0014XnU2iK6\fGCBA<\u0003\u001b\f\t\u000eC\u0004\u0002P\u0006\u001d\u0007\u0019A#\u0002\u0019!\f7oU2pe\u0016\u001cu\u000e\\:\t\u0011\u0005\u0015\u0015q\u0019a\u0001\u0003oB!\"!6\u0002.\u0006\u0005I\u0011BAl\u0003-\u0011X-\u00193SKN|GN^3\u0015\u0005\u0005e\u0007\u0003BAn\u0003Kl!!!8\u000b\t\u0005}\u0017\u0011]\u0001\u0005Y\u0006twM\u0003\u0002\u0002d\u0006!!.\u0019<b\u0013\u0011\t9/!8\u0003\r=\u0013'.Z2u\u0001")
/* loaded from: input_file:com/microsoft/ml/spark/train/TrainClassifier.class */
/**
 * Estimator that featurizes a dataset and trains the classifier held in the {@code model} param
 * on it, producing a {@link TrainedClassifierModel}.
 *
 * <p>This source is decompiler output of {@code TrainClassifier.scala}: calls such as
 * {@code AutoTrainer.Cclass.*} and the {@code com$microsoft$...$_setter_$...$eq} methods are the
 * Scala trait plumbing rendered as Java.
 */
public class TrainClassifier extends Estimator<TrainedClassifierModel> implements AutoTrainer<TrainedClassifierModel> {
    /** Unique identifier of this stage. */
    private final String uid;
    /** Whether {@link #convertLabel} re-indexes the label column (default {@code true}). */
    private final BooleanParam reindexLabel;
    /** Optional explicit label values; {@link #fit} passes them to {@link #convertLabel}. */
    private final StringArrayParam labels;
    // The params below are assigned by the trait setter methods further down rather than
    // directly in the constructor body; Java forbids assigning a final field from a method,
    // so they cannot be final.
    private IntParam numFeatures;
    private EstimatorParam model;
    private Param<String> featuresCol;
    private Param<String> labelCol;

    /** Loads a saved trainer from {@code str} (delegates to the companion object). */
    public static Object load(String str) {
        return TrainClassifier$.MODULE$.load(str);
    }

    /** Returns the reader for saved trainers (delegates to the companion object). */
    public static MLReader<TrainClassifier> read() {
        return TrainClassifier$.MODULE$.read();
    }

    /**
     * Delegates schema validation to the companion object.
     *
     * @param z flag computed by {@link #transformSchema}; false for GBT and multilayer-perceptron
     *     classifiers (presumably "has score columns" — TODO confirm against the companion)
     * @param structType the input schema
     * @return the schema produced by the companion's validation
     */
    public static StructType validateTransformSchema(boolean z, StructType structType) {
        return TrainClassifier$.MODULE$.validateTransformSchema(z, structType);
    }

    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public IntParam numFeatures() {
        return this.numFeatures;
    }

    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public EstimatorParam model() {
        return this.model;
    }

    /** Trait setter used by {@code AutoTrainer} initialization to install the numFeatures param. */
    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public void com$microsoft$ml$spark$train$AutoTrainer$_setter_$numFeatures_$eq(IntParam intParam) {
        this.numFeatures = intParam;
    }

    /** Trait setter used by {@code AutoTrainer} initialization to install the model param. */
    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public void com$microsoft$ml$spark$train$AutoTrainer$_setter_$model_$eq(EstimatorParam estimatorParam) {
        this.model = estimatorParam;
    }

    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public int getNumFeatures() {
        return AutoTrainer.Cclass.getNumFeatures(this);
    }

    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public AutoTrainer<TrainedClassifierModel> setNumFeatures(int i) {
        return AutoTrainer.Cclass.setNumFeatures(this, i);
    }

    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public Estimator<? extends Model<?>> getModel() {
        return AutoTrainer.Cclass.getModel(this);
    }

    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public AutoTrainer<TrainedClassifierModel> setModel(Estimator<? extends Model<?>> estimator) {
        return AutoTrainer.Cclass.setModel(this, estimator);
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasAdditionalPythonMethods
    public String additionalPythonMethods() {
        return HasAdditionalPythonMethods.Cclass.additionalPythonMethods(this);
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasFeaturesCol
    public Param<String> featuresCol() {
        return this.featuresCol;
    }

    /** Trait setter used during construction to install the featuresCol param. */
    @Override // com.microsoft.ml.spark.core.contracts.HasFeaturesCol
    public void com$microsoft$ml$spark$core$contracts$HasFeaturesCol$_setter_$featuresCol_$eq(Param param) {
        this.featuresCol = param;
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasFeaturesCol
    public HasFeaturesCol setFeaturesCol(String str) {
        return HasFeaturesCol.Cclass.setFeaturesCol(this, str);
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasFeaturesCol
    public String getFeaturesCol() {
        return HasFeaturesCol.Cclass.getFeaturesCol(this);
    }

    @Override // org.apache.spark.ml.ComplexParamsWritable
    public MLWriter write() {
        return ComplexParamsWritable.Cclass.write(this);
    }

    /**
     * Saves this trainer to {@code str} using the writer returned by {@link #write()}.
     *
     * @param str destination path
     * @throws IOException propagated from {@link MLWriter#save(String)}
     */
    public void save(String str) throws IOException {
        write().save(str);
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasLabelCol
    public Param<String> labelCol() {
        return this.labelCol;
    }

    /** Trait setter used during construction to install the labelCol param. */
    @Override // com.microsoft.ml.spark.core.contracts.HasLabelCol
    public void com$microsoft$ml$spark$core$contracts$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasLabelCol
    public HasLabelCol setLabelCol(String str) {
        return HasLabelCol.Cclass.setLabelCol(this, str);
    }

    @Override // com.microsoft.ml.spark.core.contracts.HasLabelCol
    public String getLabelCol() {
        return HasLabelCol.Cclass.getLabelCol(this);
    }

    public String uid() {
        return this.uid;
    }

    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public String modelDoc() {
        return "Classifier to run";
    }

    public BooleanParam reindexLabel() {
        return this.reindexLabel;
    }

    public boolean getReindexLabel() {
        return BoxesRunTime.unboxToBoolean($(reindexLabel()));
    }

    public TrainClassifier setReindexLabel(boolean z) {
        return (TrainClassifier) set(reindexLabel(), BoxesRunTime.boxToBoolean(z));
    }

    public StringArrayParam labels() {
        return this.labels;
    }

    public String[] getLabels() {
        return (String[]) $(labels());
    }

    public TrainClassifier setLabels(String[] strArr) {
        return (TrainClassifier) set(labels(), strArr);
    }

    /**
     * Trains the configured classifier on {@code dataset}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Clean and optionally re-index the label column via {@link #convertLabel}.</li>
     *   <li>Resolve the learner via {@link #prepareLearner}.</li>
     *   <li>Fit a {@link Featurize} stage over the columns selected by
     *       {@code TrainClassifier$$anonfun$1} (presumably every column except the label —
     *       TODO confirm).</li>
     *   <li>Fit the learner on the featurized data.</li>
     *   <li>Wrap both fitted stages in a pipeline model.</li>
     * </ol>
     *
     * @param dataset training data
     * @return the trained model; it receives the label levels returned by {@link #convertLabel}
     * @throws UnsupportedOperationException if more than two label levels are known and the model
     *     is a {@link GBTClassifier}
     * @throws IllegalArgumentException if explicit labels are set for an already categorical label
     *     column, or if a multilayer perceptron is trained on an empty featurized dataset
     */
    public TrainedClassifierModel fit(Dataset<?> dataset) {
        Option<Object> requestedLevels = isDefined(labels())
                ? new Some<Object>(getLabels())
                : Option.<Object>empty();
        Tuple2<Dataset<Row>, Option<Object>> converted = convertLabel(dataset, getLabelCol(), requestedLevels);
        Dataset<Row> labeled = converted._1();
        Option<Object> levels = converted._2();

        Tuple3<Object, Object, Object> featurizeParams = getFeaturizeParams();
        boolean oneHotEncodeCategoricals = BoxesRunTime.unboxToBoolean(featurizeParams._1());
        boolean isNeuralNetwork = BoxesRunTime.unboxToBoolean(featurizeParams._2());
        int defaultNumFeatures = BoxesRunTime.unboxToInt(featurizeParams._3());

        // Resolve the learner before fitting the featurizer so unsupported configurations fail fast.
        Estimator<?> learner = prepareLearner(levels);

        Model<?> featurizerModel = new Featurize()
                .setFeatureColumns((Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(getFeaturesCol()), Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(labeled.columns()).filter(new TrainClassifier$$anonfun$1(this))).toSeq())})))
                .setOneHotEncodeCategoricals(oneHotEncodeCategoricals)
                .setNumberOfFeatures(getNumFeatures() != 0 ? getNumFeatures() : defaultNumFeatures)
                .fit(labeled);
        Dataset<Row> featurized = featurizerModel.transform(labeled);

        Model<?> learnerModel;
        featurized.cache();
        try {
            if (isNeuralNetwork) {
                setInputLayerSize((MultilayerPerceptronClassifier) learner, featurized);
            }
            learnerModel = learner.fit(featurized);
        } finally {
            // Release the cached data even when featurization inspection or training fails.
            featurized.unpersist();
        }

        return new TrainedClassifierModel(
                uid(),
                getLabelCol(),
                new Pipeline().setStages(new PipelineStage[]{featurizerModel, learnerModel}).fit(labeled),
                levels,
                getFeaturesCol());
    }

    /**
     * Chooses the estimator that is actually trained and points it at this trainer's label and
     * features columns.
     *
     * <p>When more than two label levels are known, a {@link LogisticRegression} is wrapped in a
     * {@link OneVsRest} and a {@link GBTClassifier} is rejected. Any other estimator is used as is.
     * A {@link Predictor} has its label/features columns set in place on the estimator returned by
     * {@link #getModel()}.
     *
     * @param levels label levels from {@link #convertLabel}, if known
     * @return the estimator to fit
     */
    private Estimator<?> prepareLearner(Option<Object> levels) {
        Estimator<? extends Model<?>> configured = getModel();
        if (configured == null) {
            throw new NullPointerException("The model param of TrainClassifier is null");
        }
        boolean multiclass = levels.isDefined() && ScalaRunTime$.MODULE$.array_length(levels.get()) > 2;

        Estimator<?> learner = configured;
        if (multiclass && configured instanceof LogisticRegression) {
            LogisticRegression logisticRegression = (LogisticRegression) configured;
            learner = new OneVsRest()
                    .setClassifier(logisticRegression.setLabelCol(getLabelCol()).setFeaturesCol(getFeaturesCol()))
                    .setLabelCol(getLabelCol())
                    .setFeaturesCol(getFeaturesCol());
        } else if (multiclass && configured instanceof GBTClassifier) {
            throw new UnsupportedOperationException("Multiclass Gradient Boosted Tree Classifier not supported yet");
        }

        if (learner instanceof Predictor) {
            return ((Predictor<?, ?, ?>) learner).setLabelCol(getLabelCol()).setFeaturesCol(getFeaturesCol());
        }
        return learner;
    }

    /**
     * Sets the first entry of {@code mlp}'s layers to the size of the features vector in the first
     * row of {@code featurized}; that width is only known after featurization.
     *
     * @param mlp the perceptron to reconfigure
     * @param featurized featurized training data containing the features column
     * @throws IllegalArgumentException if {@code featurized} has no rows
     */
    private void setInputLayerSize(MultilayerPerceptronClassifier mlp, Dataset<Row> featurized) {
        Row[] firstRows = (Row[]) featurized.take(1);
        if (firstRows.length == 0) {
            throw new IllegalArgumentException(
                    "Cannot infer the input layer size of MultilayerPerceptronClassifier from an empty dataset");
        }
        Row first = firstRows[0];
        int inputSize = ((Vector) first.get(first.fieldIndex(getFeaturesCol()))).size();
        // Edit a copy: getLayers() returns the param value itself, which may be aliased elsewhere.
        int[] layers = mlp.getLayers().clone();
        layers[0] = inputSize;
        mlp.setLayers(layers);
    }

    /**
     * Derives featurization settings from the type of the configured classifier.
     *
     * @return a tuple of (one-hot encode categoricals: Boolean, is multilayer perceptron: Boolean,
     *     default number of features: Integer). Tree models disable one-hot encoding. Tree models
     *     and the perceptron use {@code NumFeaturesTreeOrNNBased}; everything else uses
     *     {@code NumFeaturesDefault}.
     */
    public Tuple3<Object, Object, Object> getFeaturizeParams() {
        Estimator<? extends Model<?>> model = getModel();
        boolean oneHotEncodeCategoricals = true;
        boolean isNeuralNetwork = false;
        int defaultNumFeatures;
        if (model instanceof DecisionTreeClassifier
                || model instanceof GBTClassifier
                || model instanceof RandomForestClassifier) {
            oneHotEncodeCategoricals = false;
            defaultNumFeatures = FeaturizeUtilities$.MODULE$.NumFeaturesTreeOrNNBased();
        } else if (model instanceof MultilayerPerceptronClassifier) {
            isNeuralNetwork = true;
            defaultNumFeatures = FeaturizeUtilities$.MODULE$.NumFeaturesTreeOrNNBased();
        } else {
            defaultNumFeatures = FeaturizeUtilities$.MODULE$.NumFeaturesDefault();
        }
        return new Tuple3<>(
                BoxesRunTime.boxToBoolean(oneHotEncodeCategoricals),
                BoxesRunTime.boxToBoolean(isNeuralNetwork),
                BoxesRunTime.boxToInteger(defaultNumFeatures));
    }

    /**
     * Drops rows whose label is missing and, when {@code reindexLabel} is set, converts the label
     * column to indexed values.
     *
     * <ul>
     *   <li>{@code reindexLabel == false}: only drops rows with a missing label; {@code option} is
     *       returned unchanged.</li>
     *   <li>No explicit levels, label already categorical: levels come from
     *       {@code CategoricalUtilities.getLevels} on the cleaned schema.</li>
     *   <li>No explicit levels otherwise: a {@link ValueIndexer} is fit on the label column. The
     *       indexed column is cast to double, keeping its metadata. Levels come from the indexed
     *       schema.</li>
     *   <li>Explicit levels: a {@link ValueIndexerModel} built from them indexes the column.</li>
     * </ul>
     *
     * @param dataset input data
     * @param str name of the label column
     * @param option explicit label levels (an array wrapped in {@link Some}), or empty
     * @return the cleaned/indexed data paired with the label levels, if known
     * @throws IllegalArgumentException if explicit levels are given for an already categorical column
     */
    public Tuple2<Dataset<Row>, Option<Object>> convertLabel(Dataset<?> dataset, String str, Option<Object> option) {
        if (!getReindexLabel()) {
            return new Tuple2<>(dataset.na().drop(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{str}))), option);
        }
        Dataset<Row> withoutMissing = dataset.toDF().na().drop(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{str})));
        if (!option.isDefined()) {
            if (SparkSchema$.MODULE$.isCategorical(withoutMissing, str)) {
                return new Tuple2<>(withoutMissing, CategoricalUtilities$.MODULE$.getLevels(withoutMissing.schema(), str));
            }
            Dataset<Row> indexed = ((ValueIndexer) ((HasOutputCol) new ValueIndexer().setInputCol(str)).setOutputCol(str)).fit(withoutMissing).transform(withoutMissing);
            return new Tuple2<>(
                    indexed.withColumn(str, indexed.apply(str).cast(DoubleType$.MODULE$).as(str, indexed.schema().apply(str).metadata())),
                    CategoricalUtilities$.MODULE$.getLevels(indexed.schema(), str));
        }
        if (SparkSchema$.MODULE$.isCategorical(withoutMissing, str)) {
            throw new IllegalArgumentException("Column is already categorical, cannot set label values");
        }
        DataType dataType = dataset.schema().apply(str).dataType();
        // NOTE(review): the mapped array is discarded; this call only matters for whatever the
        // compiled anonymous function does per value (possibly validation) — confirm in the Scala source.
        Predef$.MODULE$.genericArrayOps(option.get()).map(new TrainClassifier$$anonfun$convertLabel$1(this, dataType), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Any()));
        return new Tuple2<>(((ValueIndexerModel) ((HasOutputCol) new ValueIndexerModel().setLevels(option.get()).setDataType(dataType).setInputCol(str)).setOutputCol(str)).transform(withoutMissing), option);
    }

    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    /**
     * Returns a copy of this trainer with {@code paramMap} applied.
     *
     * <p>The wrapped classifier is copied with the same extra params, so the copy does not share
     * it with this instance. The exception is when {@code paramMap} itself sets {@code model}, in
     * which case that value wins. This trainer is not modified.
     */
    public Estimator<TrainedClassifierModel> m906copy(ParamMap paramMap) {
        // The copied classifier goes in first so an explicit model entry in paramMap overrides it.
        ParamMap extra = new ParamMap()
                .put(model(), getModel().copy(paramMap))
                .$plus$plus(paramMap);
        return defaultCopy(extra);
    }

    /**
     * Returns the output schema computed by {@link #validateTransformSchema}. The flag passed is
     * false when the configured classifier is a {@link GBTClassifier} or a
     * {@link MultilayerPerceptronClassifier}, and true otherwise.
     */
    @DeveloperApi
    public StructType transformSchema(StructType structType) {
        Estimator estimator = (Estimator) $(model());
        boolean flag = !(estimator instanceof GBTClassifier) && !(estimator instanceof MultilayerPerceptronClassifier);
        return TrainClassifier$.MODULE$.validateTransformSchema(flag, structType);
    }

    /* renamed from: fit, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m907fit(Dataset dataset) {
        return fit((Dataset<?>) dataset);
    }

    /**
     * Creates a trainer with the given uid.
     *
     * <p>Defaults set here: {@code reindexLabel = true} and {@code featuresCol = uid + "_features"}.
     * The statement order mirrors the Scala trait initialization order.
     */
    public TrainClassifier(String str) {
        this.uid = str;
        com$microsoft$ml$spark$core$contracts$HasLabelCol$_setter_$labelCol_$eq(new Param(this, "labelCol", "The name of the label column"));
        // NOTE(review): decompiler rendering of the Scala trait initializer; a class literal has no
        // $init$ method, so this line is not valid Java as written.
        MLWritable.class.$init$(this);
        ComplexParamsWritable.Cclass.$init$(this);
        com$microsoft$ml$spark$core$contracts$HasFeaturesCol$_setter_$featuresCol_$eq(new Param(this, "featuresCol", "The name of the features column"));
        HasAdditionalPythonMethods.Cclass.$init$(this);
        AutoTrainer.Cclass.$init$(this);
        this.reindexLabel = new BooleanParam(this, "reindexLabel", "Re-index the label column");
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{reindexLabel().$minus$greater(BoxesRunTime.boxToBoolean(true))}));
        this.labels = new StringArrayParam(this, "labels", "Sorted label values on the labels column");
        setDefault(featuresCol(), new StringBuilder().append(str).append("_features").toString());
    }

    /** Creates a trainer with a random uid prefixed by {@code "TrainClassifier"}. */
    public TrainClassifier() {
        this(Identifiable$.MODULE$.randomUID("TrainClassifier"));
    }
}
