diff options
Diffstat (limited to 'tensorflow/compiler/tests/gather_test.py')
-rw-r--r-- | tensorflow/compiler/tests/gather_test.py | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 089d95daab..a38e1edafe 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -51,7 +51,7 @@ class GatherTest(xla_test.XLATestCase): indices_tf = constant_op.constant(indices) gather_t = array_ops.gather(params, indices_tf) gather_val = session.run(gather_t, feed_dict={params: params_np}) - np_val = params_np[indices] + np_val = constant_op.constant(params_np[indices]) self.assertAllEqual(np_val, gather_val) def testScalar2D(self): @@ -65,7 +65,8 @@ class GatherTest(xla_test.XLATestCase): indices = constant_op.constant(2) gather_t = array_ops.gather(params, indices, axis=axis) gather_val = session.run(gather_t, feed_dict={params: params_np}) - expected = np.take(params_np, 2, axis=axis) + expected = constant_op.constant( + np.take(params_np, 2, axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testSimpleTwoD32(self): @@ -80,7 +81,8 @@ class GatherTest(xla_test.XLATestCase): indices = constant_op.constant([0, 1, 0, 2]) gather_t = array_ops.gather(params, indices, axis=axis) gather_val = session.run(gather_t, feed_dict={params: params_np}) - expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + expected = constant_op.constant( + np.take(params_np, [0, 1, 0, 2], axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testSimpleTwoD32_Int64Indices(self): @@ -103,7 +105,8 @@ class GatherTest(xla_test.XLATestCase): params: params_np, indices: indices_np }) - expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + expected = constant_op.constant( + np.take(params_np, [0, 1, 0, 2], axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testHigherRank(self): @@ -119,7 +122,8 @@ class GatherTest(xla_test.XLATestCase): tf_indices = constant_op.constant(indices, dtype=dtypes.int32) gather = array_ops.gather(tf_params, tf_indices, axis=axis) gather_value = sess.run(gather, feed_dict={tf_params: params}) - gather_np = np.take(params, indices, axis=axis) + gather_np = constant_op.constant( + np.take(params, indices, axis=axis), dtype) self.assertAllEqual(gather_np, gather_value) def testIndicesWithDifferentDimensions(self): |