aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-01-03 13:18:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-03 13:24:58 -0800
commit1055b6a81e2c58773c1db415c5bc5a8b3b9b74c7 (patch)
tree246c388fba6246e0f5d4923d00f5378b7a192033
parent7f62ba6e7fa475ab4aa2f99ef899a07dac835a2c (diff)
Handle non-tensor args for predictions and labels.
Add test for 3d predictions and losses. Change: 143478073
-rw-r--r--tensorflow/python/kernel_tests/losses_test.py198
-rw-r--r--tensorflow/python/ops/losses/losses_impl.py38
2 files changed, 125 insertions, 111 deletions
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 69b6e2c39f..f70a59cc01 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -780,7 +780,7 @@ class MeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(0.0, loss.eval(), 3)
-class MeanPairwiseSquaresErrorTest(test.TestCase):
+class MeanPairwiseSquaredErrorTest(test.TestCase):
def setUp(self):
self._predictions = np.array([[4, 8, 12], [8, 1, 3]])
@@ -789,14 +789,14 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
batch_size, dims = self._labels.shape
# Compute the expected loss 'manually'.
- total = np.zeros((batch_size, 1))
+ total = np.zeros((batch_size,))
for b in range(batch_size):
for i in range(dims):
for j in range(dims):
x = self._predictions[b, i].item() - self._predictions[b, j].item()
y = self._labels[b, i].item() - self._labels[b, j].item()
- tmp = (x - y) * (x - y)
- total[b] += tmp
+ diff = (x - y)
+ total[b] += (diff * diff)
self._expected_losses = np.divide(total, 9.0)
@@ -808,19 +808,39 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=None)
+ def _test_mean_pairwise_squared_error(
+ self, labels, predictions, expected_loss, weights=1.0):
+ with self.test_session():
+ static_inputs_op = losses.mean_pairwise_squared_error(
+ predictions=predictions, labels=labels, weights=weights)
+ self.assertAlmostEqual(expected_loss, static_inputs_op.eval(), places=3)
+
+ predictions_placeholder = array_ops.placeholder(
+ dtypes.float32, shape=np.asarray(predictions.shape))
+ labels_placeholder = array_ops.placeholder(
+ dtypes.int32, shape=np.asarray(labels.shape))
+ weights_placeholder = array_ops.placeholder(
+ dtypes.float32, shape=np.asarray(weights).shape)
+ dynamic_inputs_op = losses.mean_pairwise_squared_error(
+ predictions=predictions_placeholder,
+ labels=labels_placeholder,
+ weights=weights_placeholder)
+ feed_dict = {
+ predictions_placeholder: predictions,
+ labels_placeholder: labels,
+ weights_placeholder: weights,
+ }
+ self.assertAlmostEqual(
+ expected_loss, dynamic_inputs_op.eval(feed_dict=feed_dict), places=3)
+
def testAllCorrectNoLossWeight(self):
- loss = losses.mean_pairwise_squared_error(
- predictions=constant_op.constant(self._labels),
- labels=constant_op.constant(self._labels))
- with self.test_session():
- self.assertAlmostEqual(0.0, loss.eval(), 3)
+ self._test_mean_pairwise_squared_error(
+ self._labels, self._labels, expected_loss=0.0)
def testNonZeroLoss(self):
- loss = losses.mean_pairwise_squared_error(
- predictions=constant_op.constant(self._predictions),
- labels=constant_op.constant(self._labels))
- with self.test_session():
- self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3)
+ self._test_mean_pairwise_squared_error(
+ self._labels, self._predictions,
+ expected_loss=np.sum(self._expected_losses))
def testGradientWithZeroWeight(self):
with ops.Graph().as_default():
@@ -848,14 +868,11 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
self.assertFalse(np.isnan(np_grad).any())
def testNonZeroLossWithPythonScalarWeight(self):
- weights = 2.3
- loss = losses.mean_pairwise_squared_error(
- predictions=constant_op.constant(self._predictions),
- labels=constant_op.constant(self._labels),
- weights=weights)
- with self.test_session():
- self.assertAlmostEqual(weights * np.sum(self._expected_losses),
- loss.eval(), 3)
+ weight = 2.3
+ self._test_mean_pairwise_squared_error(
+ self._labels, self._predictions,
+ expected_loss=weight * np.sum(self._expected_losses),
+ weights=weight)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
@@ -868,83 +885,80 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
loss.eval(), 3)
def testNonZeroLossWithScalarZeroWeight(self):
- weights = 0
- loss = losses.mean_pairwise_squared_error(
- predictions=constant_op.constant(self._predictions),
- labels=constant_op.constant(self._labels),
- weights=constant_op.constant(weights))
- with self.test_session():
- self.assertAlmostEqual(0, loss.eval(), 3)
-
- def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self):
- weights = 2.3
- tf_predictions = array_ops.placeholder(
- dtypes.float32, shape=self._predictions.shape)
- tf_labels = array_ops.placeholder(dtypes.float32, shape=self._labels.shape)
- loss = losses.mean_pairwise_squared_error(
- predictions=tf_predictions,
- labels=tf_labels,
- weights=constant_op.constant(weights))
- with self.test_session() as sess:
- loss = sess.run(loss,
- feed_dict={
- tf_predictions: self._predictions,
- tf_labels: self._labels,
- })
- self.assertAlmostEqual(weights * np.sum(self._expected_losses), loss, 3)
+ self._test_mean_pairwise_squared_error(
+ self._labels, self._predictions, expected_loss=0.0, weights=0.0)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
- weights = np.asarray([2.0, 1.0]).reshape((2, 1))
- expected_losses = np.multiply(weights, self._expected_losses)
+ weights = np.asarray((1.2, 3.4))
+ self._test_mean_pairwise_squared_error(
+ self._labels, self._predictions,
+ expected_loss=np.sum(np.multiply(weights, self._expected_losses)),
+ weights=weights)
- loss = losses.mean_pairwise_squared_error(
- predictions=constant_op.constant(self._predictions),
- labels=constant_op.constant(self._labels),
- weights=constant_op.constant(
- weights, shape=[2]))
- with self.test_session():
- self.assertAlmostEqual(np.sum(expected_losses), loss.eval(), 3)
+ def test3d(self):
+ labels = np.array([
+ [[1, 9, 2], [12, 11, 10], [9, 8, 7]],
+ [[-5, -5, 7], [6, 5, 4], [3, 2, 1]],
+ ])
+ predictions = np.array([
+ [[4, 8, 12], [1, 2, 3], [4, 5, 6]],
+ [[8, 1, 3], [7, 8, 9], [10, 11, 12]],
+ ])
+ self._test_mean_pairwise_squared_error(
+ labels, predictions, expected_loss=122.22222)
+
+ def test3dWeightedScalar(self):
+ labels = np.array([
+ [[1, 9, 2], [12, 11, 10], [9, 8, 7]],
+ [[-5, -5, 7], [6, 5, 4], [3, 2, 1]],
+ ])
+ predictions = np.array([
+ [[4, 8, 12], [1, 2, 3], [4, 5, 6]],
+ [[8, 1, 3], [7, 8, 9], [10, 11, 12]],
+ ])
+ weight = 3.0
+ self._test_mean_pairwise_squared_error(
+ labels, predictions, expected_loss=weight * 122.22222,
+ weights=weight)
+
+ def test3dWeighted2x0(self):
+ labels = np.array([
+ [[1, 9, 2], [12, 11, 10], [9, 8, 7]],
+ [[-5, -5, 7], [6, 5, 4], [3, 2, 1]],
+ ])
+ predictions = np.array([
+ [[4, 8, 12], [1, 2, 3], [4, 5, 6]],
+ [[8, 1, 3], [7, 8, 9], [10, 11, 12]],
+ ])
+ self._test_mean_pairwise_squared_error(
+ labels, predictions, expected_loss=253.24445,
+ weights=np.asarray((1.2, 3.4)))
+
+ # TODO(ptucker): According to the pydoc, this should work.
+ def test3dWeighted2x3x3(self):
+ labels = np.array([
+ [[1, 9, 2], [12, 11, 10], [9, 8, 7]],
+ [[-5, -5, 7], [6, 5, 4], [3, 2, 1]],
+ ])
+ predictions = np.array([
+ [[4, 8, 12], [1, 2, 3], [4, 5, 6]],
+ [[8, 1, 3], [7, 8, 9], [10, 11, 12]],
+ ])
+ with self.assertRaisesRegexp(
+ ValueError, 'Dimensions must be equal, but are 2 and 3'):
+ losses.mean_pairwise_squared_error(
+ predictions=predictions, labels=labels,
+ weights=np.ones((2, 3, 3)))
def testZeroLossWithOneDimBatchZeroWeights(self):
- weights = np.asarray([0.0, 0.0]).reshape((2, 1))
- loss = losses.mean_pairwise_squared_error(
- predictions=constant_op.constant(self._predictions),
- labels=constant_op.constant(self._labels),
- weights=constant_op.constant(
- weights, shape=[2]))
- with self.test_session():
- self.assertAlmostEqual(0, loss.eval(), 3)
-
- def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self):
- weights = np.asarray([1.2, 3.4]).reshape((2, 1))
- expected_losses = np.multiply(weights, self._expected_losses)
-
- tf_predictions = array_ops.placeholder(
- dtypes.float32, shape=self._predictions.shape)
- tf_labels = array_ops.placeholder(dtypes.int32, shape=self._labels.shape)
- loss = losses.mean_pairwise_squared_error(
- predictions=tf_predictions,
- labels=tf_labels,
- weights=constant_op.constant(
- weights, shape=[2]))
-
- with self.test_session() as sess:
- loss = sess.run(loss,
- feed_dict={
- tf_predictions: self._predictions,
- tf_labels: self._labels,
- })
- self.assertAlmostEqual(np.sum(expected_losses), loss, 3)
+ self._test_mean_pairwise_squared_error(
+ self._labels, self._predictions, expected_loss=0.0,
+ weights=np.zeros((2,)))
def testLossWithAllZeroBatchSpecificWeights(self):
- weights = np.zeros((2, 1))
- loss = losses.mean_pairwise_squared_error(
- predictions=constant_op.constant(self._predictions),
- labels=constant_op.constant(self._labels),
- weights=constant_op.constant(
- weights, shape=[2]))
- with self.test_session():
- self.assertAlmostEqual(0.0, loss.eval(), 3)
+ self._test_mean_pairwise_squared_error(
+ self._labels, self._predictions, expected_loss=0.0,
+ weights=np.zeros((2, 1)))
class CosineDistanceLossTest(test.TestCase):
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 2a06753f3c..bd6b87a3a5 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -247,10 +247,10 @@ def absolute_difference(
if the shape of `weights` is invalid.
"""
with ops.name_scope(scope, "absolute_difference",
- [predictions, labels, weights]) as scope:
- predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+ (predictions, labels, weights)) as scope:
predictions = math_ops.to_float(predictions)
labels = math_ops.to_float(labels)
+ predictions.get_shape().assert_is_compatible_with(labels.get_shape())
losses = math_ops.abs(math_ops.subtract(predictions, labels))
return compute_weighted_loss(losses, weights, scope, loss_collection)
@@ -288,11 +288,10 @@ def cosine_distance(
if dim is None:
raise ValueError("`dim` cannot be None.")
with ops.name_scope(scope, "cosine_distance_loss",
- [predictions, labels, weights]) as scope:
- predictions.get_shape().assert_is_compatible_with(labels.get_shape())
-
+ (predictions, labels, weights)) as scope:
predictions = math_ops.to_float(predictions)
labels = math_ops.to_float(labels)
+ predictions.get_shape().assert_is_compatible_with(labels.get_shape())
radial_diffs = math_ops.multiply(predictions, labels)
losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,])
@@ -324,10 +323,11 @@ def hinge_loss(labels, logits, weights=1.0, scope=None,
Raises:
ValueError: If the shapes of `logits` and `labels` don't match.
"""
- with ops.name_scope(scope, "hinge_loss", [logits, labels]) as scope:
+ with ops.name_scope(scope, "hinge_loss", (logits, labels)) as scope:
+ logits = math_ops.to_float(logits)
+ labels = math_ops.to_float(labels)
logits.get_shape().assert_is_compatible_with(labels.get_shape())
# We first need to convert binary labels to -1/1 labels (as floats).
- labels = math_ops.to_float(labels)
all_ones = array_ops.ones_like(labels)
labels = math_ops.subtract(2 * labels, all_ones)
losses = nn_ops.relu(
@@ -370,10 +370,10 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
if the shape of `weights` is invalid.
"""
with ops.name_scope(scope, "log_loss",
- [predictions, labels, weights]) as scope:
- predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+ (predictions, labels, weights)) as scope:
predictions = math_ops.to_float(predictions)
labels = math_ops.to_float(labels)
+ predictions.get_shape().assert_is_compatible_with(labels.get_shape())
losses = -math_ops.multiply(
labels,
math_ops.log(predictions + epsilon)) - math_ops.multiply(
@@ -424,10 +424,10 @@ def mean_pairwise_squared_error(labels, predictions, weights=1.0, scope=None,
if the shape of `weights` is invalid.
"""
with ops.name_scope(scope, "mean_pairwise_squared_error",
- [predictions, labels, weights]) as scope:
- predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+ (predictions, labels, weights)) as scope:
predictions = math_ops.to_float(predictions)
labels = math_ops.to_float(labels)
+ predictions.get_shape().assert_is_compatible_with(labels.get_shape())
weights = math_ops.to_float(ops.convert_to_tensor(weights))
diffs = math_ops.subtract(predictions, labels)
@@ -496,10 +496,10 @@ def mean_squared_error(labels, predictions, weights=1.0, scope=None,
if the shape of `weights` is invalid.
"""
with ops.name_scope(scope, "mean_squared_error",
- [predictions, labels, weights]) as scope:
- predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+ (predictions, labels, weights)) as scope:
predictions = math_ops.to_float(predictions)
labels = math_ops.to_float(labels)
+ predictions.get_shape().assert_is_compatible_with(labels.get_shape())
losses = math_ops.square(math_ops.subtract(predictions, labels))
return compute_weighted_loss(losses, weights, scope, loss_collection)
@@ -544,10 +544,10 @@ def sigmoid_cross_entropy(
`weights` is None.
"""
with ops.name_scope(scope, "sigmoid_cross_entropy_loss",
- [logits, multi_class_labels, weights]) as scope:
- logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape())
-
+ (logits, multi_class_labels, weights)) as scope:
+ logits = ops.convert_to_tensor(logits)
multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype)
+ logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape())
if label_smoothing > 0:
multi_class_labels = (multi_class_labels * (1 - label_smoothing) +
@@ -595,10 +595,10 @@ def softmax_cross_entropy(
or if the shape of `weights` is invalid or if `weights` is None.
"""
with ops.name_scope(scope, "softmax_cross_entropy_loss",
- [logits, onehot_labels, weights]) as scope:
- logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())
-
+ (logits, onehot_labels, weights)) as scope:
+ logits = ops.convert_to_tensor(logits)
onehot_labels = math_ops.cast(onehot_labels, logits.dtype)
+ logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())
if label_smoothing > 0:
num_classes = math_ops.cast(