diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-08-30 12:24:42 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-30 13:33:09 -0700 |
commit | 660d7e5ae51d8d74643874b9ae10f855dced69bb (patch) | |
tree | ec62bdb867b4c09a8a458f94911dc96387973e95 /tensorflow | |
parent | 5532f08194c2e29e6056e5339cbd57c52db1907a (diff) |
Delegate to C++ shape inference functions for several ops in
image_ops, tensor_array_ops, sparse_ops, and string_ops.
Fix bugs in C++ shape inference for image decoding (channel 0 means
auto-detect) and string split (some output dims are known).
Change: 131747517
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/ops/image_ops.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/ops/image_ops_test.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/ops/string_ops.cc | 11 | ||||
-rw-r--r-- | tensorflow/python/ops/image_ops.py | 46 | ||||
-rw-r--r-- | tensorflow/python/ops/sparse_ops.py | 156 | ||||
-rw-r--r-- | tensorflow/python/ops/string_ops.py | 29 | ||||
-rw-r--r-- | tensorflow/python/ops/tensor_array_ops.py | 93 |
7 files changed, 51 insertions, 297 deletions
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 7eb380798c..c0a60aab0a 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -62,15 +62,15 @@ Status DecodeImageShapeFn(InferenceContext* c) { TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); DimensionHandle channels_dim; int32 channels; - Status s = c->GetAttr("channels", &channels); - if (s.ok()) { + TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels)); + if (channels == 0) { + channels_dim = c->UnknownDim(); + } else { if (channels < 0) { return errors::InvalidArgument("channels must be non-negative, got ", channels); } channels_dim = c->MakeDim(channels); - } else { - channels_dim = c->UnknownDim(); } c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim, diff --git a/tensorflow/core/ops/image_ops_test.cc b/tensorflow/core/ops/image_ops_test.cc index 1b5db82ab8..d9f1400a2c 100644 --- a/tensorflow/core/ops/image_ops_test.cc +++ b/tensorflow/core/ops/image_ops_test.cc @@ -58,7 +58,10 @@ TEST(ImageOpsTest, DecodeImage_ShapeFn) { // Rank check. INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1]"); - // Channel not set - output is unknown. + // Set the channel to zero - output is not known. + TF_ASSERT_OK(NodeDefBuilder("test", op_name) + .Input({"a", 0, DT_STRING}) + .Finalize(&op.node_def)); INFER_OK(op, "[]", "[?,?,?]"); // Set the channel and so that part of output shape is known. diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index dd4cb12f5d..248fb7cbbb 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -206,14 +206,13 @@ REGISTER_OP("StringSplit") .Output("values: string") .Output("shape: int64") .SetShapeFn([](InferenceContext* c) { - ShapeHandle unsed_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unsed_shape)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unsed_shape)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); - c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, - InferenceContext::kUnknownDim)); + c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2)); c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); - c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); + c->set_output(2, c->Vector(2)); return Status::OK(); }) .Doc(R"doc( diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index adabf4f5e1..8291474e6c 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -1007,12 +1007,7 @@ ops.RegisterShape('DrawBoundingBoxes')( common_shapes.unchanged_shape_with_rank_at_least(3)) -@ops.RegisterShape('SampleDistortedBoundingBox') -def _SampleDistortedBoundingBoxShape(unused_op): # pylint: disable=invalid-name - """Shape function for the sample distorted bounding box.""" - return [tensor_shape.TensorShape([3]), - tensor_shape.TensorShape([3]), - tensor_shape.TensorShape([1, 1, 4])] +ops.RegisterShape('SampleDistortedBoundingBox')(common_shapes.call_cpp_shape_fn) @ops.RegisterShape('ResizeBilinear') @@ -1034,28 +1029,16 @@ def _ResizeShape(op): [input_shape[0], height, width, input_shape[3]])] @ops.RegisterShape('DecodeGif') -def _ImageDecodeShape(op): +def _DecodeGifShape(op): """Shape function for decode gif.""" unused_input_shape = op.inputs[0].get_shape().merge_with( tensor_shape.scalar()) return [tensor_shape.TensorShape([None, None, None, 3])] -@ops.RegisterShape('DecodeJpeg') -@ops.RegisterShape('DecodePng') -def _ImageDecodeShape(op): - """Shape function for image decoding ops.""" - unused_input_shape = op.inputs[0].get_shape().merge_with( - tensor_shape.scalar()) - channels = op.get_attr('channels') or None - return [tensor_shape.TensorShape([None, None, channels])] - - -@ops.RegisterShape('EncodeJpeg') -@ops.RegisterShape('EncodePng') -def _ImageEncodeShape(op): - """Shape function for image encoding ops.""" - unused_input_shape = op.inputs[0].get_shape().with_rank(3) - return [tensor_shape.scalar()] +ops.RegisterShape('DecodeJpeg')(common_shapes.call_cpp_shape_fn) +ops.RegisterShape('DecodePng')(common_shapes.call_cpp_shape_fn) +ops.RegisterShape('EncodeJpeg')(common_shapes.call_cpp_shape_fn) +ops.RegisterShape('EncodePng')(common_shapes.call_cpp_shape_fn) def convert_image_dtype(image, dtype, saturate=False, name=None): @@ -1194,16 +1177,8 @@ def grayscale_to_rgb(images, name=None): # pylint: disable=invalid-name -@ops.RegisterShape('HSVToRGB') -@ops.RegisterShape('RGBToHSV') -def _ColorspaceShape(op): - """Shape function for colorspace ops.""" - input_shape = op.inputs[0].get_shape().with_rank_at_least(1) - input_rank = input_shape.ndims - if input_rank is not None: - input_shape = input_shape.merge_with([None] * (input_rank - 1) + [3]) - return [input_shape] -# pylint: enable=invalid-name +ops.RegisterShape('HSVToRGB')(common_shapes.call_cpp_shape_fn) +ops.RegisterShape('RGBToHSV')(common_shapes.call_cpp_shape_fn) def random_hue(image, max_delta, seed=None): @@ -1413,10 +1388,7 @@ def _crop_and_resize_shape(op): [box_shape[0], crop_height, crop_width, image_shape[3]])] -@ops.RegisterShape('NonMaxSuppression') -def _non_max_suppression_shape(_): - """Shape function for the NonMaxSuppression op.""" - return [tensor_shape.TensorShape([None])] +ops.RegisterShape('NonMaxSuppression')(common_shapes.call_cpp_shape_fn) __all__ = make_all(__name__) diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index 198ccab021..e14324614e 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -59,6 +59,7 @@ from __future__ import print_function import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -272,14 +273,7 @@ def sparse_add(a, b, thresh=0): a.indices, a.values, a.shape, b) -@ops.RegisterShape("SparseAdd") -def _SparseAddShape(op): # pylint: disable=invalid-name - input_shape_shape = op.inputs[2].get_shape() - input_shape_shape.assert_has_rank(1) - return [ - tensor_shape.TensorShape([None, input_shape_shape[0]]), - tensor_shape.unknown_shape(1), input_shape_shape - ] +ops.RegisterShape("SparseAdd")(common_shapes.call_cpp_shape_fn) def sparse_dense_cwise_add(sp_t, dense_t): @@ -307,47 +301,9 @@ def sparse_dense_cwise_add(sp_t, dense_t): return ops.SparseTensor(sp_t.indices, result, sp_t.shape) -@ops.RegisterShape("SparseTensorDenseAdd") -def _SparseTensorDenseAddShape(op): # pylint: disable=invalid-name - return [op.inputs[3].get_shape()] - - -@ops.RegisterShape("SparseAddGrad") -def _SparseAddGradShape(op): # pylint: disable=invalid-name - # shapes for (a_val_grad, b_val_grad) - a_nnz = op.inputs[1].get_shape()[0] - b_nnz = op.inputs[2].get_shape()[0] - return [tensor_shape.TensorShape([a_nnz]), tensor_shape.TensorShape([b_nnz])] - - -@ops.RegisterShape("SparseConcat") -def _SparseConcatShape(op): - """Shape function for SparseConcat op.""" - num_inputs = int(op.get_attr("N")) - - # TF flattens and concatenates all list inputs, so reconstruct the lists here. - ind_shapes = [ind.get_shape().with_rank(2) for ind in op.inputs[0:num_inputs]] - val_shapes = [val.get_shape().with_rank(1) - for val in op.inputs[num_inputs:2 * num_inputs]] - shape_shapes = [shape.get_shape().with_rank(1) - for shape in op.inputs[2 * num_inputs:]] - - output_ind_rows = tensor_shape.Dimension(0) - output_ind_cols = tensor_shape.Dimension(None) - output_val_elems = tensor_shape.Dimension(0) - output_shape_shape = tensor_shape.TensorShape(None) - - for i in xrange(num_inputs): - num_elems_i = ind_shapes[i][0].merge_with(val_shapes[i][0]) - output_ind_rows += num_elems_i - output_ind_cols = output_ind_cols.merge_with(ind_shapes[i][1]) - output_val_elems += num_elems_i - output_shape_shape = output_shape_shape.merge_with(shape_shapes[i]) - - output_ind_shape = tensor_shape.matrix(output_ind_rows, output_ind_cols) - output_val_shape = tensor_shape.vector(output_val_elems) - - return [output_ind_shape, output_val_shape, output_shape_shape] +ops.RegisterShape("SparseTensorDenseAdd")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("SparseAddGrad")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("SparseConcat")(common_shapes.call_cpp_shape_fn) def sparse_reorder(sp_input, name=None): @@ -398,14 +354,7 @@ def sparse_reorder(sp_input, name=None): array_ops.identity(sp_input.shape)) -@ops.RegisterShape("SparseReorder") -def _SparseReorderShape(op): - """Shape function for SparseReorder op.""" - input_indices_shape = op.inputs[0].get_shape().with_rank(2) - input_values_shape = op.inputs[1].get_shape().with_rank(1) - unused_shape_shape = op.inputs[2].get_shape().with_rank(1) - - return [input_indices_shape, input_values_shape] +ops.RegisterShape("SparseReorder")(common_shapes.call_cpp_shape_fn) def sparse_reshape(sp_input, shape, name=None): @@ -464,17 +413,7 @@ def sparse_reshape(sp_input, shape, name=None): reshaped_shape) -@ops.RegisterShape("SparseReshape") -def _SparseReshapeShape(op): # pylint: disable=invalid-name - """Shape function for SparseReshape op.""" - input_indices_shape = op.inputs[0].get_shape().with_rank(2) - unused_input_shape_shape = op.inputs[1].get_shape().with_rank(1) - new_shape_shape = op.inputs[2].get_shape().with_rank(1) - - new_indices_shape = tensor_shape.matrix(input_indices_shape[0], - new_shape_shape[0]) - - return [new_indices_shape, new_shape_shape] +ops.RegisterShape("SparseReshape")(common_shapes.call_cpp_shape_fn) def sparse_split(split_dim, num_split, sp_input, name=None): @@ -528,20 +467,7 @@ def sparse_split(split_dim, num_split, sp_input, name=None): return sparse_tensors -# pylint: disable=invalid-name -@ops.RegisterShape("SparseSplit") -def _SparseSplitShape(op): - """Shape function for SparseSplit op.""" - num_split = int(op.get_attr("num_split")) - input_shape_shape = op.inputs[3].get_shape() - dim = input_shape_shape.num_elements() - output_indices_shape = tensor_shape.TensorShape([None, dim]) - output_values_shape = tensor_shape.unknown_shape(1) - output_indices_shape = [output_indices_shape] * num_split - output_values_shape = [output_values_shape] * num_split - output_shape_shape = [input_shape_shape] * num_split - return output_indices_shape + output_values_shape + output_shape_shape -# pylint: enable=invalid-name +ops.RegisterShape("SparseSplit")(common_shapes.call_cpp_shape_fn) @ops.RegisterShape("SparseToDense") @@ -656,9 +582,7 @@ def sparse_reduce_sum(sp_input, reduction_axes=None, keep_dims=False): keep_dims) -@ops.RegisterShape("SparseReduceSum") -def _SparseReduceSumShape(unused_op): # pylint: disable=invalid-name - return [tensor_shape.unknown_shape()] +ops.RegisterShape("SparseReduceSum")(common_shapes.call_cpp_shape_fn) def sparse_tensor_to_dense(sp_input, @@ -1116,14 +1040,7 @@ def serialize_sparse(sp_input, name=None): name=name) -@ops.RegisterShape("SerializeSparse") -def _SerializeSparseShape(op): # pylint: disable=invalid-name - """Shape function for SerializeSparse op.""" - op.inputs[0].get_shape().with_rank(2) - op.inputs[1].get_shape().with_rank(1) - op.inputs[2].get_shape().with_rank(1) - - return [tensor_shape.vector(3)] +ops.RegisterShape("SerializeSparse")(common_shapes.call_cpp_shape_fn) def serialize_many_sparse(sp_input, name=None): @@ -1159,14 +1076,7 @@ def serialize_many_sparse(sp_input, name=None): name=name) -@ops.RegisterShape("SerializeManySparse") -def _SerializeManySparseShape(op): # pylint: disable=invalid-name - """Shape function for SerializeSparse op.""" - op.inputs[0].get_shape().with_rank(2) - op.inputs[1].get_shape().with_rank(1) - op.inputs[2].get_shape().with_rank(1) - - return [tensor_shape.matrix(None, 3)] +ops.RegisterShape("SerializeManySparse")(common_shapes.call_cpp_shape_fn) def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None): @@ -1238,16 +1148,7 @@ def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None): return ops.SparseTensor(output_indices, output_values, output_shape) -@ops.RegisterShape("DeserializeManySparse") -def _DeserializeSparseShape(op): # pylint: disable=invalid-name - """Shape function for DeserializeManySparse op.""" - serialized_sparse_shape = op.inputs[0].get_shape().with_rank(2) - serialized_sparse_shape.merge_with( - tensor_shape.TensorShape([None, 3])) - - return [tensor_shape.matrix(None, None), - tensor_shape.vector(None), - tensor_shape.vector(None)] +ops.RegisterShape("DeserializeManySparse")(common_shapes.call_cpp_shape_fn) def sparse_tensor_dense_matmul(sp_a, b, adjoint_a=False, adjoint_b=False, @@ -1426,16 +1327,7 @@ def sparse_tensor_dense_matmul(sp_a, b, adjoint_a=False, adjoint_b=False, adjoint_b=adjoint_b) -@ops.RegisterShape("SparseTensorDenseMatMul") -def _SparseTensorDenseMatMulShape(op): # pylint: disable=invalid-name - """Shape function for SparseTensorDenseMatMul op.""" - adjoint_b = op.get_attr("adjoint_b") - op.inputs[0].get_shape().assert_has_rank(2) # a_indices - op.inputs[1].get_shape().assert_has_rank(1) # a_values - op.inputs[2].get_shape().merge_with(tensor_shape.vector(2)) # a_shape - b_shape = op.inputs[3].get_shape().with_rank(2) - output_shape_right = b_shape[0] if adjoint_b else b_shape[1] - return [tensor_shape.matrix(None, output_shape_right)] +ops.RegisterShape("SparseTensorDenseMatMul")(common_shapes.call_cpp_shape_fn) def sparse_softmax(sp_input, name=None): @@ -1492,14 +1384,7 @@ def sparse_softmax(sp_input, name=None): return ops.SparseTensor(sp_input.indices, out_vals, sp_input.shape) -@ops.RegisterShape("SparseSoftmax") -def _SparseSoftmaxShape(op): # pylint: disable=invalid-name - """Shape function for SparseSoftmax op.""" - unused_indices_shape = op.inputs[0].get_shape().with_rank(2) - values_shape = op.inputs[1].get_shape().with_rank(1) - unused_shape_shape = op.inputs[2].get_shape().with_rank(1) - nnz = values_shape[0] - return [tensor_shape.vector(nnz)] +ops.RegisterShape("SparseSoftmax")(common_shapes.call_cpp_shape_fn) def sparse_maximum(sp_a, sp_b, name=None): @@ -1572,17 +1457,8 @@ def sparse_minimum(sp_a, sp_b, name=None): return ops.SparseTensor(out_indices, out_values, sp_a.shape) -@ops.RegisterShape("SparseSparseMaximum") -@ops.RegisterShape("SparseSparseMinimum") -def _SparseSparseMaximumMinimumShape(op): # pylint: disable=invalid-name - """Shape function for SparseSparseMaximum and SparseSparseMinimum.""" - op.inputs[0].get_shape().assert_has_rank(2) # a_indices - op.inputs[1].get_shape().assert_has_rank(1) # a_values - op.inputs[2].get_shape().assert_has_rank(1) # a_shape - op.inputs[3].get_shape().assert_has_rank(2) # b_indices - op.inputs[4].get_shape().assert_has_rank(1) # b_values - op.inputs[5].get_shape().assert_has_rank(1) # b_shape - return [tensor_shape.unknown_shape(2), tensor_shape.unknown_shape(1)] +ops.RegisterShape("SparseSparseMaximum")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("SparseSparseMinimum")(common_shapes.call_cpp_shape_fn) def sparse_transpose(sp_input, perm=None, name=None): diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index 7ffa521ac6..8a547eeac5 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -176,30 +176,5 @@ def _ReduceJoinShape(op): return [tensor_shape.TensorShape(returned_dims)] -@ops.RegisterShape("StringJoin") -def _StringJoinShape(op): - """Shape function for the StringJoin op.""" - input_shapes = [x.get_shape() for x in op.inputs] - - # First check if all inputs are scalars. In the next section - # we may have *some* scalars and we will be broadcasting them - if all([s.ndims == 0 for s in input_shapes]): - return [tensor_shape.scalar()] - - base_shape = tensor_shape.unknown_shape() - for shape in input_shapes: - if shape.ndims != 0: - base_shape = base_shape.merge_with(shape) - return [base_shape] - - -@ops.RegisterShape("StringSplit") -def _StringSplitShape(op): - """Shape function for string_ops.string_split.""" - unused_sfs_shape = op.inputs[0].get_shape().with_rank(1) - unused_sfs_shape = op.inputs[1].get_shape().merge_with(tensor_shape.scalar()) - - indices_shape = tensor_shape.TensorShape([None, 2]) - values_shape = tensor_shape.TensorShape([None]) - shape_shape = tensor_shape.TensorShape([2]) - return [indices_shape, values_shape, shape_shape] +ops.RegisterShape("StringJoin")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("StringSplit")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index ef0573c6fd..3472229aff 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -25,6 +25,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.framework import ops @@ -373,87 +374,15 @@ class TensorArray(object): handle=self._handle, name=name) -@ops.RegisterShape("TensorArray") -def _TensorArrayShape(op): - # size is a scalar - op.inputs[0].get_shape().merge_with(tensor_shape.scalar()) - return [tensor_shape.vector(2)] - - -@ops.RegisterShape("TensorArrayRead") -def _TensorArrayReadShape(op): - # handle, index, flow_in - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - op.inputs[1].get_shape().merge_with(tensor_shape.scalar()) - op.inputs[2].get_shape().merge_with(tensor_shape.scalar()) - # value - return [tensor_shape.unknown_shape()] - - -@ops.RegisterShape("TensorArrayWrite") -def _TensorArrayWriteShape(op): - # handle, index, value, flow_in - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - op.inputs[1].get_shape().merge_with(tensor_shape.scalar()) - op.inputs[3].get_shape().merge_with(tensor_shape.scalar()) - # flow_out - return [tensor_shape.scalar()] - - -@ops.RegisterShape("TensorArraySize") -def _TensorArraySizeShape(op): - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - return [tensor_shape.scalar()] - - -@ops.RegisterShape("TensorArrayClose") -def _TensorArrayCloseShape(op): - """Shape function for ops that take a scalar and produce no outputs.""" - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - return [] - - -@ops.RegisterShape("TensorArrayGrad") -def _TensorArrayGradShape(op): - """Shape function for ops that take a scalar and produce no outputs.""" - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - return [tensor_shape.vector(2)] - - -@ops.RegisterShape("TensorArrayPack") -def _TensorArrayPackShape(op): - # handle, flow_in - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - op.inputs[1].get_shape().merge_with(tensor_shape.scalar()) - # value - return [tensor_shape.unknown_shape()] - - -@ops.RegisterShape("TensorArrayConcat") -def _TensorArrayConcatShape(op): - # handle, flow_in - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - op.inputs[1].get_shape().merge_with(tensor_shape.scalar()) - # value, lengths - return [tensor_shape.unknown_shape(), tensor_shape.vector(None)] - - -@ops.RegisterShape("TensorArraySplit") -def _TensorArraySplitShape(op): - # handle, value, lengths, flow_in - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - op.inputs[2].get_shape().merge_with(tensor_shape.vector(None)) - op.inputs[3].get_shape().merge_with(tensor_shape.scalar()) - # flow_out - return [tensor_shape.scalar()] - - -@ops.RegisterShape("TensorArrayUnpack") -def _TensorArrayUnpackShape(op): - # handle, value, flow_in - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - op.inputs[2].get_shape().merge_with(tensor_shape.scalar()) - # flow_out - return [tensor_shape.scalar()] +ops.RegisterShape("TensorArray")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayRead")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayWrite")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArraySize")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayClose")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayGrad")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayPack")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayConcat")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArraySplit")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayUnpack")(common_shapes.call_cpp_shape_fn) # pylint: enable=protected-access |