package com.microsoft.ml.spark.image;

import com.microsoft.CNTK.CNTKExtensions$;
import com.microsoft.CNTK.SerializableFunction;
import com.microsoft.CNTK.Variable;
import com.microsoft.ml.spark.cntk.CNTKModel;
import com.microsoft.ml.spark.core.contracts.HasAdditionalPythonMethods;
import com.microsoft.ml.spark.core.contracts.HasInputCol;
import com.microsoft.ml.spark.core.contracts.HasOutputCol;
import com.microsoft.ml.spark.core.contracts.Wrappable;
import com.microsoft.ml.spark.core.schema.DatasetExtensions$;
import com.microsoft.ml.spark.core.schema.ImageSchemaUtils$;
import com.microsoft.ml.spark.downloader.ModelSchema;
import java.io.IOException;
import org.apache.spark.ml.ComplexParamsWritable;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.linalg.SQLDataTypes$;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.ParamValidators$;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.param.TransformerParam;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.BinaryType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import scala.Predef$;
import scala.StringContext;
import scala.collection.Set;
import scala.collection.immutable.List$;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: ImageFeaturizer.scala */
@ScalaSignature(bytes = "\u0006\u0001\t]r!B\u0001\u0003\u0011\u0003i\u0011aD%nC\u001e,g)Z1ukJL'0\u001a:\u000b\u0005\r!\u0011!B5nC\u001e,'BA\u0003\u0007\u0003\u0015\u0019\b/\u0019:l\u0015\t9\u0001\"\u0001\u0002nY*\u0011\u0011BC\u0001\n[&\u001c'o\\:pMRT\u0011aC\u0001\u0004G>l7\u0001\u0001\t\u0003\u001d=i\u0011A\u0001\u0004\u0006!\tA\t!\u0005\u0002\u0010\u00136\fw-\u001a$fCR,(/\u001b>feN)qB\u0005\r\u0003\u0018A\u00111CF\u0007\u0002))\tQ#A\u0003tG\u0006d\u0017-\u0003\u0002\u0018)\t1\u0011I\\=SK\u001a\u00042!\u0007\u0011#\u001b\u0005Q\"BA\u0004\u001c\u0015\t)AD\u0003\u0002\u001e=\u00051\u0011\r]1dQ\u0016T\u0011aH\u0001\u0004_J<\u0017BA\u0011\u001b\u0005U\u0019u.\u001c9mKb\u0004\u0016M]1ngJ+\u0017\rZ1cY\u0016\u0004\"AD\u0012\u0007\tA\u0011\u0001\u0001J\n\u0007G\u0015B\u0003g\r\u001c\u0011\u0005e1\u0013BA\u0014\u001b\u0005-!&/\u00198tM>\u0014X.\u001a:\u0011\u0005%rS\"\u0001\u0016\u000b\u0005-b\u0013!C2p]R\u0014\u0018m\u0019;t\u0015\tiC!\u0001\u0003d_J,\u0017BA\u0018+\u0005-A\u0015m]%oaV$8i\u001c7\u0011\u0005%\n\u0014B\u0001\u001a+\u00051A\u0015m](viB,HoQ8m!\tIC'\u0003\u00026U\tIqK]1qa\u0006\u0014G.\u001a\t\u00033]J!\u0001\u000f\u000e\u0003+\r{W\u000e\u001d7fqB\u000b'/Y7t/JLG/\u00192mK\"A!h\tBC\u0002\u0013\u00051(A\u0002vS\u0012,\u0012\u0001\u0010\t\u0003{\u0001s!a\u0005 
\n\u0005}\"\u0012A\u0002)sK\u0012,g-\u0003\u0002B\u0005\n11\u000b\u001e:j]\u001eT!a\u0010\u000b\t\u0011\u0011\u001b#\u0011!Q\u0001\nq\nA!^5eA!)ai\tC\u0001\u000f\u00061A(\u001b8jiz\"\"A\t%\t\u000bi*\u0005\u0019\u0001\u001f\t\u000b\u0019\u001bC\u0011\u0001&\u0015\u0003\tBq\u0001T\u0012C\u0002\u0013\u0005Q*A\u0005d]R\\Wj\u001c3fYV\ta\n\u0005\u0002P%6\t\u0001K\u0003\u0002R5\u0005)\u0001/\u0019:b[&\u00111\u000b\u0015\u0002\u0011)J\fgn\u001d4pe6,'\u000fU1sC6Da!V\u0012!\u0002\u0013q\u0015AC2oi.lu\u000eZ3mA!)qk\tC\u00011\u0006a1/\u001a;D]R\\Wj\u001c3fYR\u0011\u0011LW\u0007\u0002G!)1L\u0016a\u00019\u0006)a/\u00197vKB\u0011Q\fY\u0007\u0002=*\u0011q\fB\u0001\u0005G:$8.\u0003\u0002b=\nI1I\u0014+L\u001b>$W\r\u001c\u0005\bG\u000e\u0012\r\u0011\"\u0001e\u00039)W\u000e\u001d;z\u0007:$8.T8eK2,\u0012\u0001\u0018\u0005\u0007M\u000e\u0002\u000b\u0011\u0002/\u0002\u001f\u0015l\u0007\u000f^=D]R\\Wj\u001c3fY\u0002BQ\u0001[\u0012\u0005\u0002\u0011\fAbZ3u\u0007:$8.T8eK2DQA[\u0012\u0005\u0002-\f\u0001c]3u\u001b&t\u0017NQ1uG\"\u001c\u0016N_3\u0015\u0005ec\u0007\"B.j\u0001\u0004i\u0007CA\no\u0013\tyGCA\u0002J]RDQ!]\u0012\u0005\u0002I\f\u0001cZ3u\u001b&t\u0017NQ1uG\"\u001c\u0016N_3\u0016\u00035DQ\u0001^\u0012\u0005\u0002U\fAb]3u\u0013:\u0004X\u000f\u001e(pI\u0016$\"!\u0017<\t\u000bm\u001b\b\u0019A7\t\u000ba\u001cC\u0011\u0001:\u0002\u0019\u001d,G/\u00138qkRtu\u000eZ3\t\u000bi\u001cC\u0011A>\u0002!M,G/T8eK2dunY1uS>tGCA-}\u0011\u0015i\u0018\u00101\u0001=\u0003\u0011\u0001\u0018\r\u001e5\t\r}\u001cC\u0011AA\u0001\u0003!\u0019X\r^'pI\u0016dGcA-\u0002\u0004!9\u0011Q\u0001@A\u0002\u0005\u001d\u0011aC7pI\u0016d7k\u00195f[\u0006\u0004B!!\u0003\u0002\u00105\u0011\u00111\u0002\u0006\u0004\u0003\u001b!\u0011A\u00033po:dw.\u00193fe&!\u0011\u0011CA\u0006\u0005-iu\u000eZ3m'\u000eDW-\\1\t\r}\u001cC\u0011AA\u000b)\rI\u0016q\u0003\u0005\t\u00033\t\u0019\u00021\u0001\u0002\u001c\u0005)Qn\u001c3fYB!\u0011QDA\u0012\u001b\t\tyBC\u0002\u0002\"!\tAa\u0011(U\u0017&!\u0011QEA\u0010\u0005Q\u0019VM]5bY&T\u0018M\u00197f\r
Vt7\r^5p]\"9\u0011\u0011F\u0012\u0005\u0002\u0005-\u0012\u0001C4fi6{G-\u001a7\u0016\u0005\u0005m\u0001\"CA\u0018G\t\u0007I\u0011AA\u0019\u0003=\u0019W\u000f^(viB,H\u000fT1zKJ\u001cXCAA\u001a!\ry\u0015QG\u0005\u0004\u0003o\u0001&\u0001C%oiB\u000b'/Y7\t\u0011\u0005m2\u0005)A\u0005\u0003g\t\u0001cY;u\u001fV$\b/\u001e;MCf,'o\u001d\u0011\t\u000f\u0005}2\u0005\"\u0001\u0002B\u0005\u00112/\u001a;DkR|U\u000f\u001e9vi2\u000b\u00170\u001a:t)\rI\u00161\t\u0005\u00077\u0006u\u0002\u0019A7\t\r\u0005\u001d3\u0005\"\u0001s\u0003I9W\r^\"vi>+H\u000f];u\u0019\u0006LXM]:\t\u0013\u0005-3E1A\u0005\u0002\u00055\u0013A\u00023s_Bt\u0015-\u0006\u0002\u0002PA\u0019q*!\u0015\n\u0007\u0005M\u0003K\u0001\u0007C_>dW-\u00198QCJ\fW\u000e\u0003\u0005\u0002X\r\u0002\u000b\u0011BA(\u0003\u001d!'o\u001c9OC\u0002Bq!a\u0017$\t\u0003\ti&A\u0005tKR$%o\u001c9OCR\u0019\u0011,a\u0018\t\u000fm\u000bI\u00061\u0001\u0002bA\u00191#a\u0019\n\u0007\u0005\u0015DCA\u0004C_>dW-\u00198\t\u000f\u0005%4\u0005\"\u0001\u0002l\u0005Iq-\u001a;Ee>\u0004h*Y\u000b\u0003\u0003CB\u0011\"a\u001c$\u0005\u0004%\t!!\u001d\u0002\u00151\f\u00170\u001a:OC6,7/\u0006\u0002\u0002tA\u0019q*!\u001e\n\u0007\u0005]\u0004K\u0001\tTiJLgnZ!se\u0006L\b+\u0019:b[\"A\u00111P\u0012!\u0002\u0013\t\u0019(A\u0006mCf,'OT1nKN\u0004\u0003bBA@G\u0011\u0005\u0011\u0011Q\u0001\u000eg\u0016$H*Y=fe:\u000bW.Z:\u0015\u0007e\u000b\u0019\tC\u0004\\\u0003{\u0002\r!!\"\u0011\tM\t9\tP\u0005\u0004\u0003\u0013#\"!B!se\u0006L\bbBAGG\u0011\u0005\u0011qR\u0001\u000eO\u0016$H*Y=fe:\u000bW.Z:\u0016\u0005\u0005\u0015\u0005bBAJG\u0011\u0005\u0013QS\u0001\niJ\fgn\u001d4pe6$B!a&\u0002@B!\u0011\u0011TA]\u001d\u0011\tY*a-\u000f\t\u0005u\u0015q\u0016\b\u0005\u0003?\u000biK\u0004\u0003\u0002\"\u0006-f\u0002BAR\u0003Sk!!!*\u000b\u0007\u0005\u001dF\"\u0001\u0004=e>|GOP\u0005\u0002?%\u0011QDH\u0005\u0003\u000bqI1!!-\u001c\u0003\r\u0019\u0018\u000f\\\u0005\u0005\u0003k\u000b9,A\u0004qC\u000e\\\u0017mZ3\u000b\u0007\u0005E6$\u0003\u0003\u0002<\u0006u&!\u0003#bi\u00064%/Y7f\u0015\u0011\t)
,a.\t\u0011\u0005\u0005\u0017\u0011\u0013a\u0001\u0003\u0007\fq\u0001Z1uCN,G\u000f\r\u0003\u0002F\u0006E\u0007CBAd\u0003\u0013\fi-\u0004\u0002\u00028&!\u00111ZA\\\u0005\u001d!\u0015\r^1tKR\u0004B!a4\u0002R2\u0001A\u0001DAj\u0003\u007f\u000b\t\u0011!A\u0003\u0002\u0005U'aA0%cE!\u0011q[Ao!\r\u0019\u0012\u0011\\\u0005\u0004\u00037$\"a\u0002(pi\"Lgn\u001a\t\u0004'\u0005}\u0017bAAq)\t\u0019\u0011I\\=\t\u000f\u0005\u00158\u0005\"\u0011\u0002h\u0006!1m\u001c9z)\r)\u0013\u0011\u001e\u0005\t\u0003W\f\u0019\u000f1\u0001\u0002n\u0006)Q\r\u001f;sCB\u0019q*a<\n\u0007\u0005E\bK\u0001\u0005QCJ\fW.T1q\u0011\u001d\t)p\tC!\u0003o\fq\u0002\u001e:b]N4wN]7TG\",W.\u0019\u000b\u0005\u0003s\u0014)\u0001\u0005\u0003\u0002|\n\u0005QBAA\u007f\u0015\u0011\ty0a.\u0002\u000bQL\b/Z:\n\t\t\r\u0011Q \u0002\u000b'R\u0014Xo\u0019;UsB,\u0007\u0002\u0003B\u0004\u0003g\u0004\r!!?\u0002\rM\u001c\u0007.Z7bQ\r\u0019#1\u0002\t\u0005\u0005\u001b\u0011\u0019\"\u0004\u0002\u0003\u0010)\u0019!\u0011\u0003\u0017\u0002\u0007\u0015tg/\u0003\u0003\u0003\u0016\t=!aD%oi\u0016\u0014h.\u00197Xe\u0006\u0004\b/\u001a:\u0011\u0007M\u0011I\"C\u0002\u0003\u001cQ\u0011AbU3sS\u0006d\u0017N_1cY\u0016DaAR\b\u0005\u0002\t}A#A\u0007\t\u0013\t\rr\"!A\u0005\n\t\u0015\u0012a\u0003:fC\u0012\u0014Vm]8mm\u0016$\"Aa\n\u0011\t\t%\"1G\u0007\u0003\u0005WQAA!\f\u00030\u0005!A.\u00198h\u0015\t\u0011\t$\u0001\u0003kCZ\f\u0017\u0002\u0002B\u001b\u0005W\u0011aa\u00142kK\u000e$\b")
/* loaded from: input_file:com/microsoft/ml/spark/image/ImageFeaturizer.class */
// Spark ML Transformer that featurizes images with a CNTK model.
//
// The input column (image struct or raw binary) is resized to the model's input
// width/height, forced to 3 channels, unrolled into a temporary column, and fed to the
// CNTK model whose output node is chosen by layerNames[cutOutputLayers].
//
// NOTE(review): this class is decompiled from Scala bytecode (see the ScalaSignature
// annotation). Several constructs mirror the bytecode and are not valid hand-written
// Java: final fields assigned from the trait setter methods, MLWritable.class.save,
// and the *.Cclass trait calls.
public class ImageFeaturizer extends Transformer implements HasInputCol, HasOutputCol, Wrappable, ComplexParamsWritable {
    private final String uid;
    // Holds the CNTKModel that is evaluated; see getCntkModel() for the fallback when unset.
    private final TransformerParam cntkModel;
    // Fallback model returned by getCntkModel() when cntkModel has no value.
    private final CNTKModel emptyCntkModel;
    // Number of layers to cut from the network end; also the index into layerNames.
    private final IntParam cutOutputLayers;
    // Whether rows whose resized/unrolled image is null are dropped before evaluation.
    private final BooleanParam dropNa;
    // Candidate output node names, ordered from the output node backwards.
    private final StringArrayParam layerNames;
    private final Param<String> outputCol;
    private final Param<String> inputCol;

