diff options
author | 2017-10-10 10:29:43 -0700 | |
---|---|---|
committer | 2017-10-10 10:33:51 -0700 | |
commit | 091504af57f70df13ebf1db9946dc59482e1190a (patch) | |
tree | 0f6874c8047a94f4526e7d028a7e49f020153fac /tensorflow/python/ops/tensor_array_ops.py | |
parent | cf3cddc2089d310360f2332ac4df2b14344f6cde (diff) |
Fix gradient behavior of fully dynamic tensor arrays + stop_gradients on tf.scan.
Added a test checking that this fixes a bug with tf.stop_gradient of tf.scan output.
PiperOrigin-RevId: 171697920
Diffstat (limited to 'tensorflow/python/ops/tensor_array_ops.py')
-rw-r--r-- | tensorflow/python/ops/tensor_array_ops.py | 28 |
1 files changed, 11 insertions, 17 deletions
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index 08325ba771..37b4b3bcf9 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -301,6 +301,8 @@ class TensorArray(object): """ with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]): value = ops.convert_to_tensor(value, name="value") + if self._infer_shape: + self._merge_element_shape(value.shape) with self._maybe_colocate_with(value): flow_out = gen_data_flow_ops._tensor_array_write_v3( handle=self._handle, @@ -314,8 +316,6 @@ class TensorArray(object): ta._infer_shape = self._infer_shape ta._element_shape = self._element_shape ta._colocate_with = self._colocate_with - if ta._infer_shape: - ta._merge_element_shape(value.get_shape()) return ta def stack(self, name=None): @@ -433,6 +433,8 @@ class TensorArray(object): with ops.name_scope(name, "TensorArrayScatter", [self._handle, value, indices]): value = ops.convert_to_tensor(value, name="value") + if self._infer_shape and context.in_graph_mode(): + self._merge_element_shape(value.shape[1:]) with self._maybe_colocate_with(value): flow_out = gen_data_flow_ops._tensor_array_scatter_v3( handle=self._handle, @@ -446,12 +448,6 @@ class TensorArray(object): ta._infer_shape = self._infer_shape ta._element_shape = self._element_shape ta._colocate_with = self._colocate_with - if ta._infer_shape and context.in_graph_mode(): - val_shape = flow_out.op.inputs[2].get_shape() - element_shape = tensor_shape.unknown_shape() - if val_shape.dims is not None: - element_shape = tensor_shape.TensorShape(val_shape.dims[1:]) - ta._merge_element_shape(element_shape) return ta @tf_should_use.should_use_result @@ -476,6 +472,13 @@ class TensorArray(object): value = ops.convert_to_tensor(value, name="value") with self._maybe_colocate_with(value): lengths_64 = math_ops.to_int64(lengths) + if self._infer_shape and context.in_graph_mode(): + clengths = tensor_util.constant_value(lengths_64) + if value.shape.dims is not None: + if clengths is not None and clengths.max() == clengths.min(): + self._merge_element_shape( + tensor_shape.TensorShape([clengths[0]]).concatenate( + value.shape[1:])) flow_out = gen_data_flow_ops._tensor_array_split_v3( handle=self._handle, value=value, @@ -488,15 +491,6 @@ class TensorArray(object): ta._infer_shape = self._infer_shape ta._element_shape = self._element_shape ta._colocate_with = self._colocate_with - if ta._infer_shape and context.in_graph_mode(): - val_shape = flow_out.op.inputs[1].get_shape() - clengths = tensor_util.constant_value(flow_out.op.inputs[2]) - element_shape = tensor_shape.unknown_shape() - if val_shape.dims is not None: - if clengths is not None and clengths.max() == clengths.min(): - element_shape = tensor_shape.TensorShape([clengths[0]] + - val_shape.dims[1:]) - ta._merge_element_shape(element_shape) return ta def size(self, name=None): |