aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-09-24 16:13:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 16:19:42 -0700
commit6c40bc717442d56f0b6a60658b05f0549afd69ee (patch)
tree2d9a179e074e6d0ed7beec2ff3f14f0796bc0107 /tensorflow/python/kernel_tests
parentd25b23d5ec6a0a7828e86fa8868f7a6574f9f827 (diff)
BEGIN_PUBLIC
Temporary rollback to fix forward compatibility. END_PUBLIC Automated rollback of commit 0c48c703c3c1455cf3b2c0e47e2108e053ff83e2. Revert #21798. PiperOrigin-RevId: 214349479
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/losses_test.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index fb0b5f1137..3ce0b74263 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -33,11 +34,25 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
+from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.ops.losses import util
from tensorflow.python.platform import test
from tensorflow.python.training import momentum as momentum_lib
+safe_div = losses_impl._safe_div # pylint: disable=protected-access
+
+
+class SafeDivTest(test.TestCase):
+
+ def testEager(self):
+ with context.eager_mode():
+ self.assertAllEqual(safe_div(constant_op.constant(1.0),
+ constant_op.constant(0.0)), 0.0)
+ self.assertAllEqual(safe_div(constant_op.constant(1.0),
+ 0.0), 0.0)
+
+
class AbsoluteDifferenceLossTest(test.TestCase):
def setUp(self):