author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-01 09:19:35 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-10-01 09:47:33 -0700
commit    af8da61ad4b688a7bedb4ba1e0365735c9f25b14 (patch)
tree      37fb6942545a14fda79814278a22311496d74caf
parent    418fac23f1355fe886fec94f161609c2fa080c7b (diff)
Make DynamicStitch's shape function handle the case where all inputs are
constant.

PiperOrigin-RevId: 170637740
-rw-r--r-- tensorflow/core/ops/data_flow_ops.cc                     26
-rw-r--r-- tensorflow/core/ops/data_flow_ops_test.cc                28
-rw-r--r-- tensorflow/python/kernel_tests/dynamic_stitch_op_test.py 51
3 files changed, 74 insertions(+), 31 deletions(-)
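
For context, the user-visible effect of this change: when every indices input to dynamic_stitch is a constant, the output's leading dimension can now be inferred statically as max(flatten(indices)) + 1 instead of being left unknown. A minimal sketch, assuming the TF 1.x-era tf.dynamic_stitch API that the tests in this commit use:

    import tensorflow as tf

    indices = [tf.constant([0, 4, 7]), tf.constant([1, 6, 2, 3, 5])]
    data = [tf.constant([0, 40, 70]), tf.constant([10, 60, 20, 30, 50])]
    stitched = tf.dynamic_stitch(indices, data)
    # All indices inputs are constants, so after this change the shape
    # function infers dimension 0 as max(flatten(indices)) + 1 = 8.
    print(stitched.get_shape().as_list())  # [8]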
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 2209ecf1de..8e24ea70cb 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -133,17 +133,23 @@ num_partitions: The number of partitions to output.
namespace {
Status DynamicStitchShapeFunction(InferenceContext* c) {
- int64 num_partitions;
+ int32 num_partitions;
TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions));
+ bool all_indices_constant = true;
+ int32 max_index = 0;
ShapeHandle extra_shape = c->UnknownShape();
- for (int64 i = 0; i < num_partitions; ++i) {
+ for (int i = 0; i < num_partitions; ++i) {
+ const Tensor* indices_t = c->input_tensor(i);
+ if (indices_t == nullptr) {
+ all_indices_constant = false;
+ }
+
ShapeHandle indices_shape = c->input(i);
ShapeHandle data_shape = c->input(i + num_partitions);
if (!c->RankKnown(indices_shape)) {
continue;
}
-
const int64 indices_rank = c->Rank(indices_shape);
// Assert that data_shape starts with indices_shape.
@@ -155,9 +161,21 @@ Status DynamicStitchShapeFunction(InferenceContext* c) {
ShapeHandle rest;
TF_RETURN_IF_ERROR(c->Subshape(data_shape, indices_rank, &rest));
TF_RETURN_IF_ERROR(c->Merge(extra_shape, rest, &extra_shape));
+
+ if (indices_t != nullptr) {
+ // The output length is based on the highest index in the flattened indices.
+ const int32* indices = indices_t->flat<int32>().data();
+ int64 count = indices_t->NumElements();
+ for (int64 i = 0; i < count; ++i) {
+ if (indices[i] > max_index) {
+ max_index = indices[i];
+ }
+ }
+ }
}
- ShapeHandle output_shape = c->Vector(c->UnknownDim());
+ ShapeHandle output_shape = c->Vector(
+ all_indices_constant ? c->MakeDim(max_index + 1) : c->UnknownDim());
TF_RETURN_IF_ERROR(c->Concatenate(output_shape, extra_shape, &output_shape));
c->set_output(0, output_shape);
return Status::OK();
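
In outline, the new inference logic scans each constant indices tensor for its largest element and only produces a known leading dimension when no input is missing. A minimal Python sketch of the same logic (infer_leading_dim is a hypothetical name; numpy arrays stand in for constant tensors, None for a non-constant input):

    import numpy as np

    def infer_leading_dim(indices_tensors):
        # Mirrors the C++ above: the leading output dimension is
        # max(flatten(indices)) + 1 only when every indices input is a
        # known constant; otherwise it stays unknown (None).
        max_index = 0
        for t in indices_tensors:
            if t is None:  # corresponds to c->input_tensor(i) == nullptr
                return None
            max_index = max(max_index, int(np.max(t)))
        return max_index + 1

    # The largest flattened index below is 1000, as in the C++ test that follows.
    known = [np.array([2, 4, 6, 0, 10, 11]),
             np.array([[0, 1, 2], [1000, 21, 22]])]
    print(infer_leading_dim(known))             # 1001
    print(infer_leading_dim([None, known[1]]))  # None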
diff --git a/tensorflow/core/ops/data_flow_ops_test.cc b/tensorflow/core/ops/data_flow_ops_test.cc
index 9c94d9aac9..a071eac453 100644
--- a/tensorflow/core/ops/data_flow_ops_test.cc
+++ b/tensorflow/core/ops/data_flow_ops_test.cc
@@ -126,8 +126,6 @@ TEST(DataFlowOpsTest, DynamicStitch) {
.Attr("N", 2)
.Finalize(&op.node_def));
- INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
-
// Bad prefix for the second data input.
INFER_ERROR("Dimensions must be equal, but are 10 and 5", op,
"[2,3];[5,6];[2,3,4,5];[10,11,4,5]");
@@ -135,6 +133,32 @@ TEST(DataFlowOpsTest, DynamicStitch) {
// Inconsistent suffix dimensions
INFER_ERROR("Dimension 0 in both shapes must be equal, but are 4 and 13", op,
"[2,3];[5,6];[2,3,4,5];[5,6,13,14]");
+
+ // Good case, but no known input tensors.
+ INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
+
+ // One known input tensor, not enough to change the answer.
+ Tensor tensor_2 = test::AsTensor<int32>(
+ std::vector<int32>{2, 4, 6, 0, 10, 11}, TensorShape({2, 3}));
+ Tensor tensor_5 = test::AsTensor<int32>(
+ std::vector<int32>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
+ 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
+ 1000, 21, 22, 23, 24, 25, 26, 27, 28, 29},
+ TensorShape({5, 6}));
+ op.input_tensors.push_back(nullptr);
+ op.input_tensors.push_back(&tensor_5);
+ INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
+
+ op.input_tensors[0] = &tensor_2;
+ op.input_tensors[1] = nullptr;
+ INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
+ INFER_OK(op, "[2,3];?;[2,3,4,5];[5,6,4,5]", "[?,d2_2,d2_3]");
+
+ op.input_tensors[1] = &tensor_5;
+ INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[1001,d2_2,d2_3]");
+
+ tensor_2.flat<int32>()(3) = 10000;
+ INFER_OK(op, "[2,3];[5,6];[2,3,4,5];[5,6,4,5]", "[10001,d2_2,d2_3]");
}
TEST(DataFlowOpsTest, ParallelDynamicStitch) {
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index 9b9aa98b37..cf723f5eec 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import
@@ -42,8 +43,18 @@ class DynamicStitchTestBase(object):
stitched_t = self.stitch_op(indices[::step], data)
stitched_val = stitched_t.eval()
self.assertAllEqual([40, 60][::step], stitched_val)
- # Dimension 0 is determined by the max index in indices, so we
- # can only infer that the output is a vector of some unknown
+ # Dimension 0 is max(flatten(indices))+1.
+ self.assertEqual([2], stitched_t.get_shape().as_list())
+
+ def testShapeInferenceForScalarWithNonConstantIndices(self):
+ with self.test_session(use_gpu=True):
+ indices = [array_ops.placeholder(dtype=dtypes.int32),
+ constant_op.constant(1)]
+ data = [constant_op.constant(40), constant_op.constant(60)]
+ for step in -1, 1:
+ stitched_t = self.stitch_op(indices[::step], data)
+ # Dimension 0 is max(flatten(indices))+1, but the first indices input is
+ # not a constant tensor, so we can only infer it as a vector of unknown
# length.
self.assertEqual([None], stitched_t.get_shape().as_list())
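
The test above pins down the fallback path; from the Python side, a single non-constant indices input keeps the leading dimension unknown. A minimal sketch of that case, again assuming the TF 1.x placeholder API used in this file:

    import tensorflow as tf

    indices = [tf.placeholder(dtype=tf.int32), tf.constant(1)]
    data = [tf.constant(40), tf.constant(60)]
    stitched = tf.dynamic_stitch(indices, data)
    # The first indices input is not a constant tensor, so the shape
    # function falls back to an unknown leading dimension.
    print(stitched.get_shape().as_list())  # [None]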
@@ -59,10 +70,8 @@ class DynamicStitchTestBase(object):
stitched_t = self.stitch_op(indices, data)
stitched_val = stitched_t.eval()
self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
- # Dimension 0 is determined by the max index in indices, so we
- # can only infer that the output is a vector of some unknown
- # length.
- self.assertEqual([None], stitched_t.get_shape().as_list())
+ # Dimension 0 is max(flatten(indices))+1.
+ self.assertEqual([8], stitched_t.get_shape().as_list())
def testOneListOneDimensional(self):
with self.test_session(use_gpu=True):
@@ -71,10 +80,8 @@ class DynamicStitchTestBase(object):
stitched_t = self.stitch_op(indices, data)
stitched_val = stitched_t.eval()
self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
- # Dimension 0 is determined by the max index in indices, so we
- # can only infer that the output is a vector of some unknown
- # length.
- self.assertEqual([None], stitched_t.get_shape().as_list())
+ # Dimension 0 is max(flatten(indices))+1.
+ self.assertEqual([8], stitched_t.get_shape().as_list())
def testSimpleTwoDimensional(self):
with self.test_session(use_gpu=True):
@@ -91,10 +98,8 @@ class DynamicStitchTestBase(object):
stitched_val = stitched_t.eval()
self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31], [40, 41],
[50, 51], [60, 61], [70, 71]], stitched_val)
- # Dimension 0 is determined by the max index in indices, so we
- # can only infer that the output is a matrix with 2 columns and
- # some unknown number of rows.
- self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+ # Dimension 0 is max(flatten(indices))+1.
+ self.assertEqual([8, 2], stitched_t.get_shape().as_list())
def testHigherRank(self):
with self.test_session(use_gpu=True) as sess:
@@ -111,7 +116,7 @@ class DynamicStitchTestBase(object):
stitched_val = stitched_t.eval()
correct = 10 * np.arange(7)[:, None] + [1, 2]
self.assertAllEqual(correct, stitched_val)
- self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+ self.assertEqual([7, 2], stitched_t.get_shape().as_list())
# Test gradients
stitched_grad = 7 * stitched_val
grads = gradients_impl.gradients(stitched_t, indices + data,
@@ -186,10 +191,8 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
stitched_val = stitched_t.eval()
self.assertAllEqual([40.0, 60.0][::step], stitched_val)
- # Dimension 0 is determined by the max index in indices, so we
- # can only infer that the output is a vector of some unknown
- # length.
- self.assertEqual([None], stitched_t.get_shape().as_list())
+ # Dimension 0 is max(flatten(indices))+1.
+ self.assertEqual([2], stitched_t.get_shape().as_list())
def testHigherRank(self):
with self.test_session(use_gpu=True) as sess:
@@ -208,7 +211,7 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
stitched_val = stitched_t.eval()
correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
self.assertAllEqual(correct, stitched_val)
- self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+ self.assertEqual([7, 2], stitched_t.get_shape().as_list())
# Test gradients
stitched_grad = 7 * stitched_val
grads = gradients_impl.gradients(stitched_t, indices + data,
@@ -226,10 +229,8 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
stitched_val = stitched_t.eval()
self.assertAllEqual([40.0, 60.0][::step], stitched_val)
- # Dimension 0 is determined by the max index in indices, so we
- # can only infer that the output is a vector of some unknown
- # length.
- self.assertEqual([None], stitched_t.get_shape().as_list())
+ # Dimension 0 is max(flatten(indices))+1.
+ self.assertEqual([2], stitched_t.get_shape().as_list())
def testHigherRankGPU(self):
with self.test_session() as sess:
@@ -246,7 +247,7 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
stitched_val = stitched_t.eval()
correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
self.assertAllEqual(correct, stitched_val)
- self.assertEqual([None, 2], stitched_t.get_shape().as_list())
+ self.assertEqual([7, 2], stitched_t.get_shape().as_list())
# Test gradients
stitched_grad = 7 * stitched_val
grads = gradients_impl.gradients(stitched_t, indices + data,