From af8da61ad4b688a7bedb4ba1e0365735c9f25b14 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Sun, 1 Oct 2017 09:19:35 -0700
Subject: Make DynamicStitch's shape function handle the case where all inputs
 are constant.

PiperOrigin-RevId: 170637740
---
 tensorflow/core/ops/data_flow_ops.cc               | 26 +++++++++--
 tensorflow/core/ops/data_flow_ops_test.cc          | 28 +++++++++++-
 .../python/kernel_tests/dynamic_stitch_op_test.py  | 51 +++++++++++-----------
 3 files changed, 74 insertions(+), 31 deletions(-)

diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 2209ecf1de..8e24ea70cb 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -133,17 +133,23 @@ num_partitions: The number of partitions to output.
 
 namespace {
 
 Status DynamicStitchShapeFunction(InferenceContext* c) {
-  int64 num_partitions;
+  int32 num_partitions;
   TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions));
+
+  bool all_indices_constant = true;
+  int32 max_index = 0;
   ShapeHandle extra_shape = c->UnknownShape();
-  for (int64 i = 0; i < num_partitions; ++i) {
+  for (int i = 0; i < num_partitions; ++i) {
+    const Tensor* indices_t = c->input_tensor(i);
+    if (indices_t == nullptr) {
+      all_indices_constant = false;
+    }
+
     ShapeHandle indices_shape = c->input(i);
     ShapeHandle data_shape = c->input(i + num_partitions);
     if (!c->RankKnown(indices_shape)) {
       continue;
     }
-
     const int64 indices_rank = c->Rank(indices_shape);
 
     // Assert that data_shape starts with indices_shape.
@@ -155,9 +161,21 @@ Status DynamicStitchShapeFunction(InferenceContext* c) {
     ShapeHandle rest;
     TF_RETURN_IF_ERROR(c->Subshape(data_shape, indices_rank, &rest));
     TF_RETURN_IF_ERROR(c->Merge(extra_shape, rest, &extra_shape));
+
+    if (indices_t != nullptr) {
+      // The length is based on the highest index from flattened indices.
+      const int32* indices = indices_t->flat<int32>().data();
+      int64 count = indices_t->NumElements();
+      for (int64 k = 0; k < count; ++k) {
+        if (indices[k] > max_index) {
+          max_index = indices[k];
+        }
+      }
+    }
   }
-  ShapeHandle output_shape = c->Vector(c->UnknownDim());
+
+  ShapeHandle output_shape = c->Vector(
+      all_indices_constant ? c->MakeDim(max_index + 1) : c->UnknownDim());
   TF_RETURN_IF_ERROR(c->Concatenate(output_shape, extra_shape, &output_shape));
   c->set_output(0, output_shape);
   return Status::OK();
diff --git a/tensorflow/core/ops/data_flow_ops_test.cc b/tensorflow/core/ops/data_flow_ops_test.cc
index 9c94d9aac9..a071eac453 100644
--- a/tensorflow/core/ops/data_flow_ops_test.cc
+++ b/tensorflow/core/ops/data_flow_ops_test.cc
@@ -126,8 +126,6 @@ TEST(DataFlowOpsTest, DynamicStitch) {
             .Attr("N", 2)
             .Finalize(&op.node_def));
 
-  INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
-
   // Bad prefix for the second data input.
   INFER_ERROR("Dimensions must be equal, but are 10 and 5", op,
               "[2,3];[5,6];[2,3,4,5];[10,11,4,5]");
@@ -135,6 +133,32 @@ TEST(DataFlowOpsTest, DynamicStitch) {
   // Inconsistent suffix dimensions
   INFER_ERROR("Dimension 0 in both shapes must be equal, but are 4 and 13",
               op, "[2,3];[5,6];[2,3,4,5];[5,6,13,14]");
+
+  // Good case, but no known input tensors.
+  INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
+
+  // One known input tensor, not enough to change the answer.
+  Tensor tensor_2 = test::AsTensor<int32>(
+      std::vector<int32>{2, 4, 6, 0, 10, 11}, TensorShape({2, 3}));
+  Tensor tensor_5 = test::AsTensor<int32>(
+      std::vector<int32>{0,    1,  2,  3,  4,  5,  6,  7,  8,  9,
+                         10,   11, 12, 13, 14, 15, 16, 17, 18, 19,
+                         1000, 21, 22, 23, 24, 25, 26, 27, 28, 29},
+      TensorShape({5, 6}));
+  op.input_tensors.push_back(nullptr);
+  op.input_tensors.push_back(&tensor_5);
+  INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
+
+  op.input_tensors[0] = &tensor_2;
+  op.input_tensors[1] = nullptr;
+  INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
+  INFER_OK(op, "[2,3];?;[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
+
+  op.input_tensors[1] = &tensor_5;
+  INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[1001,d2_2,d2_3]");
+
+  tensor_2.flat<int32>()(3) = 10000;
+  INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[10001,d2_2,d2_3]");
 }
 
 TEST(DataFlowOpsTest, ParallelDynamicStitch) {
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index 9b9aa98b37..cf723f5eec 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -22,6 +22,7 @@ import numpy as np
 
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gradients_impl
 import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
@@ -42,8 +43,18 @@ class DynamicStitchTestBase(object):
         stitched_t = self.stitch_op(indices[::step], data)
         stitched_val = stitched_t.eval()
         self.assertAllEqual([40, 60][::step], stitched_val)
-        # Dimension 0 is determined by the max index in indices, so we
-        # can only infer that the output is a vector of some unknown
+        # Dimension 0 is max(flatten(indices))+1.
+        self.assertEqual([2], stitched_t.get_shape().as_list())
+
+  def testShapeInferenceForScalarWithNonConstantIndices(self):
+    with self.test_session(use_gpu=True):
+      indices = [array_ops.placeholder(dtype=dtypes.int32),
+                 constant_op.constant(1)]
+      data = [constant_op.constant(40), constant_op.constant(60)]
+      for step in -1, 1:
+        stitched_t = self.stitch_op(indices[::step], data)
+        # Dimension 0 is max(flatten(indices))+1, but the first indices input is
+        # not a constant tensor, so we can only infer it as a vector of unknown
         # length.
         self.assertEqual([None], stitched_t.get_shape().as_list())
 
@@ -59,10 +70,8 @@ class DynamicStitchTestBase(object):
       stitched_t = self.stitch_op(indices, data)
       stitched_val = stitched_t.eval()
       self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
-      # Dimension 0 is determined by the max index in indices, so we
-      # can only infer that the output is a vector of some unknown
-      # length.
-      self.assertEqual([None], stitched_t.get_shape().as_list())
+      # Dimension 0 is max(flatten(indices))+1.
+      self.assertEqual([8], stitched_t.get_shape().as_list())
 
   def testOneListOneDimensional(self):
     with self.test_session(use_gpu=True):
@@ -71,10 +80,8 @@ class DynamicStitchTestBase(object):
       stitched_t = self.stitch_op(indices, data)
       stitched_val = stitched_t.eval()
       self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
-      # Dimension 0 is determined by the max index in indices, so we
-      # can only infer that the output is a vector of some unknown
-      # length.
-      self.assertEqual([None], stitched_t.get_shape().as_list())
+      # Dimension 0 is max(flatten(indices))+1.
+      self.assertEqual([8], stitched_t.get_shape().as_list())
 
   def testSimpleTwoDimensional(self):
     with self.test_session(use_gpu=True):
@@ -91,10 +98,8 @@ class DynamicStitchTestBase(object):
       stitched_val = stitched_t.eval()
       self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31], [40, 41],
                            [50, 51], [60, 61], [70, 71]], stitched_val)
-      # Dimension 0 is determined by the max index in indices, so we
-      # can only infer that the output is a matrix with 2 columns and
-      # some unknown number of rows.
-      self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+      # Dimension 0 is max(flatten(indices))+1.
+      self.assertEqual([8, 2], stitched_t.get_shape().as_list())
 
   def testHigherRank(self):
     with self.test_session(use_gpu=True) as sess:
@@ -111,7 +116,7 @@ class DynamicStitchTestBase(object):
       stitched_val = stitched_t.eval()
       correct = 10 * np.arange(7)[:, None] + [1, 2]
       self.assertAllEqual(correct, stitched_val)
-      self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+      self.assertEqual([7, 2], stitched_t.get_shape().as_list())
       # Test gradients
       stitched_grad = 7 * stitched_val
       grads = gradients_impl.gradients(stitched_t, indices + data,
@@ -186,10 +191,8 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
         stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
         stitched_val = stitched_t.eval()
         self.assertAllEqual([40.0, 60.0][::step], stitched_val)
-        # Dimension 0 is determined by the max index in indices, so we
-        # can only infer that the output is a vector of some unknown
-        # length.
-        self.assertEqual([None], stitched_t.get_shape().as_list())
+        # Dimension 0 is max(flatten(indices))+1.
+        self.assertEqual([2], stitched_t.get_shape().as_list())
 
   def testHigherRank(self):
     with self.test_session(use_gpu=True) as sess:
@@ -208,7 +211,7 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
       stitched_val = stitched_t.eval()
       correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
       self.assertAllEqual(correct, stitched_val)
-      self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+      self.assertEqual([7, 2], stitched_t.get_shape().as_list())
       # Test gradients
       stitched_grad = 7 * stitched_val
       grads = gradients_impl.gradients(stitched_t, indices + data,
@@ -226,10 +229,8 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
         stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
         stitched_val = stitched_t.eval()
         self.assertAllEqual([40.0, 60.0][::step], stitched_val)
-        # Dimension 0 is determined by the max index in indices, so we
-        # can only infer that the output is a vector of some unknown
-        # length.
-        self.assertEqual([None], stitched_t.get_shape().as_list())
+        # Dimension 0 is max(flatten(indices))+1.
+        self.assertEqual([2], stitched_t.get_shape().as_list())
 
   def testHigherRankGPU(self):
     with self.test_session() as sess:
@@ -246,7 +247,7 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
      stitched_val = stitched_t.eval()
       correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
       self.assertAllEqual(correct, stitched_val)
-      self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+      self.assertEqual([7, 2], stitched_t.get_shape().as_list())
       # Test gradients
       stitched_grad = 7 * stitched_val
       grads = gradients_impl.gradients(stitched_t, indices + data,
-- 
cgit v1.2.3
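
Illustration (not part of the commit): a minimal Python sketch of the shape inference this change enables, using the TF 1.x-era tf.dynamic_stitch and tf.placeholder API that the tests above exercise. When every indices input is a constant, the shape function now infers dimension 0 statically as max(flatten(indices)) + 1; if any indices input is non-constant, dimension 0 remains unknown as before.

    import tensorflow as tf

    # All indices inputs are constants, so the shape function can compute
    # dimension 0 statically: max index is 3, so the length is 3 + 1 = 4.
    indices = [tf.constant([0, 2]), tf.constant([1, 3])]
    data = [tf.constant([10, 30]), tf.constant([20, 40])]
    stitched = tf.dynamic_stitch(indices, data)
    print(stitched.get_shape().as_list())  # [4]

    # With a non-constant indices input, dimension 0 stays unknown.
    ph = tf.placeholder(tf.int32, shape=[2])
    stitched_unknown = tf.dynamic_stitch([ph, tf.constant([1, 3])], data)
    print(stitched_unknown.get_shape().as_list())  # [None]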