package com.microsoft.ml.spark.train;

import com.microsoft.ml.spark.core.contracts.HasAdditionalPythonMethods;
import com.microsoft.ml.spark.core.contracts.HasFeaturesCol;
import com.microsoft.ml.spark.core.contracts.HasLabelCol;
import com.microsoft.ml.spark.featurize.Featurize;
import com.microsoft.ml.spark.featurize.FeaturizeUtilities$;
import com.microsoft.ml.spark.train.AutoTrainer;
import java.io.IOException;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.ml.ComplexParamsWritable;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Predictor;
import org.apache.spark.ml.param.EstimatorParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.ml.regression.GBTRegressor;
import org.apache.spark.ml.regression.RandomForestRegressor;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructType;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.collection.Seq$;
import scala.collection.immutable.Map;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ScalaSignature;

/* compiled from: TrainRegressor.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\u001db\u0001B\u0001\u0003\u00015\u0011a\u0002\u0016:bS:\u0014Vm\u001a:fgN|'O\u0003\u0002\u0004\t\u0005)AO]1j]*\u0011QAB\u0001\u0006gB\f'o\u001b\u0006\u0003\u000f!\t!!\u001c7\u000b\u0005%Q\u0011!C7jGJ|7o\u001c4u\u0015\u0005Y\u0011aA2p[\u000e\u00011c\u0001\u0001\u000f9A\u0019qB\u0006\r\u000e\u0003AQ!aB\t\u000b\u0005\u0015\u0011\"BA\n\u0015\u0003\u0019\t\u0007/Y2iK*\tQ#A\u0002pe\u001eL!a\u0006\t\u0003\u0013\u0015\u001bH/[7bi>\u0014\bCA\r\u001b\u001b\u0005\u0011\u0011BA\u000e\u0003\u0005U!&/Y5oK\u0012\u0014Vm\u001a:fgN|'/T8eK2\u00042!G\u000f\u0019\u0013\tq\"AA\u0006BkR|GK]1j]\u0016\u0014\b\u0002\u0003\u0011\u0001\u0005\u000b\u0007I\u0011I\u0011\u0002\u0007ULG-F\u0001#!\t\u0019\u0013F\u0004\u0002%O5\tQEC\u0001'\u0003\u0015\u00198-\u00197b\u0013\tAS%\u0001\u0004Qe\u0016$WMZ\u0005\u0003U-\u0012aa\u0015;sS:<'B\u0001\u0015&\u0011!i\u0003A!A!\u0002\u0013\u0011\u0013\u0001B;jI\u0002BQa\f\u0001\u0005\u0002A\na\u0001P5oSRtDCA\u00193!\tI\u0002\u0001C\u0003!]\u0001\u0007!\u0005C\u00030\u0001\u0011\u0005A\u0007F\u00012\u0011\u00151\u0004\u0001\"\u0011\"\u0003!iw\u000eZ3m\t>\u001c\u0007\"\u0002\u001d\u0001\t\u0003J\u0014a\u00014jiR\u0011\u0001D\u000f\u0005\u0006w]\u0002\r\u0001P\u0001\bI\u0006$\u0018m]3ua\tiT\tE\u0002?\u0003\u000ek\u0011a\u0010\u0006\u0003\u0001F\t1a]9m\u0013\t\u0011uHA\u0004ECR\f7/\u001a;\u0011\u0005\u0011+E\u0002\u0001\u0003\n\rj\n\t\u0011!A\u0003\u0002\u001d\u00131a\u0018\u00132#\tA5\n\u0005\u0002%\u0013&\u0011!*\n\u0002\b\u001d>$\b.\u001b8h!\t!C*\u0003\u0002NK\t\u0019\u0011I\\=\t\u000b=\u0003A\u0011\t)\u0002\t\r|\u0007/\u001f\u000b\u0003\u001dECQA\u0015(A\u0002M\u000bQ!\u001a=ue\u0006\u0004\"\u0001V,\u000e\u0003US!A\u0016\t\u0002\u000bA\f'/Y7\n\u0005a+&\u0001\u0003)be\u0006lW*\u00199\t\u000bi\u0003A\u0011I.\u0002\u001fQ\u0014\u0018M\\:g_Jl7k\u00195f[\u0006$\"\u0001\u00182\u0011\u0005u\u0003W\"\u00010\u000b\u0005}{\u0014!\u0002;za\u0016\u001c\u0018BA1_\u0005)\u0019FO];diRK\b/\u001a\u0005\u0006Gf\u0003\r\u0001X\u0001\u00
07g\u000eDW-\\1)\u0005e+\u0007C\u00014j\u001b\u00059'B\u00015\u0012\u0003)\tgN\\8uCRLwN\\\u0005\u0003U\u001e\u0014A\u0002R3wK2|\u0007/\u001a:Ba&D#\u0001\u00017\u0011\u00055\u0014X\"\u00018\u000b\u0005=\u0004\u0018aA3om*\u0011\u0011\u000fB\u0001\u0005G>\u0014X-\u0003\u0002t]\ny\u0011J\u001c;fe:\fGn\u0016:baB,'oB\u0003v\u0005!\u0005a/\u0001\bUe\u0006LgNU3he\u0016\u001c8o\u001c:\u0011\u0005e9h!B\u0001\u0003\u0011\u0003A8\u0003B<zy~\u0004\"\u0001\n>\n\u0005m,#AB!osJ+g\rE\u0002\u0010{FJ!A \t\u0003+\r{W\u000e\u001d7fqB\u000b'/Y7t%\u0016\fG-\u00192mKB\u0019A%!\u0001\n\u0007\u0005\rQE\u0001\u0007TKJL\u0017\r\\5{C\ndW\r\u0003\u00040o\u0012\u0005\u0011q\u0001\u000b\u0002m\"9\u00111B<\u0005\u0002\u00055\u0011a\u0006<bY&$\u0017\r^3Ue\u0006t7OZ8s[N\u001b\u0007.Z7b)\ra\u0016q\u0002\u0005\u0007G\u0006%\u0001\u0019\u0001/\t\u0013\u0005Mq/!A\u0005\n\u0005U\u0011a\u0003:fC\u0012\u0014Vm]8mm\u0016$\"!a\u0006\u0011\t\u0005e\u00111E\u0007\u0003\u00037QA!!\b\u0002 \u0005!A.\u00198h\u0015\t\t\t#\u0001\u0003kCZ\f\u0017\u0002BA\u0013\u00037\u0011aa\u00142kK\u000e$\b")
/* loaded from: input_file:com/microsoft/ml/spark/train/TrainRegressor.class */
/*
 * Estimator that featurizes a dataset and trains a configured regressor on it, producing a
 * TrainedRegressorModel. The inner regressor is supplied through AutoTrainer's model param.
 *
 * NOTE(review): this class is decompiler output of TrainRegressor.scala. Members such as the
 * "_setter_" methods, the Cclass forwarders and MLWritable.class.* are the decompiler's
 * rendering of Scala trait plumbing, not hand-written Java API.
 */
