# Source code for RankingSplit

# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root for information.

import numpy as np
import pyspark

# Methods for train-test split.
class RankingSplit:
    """Train/test splitting strategies for a rating DataFrame.

    Every public method is a ``staticmethod`` whose first parameter, ``self``,
    is the rating DataFrame to split (the code calls ``groupBy``, ``join``,
    ``select`` and ``randomSplit`` on it). The DataFrame is expected to carry
    the columns ``customerID`` and ``itemID``; ``chronological_split``
    additionally reads ``timeStamp``.

    All methods first discard customers (or items) that have fewer than
    ``min_rating`` ratings, see :meth:`min_rating_filter`.
    """

    @staticmethod
    def _split_columns(by_customer):
        """Return ``(split_by_column, split_with_column)`` for the given mode.

        :param by_customer: ``True`` to group by customer, ``False`` to group
            by item.
        :return: ``("customerID", "itemID")`` when ``by_customer`` is truthy,
            otherwise ``("itemID", "customerID")``.
        """
        if by_customer:
            return "customerID", "itemID"
        return "itemID", "customerID"

    @staticmethod
    def _split_by_rank(rating_joined, window_spec, split_by_column,
                       count_column, ratio, fixed_test_sample, sample):
        """Rank rows inside each group and cut them into train and test.

        Rows whose rank (``row_number`` over ``window_spec``) is at most the
        threshold go to the test set, the remaining rows go to the train set.
        The threshold is ``sample`` when ``fixed_test_sample`` is truthy,
        otherwise ``bround(group size * ratio)``.

        :param rating_joined: filtered rating DataFrame.
        :param window_spec: window that partitions by ``split_by_column`` and
            defines the ordering used for ranking.
        :param split_by_column: name of the grouping column.
        :param count_column: column counted to obtain the group size.
        :param ratio: fraction of each group assigned to the test set.
        :param fixed_test_sample: use a fixed number of test rows per group.
        :param sample: number of test rows per group in fixed mode.
        :return: tuple ``(rating_train, rating_test)``.
        """
        from pyspark.sql.functions import bround, broadcast, col, row_number

        rank = row_number().over(window_spec).alias("rank")

        if fixed_test_sample:
            ranked = rating_joined.select("*", rank)
            threshold = sample
            helper_columns = ("rank",)
        else:
            rating_grouped = (
                rating_joined
                .groupBy(split_by_column)
                .agg({count_column: "count"})
                .withColumnRenamed("count(" + count_column + ")", "count")
            )
            ranked = (
                rating_joined
                .join(broadcast(rating_grouped), on=split_by_column,
                      how="outer")
                .withColumn("splitPoint", bround(col("count") * ratio))
                .select("*", rank)
            )
            threshold = col("splitPoint")
            # "count" is a helper column as well; dropping it keeps the output
            # schema identical to the fixed-sample branch.
            helper_columns = ("count", "splitPoint", "rank")

        # NOTE(review): train and test are two lazy views of the same ranked
        # frame. With a random ordering the ranks may be recomputed for each
        # action; consider persisting the input if strict disjointness of
        # train and test is required -- TODO confirm.
        rating_train = ranked.filter(col("rank") > threshold) \
            .drop(*helper_columns)
        rating_test = ranked.filter(col("rank") <= threshold) \
            .drop(*helper_columns)
        return rating_train, rating_test

    @staticmethod
    def min_rating_filter(self, min_rating, by_customer=True):
        """Keep only the ratings of customers (or items) with enough ratings.

        :param self: rating DataFrame.
        :param min_rating: minimum number of ratings a customer (or item) must
            have for its rows to be kept.
        :param by_customer: ``True`` to count ratings per customer, ``False``
            to count them per item.
        :return: DataFrame with the rows of the qualifying customers (items).
        """
        from pyspark.sql.functions import broadcast, col

        split_by_column, split_with_column = \
            RankingSplit._split_columns(by_customer)
        count_column = "n" + split_with_column

        rating_temp = (
            self.groupBy(split_by_column)
            .agg({split_with_column: "count"})
            .withColumnRenamed("count(" + split_with_column + ")",
                               count_column)
            .where(col(count_column) >= min_rating)
        )
        return self.join(broadcast(rating_temp), split_by_column) \
            .drop(count_column)

    @staticmethod
    def stratified_split(self, min_rating, by_customer=True, ratio=0.3,
                         fixed_test_sample=False, sample=3):
        """Split ratings into train and test by random stratified sampling.

        Within every customer (or item) the rows are ranked in random order;
        the first ``bround(count * ratio)`` rows, or the first ``sample`` rows
        when ``fixed_test_sample`` is set, form the test set.

        :param self: rating DataFrame.
        :param min_rating: minimum number of ratings for filtering.
        :param by_customer: stratify by customer (``True``) or item.
        :param ratio: fraction of each group assigned to the test set.
        :param fixed_test_sample: take a fixed number of test rows per group
            instead of a fraction.
        :param sample: number of test rows per group in fixed mode.
        :return: tuple ``(rating_train, rating_test)``, or ``-1`` (after
            printing a message) when ``fixed_test_sample`` is set and
            ``sample`` exceeds ``min_rating``.
        """
        # Validate before touching Spark so bad arguments fail fast. The
        # original ``a == True & b > c`` was parsed as a chained comparison
        # on ``True & b`` and therefore never rejected anything useful.
        if fixed_test_sample and sample > min_rating:
            print("sample should be less than min_rating.")
            return -1

        from pyspark.sql import Window
        from pyspark.sql.functions import rand

        split_by_column, split_with_column = \
            RankingSplit._split_columns(by_customer)
        rating_joined = RankingSplit.min_rating_filter(
            self, min_rating, by_customer)
        window_spec = Window.partitionBy(split_by_column).orderBy(rand())

        return RankingSplit._split_by_rank(
            rating_joined, window_spec, split_by_column, split_with_column,
            ratio, fixed_test_sample, sample)

    @staticmethod
    def chronological_split(self, min_rating, by_customer=True, ratio=0.3,
                            fixed_test_sample=False, sample=3):
        """Split ratings into train and test by recency.

        Within every customer (or item) the rows are ranked by ``timeStamp``
        in descending order, so the most recent rows form the test set:
        the first ``bround(count * ratio)`` rows, or the first ``sample`` rows
        when ``fixed_test_sample`` is set. The returned DataFrames contain
        only the two ID columns and ``timeStamp``.

        :param self: rating DataFrame with a ``timeStamp`` column.
        :param min_rating: minimum number of ratings for filtering.
        :param by_customer: split per customer (``True``) or per item.
        :param ratio: fraction of each group assigned to the test set.
        :param fixed_test_sample: take a fixed number of test rows per group
            instead of a fraction.
        :param sample: number of test rows per group in fixed mode.
        :return: tuple ``(rating_train, rating_test)``, or ``-1`` (after
            printing a message) when ``fixed_test_sample`` is set and
            ``sample`` exceeds ``min_rating``.
        """
        if fixed_test_sample and sample > min_rating:
            print("sample should be less than min_rating.")
            return -1

        from pyspark.sql import Window
        from pyspark.sql.functions import col

        split_by_column, split_with_column = \
            RankingSplit._split_columns(by_customer)
        rating_joined = RankingSplit.min_rating_filter(
            self, min_rating, by_customer)
        window_spec = Window.partitionBy(split_by_column) \
            .orderBy(col("timeStamp").desc())

        rating_train, rating_test = RankingSplit._split_by_rank(
            rating_joined, window_spec, split_by_column, "timeStamp",
            ratio, fixed_test_sample, sample)

        output_columns = (split_by_column, split_with_column, "timeStamp")
        return (rating_train.select(*output_columns),
                rating_test.select(*output_columns))

    @staticmethod
    def non_overlapping_split(self, min_rating, by_customer=True, ratio=0.7):
        """Split so that train and test share no customer (or item).

        The distinct customers (or items) are put in random order; the first
        ``round(count * ratio)`` of them supply the training rows and the
        rest supply the testing rows.

        :param self: rating DataFrame.
        :param min_rating: minimum number of ratings for filtering.
        :param by_customer: make customers (``True``) or items mutually
            exclusive between the two sets.
        :param ratio: fraction of customers (or items) assigned to the
            training set.
        :return: tuple ``(rating_train, rating_test)``.
        """
        from pyspark.sql.functions import broadcast, rand, row_number
        from pyspark.sql.window import Window

        split_by_column, _ = RankingSplit._split_columns(by_customer)
        rating_joined = RankingSplit.min_rating_filter(
            self, min_rating, by_customer)

        rating_exclusive = rating_joined.select(split_by_column).distinct()
        train_size = round(rating_exclusive.count() * ratio)

        rating_tmp = rating_exclusive.select(
            "*", row_number().over(Window.orderBy(rand())).alias("rowNumber"))
        train_keys = rating_tmp \
            .filter(rating_tmp["rowNumber"] <= train_size) \
            .drop("rowNumber")
        test_keys = rating_tmp \
            .filter(rating_tmp["rowNumber"] > train_size) \
            .drop("rowNumber")

        rating_train = rating_joined.join(broadcast(train_keys),
                                          split_by_column)
        rating_test = rating_joined.join(broadcast(test_keys),
                                         split_by_column)
        return rating_train, rating_test

    @staticmethod
    def random_split(self, min_rating, by_customer=True, ratio=0.7):
        """Split the filtered ratings purely at random.

        :param self: rating DataFrame.
        :param min_rating: minimum number of ratings for filtering.
        :param by_customer: count ratings per customer (``True``) or per item
            when filtering.
        :param ratio: ``randomSplit`` weight of the testing set; the training
            set gets the weight ``1 - ratio``.
        :return: tuple ``(rating_train, rating_test)``.
        """
        rating_filtered = RankingSplit.min_rating_filter(
            self, min_rating, by_customer)
        rating_train, rating_test = \
            rating_filtered.randomSplit([1 - ratio, ratio])
        return rating_train, rating_test
# Running the module directly only prints an identifying banner.
if __name__ == "__main__":
    print("Splitter")