aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/losses
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-01-30 11:18:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-30 12:36:56 -0800
commit29d24237b0e29d83478182ad219da478ee2135c1 (patch)
treebf983b721d4d3007edfeffb2bcc9c7d32eb0a67a /tensorflow/contrib/losses
parentb970652bf714cbf676fdd84256cb128afc2b1306 (diff)
Make loss_ops_test.py work with C API enabled.
PiperOrigin-RevId: 183861779
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops_test.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index 9d0f95e6f3..1417772e04 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -274,6 +275,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+@test_util.with_c_api
class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testNoneWeightRaisesValueError(self):
@@ -471,7 +473,11 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
labels = constant_op.constant([[0, 1], [2, 3]])
weights = constant_op.constant([1.2, 3.4, 5.6, 7.8])
- with self.assertRaises(errors_impl.InvalidArgumentError):
+ if ops._USE_C_API:
+ error_type = ValueError
+ else:
+ error_type = errors_impl.InvalidArgumentError
+ with self.assertRaises(error_type):
loss_ops.sparse_softmax_cross_entropy(
logits, labels, weights=weights).eval()