diff options
author | Alexandre Passos <apassos@google.com> | 2018-09-24 16:13:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 16:19:42 -0700 |
commit | 6c40bc717442d56f0b6a60658b05f0549afd69ee (patch) | |
tree | 2d9a179e074e6d0ed7beec2ff3f14f0796bc0107 /tensorflow/python/kernel_tests | |
parent | d25b23d5ec6a0a7828e86fa8868f7a6574f9f827 (diff) |
BEGIN_PUBLIC
Temporary rollback to fix forward compatibility.
END_PUBLIC
Automated rollback of commit 0c48c703c3c1455cf3b2c0e47e2108e053ff83e2. Revert #21798.
PiperOrigin-RevId: 214349479
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r-- | tensorflow/python/kernel_tests/losses_test.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index fb0b5f1137..3ce0b74263 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -33,11 +34,25 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses +from tensorflow.python.ops.losses import losses_impl from tensorflow.python.ops.losses import util from tensorflow.python.platform import test from tensorflow.python.training import momentum as momentum_lib +safe_div = losses_impl._safe_div # pylint: disable=protected-access + + +class SafeDivTest(test.TestCase): + + def testEager(self): + with context.eager_mode(): + self.assertAllEqual(safe_div(constant_op.constant(1.0), + constant_op.constant(0.0)), 0.0) + self.assertAllEqual(safe_div(constant_op.constant(1.0), + 0.0), 0.0) + + class AbsoluteDifferenceLossTest(test.TestCase): def setUp(self): |