aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-17 07:41:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-17 07:43:45 -0700
commitfb7675e06d6b5ee1d45dcd4eda64a4caa689e393 (patch)
tree7bc99314abe91bc7703a3739d004f61373bde470
parentf73d793e7a9234efb14fd8f11322429d122949b1 (diff)
Add uint32/uint64 support to Gather op.
PiperOrigin-RevId: 193195939
-rw-r--r--tensorflow/core/kernels/gather_op.cc2
-rw-r--r--tensorflow/python/kernel_tests/gather_op_test.py9
2 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc
index 08adf4badb..ef332ebee3 100644
--- a/tensorflow/core/kernels/gather_op.cc
+++ b/tensorflow/core/kernels/gather_op.cc
@@ -143,6 +143,8 @@ TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
TF_CALL_quint16(REGISTER_GATHER_CPU);
TF_CALL_qint16(REGISTER_GATHER_CPU);
+TF_CALL_uint32(REGISTER_GATHER_CPU);
+TF_CALL_uint64(REGISTER_GATHER_CPU);
#undef REGISTER_GATHER_CPU
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index 9a94692569..a2fcd751df 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -149,6 +149,15 @@ class GatherTest(test.TestCase):
self.assertAllEqual([b"asdf", b"qwer"],
array_ops.gather(params, 0, axis=1).eval())
+ def testUInt32AndUInt64(self):
+ for unsigned_type in (dtypes.uint32, dtypes.uint64):
+ params = self._buildParams(
+ np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
+ with self.test_session():
+ self.assertAllEqual([7, 8, 9],
+ array_ops.gather(params, 1, axis=0).eval())
+ self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1).eval())
+
def testUnknownIndices(self):
params = constant_op.constant([[0, 1, 2]])
indices = array_ops.placeholder(dtypes.int32)