# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root for information.
import sys
if sys.version >= '3':
basestring = str
from pyspark.ml.common import inherit_doc
from pyspark.ml.tuning import ValidatorParams
from pyspark.ml.util import *
from mmlspark._RankingTrainValidationSplitModel import _RankingTrainValidationSplitModel
[docs]@inherit_doc
class RankingTrainValidationSplitModel(_RankingTrainValidationSplitModel, ValidatorParams):
def __init__(self, bestModel, validationMetrics=[]):
super(RankingTrainValidationSplitModel, self).__init__()
#: best model from cross validation
self.bestModel = bestModel
#: evaluated validation metrics
self.validationMetrics = validationMetrics
def _transform(self, dataset):
return self.bestModel.transform(dataset)
[docs] def copy(self, extra=None):
"""
Creates a copy of this instance with a randomly generated uid
and some extra params. This copies the underlying bestModel,
creates a deep copy of the embedded paramMap, and
copies the embedded and extra parameters over.
And, this creates a shallow copy of the validationMetrics.
:param extra: Extra parameters to copy to the new instance
:return: Copy of this instance
"""
if extra is None:
extra = dict()
bestModel = self.bestModel.copy(extra)
validationMetrics = list(self.validationMetrics)
return RankingTrainValidationSplitModel(bestModel, validationMetrics)
[docs] def recommendForAllUsers(self, numItems):
return self.bestModel._call_java("recommendForAllUsers", numItems)
[docs] def recommendForAllItems(self, numItems):
return self.bestModel._call_java("recommendForAllItems", numItems)