aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
authorGravatar Tony Wang <tonywy@google.com>2018-06-28 14:04:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-28 14:08:54 -0700
commit15b258dc93fddd39e1f86a6768a7fc9fe70e0f62 (patch)
treed570e7bdca76f564d798d243d68f6a5dd930ee78 /tensorflow/compiler/tests
parentf5606a442ae2097beea77db9d0f517125696cb84 (diff)
Automated g4 rollback of changelist 201419522
PiperOrigin-RevId: 202539762
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/gather_test.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py
index 1a8c451911..e9c8ef7c91 100644
--- a/tensorflow/compiler/tests/gather_test.py
+++ b/tensorflow/compiler/tests/gather_test.py
@@ -136,6 +136,20 @@ class GatherTest(xla_test.XLATestCase):
self.assertAllEqual(
[[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]}))
+ def testGatherPrecision(self):
+ with self.test_session() as session, self.test_scope():
+ data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0],
+ [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]])
+ indices = np.array([1, 2, 3, 1])
+ dtype = dtypes.float32
+ params_np = self._buildParams(data, dtype)
+ params = array_ops.placeholder(dtype=dtype)
+ indices_tf = constant_op.constant(indices)
+ gather_t = array_ops.gather(params, indices_tf)
+ gather_val = session.run(gather_t, feed_dict={params: params_np})
+ np_val = params_np[indices]
+ self.assertAllEqual(np_val, gather_val)
+
class GatherBenchmark(test.Benchmark):
"""Microbenchmarks for the gather op."""