    /**
     * Loads a saved ImageFeaturizer; delegates to the companion object.
     *
     * @param str path the featurizer was saved to
     * @return the loaded featurizer
     */
    public static Object load(String str) {
        return ImageFeaturizer$.MODULE$.load(str);
    }

    /**
     * Returns a reader for saved ImageFeaturizers; delegates to the companion object.
     *
     * @return an MLReader for this class
     */
    public static MLReader<ImageFeaturizer> read() {
        return ImageFeaturizer$.MODULE$.read();
    }

    /** Returns a writer via the ComplexParamsWritable trait implementation. */
    @Override // org.apache.spark.ml.ComplexParamsWritable
    public MLWriter write() {
        return ComplexParamsWritable.Cclass.write(this);
    }

    /**
     * Saves this featurizer to the given path via the MLWritable trait implementation.
     *
     * @param str destination path
     * @throws IOException if saving fails
     */
    public void save(String str) throws IOException {
        MLWritable.class.save(this, str);
    }

    /** Returns extra Python wrapper methods via the HasAdditionalPythonMethods trait. */
    @Override // com.microsoft.ml.spark.core.contracts.HasAdditionalPythonMethods
    public String additionalPythonMethods() {
        return HasAdditionalPythonMethods.Cclass.additionalPythonMethods(this);
    }

