# Source code for LightGBMClassifier

# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys
from pyspark import SQLContext
from pyspark import SparkContext

# Python 3 has no ``basestring`` builtin; alias it to ``str`` there so any
# code in this module that refers to ``basestring`` still resolves.
# Compare the numeric ``sys.version_info`` rather than the ``sys.version``
# string: string comparison is lexicographic and would misorder e.g. "10.x"
# against "3".
if sys.version_info[0] >= 3:
    basestring = str

from mmlspark._LightGBMClassifier import _LightGBMClassifier
from mmlspark._LightGBMClassifier import _LightGBMClassificationModel
from pyspark.ml.common import inherit_doc

@inherit_doc
class LightGBMClassifier(_LightGBMClassifier):
    """Python wrapper around the generated ``_LightGBMClassifier``.

    Overrides model construction so that the model produced from the JVM is
    exposed as a ``LightGBMClassificationModel``.
    """

    def _create_model(self, java_model):
        """Wrap a JVM model handle in a ``LightGBMClassificationModel``.

        :param java_model: JVM-side model object to wrap.
        :return: a new ``LightGBMClassificationModel`` whose ``_java_obj`` is
            ``java_model`` and whose params were copied from the JVM side.
        """
        wrapped = LightGBMClassificationModel()
        wrapped._java_obj = java_model
        # Copy the JVM model's param values onto the Python wrapper.
        wrapped._transfer_params_from_java()
        return wrapped
@inherit_doc
class LightGBMClassificationModel(_LightGBMClassificationModel):
    """Python-side LightGBM classification model.

    Extends the generated ``_LightGBMClassificationModel`` wrapper with
    helpers that delegate to the JVM model held in ``self._java_obj``.
    """

    def saveNativeModel(self, sparkSession, filename):
        """Save the booster in string format to a local or WASB remote location.

        :param sparkSession: active ``SparkSession``; its JVM session handle
            (``_jsparkSession``) is passed through to the JVM model.
        :param filename: destination path for the saved model.
        """
        jsession = sparkSession._jsparkSession
        self._java_obj.saveNativeModel(jsession, filename)

    def loadNativeModelFromFile(self, sparkSession, filename,
                                labelColName="label",
                                featuresColName="features",
                                predictionColName="prediction",
                                probColName="probability",
                                rawPredictionColName="rawPrediction"):
        """Load a model from a native LightGBM text file.

        :param sparkSession: active ``SparkSession``; its JVM session handle
            (``_jsparkSession``) is passed through to the JVM loader.
        :param filename: path of the native LightGBM model file.
        :param labelColName: label column name passed to the JVM loader.
        :param featuresColName: features column name passed to the JVM loader.
        :param predictionColName: prediction column name passed to the JVM
            loader.
        :param probColName: probability column name passed to the JVM loader.
        :param rawPredictionColName: raw-prediction column name passed to the
            JVM loader.
        :return: the result of the JVM ``loadNativeModelFromFile`` call,
            returned as-is (it is NOT wrapped in a
            ``LightGBMClassificationModel``).

        NOTE(review): the call goes through ``self._java_obj``, so this
        instance must already hold a JVM object.
        """
        jsession = sparkSession._jsparkSession
        return self._java_obj.loadNativeModelFromFile(
            jsession, filename, labelColName, featuresColName,
            predictionColName, probColName, rawPredictionColName)

    def loadNativeModelFromString(self, model,
                                  labelColName="label",
                                  featuresColName="features",
                                  predictionColName="prediction",
                                  probColName="probability",
                                  rawPredictionColName="rawPrediction"):
        """Load a model from a string containing a native LightGBM model.

        Unlike ``loadNativeModelFromFile``, no file is read and no Spark
        session is needed. The model text is passed directly to the JVM.

        :param model: the native LightGBM model, as a string.
        :param labelColName: label column name passed to the JVM loader.
        :param featuresColName: features column name passed to the JVM loader.
        :param predictionColName: prediction column name passed to the JVM
            loader.
        :param probColName: probability column name passed to the JVM loader.
        :param rawPredictionColName: raw-prediction column name passed to the
            JVM loader.
        :return: the result of the JVM ``loadNativeModelFromString`` call,
            returned as-is (it is NOT wrapped in a
            ``LightGBMClassificationModel``).

        NOTE(review): the call goes through ``self._java_obj``, so this
        instance must already hold a JVM object.
        """
        return self._java_obj.loadNativeModelFromString(
            model, labelColName, featuresColName, predictionColName,
            probColName, rawPredictionColName)

    def getFeatureImportances(self, importance_type="split"):
        """Get the feature importances.

        :param importance_type: ``"split"`` or ``"gain"``; forwarded
            unchanged to the JVM model.
        :return: the result of the JVM ``getFeatureImportances`` call,
            returned as-is.
        """
        return self._java_obj.getFeatureImportances(importance_type)