public class TrainRegressor extends Estimator<TrainedRegressorModel> implements AutoTrainer<TrainedRegressorModel> {
    private final String uid;
    // Assigned exactly once during construction through the trait "_setter_" methods below
    // (labelCol/featuresCol from the constructor, numFeatures/model presumably from
    // AutoTrainer's $init$), so they cannot be declared final in Java.
    private IntParam numFeatures;
    private EstimatorParam model;
    private Param<String> featuresCol;
    private Param<String> labelCol;

    /** Loads a saved trainer from path {@code str}; delegates to the companion object. */
    public static Object load(String str) {
        return TrainRegressor$.MODULE$.load(str);
    }

    /** Returns a reader for saved trainers; delegates to the companion object. */
    public static MLReader<TrainRegressor> read() {
        return TrainRegressor$.MODULE$.read();
    }

    /** Validates {@code structType} for transformation; delegates to the companion object. */
    public static StructType validateTransformSchema(StructType structType) {
        return TrainRegressor$.MODULE$.validateTransformSchema(structType);
    }

    /** Returns the param holding the requested number of features. */
    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public IntParam numFeatures() {
        return this.numFeatures;
    }

    /** Returns the param holding the inner estimator to train. */
    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public EstimatorParam model() {
        return this.model;
    }

    /** Trait-field setter used during initialization; assigns the numFeatures param. */
    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public void com$microsoft$ml$spark$train$AutoTrainer$_setter_$numFeatures_$eq(IntParam intParam) {
        this.numFeatures = intParam;
    }

