Source code for TrainClassifier

# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.


import sys
# Python 3 has no ``basestring`` builtin; alias it to ``str`` so code written
# against the Python 2 name keeps working.
# The check uses the structured ``sys.version_info`` tuple rather than a
# lexicographic comparison on the free-form ``sys.version`` string, which
# would give the wrong answer for a major version such as "10".
if sys.version_info[0] >= 3:
    basestring = str

from pyspark.ml.param.shared import *
from pyspark import keyword_only
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaTransformer, JavaEstimator, JavaModel
from pyspark.ml.common import inherit_doc
from mmlspark.Utils import *
from mmlspark.TypeConversionUtils import generateTypeConverter, complexTypeConverter

@inherit_doc
class TrainClassifier(ComplexParamsMixin, JavaMLReadable, JavaMLWritable, JavaEstimator):
    """
    Trains a classification model.

    The currently supported classifiers are:

    - Logistic Regression Classifier
    - Decision Tree Classifier
    - Random Forest Classifier
    - Gradient Boosted Trees Classifier
    - Naive Bayes Classifier
    - Multilayer Perceptron Classifier

    In addition to any generic learner that inherits from Predictor.

    This module featurizes the given data into a vector of doubles and
    passes it to the given learner.

    Note the behavior of the reindex and labels parameters, the
    parameters interact as:

    - reindex false, labels false (empty):
      Assume all double values, don't use metadata, assume natural
      ordering.
    - reindex true, labels false (empty):
      Index, use natural ordering of string indexer.
    - reindex false, labels true (specified):
      Assume user knows indexing, apply label values. Currently only
      string type supported.
    - reindex true, labels true (specified):
      Validate labels matches column type, try to recast to label type,
      reindex label column.

    Args:
        featuresCol (str): The name of the features column (default: [self.uid]_features)
        labelCol (str): The name of the label column
        labels (list): Sorted label values on the labels column
        model (object): Classifier to run
        numFeatures (int): Number of features to hash to (default: 0)
        reindexLabel (bool): Re-index the label column (default: true)
    """

    @keyword_only
    def __init__(self, featuresCol=None, labelCol=None, labels=None, model=None, numFeatures=0, reindexLabel=True):
        super(TrainClassifier, self).__init__()
        # Backing JVM estimator that this Python class wraps.
        self._java_obj = self._new_java_obj("com.microsoft.ml.spark.TrainClassifier")
        # Shared with the "model" type converter below and read back by
        # getModel().
        self._cache = {}
        self.featuresCol = Param(self, "featuresCol", "featuresCol: The name of the features column (default: [self.uid]_features)")
        self._setDefault(featuresCol=self.uid + "_features")
        self.labelCol = Param(self, "labelCol", "labelCol: The name of the label column")
        self.labels = Param(self, "labels", "labels: Sorted label values on the labels column")
        self.model = Param(self, "model", "model: Classifier to run", generateTypeConverter("model", self._cache, complexTypeConverter))
        self.numFeatures = Param(self, "numFeatures", "numFeatures: Number of features to hash to (default: 0)")
        self._setDefault(numFeatures=0)
        self.reindexLabel = Param(self, "reindexLabel", "reindexLabel: Re-index the label column (default: true)")
        self._setDefault(reindexLabel=True)
        # ``keyword_only`` records the keyword arguments either on the
        # instance or, in older PySpark releases, on the decorated function.
        if hasattr(self, "_input_kwargs"):
            kwargs = self._input_kwargs
        else:
            kwargs = self.__init__._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, featuresCol=None, labelCol=None, labels=None, model=None, numFeatures=0, reindexLabel=True):
        """
        Set the (keyword only) parameters

        Args:
            featuresCol (str): The name of the features column (default: [self.uid]_features)
            labelCol (str): The name of the label column
            labels (list): Sorted label values on the labels column
            model (object): Classifier to run
            numFeatures (int): Number of features to hash to (default: 0)
            reindexLabel (bool): Re-index the label column (default: true)

        Returns:
            The result of ``self._set`` applied to the supplied keyword
            arguments.
        """
        if hasattr(self, "_input_kwargs"):
            kwargs = self._input_kwargs
        else:
            # The legacy storage location is the decorated function itself,
            # so read this method's own kwargs. Reading them from
            # ``__init__`` would re-apply the constructor arguments and
            # ignore what the caller just passed.
            kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    def setFeaturesCol(self, value):
        """
        Args:
            featuresCol (str): The name of the features column (default: [self.uid]_features)

        Returns:
            This instance, to allow call chaining.
        """
        self._set(featuresCol=value)
        return self

    def getFeaturesCol(self):
        """
        Returns:
            str: The name of the features column (default: [self.uid]_features)
        """
        return self.getOrDefault(self.featuresCol)

    def setLabelCol(self, value):
        """
        Args:
            labelCol (str): The name of the label column

        Returns:
            This instance, to allow call chaining.
        """
        self._set(labelCol=value)
        return self

    def getLabelCol(self):
        """
        Returns:
            str: The name of the label column
        """
        return self.getOrDefault(self.labelCol)

    def setLabels(self, value):
        """
        Args:
            labels (list): Sorted label values on the labels column

        Returns:
            This instance, to allow call chaining.
        """
        self._set(labels=value)
        return self

    def getLabels(self):
        """
        Returns:
            list: Sorted label values on the labels column
        """
        return self.getOrDefault(self.labels)

    def setModel(self, value):
        """
        Args:
            model (object): Classifier to run

        Returns:
            This instance, to allow call chaining.
        """
        self._set(model=value)
        return self

    def getModel(self):
        """
        Returns:
            object: Classifier to run, looked up in the instance cache
            under the key "model"; None when no entry is cached.
        """
        return self._cache.get("model", None)

    def setNumFeatures(self, value):
        """
        Args:
            numFeatures (int): Number of features to hash to (default: 0)

        Returns:
            This instance, to allow call chaining.
        """
        self._set(numFeatures=value)
        return self

    def getNumFeatures(self):
        """
        Returns:
            int: Number of features to hash to (default: 0)
        """
        return self.getOrDefault(self.numFeatures)

    def setReindexLabel(self, value):
        """
        Args:
            reindexLabel (bool): Re-index the label column (default: true)

        Returns:
            This instance, to allow call chaining.
        """
        self._set(reindexLabel=value)
        return self

    def getReindexLabel(self):
        """
        Returns:
            bool: Re-index the label column (default: true)
        """
        return self.getOrDefault(self.reindexLabel)

    @classmethod
    def read(cls):
        """ Returns an MLReader instance for this class. """
        return JavaMMLReader(cls)

    @staticmethod
    def getJavaPackage():
        """ Returns package name String. """
        return "com.microsoft.ml.spark.TrainClassifier"

    @staticmethod
    def _from_java(java_stage):
        # Swap the last component of this class's module path for the class
        # name, then hand the Java stage and that name to ``from_java``.
        module_name = TrainClassifier.__module__
        module_name = module_name.rsplit(".", 1)[0] + ".TrainClassifier"
        return from_java(java_stage, module_name)

    def _create_model(self, java_model):
        # Wrap the given Java model in its Python counterpart.
        return TrainedClassifierModel(java_model)
class TrainedClassifierModel(ComplexParamsMixin, JavaModel, JavaMLWritable, JavaMLReadable):
    """
    Model fitted by :class:`TrainClassifier`.

    The class body is intentionally minimal: everything else it needs is
    exposed through its base classes.
    """

    @classmethod
    def read(cls):
        """ Return an MLReader instance bound to this class. """
        reader = JavaMMLReader(cls)
        return reader

    @staticmethod
    def getJavaPackage():
        """ Return the package name as a string. """
        return "com.microsoft.ml.spark.TrainedClassifierModel"

    @staticmethod
    def _from_java(java_stage):
        # Drop the last component of this class's module path and append
        # the class name; ``from_java`` receives the stage and that name.
        parent_package = TrainedClassifierModel.__module__.rsplit(".", 1)[0]
        qualified_name = parent_package + ".TrainedClassifierModel"
        return from_java(java_stage, qualified_name)