Use a custom metric function in a PySpark DataFrame

I have defined a custom Python function to calculate class-wise AUC scores in a one-vs-rest fashion. It takes the true class labels and the predicted probability for each class as input, and returns the AUC score of each class.

from sklearn.metrics import roc_curve, auc
import numpy as np
import pandas as pd


def mclass_auc(y_true, y_pred, n_class):
    """Return the one-vs-rest ROC AUC of each class ``0 .. n_class - 1``.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,)
        Integer class label of each sample (pandas Series, NumPy array or list).
    y_pred : array-like of shape (n_samples, >= n_class)
        Column ``i`` holds the score/probability that each sample belongs to
        class ``i`` (NumPy array or pandas DataFrame).
    n_class : int
        Number of classes to score.

    Returns
    -------
    dict[int, float]
        Maps each class index to its one-vs-rest AUC.
    """
    labels = np.asarray(y_true)
    scores = np.asarray(y_pred)
    aucs = {}
    for i in range(n_class):
        # Binary target "class i vs. everything else". Comparing directly,
        # rather than Series.replace with a one-hot list, also works for plain
        # arrays. It also maps labels outside 0..n_class-1 to 0; replace()
        # would leave them as-is and roc_curve would reject the target.
        is_class = (labels == i).astype(int)
        fpr, tpr, _ = roc_curve(is_class, scores[:, i])
        aucs[i] = auc(fpr, tpr)
    return aucs

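For the sake of simplicity, I am generating random probability values that don't sum to one per row. That is fine here, because a one-vs-rest AUC depends only on how each class's column ranks the samples.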

import numpy as np
import pandas as pd

# Toy data: one score column per class ('0', '1', '2') plus an integer label 'y'.
# The scores are random and do not sum to one per row. One-vs-rest AUC only
# depends on how each column ranks the rows, so that is fine for a demo.
N_ROWS = 10_000
N_CLASS = 3
rng = np.random.default_rng(42)  # seeded so the example is reproducible

sample_df = pd.DataFrame({
    '0': rng.integers(40, 81, size=N_ROWS) / 100,
    '1': rng.integers(30, 801, size=N_ROWS) / 1000,
    '2': rng.integers(40, 81, size=N_ROWS) / 200,
    # Labels must lie in 0..N_CLASS-1 so every label has a score column.
    'y': rng.integers(0, N_CLASS, size=N_ROWS),
})

y_true = sample_df['y']
# The score columns are named '0' .. str(N_CLASS - 1).
y_pred = sample_df[[str(i) for i in range(N_CLASS)]].values

pandas_aucs = mclass_auc(y_true, y_pred, N_CLASS)

# `sql` is assumed to be a SparkSession/SQLContext created elsewhere.
spark_df = sql.createDataFrame(sample_df)

In plain Python/pandas the function above works. How can I compute the same scores on a PySpark DataFrame? Converting the Spark DataFrame to pandas and computing the scores there fails in my case.

from pyspark.mllib.evaluation import BinaryClassificationMetrics
from pyspark.sql.types import LongType, StructField, StructType
import pyspark.sql.functions as F


def mclass_auc_spark_df(df, n_class, label_col='y'):
    """Return the one-vs-rest ROC AUC of each class from one Spark DataFrame.

    *df* must hold the integer class label in *label_col* and, for each class
    ``i`` in ``0 .. n_class - 1``, a numeric score column named ``str(i)``.
    Returns a dict mapping class index to AUC.
    """
    aucs = {}
    for i in range(n_class):
        is_class = F.when(F.col(label_col) == i, 1.0).otherwise(0.0)
        # BinaryClassificationMetrics expects an RDD of (score, label) doubles:
        # score FIRST, label second. Swapping them silently yields a wrong AUC.
        score_and_label = df.select(F.col(str(i)).cast('double'), is_class).rdd
        aucs[i] = BinaryClassificationMetrics(score_and_label).areaUnderROC
    return aucs


def _with_row_index(df, name='row_id'):
    """Return *df* with an extra *name* column numbering its rows 0, 1, 2, ..."""
    schema = StructType(df.schema.fields + [StructField(name, LongType(), False)])
    indexed = df.rdd.zipWithIndex().map(lambda pair: tuple(pair[0]) + (pair[1],))
    return indexed.toDF(schema)


def mclass_auc_spark(y_true, y_pred, n_class):
    """Return the one-vs-rest ROC AUC of each class for labels and scores held apart.

    *y_true* holds the integer label in column ``'y'``. *y_pred* holds one
    score column per class, named ``'0'``, ``'1'``, and so on. The two
    DataFrames are matched row by row in their current order, so they must
    list the samples in the same order (e.g. both selected from the same
    DataFrame).
    """
    # Align once, outside the per-class loop. zipWithIndex numbers rows
    # 0..n-1 independent of partitioning. monotonically_increasing_id()
    # values depend on the partition layout and need not agree across two
    # DataFrames.
    joined = _with_row_index(y_true).join(_with_row_index(y_pred), on='row_id')
    return mclass_auc_spark_df(joined, n_class)


# `sql` is assumed to be a SparkSession/SQLContext created elsewhere.
spark_df = sql.createDataFrame(sample_df)
# Labels and scores already share one DataFrame, so no row alignment is needed.
auc_scores = mclass_auc_spark_df(spark_df, 3)
Answered on July 16, 2020.