    /** Trait-field setter used during initialization; assigns the model param. */
    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public void com$microsoft$ml$spark$train$AutoTrainer$_setter_$model_$eq(EstimatorParam estimatorParam) {
        this.model = estimatorParam;
    }

    /** Returns the numFeatures value; delegates to the AutoTrainer trait implementation. */
    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public int getNumFeatures() {
        return AutoTrainer.Cclass.getNumFeatures(this);
    }

    /** Sets numFeatures; delegates to the AutoTrainer trait implementation. */
    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public AutoTrainer<TrainedRegressorModel> setNumFeatures(int i) {
        return AutoTrainer.Cclass.setNumFeatures(this, i);
    }

    /** Returns the inner estimator; delegates to the AutoTrainer trait implementation. */
    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public Estimator<? extends Model<?>> getModel() {
        return AutoTrainer.Cclass.getModel(this);
    }

    /** Sets the inner estimator; delegates to the AutoTrainer trait implementation. */
    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public AutoTrainer<TrainedRegressorModel> setModel(Estimator<? extends Model<?>> estimator) {
        return AutoTrainer.Cclass.setModel(this, estimator);
    }

    /** Delegates to the HasAdditionalPythonMethods trait implementation. */
    @Override // com.microsoft.ml.spark.core.contracts.HasAdditionalPythonMethods
    public String additionalPythonMethods() {
        return HasAdditionalPythonMethods.Cclass.additionalPythonMethods(this);
    }

    /** Returns the param naming the features column. */
    @Override // com.microsoft.ml.spark.core.contracts.HasFeaturesCol
    public Param<String> featuresCol() {
        return this.featuresCol;
    }

    /** Trait-field setter used during initialization; assigns the featuresCol param. */
    @Override // com.microsoft.ml.spark.core.contracts.HasFeaturesCol
    public void com$microsoft$ml$spark$core$contracts$HasFeaturesCol$_setter_$featuresCol_$eq(Param param) {
        this.featuresCol = param;
    }

    /** Sets the features column name; delegates to the HasFeaturesCol trait implementation. */
    @Override // com.microsoft.ml.spark.core.contracts.HasFeaturesCol
    public HasFeaturesCol setFeaturesCol(String str) {
        return HasFeaturesCol.Cclass.setFeaturesCol(this, str);
    }

    /** Returns the features column name; delegates to the HasFeaturesCol trait implementation. */
    @Override // com.microsoft.ml.spark.core.contracts.HasFeaturesCol
    public String getFeaturesCol() {
        return HasFeaturesCol.Cclass.getFeaturesCol(this);
    }

    /** Returns a writer for this trainer; delegates to the ComplexParamsWritable trait. */
    @Override // org.apache.spark.ml.ComplexParamsWritable
    public MLWriter write() {
        return ComplexParamsWritable.Cclass.write(this);
    }

    /**
     * Saves this trainer to path {@code str}.
     *
     * <p>NOTE(review): {@code MLWritable.class.save} is the decompiler's rendering of the Scala
     * trait implementation class; a Class literal has no {@code save} method in Java.
     */
    public void save(String str) throws IOException {
        MLWritable.class.save(this, str);
    }

    /** Returns the param naming the label column. */
    @Override // com.microsoft.ml.spark.core.contracts.HasLabelCol
    public Param<String> labelCol() {
        return this.labelCol;
    }

    /** Trait-field setter used during initialization; assigns the labelCol param. */
    @Override // com.microsoft.ml.spark.core.contracts.HasLabelCol
    public void com$microsoft$ml$spark$core$contracts$HasLabelCol$_setter_$labelCol_$eq(Param param) {
        this.labelCol = param;
    }

    /** Sets the label column name; delegates to the HasLabelCol trait implementation. */
    @Override // com.microsoft.ml.spark.core.contracts.HasLabelCol
    public HasLabelCol setLabelCol(String str) {
        return HasLabelCol.Cclass.setLabelCol(this, str);
    }

