diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-01-03 13:18:38 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-03 13:24:58 -0800 |
commit | 1055b6a81e2c58773c1db415c5bc5a8b3b9b74c7 (patch) | |
tree | 246c388fba6246e0f5d4923d00f5378b7a192033 | |
parent | 7f62ba6e7fa475ab4aa2f99ef899a07dac835a2c (diff) |
Handle non-tensor args for predictions and labels.
Add test for 3d predictions and losses.
Change: 143478073
-rw-r--r-- | tensorflow/python/kernel_tests/losses_test.py | 198 | ||||
-rw-r--r-- | tensorflow/python/ops/losses/losses_impl.py | 38 |
2 files changed, 125 insertions, 111 deletions
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index 69b6e2c39f..f70a59cc01 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -780,7 +780,7 @@ class MeanSquaredErrorTest(test.TestCase): self.assertAlmostEqual(0.0, loss.eval(), 3) -class MeanPairwiseSquaresErrorTest(test.TestCase): +class MeanPairwiseSquaredErrorTest(test.TestCase): def setUp(self): self._predictions = np.array([[4, 8, 12], [8, 1, 3]]) @@ -789,14 +789,14 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): batch_size, dims = self._labels.shape # Compute the expected loss 'manually'. - total = np.zeros((batch_size, 1)) + total = np.zeros((batch_size,)) for b in range(batch_size): for i in range(dims): for j in range(dims): x = self._predictions[b, i].item() - self._predictions[b, j].item() y = self._labels[b, i].item() - self._labels[b, j].item() - tmp = (x - y) * (x - y) - total[b] += tmp + diff = (x - y) + total[b] += (diff * diff) self._expected_losses = np.divide(total, 9.0) @@ -808,19 +808,39 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): labels=constant_op.constant(self._labels), weights=None) + def _test_mean_pairwise_squared_error( + self, labels, predictions, expected_loss, weights=1.0): + with self.test_session(): + static_inputs_op = losses.mean_pairwise_squared_error( + predictions=predictions, labels=labels, weights=weights) + self.assertAlmostEqual(expected_loss, static_inputs_op.eval(), places=3) + + predictions_placeholder = array_ops.placeholder( + dtypes.float32, shape=np.asarray(predictions.shape)) + labels_placeholder = array_ops.placeholder( + dtypes.int32, shape=np.asarray(labels.shape)) + weights_placeholder = array_ops.placeholder( + dtypes.float32, shape=np.asarray(weights).shape) + dynamic_inputs_op = losses.mean_pairwise_squared_error( + predictions=predictions_placeholder, + labels=labels_placeholder, + weights=weights_placeholder) + feed_dict = { + predictions_placeholder: predictions, + labels_placeholder: labels, + weights_placeholder: weights, + } + self.assertAlmostEqual( + expected_loss, dynamic_inputs_op.eval(feed_dict=feed_dict), places=3) + def testAllCorrectNoLossWeight(self): - loss = losses.mean_pairwise_squared_error( - predictions=constant_op.constant(self._labels), - labels=constant_op.constant(self._labels)) - with self.test_session(): - self.assertAlmostEqual(0.0, loss.eval(), 3) + self._test_mean_pairwise_squared_error( + self._labels, self._labels, expected_loss=0.0) def testNonZeroLoss(self): - loss = losses.mean_pairwise_squared_error( - predictions=constant_op.constant(self._predictions), - labels=constant_op.constant(self._labels)) - with self.test_session(): - self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3) + self._test_mean_pairwise_squared_error( + self._labels, self._predictions, + expected_loss=np.sum(self._expected_losses)) def testGradientWithZeroWeight(self): with ops.Graph().as_default(): @@ -848,14 +868,11 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): self.assertFalse(np.isnan(np_grad).any()) def testNonZeroLossWithPythonScalarWeight(self): - weights = 2.3 - loss = losses.mean_pairwise_squared_error( - predictions=constant_op.constant(self._predictions), - labels=constant_op.constant(self._labels), - weights=weights) - with self.test_session(): - self.assertAlmostEqual(weights * np.sum(self._expected_losses), - loss.eval(), 3) + weight = 2.3 + self._test_mean_pairwise_squared_error( + self._labels, self._predictions, + expected_loss=weight * np.sum(self._expected_losses), + weights=weight) def testNonZeroLossWithScalarTensorWeight(self): weights = 2.3 @@ -868,83 +885,80 @@ class MeanPairwiseSquaresErrorTest(test.TestCase): loss.eval(), 3) def testNonZeroLossWithScalarZeroWeight(self): - weights = 0 - loss = losses.mean_pairwise_squared_error( - predictions=constant_op.constant(self._predictions), - labels=constant_op.constant(self._labels), - weights=constant_op.constant(weights)) - with self.test_session(): - self.assertAlmostEqual(0, loss.eval(), 3) - - def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self): - weights = 2.3 - tf_predictions = array_ops.placeholder( - dtypes.float32, shape=self._predictions.shape) - tf_labels = array_ops.placeholder(dtypes.float32, shape=self._labels.shape) - loss = losses.mean_pairwise_squared_error( - predictions=tf_predictions, - labels=tf_labels, - weights=constant_op.constant(weights)) - with self.test_session() as sess: - loss = sess.run(loss, - feed_dict={ - tf_predictions: self._predictions, - tf_labels: self._labels, - }) - self.assertAlmostEqual(weights * np.sum(self._expected_losses), loss, 3) + self._test_mean_pairwise_squared_error( + self._labels, self._predictions, expected_loss=0.0, weights=0.0) def testNonZeroLossWithOneDimBatchSpecificWeights(self): - weights = np.asarray([2.0, 1.0]).reshape((2, 1)) - expected_losses = np.multiply(weights, self._expected_losses) + weights = np.asarray((1.2, 3.4)) + self._test_mean_pairwise_squared_error( + self._labels, self._predictions, + expected_loss=np.sum(np.multiply(weights, self._expected_losses)), + weights=weights) - loss = losses.mean_pairwise_squared_error( - predictions=constant_op.constant(self._predictions), - labels=constant_op.constant(self._labels), - weights=constant_op.constant( - weights, shape=[2])) - with self.test_session(): - self.assertAlmostEqual(np.sum(expected_losses), loss.eval(), 3) + def test3d(self): + labels = np.array([ + [[1, 9, 2], [12, 11, 10], [9, 8, 7]], + [[-5, -5, 7], [6, 5, 4], [3, 2, 1]], + ]) + predictions = np.array([ + [[4, 8, 12], [1, 2, 3], [4, 5, 6]], + [[8, 1, 3], [7, 8, 9], [10, 11, 12]], + ]) + self._test_mean_pairwise_squared_error( + labels, predictions, expected_loss=122.22222) + + def test3dWeightedScalar(self): + labels = np.array([ + [[1, 9, 2], [12, 11, 10], [9, 8, 7]], + [[-5, -5, 7], [6, 5, 4], [3, 2, 1]], + ]) + predictions = np.array([ + [[4, 8, 12], [1, 2, 3], [4, 5, 6]], + [[8, 1, 3], [7, 8, 9], [10, 11, 12]], + ]) + weight = 3.0 + self._test_mean_pairwise_squared_error( + labels, predictions, expected_loss=weight * 122.22222, + weights=weight) + + def test3dWeighted2x0(self): + labels = np.array([ + [[1, 9, 2], [12, 11, 10], [9, 8, 7]], + [[-5, -5, 7], [6, 5, 4], [3, 2, 1]], + ]) + predictions = np.array([ + [[4, 8, 12], [1, 2, 3], [4, 5, 6]], + [[8, 1, 3], [7, 8, 9], [10, 11, 12]], + ]) + self._test_mean_pairwise_squared_error( + labels, predictions, expected_loss=253.24445, + weights=np.asarray((1.2, 3.4))) + + # TODO(ptucker): According to the pydoc, this should work. + def test3dWeighted2x3x3(self): + labels = np.array([ + [[1, 9, 2], [12, 11, 10], [9, 8, 7]], + [[-5, -5, 7], [6, 5, 4], [3, 2, 1]], + ]) + predictions = np.array([ + [[4, 8, 12], [1, 2, 3], [4, 5, 6]], + [[8, 1, 3], [7, 8, 9], [10, 11, 12]], + ]) + with self.assertRaisesRegexp( + ValueError, 'Dimensions must be equal, but are 2 and 3'): + losses.mean_pairwise_squared_error( + predictions=predictions, labels=labels, + weights=np.ones((2, 3, 3))) def testZeroLossWithOneDimBatchZeroWeights(self): - weights = np.asarray([0.0, 0.0]).reshape((2, 1)) - loss = losses.mean_pairwise_squared_error( - predictions=constant_op.constant(self._predictions), - labels=constant_op.constant(self._labels), - weights=constant_op.constant( - weights, shape=[2])) - with self.test_session(): - self.assertAlmostEqual(0, loss.eval(), 3) - - def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self): - weights = np.asarray([1.2, 3.4]).reshape((2, 1)) - expected_losses = np.multiply(weights, self._expected_losses) - - tf_predictions = array_ops.placeholder( - dtypes.float32, shape=self._predictions.shape) - tf_labels = array_ops.placeholder(dtypes.int32, shape=self._labels.shape) - loss = losses.mean_pairwise_squared_error( - predictions=tf_predictions, - labels=tf_labels, - weights=constant_op.constant( - weights, shape=[2])) - - with self.test_session() as sess: - loss = sess.run(loss, - feed_dict={ - tf_predictions: self._predictions, - tf_labels: self._labels, - }) - self.assertAlmostEqual(np.sum(expected_losses), loss, 3) + self._test_mean_pairwise_squared_error( + self._labels, self._predictions, expected_loss=0.0, + weights=np.zeros((2,))) def testLossWithAllZeroBatchSpecificWeights(self): - weights = np.zeros((2, 1)) - loss = losses.mean_pairwise_squared_error( - predictions=constant_op.constant(self._predictions), - labels=constant_op.constant(self._labels), - weights=constant_op.constant( - weights, shape=[2])) - with self.test_session(): - self.assertAlmostEqual(0.0, loss.eval(), 3) + self._test_mean_pairwise_squared_error( + self._labels, self._predictions, expected_loss=0.0, + weights=np.zeros((2, 1))) class CosineDistanceLossTest(test.TestCase): diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 2a06753f3c..bd6b87a3a5 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -247,10 +247,10 @@ def absolute_difference( if the shape of `weights` is invalid. """ with ops.name_scope(scope, "absolute_difference", - [predictions, labels, weights]) as scope: - predictions.get_shape().assert_is_compatible_with(labels.get_shape()) + (predictions, labels, weights)) as scope: predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) losses = math_ops.abs(math_ops.subtract(predictions, labels)) return compute_weighted_loss(losses, weights, scope, loss_collection) @@ -288,11 +288,10 @@ def cosine_distance( if dim is None: raise ValueError("`dim` cannot be None.") with ops.name_scope(scope, "cosine_distance_loss", - [predictions, labels, weights]) as scope: - predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - + (predictions, labels, weights)) as scope: predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.multiply(predictions, labels) losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,]) @@ -324,10 +323,11 @@ def hinge_loss(labels, logits, weights=1.0, scope=None, Raises: ValueError: If the shapes of `logits` and `labels` don't match. """ - with ops.name_scope(scope, "hinge_loss", [logits, labels]) as scope: + with ops.name_scope(scope, "hinge_loss", (logits, labels)) as scope: + logits = math_ops.to_float(logits) + labels = math_ops.to_float(labels) logits.get_shape().assert_is_compatible_with(labels.get_shape()) # We first need to convert binary labels to -1/1 labels (as floats). - labels = math_ops.to_float(labels) all_ones = array_ops.ones_like(labels) labels = math_ops.subtract(2 * labels, all_ones) losses = nn_ops.relu( @@ -370,10 +370,10 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None, if the shape of `weights` is invalid. """ with ops.name_scope(scope, "log_loss", - [predictions, labels, weights]) as scope: - predictions.get_shape().assert_is_compatible_with(labels.get_shape()) + (predictions, labels, weights)) as scope: predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) losses = -math_ops.multiply( labels, math_ops.log(predictions + epsilon)) - math_ops.multiply( @@ -424,10 +424,10 @@ def mean_pairwise_squared_error(labels, predictions, weights=1.0, scope=None, if the shape of `weights` is invalid. """ with ops.name_scope(scope, "mean_pairwise_squared_error", - [predictions, labels, weights]) as scope: - predictions.get_shape().assert_is_compatible_with(labels.get_shape()) + (predictions, labels, weights)) as scope: predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) weights = math_ops.to_float(ops.convert_to_tensor(weights)) diffs = math_ops.subtract(predictions, labels) @@ -496,10 +496,10 @@ def mean_squared_error(labels, predictions, weights=1.0, scope=None, if the shape of `weights` is invalid. """ with ops.name_scope(scope, "mean_squared_error", - [predictions, labels, weights]) as scope: - predictions.get_shape().assert_is_compatible_with(labels.get_shape()) + (predictions, labels, weights)) as scope: predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) losses = math_ops.square(math_ops.subtract(predictions, labels)) return compute_weighted_loss(losses, weights, scope, loss_collection) @@ -544,10 +544,10 @@ def sigmoid_cross_entropy( `weights` is None. """ with ops.name_scope(scope, "sigmoid_cross_entropy_loss", - [logits, multi_class_labels, weights]) as scope: - logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape()) - + (logits, multi_class_labels, weights)) as scope: + logits = ops.convert_to_tensor(logits) multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype) + logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape()) if label_smoothing > 0: multi_class_labels = (multi_class_labels * (1 - label_smoothing) + @@ -595,10 +595,10 @@ def softmax_cross_entropy( or if the shape of `weights` is invalid or if `weights` is None. """ with ops.name_scope(scope, "softmax_cross_entropy_loss", - [logits, onehot_labels, weights]) as scope: - logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape()) - + (logits, onehot_labels, weights)) as scope: + logits = ops.convert_to_tensor(logits) onehot_labels = math_ops.cast(onehot_labels, logits.dtype) + logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape()) if label_smoothing > 0: num_classes = math_ops.cast( |