diff --git a/tensorflow_ranking/python/losses_test.py b/tensorflow_ranking/python/losses_test.py index 5864b16..0f0ea53 100644 --- a/tensorflow_ranking/python/losses_test.py +++ b/tensorflow_ranking/python/losses_test.py @@ -19,6 +19,8 @@ from __future__ import print_function import math + +import numpy as np import tensorflow as tf from tensorflow_ranking.python import losses as ranking_losses @@ -151,14 +153,11 @@ def _loss(si, sj, label_diff, delta): def _batch_aggregation(batch_loss_list, reduction=None): """Returns the aggregated loss.""" - loss_sum = 0. - weight_sum = 0. - for loss, weight, count in batch_loss_list: - loss_sum += loss - if reduction == 'mean': - weight_sum += weight - else: - weight_sum += count + loss_sum = np.sum([loss for loss, weight, count in batch_loss_list]) + if reduction == 'mean': + weight_sum = np.sum([weight for loss, weight, count in batch_loss_list]) + else: + weight_sum = np.sum([count for loss, weight, count in batch_loss_list]) return loss_sum / weight_sum