    /** Param naming the output column; defaults to {@code uid + "_output"}. */
    @Override // com.microsoft.ml.spark.core.contracts.HasOutputCol
    public Param<String> outputCol() {
        return this.outputCol;
    }

    /** Trait field initializer for outputCol, invoked from the constructor. */
    @Override // com.microsoft.ml.spark.core.contracts.HasOutputCol
    public void com$microsoft$ml$spark$core$contracts$HasOutputCol$_setter_$outputCol_$eq(Param param) {
        this.outputCol = param;
    }

    /** Sets the output column name. */
    @Override // com.microsoft.ml.spark.core.contracts.HasOutputCol
    public HasOutputCol setOutputCol(String str) {
        return HasOutputCol.Cclass.setOutputCol(this, str);
    }

    /** Returns the output column name. */
    @Override // com.microsoft.ml.spark.core.contracts.HasOutputCol
    public String getOutputCol() {
        return HasOutputCol.Cclass.getOutputCol(this);
    }

    /** Param naming the input image/binary column (no default is set in this class). */
    @Override // com.microsoft.ml.spark.core.contracts.HasInputCol
    public Param<String> inputCol() {
        return this.inputCol;
    }

    /** Trait field initializer for inputCol, invoked from the constructor. */
    @Override // com.microsoft.ml.spark.core.contracts.HasInputCol
    public void com$microsoft$ml$spark$core$contracts$HasInputCol$_setter_$inputCol_$eq(Param param) {
        this.inputCol = param;
    }

