# Source code for RankingTrainValidationSplit

# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root for information.

import sys
from mmlspark.RankingTrainValidationSplitModel import RankingTrainValidationSplitModel as tvmodel
from mmlspark._RankingTrainValidationSplit import _RankingTrainValidationSplit
from pyspark import keyword_only
from pyspark.ml.param import Params
from pyspark.ml.tuning import ValidatorParams
from pyspark.ml.util import *

# Python 2/3 compatibility: Python 3 has no ``basestring`` builtin, so alias it
# to ``str``. Use ``sys.version_info`` (a comparable tuple) rather than a lexical
# comparison against the ``sys.version`` string.
if sys.version_info[0] >= 3:
    basestring = str


@inherit_doc
class RankingTrainValidationSplit(_RankingTrainValidationSplit, ValidatorParams):
    """
    Train-validation-split tuner for ranking estimators.

    Combines the generated ``_RankingTrainValidationSplit`` params with PySpark's
    ``ValidatorParams`` (estimator, estimatorParamMaps, evaluator, seed). Fitting
    is delegated to the JVM, and the resulting Java model is wrapped in a Python
    ``RankingTrainValidationSplitModel``.
    """

    # Fully-qualified JVM class instantiated by ``_to_java``.
    # NOTE(review): assumed to be the Scala estimator shipped in the
    # ``com.microsoft.ml.spark`` package for this (flat ``mmlspark.*``) release;
    # confirm against the bundled jar.
    _java_estimator_class = "com.microsoft.ml.spark.RankingTrainValidationSplit"

    @keyword_only
    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, seed=None):
        """
        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, seed=None)

        Only the keyword arguments actually passed are set on the instance.

        :param estimator: estimator to tune.
        :param estimatorParamMaps: param maps to evaluate for the estimator.
        :param evaluator: evaluator used to compare the param maps.
        :param seed: random seed (``ValidatorParams`` seed param).
        """
        super(RankingTrainValidationSplit, self).__init__()
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, seed=None):
        """
        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, seed=None)

        Sets params for this train-validation split. Only the keyword arguments
        actually passed are set.

        :return: this instance, for chaining.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def copy(self, extra=None):
        """
        Creates a copy of this instance with the same uid and some extra params.

        Param values are copied via ``Params.copy``. When set, the estimator and
        evaluator are themselves copied with ``extra``. The estimatorParamMaps
        are carried over unchanged.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = dict()
        new_tvs = Params.copy(self, extra)
        if self.isSet(self.estimator):
            new_tvs.setEstimator(self.getEstimator().copy(extra))
        # estimatorParamMaps remain the same
        if self.isSet(self.evaluator):
            new_tvs.setEvaluator(self.getEvaluator().copy(extra))
        return new_tvs

    def _create_model(self, java_model):
        """
        Wrap a JVM model in a Python ``RankingTrainValidationSplitModel``.

        :param java_model: py4j handle to the fitted JVM model.
        :return: Python model whose params are pulled from ``java_model``.
        """
        model = tvmodel()
        model._java_obj = java_model
        model._transfer_params_from_java()
        return model

    def _to_java(self):
        """
        Transfer this instance to a Java RankingTrainValidationSplit.
        Used for ML persistence and by ``_fit``.

        :return: Java object equivalent to this instance.
        """
        # JavaParams is not exported by pyspark.ml.util's star import; import it
        # locally (pyspark.ml.wrapper itself imports pyspark.ml.util).
        from pyspark.ml.wrapper import JavaParams

        estimator, epms, evaluator = super(RankingTrainValidationSplit, self)._to_java_impl()

        _java_obj = JavaParams._new_java_obj(self._java_estimator_class, self.uid)
        _java_obj.setEstimatorParamMaps(epms)
        _java_obj.setEvaluator(evaluator)
        _java_obj.setEstimator(estimator)
        _java_obj.setTrainRatio(self.getTrainRatio())
        _java_obj.setSeed(self.getSeed())
        _java_obj.setParallelism(self.getParallelism())
        _java_obj.setCollectSubModels(self.getCollectSubModels())
        return _java_obj

    def _fit(self, dataset):
        """
        Fit on the JVM and wrap the result.

        :param dataset: input :py:class:`pyspark.sql.DataFrame`.
        :return: fitted ``RankingTrainValidationSplitModel``.
        """
        # _to_java() returns a raw py4j object. It has no ``_call_java`` (that is
        # a Python-side JavaWrapper helper) and cannot take a Python DataFrame, so
        # call the JVM ``fit`` directly on the underlying Java DataFrame.
        java_model = self._to_java().fit(dataset._jdf)
        return self._create_model(java_model)