# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

#' Spark ML -- LightGBMClassifier
#'
#' Trains a LightGBM Binary Classification model, a fast, distributed, high
#' performance gradient boosting framework based on decision tree algorithms.
#' For more information please see here: https://github.com/Microsoft/LightGBM.
#'
#' Optional parameters whose value is \code{NULL} (\code{categoricalSlotIndexes},
#' \code{categoricalSlotNames}, \code{thresholds}, \code{validationIndicatorCol},
#' \code{weightCol}) are not forwarded to the estimator, so the estimator keeps
#' whatever default it defines for them.
#'
#' @param x When \code{unfit.model} is \code{FALSE}, an object that \code{spark_dataframe()} accepts; it is used both for fitting and for the final transform. When \code{unfit.model} is \code{TRUE}, the Spark connection on which the estimator is created
#' @param baggingFraction Bagging fraction
#' @param baggingFreq Bagging frequency
#' @param baggingSeed Bagging seed
#' @param boostFromAverage Adjusts initial score to the mean of labels for faster convergence
#' @param boostingType Default gbdt = traditional Gradient Boosting Decision Tree. Options are: gbdt, gbrt, rf (Random Forest), random_forest, dart (Dropouts meet Multiple Additive Regression Trees), goss (Gradient-based One-Side Sampling).
#' @param categoricalSlotIndexes List of categorical column indexes, the slot index in the features column
#' @param categoricalSlotNames List of categorical column slot names, the slot name in the features column
#' @param defaultListenPort The default listen port on executors, used for testing
#' @param earlyStoppingRound Early stopping round
#' @param featureFraction Feature fraction
#' @param featuresCol features column name
#' @param isUnbalance Set to true if training data is unbalanced in binary classification scenario
#' @param labelCol label column name
#' @param learningRate Learning rate or shrinkage rate
#' @param maxBin Max bin
#' @param maxDepth Max depth
#' @param minSumHessianInLeaf Minimal sum hessian in one leaf
#' @param modelString LightGBM model to retrain
#' @param numIterations Number of iterations, LightGBM constructs num_class * num_iterations trees
#' @param numLeaves Number of leaves
#' @param objective The Objective. For regression applications, this can be: regression_l2, regression_l1, huber, fair, poisson, quantile, mape, gamma or tweedie. For classification applications, this can be: binary, multiclass, or multiclassova.
#' @param parallelism Tree learner parallelism, can be set to data_parallel or voting_parallel
#' @param predictionCol prediction column name
#' @param probabilityCol Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities
#' @param rawPredictionCol raw prediction (a.k.a. confidence) column name
#' @param thresholds Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold
#' @param timeout Timeout in seconds
#' @param validationIndicatorCol Indicates whether the row is for training or validation
#' @param verbosity Verbosity where lt 0 is Fatal, eq 0 is Error, eq 1 is Info, gt 1 is Debug
#' @param weightCol The name of the weight column
#' @param unfit.model If \code{TRUE}, return the configured estimator without fitting it; \code{x} must then be a Spark connection
#' @param only.model If \code{TRUE}, return the fitted model object instead of the transformed data
#' @return The configured estimator object when \code{unfit.model} is \code{TRUE}; the fitted model wrapper when \code{only.model} is \code{TRUE}; otherwise the result of \code{sdf_register()} applied to the model's transform of \code{x}.
#' @export
ml_light_gbmclassifier <- function(x, baggingFraction=1.0, baggingFreq=0, baggingSeed=3, boostFromAverage=TRUE, boostingType="gbdt", categoricalSlotIndexes=NULL, categoricalSlotNames=NULL, defaultListenPort=12400, earlyStoppingRound=0, featureFraction=1.0, featuresCol="features", isUnbalance=FALSE, labelCol="label", learningRate=0.1, maxBin=255, maxDepth=-1, minSumHessianInLeaf=0.001, modelString="", numIterations=100, numLeaves=31, objective="binary", parallelism="data_parallel", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", thresholds=NULL, timeout=1200.0, validationIndicatorCol=NULL, verbosity=1, weightCol=NULL, unfit.model=FALSE, only.model=FALSE)
{
  if (unfit.model) {
    # No data is involved: x is the connection itself.
    sc <- x
  } else {
    df <- spark_dataframe(x)
    sc <- spark_connection(df)
  }

  # Wrap an optional vector as an R array after coercing its elements.
  # NULL is passed through untouched because as.array(NULL) is an error.
  optional_array <- function(value, coerce = identity) {
    if (is.null(value)) {
      NULL
    } else {
      as.array(coerce(value))
    }
  }

  # Setter name -> value, in the order the setters are applied.
  # NOTE(review): the integer/double element types chosen for
  # categoricalSlotIndexes and thresholds assume the Scala setters take
  # Array[Int] and Array[Double] respectively -- confirm against the estimator.
  setter_args <- list(
    setBaggingFraction = as.double(baggingFraction),
    setBaggingFreq = as.integer(baggingFreq),
    setBaggingSeed = as.integer(baggingSeed),
    setBoostFromAverage = as.logical(boostFromAverage),
    setBoostingType = boostingType,
    setCategoricalSlotIndexes = optional_array(categoricalSlotIndexes, as.integer),
    setCategoricalSlotNames = optional_array(categoricalSlotNames),
    setDefaultListenPort = as.integer(defaultListenPort),
    setEarlyStoppingRound = as.integer(earlyStoppingRound),
    setFeatureFraction = as.double(featureFraction),
    setFeaturesCol = featuresCol,
    setIsUnbalance = as.logical(isUnbalance),
    setLabelCol = labelCol,
    setLearningRate = as.double(learningRate),
    setMaxBin = as.integer(maxBin),
    setMaxDepth = as.integer(maxDepth),
    setMinSumHessianInLeaf = as.double(minSumHessianInLeaf),
    setModelString = modelString,
    setNumIterations = as.integer(numIterations),
    setNumLeaves = as.integer(numLeaves),
    setObjective = objective,
    setParallelism = parallelism,
    setPredictionCol = predictionCol,
    setProbabilityCol = probabilityCol,
    setRawPredictionCol = rawPredictionCol,
    setThresholds = optional_array(thresholds, as.double),
    setTimeout = as.double(timeout),
    setValidationIndicatorCol = validationIndicatorCol,
    setVerbosity = as.integer(verbosity),
    setWeightCol = weightCol
  )

  # Drop unset (NULL) parameters so their setters are never called.
  setter_args <- Filter(Negate(is.null), setter_args)

  mod_parameterized <- invoke_new(sc, "com.microsoft.ml.spark.LightGBMClassifier")
  for (setter in names(setter_args)) {
    # Each setter's return value is carried forward, exactly as in a piped chain.
    mod_parameterized <- invoke(mod_parameterized, setter, setter_args[[setter]])
  }

  if (unfit.model) {
    return(mod_parameterized)
  }

  mod_model_raw <- invoke(mod_parameterized, "fit", df)

  mod_model <- sparklyr:::new_ml_model(mod_parameterized, mod_model_raw, mod_model_raw)

  if (only.model) {
    return(mod_model)
  }

  transformed <- invoke(mod_model$model, "transform", df)

  sdf_register(transformed)
}