    /** Sets the input column name. */
    @Override // com.microsoft.ml.spark.core.contracts.HasInputCol
    public HasInputCol setInputCol(String str) {
        return HasInputCol.Cclass.setInputCol(this, str);
    }

    /** Returns the input column name. */
    @Override // com.microsoft.ml.spark.core.contracts.HasInputCol
    public String getInputCol() {
        return HasInputCol.Cclass.getInputCol(this);
    }

    /** Returns this transformer's unique identifier. */
    public String uid() {
        return this.uid;
    }

    /** Param holding the internal CNTK model. */
    public TransformerParam cntkModel() {
        return this.cntkModel;
    }

    /**
     * Sets the internal CNTK model.
     *
     * @param cNTKModel the model to evaluate
     * @return this featurizer
     */
    public ImageFeaturizer setCntkModel(CNTKModel cNTKModel) {
        return (ImageFeaturizer) set(cntkModel(), cNTKModel);
    }

    /** Returns the fallback model used when {@code cntkModel} is not defined. */
    public CNTKModel emptyCntkModel() {
        return this.emptyCntkModel;
    }

    /**
     * Returns the configured CNTK model, or {@link #emptyCntkModel()} if none is defined.
     *
     * @return the model to evaluate
     */
    public CNTKModel getCntkModel() {
        return isDefined(cntkModel()) ? (CNTKModel) $(cntkModel()) : emptyCntkModel();
    }

