# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

#' Spark ML -- CNTKLearner
#'
#' Creates a \code{com.microsoft.ml.spark.CNTKLearner} estimator, applies the
#' given parameters and, unless \code{unfit.model = TRUE}, fits it on \code{x}.
#'
#' Parameters passed as \code{NULL} are not forwarded to the JVM at all, so
#' the estimator keeps its own default (or leaves the parameter unset) rather
#' than having it explicitly set to \code{null}.
#'
#' @param x A Spark DataFrame (anything \code{spark_dataframe()} accepts) to
#'   fit on; when \code{unfit.model = TRUE}, the Spark connection instead.
#' @param brainScript String of BrainScript config
#' @param dataFormat Transfer format
#' @param dataTransfer Transfer strategy
#' @param featureCount Number of features for reduction (coerced to integer)
#' @param featuresColumnName Features column name
#' @param gpuMachines Character vector of GPU machines to train on, or
#'   \code{NULL} (the default) to keep the estimator's own default
#' @param labelsColumnName Label column name
#' @param localHdfsMount Local mount point for hdfs:///
#' @param parallelTrain Train using an MPI ring (coerced to logical)
#' @param username Username for the GPU VM
#' @param weightPrecision Weight precision (defaults to \code{"float"})
#' @param workingDir Working directory for CNTK
#' @param unfit.model If \code{TRUE}, return the parameterized estimator
#'   without fitting it.
#' @param only.model If \code{TRUE}, return the fitted model wrapper without
#'   transforming \code{x}.
#' @return Depending on the flags: the unfitted estimator object
#'   (\code{unfit.model = TRUE}); the model wrapper built by sparklyr's
#'   internal \code{new_ml_model()} (\code{only.model = TRUE}); otherwise the
#'   result of \code{sdf_register()} on \code{x} transformed by the fitted
#'   model.
#' @export
ml_cntklearner <- function(x, brainScript = NULL, dataFormat = "text",
                           dataTransfer = "local", featureCount = 1,
                           featuresColumnName = "features",
                           gpuMachines = NULL, labelsColumnName = "labels",
                           localHdfsMount = NULL, parallelTrain = TRUE,
                           username = "sshuser", weightPrecision = "float",
                           workingDir = "tmp", unfit.model = FALSE,
                           only.model = FALSE) {
  if (unfit.model) {
    sc <- x
  } else {
    df <- spark_dataframe(x)
    sc <- spark_connection(df)
  }

  mod <- invoke_new(sc, "com.microsoft.ml.spark.CNTKLearner")

  # Setter name -> value, in the order the setters are invoked. A NULL value
  # means "leave this parameter alone on the JVM side". In particular, the
  # previous gpuMachines default was the JVM toString() of a String[]
  # ("[Ljava.lang.String;@..."), which was sent as if it were a host name.
  setters <- list(
    setBrainScript = brainScript,
    setDataFormat = dataFormat,
    setDataTransfer = dataTransfer,
    setFeatureCount = as.integer(featureCount),
    setFeaturesColumnName = featuresColumnName,
    setGpuMachines = if (is.null(gpuMachines)) NULL else as.array(gpuMachines),
    setLabelsColumnName = labelsColumnName,
    setLocalHdfsMount = localHdfsMount,
    setParallelTrain = as.logical(parallelTrain),
    setUsername = username,
    setWeightPrecision = weightPrecision,
    setWorkingDir = workingDir
  )

  mod_parameterized <- mod
  for (setter in names(setters)) {
    value <- setters[[setter]]
    if (!is.null(value)) {
      mod_parameterized <- invoke(mod_parameterized, setter, value)
    }
  }

  if (unfit.model) {
    return(mod_parameterized)
  }

  mod_model_raw <- invoke(mod_parameterized, "fit", df)
  mod_model <- sparklyr:::new_ml_model(
    mod_parameterized, mod_model_raw, mod_model_raw
  )

  if (only.model) {
    return(mod_model)
  }

  transformed <- invoke(mod_model$model, "transform", df)
  sdf_register(transformed)
}
