aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-08-21 08:52:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 08:59:26 -0700
commit9f59beb67643953d87e7673fa0000cc775562693 (patch)
tree3f30145de2fd5b7969bfc157d76e9b9ca287c9b2
parent5b456c9ab567c7d9262c57f32693ff33a87946e6 (diff)
Allow DT_INT64 input shapes for ReshapeOp.
The fix itself is simple but I'm not sure if this how we should be testing auto-jit. PiperOrigin-RevId: 209602427
-rw-r--r--tensorflow/compiler/tests/BUILD13
-rw-r--r--tensorflow/compiler/tests/reshape_op_test.py48
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reshape_op.cc6
3 files changed, 64 insertions, 3 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index ae98b3f0f9..47311d2630 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -388,6 +388,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "reshape_op_test",
+ size = "small",
+ srcs = ["reshape_op_test.py"],
+ deps = [
+ "//tensorflow/compiler/tests:xla_test",
+ "//tensorflow/compiler/tf2xla/python:xla",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+tf_xla_py_test(
name = "dynamic_stitch_test",
size = "small",
srcs = ["dynamic_stitch_test.py"],
diff --git a/tensorflow/compiler/tests/reshape_op_test.py b/tensorflow/compiler/tests/reshape_op_test.py
new file mode 100644
index 0000000000..8aa312cbc1
--- /dev/null
+++ b/tensorflow/compiler/tests/reshape_op_test.py
@@ -0,0 +1,48 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slicing."""
+
+from __future__ import absolute_import
+
+from absl.testing import parameterized
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import googletest
+
+
+class ReshapeTest(xla_test.XLATestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(('32_bit_index', dtypes.int32),
+ ('64_bit_index', dtypes.int64))
+ def testBasic(self, index_dtype):
+ for dtype in self.numeric_types:
+ with self.test_session():
+ i = array_ops.placeholder(dtype, shape=[2, 3])
+ with self.test_scope():
+ shape = constant_op.constant([3, 2], dtype=index_dtype)
+ o = array_ops.reshape(i, shape)
+ params = {
+ i: [[1, 2, 3], [4, 5, 6]],
+ }
+ result = o.eval(feed_dict=params)
+
+ self.assertAllEqual([[1, 2], [3, 4], [5, 6]], result)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
index 121750a82a..366ce42866 100644
--- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -41,8 +41,8 @@ class ReshapeOp : public XlaOpKernel {
sizes_shape.DebugString()));
const int64 num_dims = sizes_shape.num_elements();
- xla::Literal literal;
- OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal));
+ std::vector<int64> shape_input;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input));
// Compute the output shape. Determine product of specified
// dimensions, and find the index of the unspecified one if there
@@ -51,7 +51,7 @@ class ReshapeOp : public XlaOpKernel {
int64 product = 1;
int unknown_index = -1;
for (int d = 0; d < num_dims; ++d) {
- const int32 size = literal.Get<int>({d});
+ const int32 size = shape_input[d];
if (size == -1) {
OP_REQUIRES(
ctx, unknown_index == -1,