    /**
     * Sets the mini-batch size on the current CNTK model and stores that model back
     * into {@code cntkModel}.
     *
     * @param i mini-batch size
     * @return this featurizer
     */
    public ImageFeaturizer setMiniBatchSize(int i) {
        return (ImageFeaturizer) set(cntkModel(), getCntkModel().setMiniBatchSize(i));
    }

    /** Returns the mini-batch size of the current CNTK model. */
    public int getMiniBatchSize() {
        return getCntkModel().getMiniBatchSize();
    }

    /**
     * Sets the input node index on the current CNTK model and stores it back.
     *
     * @param i input node index
     * @return this featurizer
     */
    public ImageFeaturizer setInputNode(int i) {
        return (ImageFeaturizer) set(cntkModel(), getCntkModel().setInputNodeIndex(i));
    }

    /** Returns the input node index of the current CNTK model. */
    public int getInputNode() {
        return getCntkModel().getInputNodeIndex();
    }

    /**
     * Sets the model location on the current CNTK model and stores it back.
     *
     * @param str model location
     * @return this featurizer
     */
    public ImageFeaturizer setModelLocation(String str) {
        return (ImageFeaturizer) set(cntkModel(), getCntkModel().setModelLocation(str));
    }

    /**
     * Configures layer names, input node and model location from a downloaded
     * model's schema.
     *
     * @param modelSchema schema describing the model
     * @return this featurizer
     */
    public ImageFeaturizer setModel(ModelSchema modelSchema) {
        return setLayerNames(modelSchema.layerNames()).setInputNode(modelSchema.inputNode()).setModelLocation(modelSchema.uri().toString());
    }

    /**
     * Sets the CNTK function on the current CNTK model and stores it back.
     *
     * @param serializableFunction the CNTK function
     * @return this featurizer
     */
    public ImageFeaturizer setModel(SerializableFunction serializableFunction) {
        return (ImageFeaturizer) set(cntkModel(), getCntkModel().setModel(serializableFunction));
    }