    /** Returns the label column name; delegates to the HasLabelCol trait implementation. */
    @Override // com.microsoft.ml.spark.core.contracts.HasLabelCol
    public String getLabelCol() {
        return HasLabelCol.Cclass.getLabelCol(this);
    }

    /** Returns this trainer's unique identifier. */
    public String uid() {
        return this.uid;
    }

    /** Documentation string for the model param. */
    @Override // com.microsoft.ml.spark.train.AutoTrainer
    public String modelDoc() {
        return "Regressor to run";
    }

    /**
     * Featurizes {@code dataset} and trains the configured regressor on it.
     *
     * <ol>
     *   <li>If the inner model is a {@link Predictor}, it is pointed at this trainer's label and
     *       features columns (this mutates the configured model object).</li>
     *   <li>The label column is converted to double (see {@link #labelAsDouble}), and rows whose
     *       label is missing are dropped via {@code na().drop}.</li>
     *   <li>A {@link Featurize} stage is fitted that writes {@link #getFeaturesCol()}.
     *       Tree-based regressors (decision tree, GBT, random forest) get no one-hot encoding
     *       and the {@code NumFeaturesTreeOrNNBased} default feature count. All other models
     *       get one-hot encoding and {@code NumFeaturesDefault}. A non-zero
     *       {@link #getNumFeatures()} overrides either default.</li>
     *   <li>The learner is fitted on the cached featurized data. The cache is released even if
     *       fitting fails.</li>
     * </ol>
     *
     * @param dataset training data containing the column named by {@link #getLabelCol()}
     * @return a model wrapping this trainer's uid, the label column name, a pipeline of the
     *     fitted featurizer followed by the fitted learner, and the features column name
     * @throws IllegalArgumentException if no model is set, or if the label column is a string
     *     or of an unsupported type
     */
    public TrainedRegressorModel fit(Dataset<?> dataset) {
        String labelCol = getLabelCol();
        Estimator<? extends Model<?>> rawModel = getModel();
        if (rawModel == null) {
            throw new IllegalArgumentException("Unsupported learner type: no model is set");
        }

        boolean treeBased = rawModel instanceof DecisionTreeRegressor
                || rawModel instanceof GBTRegressor
                || rawModel instanceof RandomForestRegressor;
        boolean oneHotEncodeCategoricals = !treeBased;
        int defaultNumFeatures = treeBased
                ? FeaturizeUtilities$.MODULE$.NumFeaturesTreeOrNNBased()
                : FeaturizeUtilities$.MODULE$.NumFeaturesDefault();

        // Predictors are wired to this trainer's columns; other estimators are used as configured.
        // Raw Predictor: its self-type parameters are not known here.
        Estimator<?> learner;
        if (rawModel instanceof Predictor) {
            learner = ((Predictor) rawModel).setLabelCol(labelCol).setFeaturesCol(getFeaturesCol());
        } else {
            learner = rawModel;
        }

        // A numFeatures of 0 falls back to the model-dependent default chosen above.
        int numFeatures = getNumFeatures() != 0 ? getNumFeatures() : defaultNumFeatures;

        Dataset<?> labeled = dataset
                .withColumn(labelCol, labelAsDouble(dataset, labelCol))
                .na()
                .drop(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{labelCol})));

        Model<?> featurizer = new Featurize()
                .setFeatureColumns(featureColumnMap(labeled, labelCol))
                .setOneHotEncodeCategoricals(oneHotEncodeCategoricals)
                .setNumberOfFeatures(numFeatures)
                .fit(labeled);

        Dataset<?> featurized = featurizer.transform(labeled);
        featurized.cache();
        Model<?> fittedLearner;
        try {
            fittedLearner = learner.fit(featurized);
        } finally {
            // Release the cached data even when the learner throws.
            featurized.unpersist();
        }

        PipelineStage[] stages = new PipelineStage[]{featurizer, fittedLearner};
        return new TrainedRegressorModel(uid(), labelCol, new Pipeline().setStages(stages).fit(labeled), getFeaturesCol());
    }

    /**
     * Returns {@code labelCol} of {@code dataset} as a double-typed column.
     *
     * <p>Integer, boolean, float, byte, long and short labels are cast to double. Double labels
     * are returned unchanged.
     *
     * @throws IllegalArgumentException for string labels or any other label type
     */
    private static Column labelAsDouble(Dataset<?> dataset, String labelCol) {
        DataType labelType = dataset.schema().apply(labelCol).dataType();
        boolean castable = labelType instanceof IntegerType
                || labelType instanceof BooleanType
                || labelType instanceof FloatType
                || labelType instanceof ByteType
                || labelType instanceof LongType
                || labelType instanceof ShortType;
        if (castable) {
            return dataset.apply(labelCol).cast(DoubleType$.MODULE$);
        }
        if (labelType instanceof DoubleType) {
            return dataset.apply(labelCol);
        }
        if (labelType instanceof StringType) {
            throw new IllegalArgumentException("Invalid type: Regressors are not able to train on a string label column: " + labelCol);
        }
        throw new IllegalArgumentException("Unknown type: " + labelType.typeName() + ", for label column: " + labelCol);
    }

    /**
     * Builds the single-entry map passed to {@code Featurize.setFeatureColumns}.
     *
     * <p>The map goes from this trainer's features column name to the columns of {@code data}
     * accepted by the filter {@code TrainRegressor$$anonfun$1}. NOTE(review): that filter
     * presumably excludes {@code labelCol}; confirm against the Scala source.
     */
    private Map featureColumnMap(Dataset<?> data, String labelCol) {
        Object[] inputColumns = (Object[]) Predef$.MODULE$.refArrayOps(data.columns())
                .filter(new TrainRegressor$$anonfun$1(this, labelCol));
        Tuple2 entry = Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(
                Predef$.MODULE$.ArrowAssoc(getFeaturesCol()),
                Predef$.MODULE$.refArrayOps(inputColumns).toSeq());
        return (Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{entry}));
    }

    /**
     * Returns a copy of this trainer with {@code paramMap} applied (Spark's {@code copy}).
     *
     * <p>The copy receives its own copy of its inner model estimator, so the two trainers do not
     * share the model instance. The receiver's params, including its model, are left untouched.
     */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public Estimator<TrainedRegressorModel> m910copy(ParamMap paramMap) {
        TrainRegressor copied = defaultCopy(paramMap);
        copied.setModel(copied.getModel().copy(paramMap));
        return copied;
    }

    /** Validates and returns the output schema; delegates to the companion object. */
    @DeveloperApi
    public StructType transformSchema(StructType structType) {
        return TrainRegressor$.MODULE$.validateTransformSchema(structType);
    }

    /** Erased bridge for {@link #fit(Dataset)}. */
    /* renamed from: fit, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m911fit(Dataset dataset) {
        return fit((Dataset<?>) dataset);
    }

    /**
     * Creates a trainer with uid {@code str}.
     *
     * <p>The statements mirror the Scala trait initialization order and must stay in this order.
     * The setter calls create the labelCol and featuresCol params on this instance. AutoTrainer's
     * {@code $init$} presumably does the same for numFeatures and model. The features column
     * defaults to {@code str + "_features"}.
     *
     * <p>NOTE(review): {@code MLWritable.class.$init$} is decompiler output for the Scala trait
     * implementation class and is not valid Java as written.
     */
    public TrainRegressor(String str) {
        this.uid = str;
        com$microsoft$ml$spark$core$contracts$HasLabelCol$_setter_$labelCol_$eq(new Param(this, "labelCol", "The name of the label column"));
        MLWritable.class.$init$(this);
        ComplexParamsWritable.Cclass.$init$(this);
        com$microsoft$ml$spark$core$contracts$HasFeaturesCol$_setter_$featuresCol_$eq(new Param(this, "featuresCol", "The name of the features column"));
        HasAdditionalPythonMethods.Cclass.$init$(this);
        AutoTrainer.Cclass.$init$(this);
        setDefault(featuresCol(), str + "_features");
    }

    /** Creates a trainer with a random uid prefixed by {@code "TrainRegressor"}. */
    public TrainRegressor() {
        this(Identifiable$.MODULE$.randomUID("TrainRegressor"));
    }
}
