aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/tensor_array_ops.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-30 12:24:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-30 13:33:09 -0700
commit660d7e5ae51d8d74643874b9ae10f855dced69bb (patch)
treeec62bdb867b4c09a8a458f94911dc96387973e95 /tensorflow/python/ops/tensor_array_ops.py
parent5532f08194c2e29e6056e5339cbd57c52db1907a (diff)
Delegate to C++ shape inference functions for several ops in
image_ops, tensor_array_ops, sparse_ops, and string_ops. Fix bugs in C++ shape inference for image decoding (channel 0 means auto-detect) and string split (some output dims are known). Change: 131747517
Diffstat (limited to 'tensorflow/python/ops/tensor_array_ops.py')
-rw-r--r--tensorflow/python/ops/tensor_array_ops.py93
1 files changed, 11 insertions, 82 deletions
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index ef0573c6fd..3472229aff 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -25,6 +25,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops
@@ -373,87 +374,15 @@ class TensorArray(object):
handle=self._handle, name=name)
-@ops.RegisterShape("TensorArray")
-def _TensorArrayShape(op):
- # size is a scalar
- op.inputs[0].get_shape().merge_with(tensor_shape.scalar())
- return [tensor_shape.vector(2)]
-
-
-@ops.RegisterShape("TensorArrayRead")
-def _TensorArrayReadShape(op):
- # handle, index, flow_in
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
- op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
- # value
- return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("TensorArrayWrite")
-def _TensorArrayWriteShape(op):
- # handle, index, value, flow_in
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
- op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
- # flow_out
- return [tensor_shape.scalar()]
-
-
-@ops.RegisterShape("TensorArraySize")
-def _TensorArraySizeShape(op):
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- return [tensor_shape.scalar()]
-
-
-@ops.RegisterShape("TensorArrayClose")
-def _TensorArrayCloseShape(op):
- """Shape function for ops that take a scalar and produce no outputs."""
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- return []
-
-
-@ops.RegisterShape("TensorArrayGrad")
-def _TensorArrayGradShape(op):
- """Shape function for ops that take a scalar and produce no outputs."""
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- return [tensor_shape.vector(2)]
-
-
-@ops.RegisterShape("TensorArrayPack")
-def _TensorArrayPackShape(op):
- # handle, flow_in
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
- # value
- return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("TensorArrayConcat")
-def _TensorArrayConcatShape(op):
- # handle, flow_in
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
- # value, lengths
- return [tensor_shape.unknown_shape(), tensor_shape.vector(None)]
-
-
-@ops.RegisterShape("TensorArraySplit")
-def _TensorArraySplitShape(op):
- # handle, value, lengths, flow_in
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- op.inputs[2].get_shape().merge_with(tensor_shape.vector(None))
- op.inputs[3].get_shape().merge_with(tensor_shape.scalar())
- # flow_out
- return [tensor_shape.scalar()]
-
-
-@ops.RegisterShape("TensorArrayUnpack")
-def _TensorArrayUnpackShape(op):
- # handle, value, flow_in
- op.inputs[0].get_shape().merge_with(tensor_shape.vector(2))
- op.inputs[2].get_shape().merge_with(tensor_shape.scalar())
- # flow_out
- return [tensor_shape.scalar()]
+ops.RegisterShape("TensorArray")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayRead")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayWrite")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArraySize")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayClose")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayPack")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayConcat")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArraySplit")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TensorArrayUnpack")(common_shapes.call_cpp_shape_fn)
# pylint: enable=protected-access