    /** Returns the CNTK function of the current CNTK model. */
    public SerializableFunction getModel() {
        return getCntkModel().getModel();
    }

    /** Param: number of layers to cut off the network end (validated {@code >= 0}, default 1). */
    public IntParam cutOutputLayers() {
        return this.cutOutputLayers;
    }

    /**
     * Sets how many layers to cut off the end of the network; also the index into
     * {@code layerNames} that selects the output node.
     *
     * @param i number of layers to cut (0 keeps the network intact)
     * @return this featurizer
     */
    public ImageFeaturizer setCutOutputLayers(int i) {
        return (ImageFeaturizer) set(cutOutputLayers(), BoxesRunTime.boxToInteger(i));
    }

    /** Returns the number of output layers to cut. */
    public int getCutOutputLayers() {
        return BoxesRunTime.unboxToInt($(cutOutputLayers()));
    }

    /** Param: whether to drop na values before mapping (default true). */
    public BooleanParam dropNa() {
        return this.dropNa;
    }

    /**
     * Sets whether rows with a null unrolled image are dropped before evaluation.
     *
     * @param z true to drop such rows
     * @return this featurizer
     */
    public ImageFeaturizer setDropNa(boolean z) {
        return (ImageFeaturizer) set(dropNa(), BoxesRunTime.boxToBoolean(z));
    }

    /** Returns whether null unrolled images are dropped. */
    public boolean getDropNa() {
        return BoxesRunTime.unboxToBoolean($(dropNa()));
    }

    /** Param: candidate CNTK node names, with entries nearer the output first (no default). */
    public StringArrayParam layerNames() {
        return this.layerNames;
    }

    /**
     * Sets the candidate CNTK node names.
     *
     * @param strArr node names, with entries closer to the output node first
     * @return this featurizer
     */
    public ImageFeaturizer setLayerNames(String[] strArr) {
        return (ImageFeaturizer) set(layerNames(), strArr);
    }

    /** Returns the candidate CNTK node names. */
    public String[] getLayerNames() {
        return (String[]) $(layerNames());
    }

