diff options
author | 2016-08-30 12:24:42 -0800 | |
---|---|---|
committer | 2016-08-30 13:33:09 -0700 | |
commit | 660d7e5ae51d8d74643874b9ae10f855dced69bb (patch) | |
tree | ec62bdb867b4c09a8a458f94911dc96387973e95 /tensorflow/python/ops/tensor_array_ops.py | |
parent | 5532f08194c2e29e6056e5339cbd57c52db1907a (diff) |
Delegate to C++ shape inference functions for several ops in
image_ops, tensor_array_ops, sparse_ops, and string_ops.
Fix bugs in C++ shape inference for image decoding (channel 0 means
auto-detect) and string split (some output dims are known).
Change: 131747517
Diffstat (limited to 'tensorflow/python/ops/tensor_array_ops.py')
-rw-r--r-- | tensorflow/python/ops/tensor_array_ops.py | 93 |
1 files changed, 11 insertions, 82 deletions
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index ef0573c6fd..3472229aff 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -25,6 +25,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.framework import ops @@ -373,87 +374,15 @@ class TensorArray(object): handle=self._handle, name=name) -@ops.RegisterShape("TensorArray") -def _TensorArrayShape(op): - # size is a scalar - op.inputs[0].get_shape().merge_with(tensor_shape.scalar()) - return [tensor_shape.vector(2)] - - -@ops.RegisterShape("TensorArrayRead") -def _TensorArrayReadShape(op): - # handle, index, flow_in - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - op.inputs[1].get_shape().merge_with(tensor_shape.scalar()) - op.inputs[2].get_shape().merge_with(tensor_shape.scalar()) - # value - return [tensor_shape.unknown_shape()] - - -@ops.RegisterShape("TensorArrayWrite") -def _TensorArrayWriteShape(op): - # handle, index, value, flow_in - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - op.inputs[1].get_shape().merge_with(tensor_shape.scalar()) - op.inputs[3].get_shape().merge_with(tensor_shape.scalar()) - # flow_out - return [tensor_shape.scalar()] - - -@ops.RegisterShape("TensorArraySize") -def _TensorArraySizeShape(op): - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - return [tensor_shape.scalar()] - - -@ops.RegisterShape("TensorArrayClose") -def _TensorArrayCloseShape(op): - """Shape function for ops that take a scalar and produce no outputs.""" - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - return [] - - -@ops.RegisterShape("TensorArrayGrad") -def _TensorArrayGradShape(op): - """Shape function for ops that take a scalar and produce no outputs.""" - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - return [tensor_shape.vector(2)] - - -@ops.RegisterShape("TensorArrayPack") -def _TensorArrayPackShape(op): - # handle, flow_in - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - op.inputs[1].get_shape().merge_with(tensor_shape.scalar()) - # value - return [tensor_shape.unknown_shape()] - - -@ops.RegisterShape("TensorArrayConcat") -def _TensorArrayConcatShape(op): - # handle, flow_in - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - op.inputs[1].get_shape().merge_with(tensor_shape.scalar()) - # value, lengths - return [tensor_shape.unknown_shape(), tensor_shape.vector(None)] - - -@ops.RegisterShape("TensorArraySplit") -def _TensorArraySplitShape(op): - # handle, value, lengths, flow_in - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - op.inputs[2].get_shape().merge_with(tensor_shape.vector(None)) - op.inputs[3].get_shape().merge_with(tensor_shape.scalar()) - # flow_out - return [tensor_shape.scalar()] - - -@ops.RegisterShape("TensorArrayUnpack") -def _TensorArrayUnpackShape(op): - # handle, value, flow_in - op.inputs[0].get_shape().merge_with(tensor_shape.vector(2)) - op.inputs[2].get_shape().merge_with(tensor_shape.scalar()) - # flow_out - return [tensor_shape.scalar()] +ops.RegisterShape("TensorArray")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayRead")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayWrite")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArraySize")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayClose")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayGrad")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayPack")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayConcat")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArraySplit")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("TensorArrayUnpack")(common_shapes.call_cpp_shape_fn) # pylint: enable=protected-access |