author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-08-30 12:24:42 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-08-30 13:33:09 -0700
commit    660d7e5ae51d8d74643874b9ae10f855dced69bb (patch)
tree      ec62bdb867b4c09a8a458f94911dc96387973e95 /tensorflow
parent    5532f08194c2e29e6056e5339cbd57c52db1907a (diff)
Delegate to C++ shape inference functions for several ops in image_ops,
tensor_array_ops, sparse_ops, and string_ops.

Fix bugs in C++ shape inference for image decoding (channel 0 means
auto-detect) and string split (some output dims are known).

Change: 131747517
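
All of the Python changes below follow the same pattern: the hand-written
Python shape function is deleted and common_shapes.call_cpp_shape_fn is
registered in its place, so shape inference is delegated to the C++ shape
function attached to the corresponding REGISTER_OP. A minimal sketch of that
pattern, using a hypothetical op name "MyOp" purely for illustration:

    from tensorflow.python.framework import common_shapes
    from tensorflow.python.framework import ops

    # "MyOp" is a hypothetical op name used only to illustrate the pattern.
    #
    # Before: a hand-written Python shape function that duplicated the C++ logic.
    # @ops.RegisterShape("MyOp")
    # def _MyOpShape(op):
    #   return [tensor_shape.unknown_shape()]

    # After: defer to the C++ shape function registered with REGISTER_OP("MyOp").
    ops.RegisterShape("MyOp")(common_shapes.call_cpp_shape_fn)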
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/ops/image_ops.cc           |   8
-rw-r--r--  tensorflow/core/ops/image_ops_test.cc      |   5
-rw-r--r--  tensorflow/core/ops/string_ops.cc          |  11
-rw-r--r--  tensorflow/python/ops/image_ops.py         |  46
-rw-r--r--  tensorflow/python/ops/sparse_ops.py        | 156
-rw-r--r--  tensorflow/python/ops/string_ops.py        |  29
-rw-r--r--  tensorflow/python/ops/tensor_array_ops.py  |  93
7 files changed, 51 insertions, 297 deletions
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 7eb380798c..c0a60aab0a 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -62,15 +62,15 @@ Status DecodeImageShapeFn(InferenceContext* c) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
DimensionHandle channels_dim;
int32 channels;
- Status s = c->GetAttr("channels", &channels);
- if (s.ok()) {
+ TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
+ if (channels == 0) {
+ channels_dim = c->UnknownDim();
+ } else {
if (channels < 0) {
return errors::InvalidArgument("channels must be non-negative, got ",
channels);
}
channels_dim = c->MakeDim(channels);
- } else {
- channels_dim = c->UnknownDim();
}
c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
diff --git a/tensorflow/core/ops/image_ops_test.cc b/tensorflow/core/ops/image_ops_test.cc
index 1b5db82ab8..d9f1400a2c 100644
--- a/tensorflow/core/ops/image_ops_test.cc
+++ b/tensorflow/core/ops/image_ops_test.cc
@@ -58,7 +58,10 @@ TEST(ImageOpsTest, DecodeImage_ShapeFn) {
// Rank check.
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1]");
- // Channel not set - output is unknown.
+ // Set the channel to zero - output is not known.
+ TF_ASSERT_OK(NodeDefBuilder("test", op_name)
+ .Input({"a", 0, DT_STRING})
+ .Finalize(&op.node_def));
INFER_OK(op, "[]", "[?,?,?]");
// Set the channel and so that part of output shape is known.
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index dd4cb12f5d..248fb7cbbb 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -206,14 +206,13 @@ REGISTER_OP("StringSplit")
.Output("values: string")
.Output("shape: int64")
.SetShapeFn([](InferenceContext* c) {
- ShapeHandle unsed_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unsed_shape));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unsed_shape));
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
- c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
- InferenceContext::kUnknownDim));
+ c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2));
c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
- c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
+ c->set_output(2, c->Vector(2));
return Status::OK();
})
.Doc(R"doc(
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index adabf4f5e1..8291474e6c 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -1007,12 +1007,7 @@ ops.RegisterShape('DrawBoundingBoxes')(
common_shapes.unchanged_shape_with_rank_at_least(3))
-@ops.RegisterShape('SampleDistortedBoundingBox')
-def _SampleDistortedBoundingBoxShape(unused_op): # pylint: disable=invalid-name
- """Shape function for the sample distorted bounding box."""
- return [tensor_shape.TensorShape([3]),
- tensor_shape.TensorShape([3]),
- tensor_shape.TensorShape([1, 1, 4])]
+ops.RegisterShape('SampleDistortedBoundingBox')(common_shapes.call_cpp_shape_fn)
@ops.RegisterShape('ResizeBilinear')
@@ -1034,28 +1029,16 @@ def _ResizeShape(op):
[input_shape[0], height, width, input_shape[3]])]
@ops.RegisterShape('DecodeGif')
-def _ImageDecodeShape(op):
+def _DecodeGifShape(op):
"""Shape function for decode gif."""
unused_input_shape = op.inputs[0].get_shape().merge_with(
tensor_shape.scalar())
return [tensor_shape.TensorShape([None, None, None, 3])]
-@ops.RegisterShape('DecodeJpeg')
-@ops.RegisterShape('DecodePng')
-def _ImageDecodeShape(op):
- """Shape function for image decoding ops."""
- unused_input_shape = op.inputs[0].get_shape().merge_with(
- tensor_shape.scalar())
- channels = op.get_attr('channels') or None
- return [tensor_shape.TensorShape([None, None, channels])]
-
-
-@ops.RegisterShape('EncodeJpeg')
-@ops.RegisterShape('EncodePng')
-def _ImageEncodeShape(op):
- """Shape function for image encoding ops."""
- unused_input_shape = op.inputs[0].get_shape().with_rank(3)
- return [tensor_shape.scalar()]
+ops.RegisterShape('DecodeJpeg')(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape('DecodePng')(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape('EncodeJpeg')(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape('EncodePng')(common_shapes.call_cpp_shape_fn)
def convert_image_dtype(image, dtype, saturate=False, name=None):
@@ -1194,16 +1177,8 @@ def grayscale_to_rgb(images, name=None):
# pylint: disable=invalid-name
-@ops.RegisterShape('HSVToRGB')
-@ops.RegisterShape('RGBToHSV')
-def _ColorspaceShape(op):
- """Shape function for colorspace ops."""
- input_shape = op.inputs[0].get_shape().with_rank_at_least(1)
- input_rank = input_shape.ndims
- if input_rank is not None:
- input_shape = input_shape.merge_with([None] * (input_rank - 1) + [3])
- return [input_shape]
-# pylint: enable=invalid-name
+ops.RegisterShape('HSVToRGB')(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape('RGBToHSV')(common_shapes.call_cpp_shape_fn)
def random_hue(image, max_delta, seed=None):
@@ -1413,10 +1388,7 @@ def _crop_and_resize_shape(op):
[box_shape[0], crop_height, crop_width, image_shape[3]])]
-@ops.RegisterShape('NonMaxSuppression')
-def _non_max_suppression_shape(_):
- """Shape function for the NonMaxSuppression op."""
- return [tensor_shape.TensorShape([None])]
+ops.RegisterShape('NonMaxSuppression')(common_shapes.call_cpp_shape_fn)
__all__ = make_all(__name__)
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 198ccab021..e14324614e 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -59,6 +59,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -272,14 +273,7 @@ def sparse_add(a, b, thresh=0):
a.indices, a.values, a.shape, b)
-@ops.RegisterShape("SparseAdd")
-def _SparseAddShape(op): # pylint: disable=invalid-name
- input_shape_shape = op.inputs[2].get_shape()
- input_shape_shape.assert_has_rank(1)
- return [
- tensor_shape.TensorShape([None, input_shape_shape[0]]),
- tensor_shape.unknown_shape(1), input_shape_shape
- ]
+ops.RegisterShape("SparseAdd")(common_shapes.call_cpp_shape_fn)
def sparse_dense_cwise_add(sp_t, dense_t):
@@ -307,47 +301,9 @@ def sparse_dense_cwise_add(sp_t, dense_t):
return ops.SparseTensor(sp_t.indices, result, sp_t.shape)
-@ops.RegisterShape("SparseTensorDenseAdd")
-def _SparseTensorDenseAddShape(op): # pylint: disable=invalid-name
- return [op.inputs[3].get_shape()]
-
-
-@ops.RegisterShape("SparseAddGrad")
-def _SparseAddGradShape(op): # pylint: disable=invalid-name
- # shapes for (a_val_grad, b_val_grad)
- a_nnz = op.inputs[1].get_shape()[0]
- b_nnz = op.inputs[2].get_shape()[0]
- return [tensor_shape.TensorShape([a_nnz]), tensor_shape.TensorShape([b_nnz])]
-
-
-@ops.RegisterShape("SparseConcat")
-def _SparseConcatShape(op):
- """Shape function for SparseConcat op."""
- num_inputs = int(op.get_attr("N"))
-
- # TF flattens and concatenates all list inputs, so reconstruct the lists here.
- ind_shapes = [ind.get_shape().with_rank(2) for ind in op.inputs[0:num_inputs]]
- val_shapes = [val.get_shape().with_rank(1)
- for val in op.inputs[num_inputs:2 * num_inputs]]
- shape_shapes = [shape.get_shape().with_rank(1)
- for shape in op.inputs[2 * num_inputs:]]
-
- output_ind_rows = tensor_shape.Dimension(0)
- output_ind_cols = tensor_shape.Dimension(None)
- output_val_elems = tensor_shape.Dimension(0)
- output_shape_shape = tensor_shape.TensorShape(None)
-
- for i in xrange(num_inputs):
- num_elems_i = ind_shapes[i][0].merge_with(val_shapes[i][0])
- output_ind_rows += num_elems_i
- output_ind_cols = output_ind_cols.merge_with(ind_shapes[i][1])
- output_val_elems += num_elems_i
- output_shape_shape = output_shape_shape.merge_with(shape_shapes[i])
-
- output_ind_shape = tensor_shape.matrix(output_ind_rows, output_ind_cols)
- output_val_shape = tensor_shape.vector(output_val_elems)
-
- return [output_ind_shape, output_val_shape, output_shape_shape]
+ops.RegisterShape("SparseTensorDenseAdd")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseAddGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseConcat")(common_shapes.call_cpp_shape_fn)
def sparse_reorder(sp_input, name=None):
@@ -398,14 +354,7 @@ def sparse_reorder(sp_input, name=None):
array_ops.identity(sp_input.shape))
-@ops.RegisterShape("SparseReorder")
-def _SparseReorderShape(op):
- """Shape function for SparseReorder op."""
- input_indices_shape = op.inputs[0].get_shape().with_rank(2)
- input_values_shape = op.inputs[1].get_shape().with_rank(1)
- unused_shape_shape = op.inputs[2].get_shape().with_rank(1)
-
- return [input_indices_shape, input_values_shape]
+ops.RegisterShape("SparseReorder")(common_shapes.call_cpp_shape_fn)
def sparse_reshape(sp_input, shape, name=None):
@@ -464,17 +413,7 @@ def sparse_reshape(sp_input, shape, name=None):
reshaped_shape)
-@ops.RegisterShape("SparseReshape")
-def _SparseReshapeShape(op): # pylint: disable=invalid-name
- """Shape function for SparseReshape op."""
- input_indices_shape = op.inputs[0].get_shape().with_rank(2)
- unused_input_shape_shape = op.inputs[1].get_shape().with_rank(1)
- new_shape_shape = op.inputs[2].get_shape().with_rank(1)
-
- new_indices_shape = tensor_shape.matrix(input_indices_shape[0],
- new_shape_shape[0])
-
- return [new_indices_shape, new_shape_shape]
+ops.RegisterShape("SparseReshape")(common_shapes.call_cpp_shape_fn)
def sparse_split(split_dim, num_split, sp_input, name=None):
@@ -528,20 +467,7 @@ def sparse_split(split_dim, num_split, sp_input, name=None):
return sparse_tensors
-# pylint: disable=invalid-name
-@ops.RegisterShape("SparseSplit")
-def _SparseSplitShape(op):
- """Shape function for SparseSplit op."""
- num_split = int(op.get_attr("num_split"))
- input_shape_shape = op.inputs[3].get_shape()
- dim = input_shape_shape.num_elements()
- output_indices_shape = tensor_shape.TensorShape([None, dim])
- output_values_shape = tensor_shape.unknown_shape(1)
- output_indices_shape = [output_indices_shape] * num_split
- output_values_shape = [output_values_shape] * num_split
- output_shape_shape = [input_shape_shape] * num_split
- return output_indices_shape + output_values_shape + output_shape_shape
-# pylint: enable=invalid-name
+ops.RegisterShape("SparseSplit")(common_shapes.call_cpp_shape_fn)
@ops.RegisterShape("SparseToDense")
@@ -656,9 +582,7 @@ def sparse_reduce_sum(sp_input, reduction_axes=None, keep_dims=False):
keep_dims)
-@ops.RegisterShape("SparseReduceSum")
-def _SparseReduceSumShape(unused_op): # pylint: disable=invalid-name
- return [tensor_shape.unknown_shape()]
+ops.RegisterShape("SparseReduceSum")(common_shapes.call_cpp_shape_fn)
def sparse_tensor_to_dense(sp_input,
@@ -1116,14 +1040,7 @@ def serialize_sparse(sp_input, name=None):
name=name)
-@ops.RegisterShape("SerializeSparse")
-def _SerializeSparseShape(op): # pylint: disable=invalid-name
- """Shape function for SerializeSparse op."""
- op.inputs[0].get_shape().with_rank(2)
- op.inputs[1].get_shape().with_rank(1)
- op.inputs[2].get_shape().with_rank(1)
-
- return [tensor_shape.vector(3)]
+ops.RegisterShape("SerializeSparse")(common_shapes.call_cpp_shape_fn)
def serialize_many_sparse(sp_input, name=None):
@@ -1159,14 +1076,7 @@ def serialize_many_sparse(sp_input, name=None):
name=name)
-@ops.RegisterShape("SerializeManySparse")
-def _SerializeManySparseShape(op): # pylint: disable=invalid-name
- """Shape function for SerializeSparse op."""
- op.inputs[0].get_shape().with_rank(2)
- op.inputs[1].get_shape().with_rank(1)
- op.inputs[2].get_shape().with_rank(1)
-
- return [tensor_shape.matrix(None, 3)]
+ops.RegisterShape("SerializeManySparse")(common_shapes.call_cpp_shape_fn)
def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
@@ -1238,16 +1148,7 @@ def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
return ops.SparseTensor(output_indices, output_values, output_shape)
-@ops.RegisterShape("DeserializeManySparse")
-def _DeserializeSparseShape(op): # pylint: disable=invalid-name
- """Shape function for DeserializeManySparse op."""
- serialized_sparse_shape = op.inputs[0].get_shape().with_rank(2)
- serialized_sparse_shape.merge_with(
- tensor_shape.TensorShape([None, 3]))
-
- return [tensor_shape.matrix(None, None),
- tensor_shape.vector(None),
- tensor_shape.vector(None)]
+ops.RegisterShape("DeserializeManySparse")(common_shapes.call_cpp_shape_fn)
def sparse_tensor_dense_matmul(sp_a, b, adjoint_a=False, adjoint_b=False,
@@ -1426,16 +1327,7 @@ def sparse_tensor_dense_matmul(sp_a, b, adjoint_a=False, adjoint_b=False,
adjoint_b=adjoint_b)
-@ops.RegisterShape("SparseTensorDenseMatMul")
-def _SparseTensorDenseMatMulShape(op): # pylint: disable=invalid-name
- """Shape function for SparseTensorDenseMatMul op."""
- adjoint_b = op.get_attr("adjoint_b")
- op.inputs[0].get_shape().assert_has_rank(2) # a_indices
- op.inputs[1].get_shape().assert_has_rank(1) # a_values
- op.inputs[2].get_shape().merge_with(tensor_shape.vector(2)) # a_shape
- b_shape = op.inputs[3].get_shape().with_rank(2)
- output_shape_right = b_shape[0] if adjoint_b else b_shape[1]
- return [tensor_shape.matrix(None, output_shape_right)]
+ops.RegisterShape("SparseTensorDenseMatMul")(common_shapes.call_cpp_shape_fn)
def sparse_softmax(sp_input, name=None):
@@ -1492,14 +1384,7 @@ def sparse_softmax(sp_input, name=None):
return ops.SparseTensor(sp_input.indices, out_vals, sp_input.shape)
-@ops.RegisterShape("SparseSoftmax")
-def _SparseSoftmaxShape(op): # pylint: disable=invalid-name
- """Shape function for SparseSoftmax op."""
- unused_indices_shape = op.inputs[0].get_shape().with_rank(2)
- values_shape = op.inputs[1].get_shape().with_rank(1)
- unused_shape_shape = op.inputs[2].get_shape().with_rank(1)
- nnz = values_shape[0]
- return [tensor_shape.vector(nnz)]
+ops.RegisterShape("SparseSoftmax")(common_shapes.call_cpp_shape_fn)
def sparse_maximum(sp_a, sp_b, name=None):
@@ -1572,17 +1457,8 @@ def sparse_minimum(sp_a, sp_b, name=None):
return ops.SparseTensor(out_indices, out_values, sp_a.shape)
-@ops.RegisterShape("SparseSparseMaximum")
-@ops.RegisterShape("SparseSparseMinimum")
-def _SparseSparseMaximumMinimumShape(op): # pylint: disable=invalid-name
- """Shape function for SparseSparseMaximum and SparseSparseMinimum."""
- op.inputs[0].get_shape().assert_has_rank(2) # a_indices
- op.inputs[1].get_shape().assert_has_rank(1) # a_values
- op.inputs[2].get_shape().assert_has_rank(1) # a_shape
- op.inputs[3].get_shape().assert_has_rank(2) # b_indices
- op.inputs[4].get_shape().assert_has_rank(1) # b_values
- op.inputs[5].get_shape().assert_has_rank(1) # b_shape
- return [tensor_shape.unknown_shape(2), tensor_shape.unknown_shape(1)]
+ops.RegisterShape("SparseSparseMaximum")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SparseSparseMinimum")(common_shapes.call_cpp_shape_fn)
def sparse_transpose(sp_input, perm=None, name=None):
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 7ffa521ac6..8a547eeac5 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -176,30 +176,5 @@ def _ReduceJoinShape(op):
return [tensor_shape.TensorShape(returned_dims)]
-@ops.RegisterShape("StringJoin")
-def _StringJoinShape(op):
- """Shape function for the StringJoin op."""
- input_shapes = [x.get_shape() for x in op.inputs]
-
- # First check if all inputs are scalars. In the next section
- # we may have *some* scalars and we will be broadcasting them
- if all([s.ndims == 0 for s in input_shapes]):
- return [tensor_shape.scalar()]
-
- base_shape = tensor_shape.unknown_shape()
- for shape in input_shapes:
- if shape.ndims != 0:
- base_shape = base_shape.merge_with(shape)
- return [base_shape]
-
-
-@ops.RegisterShape("StringSplit")
-def _StringSplitShape(op):
- """Shape function for string_ops.string_split."""
- unused_sfs_shape = op.inputs[0].get_shape().with_rank(1)
- unused_sfs_shape = op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
-
- indices_shape = tensor_shape.TensorShape([None, 2])
- values_shape = tensor_shape.TensorShape([None])
- shape_shape = tensor_shape.TensorShape([2])
- return [indices_shape, values_shape, shape_shape]
+ops.RegisterShape("StringJoin")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("StringSplit")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index ef0573c6fd..3472229aff 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -25,6 +25,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops
@@ -373,87 +374,15 @@ class TensorArray(object):
handle=self._handle, name=name)
-@ops.RegisterShape("TensorArray")
-def _TensorArrayShape(op):
- # size is a scalar
- op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
- return [tensor_shape.vector(2)]
-
-
-@ops.RegisterShape("TensorArrayRead")
-def _TensorArrayReadShape(op):
- # handle, index, flow_in
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
- op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
- # value
- return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("TensorArrayWrite")
-def _TensorArrayWriteShape(op):
- # handle, index, value, flow_in
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
- op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
- # flow_out
- return [tensor_shape.scalar()]
-
-
-@ops.RegisterShape("TensorArraySize")
-def _TensorArraySizeShape(op):
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- return [tensor_shape.scalar()]
-
-
-@ops.RegisterShape("TensorArrayClose")
-def _TensorArrayCloseShape(op):
- """Shape function for ops that take a scalar and produce no outputs."""
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- return []
-
-
-@ops.RegisterShape("TensorArrayGrad")
-def _TensorArrayGradShape(op):
- """Shape function for ops that take a scalar and produce no outputs."""
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- return [tensor_shape.vector(2)]
-
-
-@ops.RegisterShape("TensorArrayPack")
-def _TensorArrayPackShape(op):
- # handle, flow_in
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
- # value
- return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("TensorArrayConcat")
-def _TensorArrayConcatShape(op):
- # handle, flow_in
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
- # value, lengths
- return [tensor_shape.unknown_shape(), tensor_shape.vector(None)]
-
-
-@ops.RegisterShape("TensorArraySplit")
-def _TensorArraySplitShape(op):
- # handle, value, lengths, flow_in
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- op.inputs[2].get_shape().merge_with(tensor_shape.vector(None))
- op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
- # flow_out
- return [tensor_shape.scalar()]
-
-
-@ops.RegisterShape("TensorArrayUnpack")
-def _TensorArrayUnpackShape(op):
- # handle, value, flow_in
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
- # flow_out
- return [tensor_shape.scalar()]
+ops.RegisterShape("TensorArray")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayRead")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayWrite")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArraySize")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayClose")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayPack")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayConcat")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArraySplit")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayUnpack")(common_shapes.call_cpp_shape_fn)
# pylint: enable=protected-access