    /**
     * Featurizes the images in the input column.
     *
     * <p>The input column must be an image type (per {@code ImageSchemaUtils.isImage})
     * or {@code BinaryType}. Images are resized to the width/height taken from the
     * model's first argument shape, forced to 3 channels, and unrolled into a temporary
     * column. If {@code dropNa} is set, rows with a null value there are dropped. The
     * result is then evaluated by the CNTK model with output node
     * {@code getLayerNames()[getCutOutputLayers()]}. The temporary column is removed
     * from the result.
     *
     * @param dataset dataset containing {@code getInputCol()}
     * @return the dataset with {@code getOutputCol()} added
     * @throws IllegalArgumentException if {@code cutOutputLayers} does not index into
     *     {@code layerNames}, if the model input shape has fewer than two dimensions, or
     *     if the input column is neither image nor binary
     */
    public Dataset<Row> transform(Dataset<?> dataset) {
        Dataset<Row> transform;
        String[] names = getLayerNames();
        int cut = getCutOutputLayers();
        // The param validator only enforces cutOutputLayers >= 0. Without this check an
        // out-of-range value fails as a bare ArrayIndexOutOfBoundsException.
        if (cut >= names.length) {
            throw new IllegalArgumentException("cutOutputLayers (" + cut + ") must be less than the number of layerNames (" + names.length + ")");
        }
        String findUnusedColumnName = DatasetExtensions$.MODULE$.findUnusedColumnName("resized", (Set<String>) Predef$.MODULE$.refArrayOps(dataset.columns()).toSet());
        // NOTE(review): these setters presumably configure the CNTKModel instance stored
        // in the cntkModel param (or emptyCntkModel) in place rather than a copy - confirm.
        CNTKModel outputCol = getCntkModel().setOutputNode(names[cut]).setInputCol(findUnusedColumnName).setOutputCol(getOutputCol());
        // NOTE(review): argument 0 is used regardless of the configured input node index;
        // confirm this matches the variable CNTKModel actually feeds.
        long[] dimensions = ((Variable) CNTKExtensions$.MODULE$.fromSerializable(outputCol.getModel()).getArguments().get(0)).getShape().getDimensions();
        if (dimensions.length < 2) {
            throw new IllegalArgumentException("Model input shape must have at least 2 dimensions (width, height), but has " + dimensions.length);
        }
        DataType dataType = dataset.schema().apply(getInputCol()).dataType();
        if (ImageSchemaUtils$.MODULE$.isImage(dataType)) {
            ResizeImageTransformer nChannels = ((ResizeImageTransformer) new ResizeImageTransformer().setInputCol(getInputCol())).setWidth((int) dimensions[0]).setHeight((int) dimensions[1]).setNChannels(3);
            transform = ((UnrollImage) ((HasOutputCol) new UnrollImage().setInputCol(nChannels.getOutputCol())).setOutputCol(findUnusedColumnName)).transform(nChannels.transform(dataset)).drop(nChannels.getOutputCol());
        } else {
            BinaryType$ binaryType$ = BinaryType$.MODULE$;
            if (dataType != null ? !dataType.equals(binaryType$) : binaryType$ != null) {
                throw new IllegalArgumentException(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"Input schema : ", " needs to have image or binary type"})).s(Predef$.MODULE$.genericWrapArray(new Object[]{dataType})));
            }
            transform = ((UnrollBinaryImage) ((UnrollBinaryImage) new UnrollBinaryImage().setInputCol(getInputCol())).setWidth((int) dimensions[0]).setHeight((int) dimensions[1]).setNChannels(3).setOutputCol(findUnusedColumnName)).transform(dataset);
        }
        Dataset<Row> dataset2 = transform;
        return outputCol.transform(getDropNa() ? dataset2.na().drop(List$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{findUnusedColumnName}))) : dataset2).drop(findUnusedColumnName);
    }

    /**
     * Returns a copy of this featurizer with the extra params applied (via defaultCopy).
     *
     * @param paramMap extra params for the copy
     * @return the copy
     */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public Transformer m417copy(ParamMap paramMap) {
        return defaultCopy(paramMap);
    }

    /**
     * Appends the output column, typed as an ML vector, to the schema.
     *
     * <p>NOTE(review): the input column's presence and type are not validated here; they
     * are only checked in {@link #transform}.
     *
     * @param structType input schema
     * @return the schema with {@code getOutputCol()} added
     */
    public StructType transformSchema(StructType structType) {
        return structType.add(getOutputCol(), SQLDataTypes$.MODULE$.VectorType());
    }

    /**
     * Creates a featurizer with the given uid, registering all params and setting
     * defaults: cutOutputLayers = 1, outputCol = {@code uid + "_output"}, dropNa = true.
     *
     * @param str unique identifier
     */
    public ImageFeaturizer(String str) {
        this.uid = str;
        com$microsoft$ml$spark$core$contracts$HasInputCol$_setter_$inputCol_$eq(new Param(this, "inputCol", "The name of the input column"));
        com$microsoft$ml$spark$core$contracts$HasOutputCol$_setter_$outputCol_$eq(new Param(this, "outputCol", "The name of the output column"));
        HasAdditionalPythonMethods.Cclass.$init$(this);
        MLWritable.class.$init$(this);
        ComplexParamsWritable.Cclass.$init$(this);
        this.cntkModel = new TransformerParam(this, "cntkModel", "The internal CNTK model used in the featurizer", new ImageFeaturizer$$anonfun$1(this));
        this.emptyCntkModel = new CNTKModel();
        this.cutOutputLayers = new IntParam(this, "cutOutputLayers", "The number of layers to cut off the end of the network, 0 leaves the network intact, 1 removes the output layer, etc", ParamValidators$.MODULE$.gtEq(0.0d));
        this.dropNa = new BooleanParam(this, "dropNa", "Whether to drop na values before mapping");
        this.layerNames = new StringArrayParam(this, "layerNames", "Array with valid CNTK nodes to choose from, the first entries of this array should be closer to the output node");
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{cutOutputLayers().$minus$greater(BoxesRunTime.boxToInteger(1)), outputCol().$minus$greater(new StringBuilder().append(str).append("_output").toString()), dropNa().$minus$greater(BoxesRunTime.boxToBoolean(true))}));
    }

    /** Creates a featurizer with a random uid prefixed "ImageFeaturizer". */
    public ImageFeaturizer() {
        this(Identifiable$.MODULE$.randomUID("ImageFeaturizer"));
    }
}
