# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root for information.
import sys
from mmlspark.RankingTrainValidationSplitModel import RankingTrainValidationSplitModel as tvmodel
from mmlspark._RankingTrainValidationSplit import _RankingTrainValidationSplit
from pyspark import keyword_only
from pyspark.ml.param import Params
from pyspark.ml.tuning import ValidatorParams
from pyspark.ml.util import *
# Python 3 has no ``basestring``; alias it to ``str`` so code written against
# the Python 2 name keeps working.  ``sys.version_info`` is compared
# numerically, unlike the ``sys.version`` string whose comparison is
# lexicographic.
if sys.version_info[0] >= 3:
    basestring = str
@inherit_doc
class RankingTrainValidationSplit(_RankingTrainValidationSplit, ValidatorParams):
    """
    Python front end for the ranking train-validation split estimator.

    Carries the ``estimator`` / ``estimatorParamMaps`` / ``evaluator`` params
    from ``ValidatorParams`` and hands the actual fitting over to a JVM object
    built by :py:meth:`_to_java`.
    """

    @keyword_only
    def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, seed=None):
        """
        __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, seed=None)
        """
        super(RankingTrainValidationSplit, self).__init__()
        # ``_input_kwargs`` is filled in by the ``@keyword_only`` decorator.
        self._set(**self._input_kwargs)

    @keyword_only
    def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, seed=None):
        """
        setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, seed=None)
        Sets params for the ranking train-validation split.

        :return: The value returned by ``self._set``.
        """
        # ``_input_kwargs`` is filled in by the ``@keyword_only`` decorator.
        return self._set(**self._input_kwargs)

    def copy(self, extra=None):
        """
        Creates a copy of this instance with some extra params.

        The param map is copied through ``Params.copy``; the embedded
        estimator and evaluator, when set, are themselves copied with the
        same extra params and attached to the new instance.

        :param extra: Extra parameters to copy to the new instance
        :return: Copy of this instance
        """
        if extra is None:
            extra = {}
        duplicate = Params.copy(self, extra)
        if self.isSet(self.estimator):
            duplicate.setEstimator(self.getEstimator().copy(extra))
        # estimatorParamMaps remain the same
        if self.isSet(self.evaluator):
            duplicate.setEvaluator(self.getEvaluator().copy(extra))
        return duplicate

    def _create_model(self, java_model):
        """
        Wrap a fitted Java model in its Python model class.

        :param java_model: Java model object to wrap.
        :return: Python model whose params were transferred from the Java object.
        """
        wrapped = tvmodel()
        wrapped._java_obj = java_model
        wrapped._transfer_params_from_java()
        return wrapped

    def _to_java(self):
        """
        Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.

        :return: Java object equivalent to this instance.
        """
        # Imported locally: the module only star-imports ``pyspark.ml.util``,
        # which is not guaranteed to export ``JavaParams``.
        from pyspark.ml.wrapper import JavaParams

        estimator, epms, evaluator = super(RankingTrainValidationSplit, self)._to_java_impl()
        # NOTE(review): this instantiates a class named "...SplitModel" for what
        # is used as an estimator below -- confirm the JVM class name against
        # the Scala side; it is left unchanged here.
        _java_obj = JavaParams._new_java_obj("mmlspark.RankingTrainValidationSplitModel",
                                             self.uid)
        _java_obj.setEstimatorParamMaps(epms)
        _java_obj.setEvaluator(evaluator)
        _java_obj.setEstimator(estimator)
        _java_obj.setTrainRatio(self.getTrainRatio())
        _java_obj.setSeed(self.getSeed())
        _java_obj.setParallelism(self.getParallelism())
        _java_obj.setCollectSubModels(self.getCollectSubModels())
        return _java_obj

    def _fit(self, dataset):
        """
        Fit on the JVM and return the result as a Python model.

        :param dataset: Input dataset; its ``_jdf`` attribute (the underlying
                        Java DataFrame) is what gets passed to the JVM.
        :return: Python model produced by :py:meth:`_create_model`.
        """
        # ``_to_java()`` returns the raw JVM handle rather than a Python
        # wrapper, so ``fit`` is invoked on it directly and the result is
        # wrapped afterwards.
        java_model = self._to_java().fit(dataset._jdf)
        return self._create_model(java_model)