about summary refs log tree commit diff homepage
path: root/tensorflow/python/ops/tensor_array_ops.py
diff options
context:
space:
mode:
author: Eugene Brevdo <ebrevdo@google.com> 2017-10-10 10:29:43 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-10-10 10:33:51 -0700
commit 091504af57f70df13ebf1db9946dc59482e1190a (patch)
tree 0f6874c8047a94f4526e7d028a7e49f020153fac /tensorflow/python/ops/tensor_array_ops.py
parent cf3cddc2089d310360f2332ac4df2b14344f6cde (diff)
Fix gradient behavior of fully dynamic tensor arrays + stop_gradients on tf.scan.
Added a test checking that this fixes a bug with tf.stop_gradient of tf.scan output. PiperOrigin-RevId: 171697920
Diffstat (limited to 'tensorflow/python/ops/tensor_array_ops.py')
-rw-r--r-- tensorflow/python/ops/tensor_array_ops.py | 28
1 file changed, 11 insertions, 17 deletions
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index 08325ba771..37b4b3bcf9 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -301,6 +301,8 @@ class TensorArray(object):
"""
with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
value = ops.convert_to_tensor(value, name="value")
+ if self._infer_shape:
+ self._merge_element_shape(value.shape)
with self._maybe_colocate_with(value):
flow_out = gen_data_flow_ops._tensor_array_write_v3(
handle=self._handle,
@@ -314,8 +316,6 @@ class TensorArray(object):
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
- if ta._infer_shape:
- ta._merge_element_shape(value.get_shape())
return ta
def stack(self, name=None):
@@ -433,6 +433,8 @@ class TensorArray(object):
with ops.name_scope(name, "TensorArrayScatter",
[self._handle, value, indices]):
value = ops.convert_to_tensor(value, name="value")
+ if self._infer_shape and context.in_graph_mode():
+ self._merge_element_shape(value.shape[1:])
with self._maybe_colocate_with(value):
flow_out = gen_data_flow_ops._tensor_array_scatter_v3(
handle=self._handle,
@@ -446,12 +448,6 @@ class TensorArray(object):
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
- if ta._infer_shape and context.in_graph_mode():
- val_shape = flow_out.op.inputs[2].get_shape()
- element_shape = tensor_shape.unknown_shape()
- if val_shape.dims is not None:
- element_shape = tensor_shape.TensorShape(val_shape.dims[1:])
- ta._merge_element_shape(element_shape)
return ta
@tf_should_use.should_use_result
@@ -476,6 +472,13 @@ class TensorArray(object):
value = ops.convert_to_tensor(value, name="value")
with self._maybe_colocate_with(value):
lengths_64 = math_ops.to_int64(lengths)
+ if self._infer_shape and context.in_graph_mode():
+ clengths = tensor_util.constant_value(lengths_64)
+ if value.shape.dims is not None:
+ if clengths is not None and clengths.max() == clengths.min():
+ self._merge_element_shape(
+ tensor_shape.TensorShape([clengths[0]]).concatenate(
+ value.shape[1:]))
flow_out = gen_data_flow_ops._tensor_array_split_v3(
handle=self._handle,
value=value,
@@ -488,15 +491,6 @@ class TensorArray(object):
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
- if ta._infer_shape and context.in_graph_mode():
- val_shape = flow_out.op.inputs[1].get_shape()
- clengths = tensor_util.constant_value(flow_out.op.inputs[2])
- element_shape = tensor_shape.unknown_shape()
- if val_shape.dims is not None:
- if clengths is not None and clengths.max() == clengths.min():
- element_shape = tensor_shape.TensorShape([clengths[0]] +
- val_shape.dims[1:])
- ta._merge_element_shape(element_shape)
return ta
def size(self, name=None):