aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-03 10:19:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 10:27:27 -0700
commita5b3cd8b4d28cfcdcb9adb3d3568b168b9b8a088 (patch)
tree97c740eebd5de54244fdb8188c1c658c94a93523 /tensorflow/core
parent022af5300701d457d848e60ea511dd8d05f68738 (diff)
Fix bug in shape function for transpose: If the rank of the input is unknown and the rank derived from the permutation array is 0 or 1, the shape is ambiguous and cannot be determined at graph construction time. In this case, forward the shape of the input.
PiperOrigin-RevId: 215583050
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/ops/array_ops.cc8
-rw-r--r--tensorflow/core/ops/array_ops_test.cc1
2 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index c9f80df5e4..f55562ec99 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -133,6 +133,14 @@ Status TransposeShapeFn(InferenceContext* c) {
} else {
rank = perm->NumElements();
}
+ if (!c->RankKnown(input) && rank < 2) {
+ // A permutation array containing a single element is ambiguous. It could
+ // indicate either a scalar or a 1-dimensional array, both of which the
+ // transpose op returns unchanged.
+ c->set_output(0, input);
+ return Status::OK();
+ }
+
std::vector<DimensionHandle> dims;
dims.resize(rank);
TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 03dab390a7..1c29cd2491 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -975,6 +975,7 @@ TEST(ArrayOpsTest, Transpose_ShapeFn) {
INFER_OK(op, "?;[2]", "[?,?]");
INFER_OK(op, "[?,?];[2]", "[d0_1,d0_0]");
INFER_OK(op, "[1,?];[2]", "[d0_1,d0_0]");
+ INFER_OK(op, "?;[0]", "in0");
// Invalid arguments.
perm = test::AsTensor<int32>({1, 2});