1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to manipulate lists of tensors."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_list_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_list_ops import *
# pylint: enable=wildcard-import
# No gradient function is registered in this file for the two ops below;
# they are explicitly declared as not differentiable instead.
ops.NotDifferentiable("TensorListConcat")
ops.NotDifferentiable("TensorListPushBackBatch")
@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
  """Gradient for TensorListPushBack.

  Pops the last element off the incoming list gradient.

  Args:
    op: The forward push-back op; only its "element_dtype" attr is read.
    dresult: Gradient with respect to the list produced by the push.

  Returns:
    The outputs of the pop-back op, returned unchanged. (Presumably these
    line up with the push op's inputs as list gradient and element gradient;
    confirm against the op definition.)
  """
  popped_dtype = op.get_attr("element_dtype")
  grads = gen_list_ops.tensor_list_pop_back(
      dresult, element_dtype=popped_dtype)
  return grads
@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
  """Gradient for TensorListPopBack.

  Pushes the popped element's gradient back onto the gradient of the
  remaining list.

  Args:
    op: The forward pop-back op. `op.outputs[0]` is the remaining list and
      `op.outputs[1]` is the popped element.
    dlist: Gradient with respect to the remaining list, or None if that
      output received no gradient.
    delement: Gradient with respect to the popped element, or None if that
      output received no gradient.

  Returns:
    The list gradient with the element gradient pushed onto its end.
  """
  if delement is None:
    # The popped element did not contribute to the loss. Use zeros so the
    # pushed-back slot still exists and `delement.dtype` below is valid.
    delement = array_ops.zeros_like(op.outputs[1])
  if dlist is None:
    # No gradient reached the remaining list: start from an empty list whose
    # element shape matches the forward list.
    dlist = gen_list_ops.empty_tensor_list(
        element_dtype=delement.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  return gen_list_ops.tensor_list_push_back(dlist, delement)
@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
  """Gradient for TensorListStack.

  Converts the gradient of the stacked tensor back into a list by splitting
  it along its first dimension.

  Args:
    unused_op: The forward stack op (not used).
    dtensor: Gradient with respect to the stacked tensor.

  Returns:
    A tensor list built from `dtensor`, with element shape taken from
    `dtensor`'s dimensions after the first.
  """
  element_shape = dtensor.shape[1:]
  if not element_shape.is_fully_defined():
    # A partially known static shape cannot be converted to a tensor, so
    # fall back to the shape computed at run time.
    element_shape = array_ops.shape(dtensor)[1:]
  return gen_list_ops.tensor_list_from_tensor(
      dtensor, element_shape=element_shape)
@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor.

  Stacks the list gradient back into a single tensor.

  Args:
    op: The forward op. `op.inputs[0]` is the source tensor and
      `op.outputs[0]` is the list built from it.
    dlist: Gradient with respect to the output list, or None.

  Returns:
    A pair `(tensor_grad, None)`; the second input gets no gradient.
  """
  source = op.inputs[0]
  # `.value` is already None when the leading dimension is not statically
  # known, so it can be forwarded as-is.
  leading_dim = source.shape[0].value

  if dlist is None:
    # No incoming gradient: substitute an empty list with the same element
    # shape as the forward list.
    shape_of_elements = gen_list_ops.tensor_list_element_shape(
        op.outputs[0], shape_type=dtypes.int32)
    dlist = gen_list_ops.empty_tensor_list(
        element_dtype=source.dtype, element_shape=shape_of_elements)

  stacked = gen_list_ops.tensor_list_stack(
      dlist, element_dtype=source.dtype, num_elements=leading_dim)
  return stacked, None
@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem.

  Builds a reserved list with the same length and element shape as the
  forward list and writes `ditem` at the index that was read.

  Args:
    op: The forward op. `op.inputs[0]` is the list and `op.inputs[1]` is the
      index that was read.
    ditem: Gradient with respect to the item that was read.

  Returns:
    A pair `(list_grad, None)`; the index gets no gradient.
  """
  forward_inputs = op.inputs
  length = gen_list_ops.tensor_list_length(forward_inputs[0])
  shape_of_elements = gen_list_ops.tensor_list_element_shape(
      forward_inputs[0], shape_type=dtypes.int32)
  blank_list = gen_list_ops.tensor_list_reserve(
      shape_of_elements, length, element_dtype=ditem.dtype)
  filled = gen_list_ops.tensor_list_set_item(
      blank_list, index=forward_inputs[1], item=ditem)
  return filled, None
@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
  """Gradient for TensorListSetItem.

  The input list's gradient is `dlist` with the written slot replaced by
  zeros, and the written item's gradient is the entry of `dlist` at that
  slot.

  Args:
    op: The forward op, whose inputs are (list, index, item).
    dlist: Gradient with respect to the output list.

  Returns:
    A triple `(list_grad, None, item_grad)`; the index gets no gradient.
  """
  _, position, written = op.inputs
  cleared_slot = array_ops.zeros_like(written)
  dlist_in = gen_list_ops.tensor_list_set_item(
      dlist, index=position, item=cleared_slot)
  dwritten = gen_list_ops.tensor_list_get_item(
      dlist, position, element_dtype=written.dtype)
  return dlist_in, None, dwritten
@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
  """Gradient for TensorListGather.

  Scatters the gradient of the gathered tensor into a list at the indices
  that were gathered.

  Args:
    op: The forward op, whose inputs are (list, indices).
    dtensor: Gradient with respect to the gathered tensor.

  Returns:
    A pair `(list_grad, None)`; the indices get no gradient.
  """
  _, positions = op.inputs
  # A scalar -1 is passed as the element shape; presumably this marks the
  # element shape as unknown (confirm against the op definition).
  shape_placeholder = ops.convert_to_tensor(-1, dtype=dtypes.int32)
  scattered = gen_list_ops.tensor_list_scatter(
      tensor=dtensor, indices=positions, element_shape=shape_placeholder)
  return scattered, None
@ops.RegisterGradient("TensorListScatter")
def _TensorListScatterGrad(op, dlist):
  """Gradient for TensorListScatter.

  Gathers from the list gradient the entries at the indices that were
  scattered.

  Args:
    op: The forward op, whose three inputs are (tensor, indices, and a third
      input that is not used here).
    dlist: Gradient with respect to the output list.

  Returns:
    A triple `(tensor_grad, None, None)`: one entry per forward input. The
    indices and the third input get no gradient.
  """
  t, indices, _ = op.inputs
  dtensor = gen_list_ops.tensor_list_gather(
      dlist, indices, element_dtype=t.dtype)
  # The op has three inputs, so three gradients must be returned; a count
  # mismatch is rejected when gradients are built in graph mode.
  return dtensor, None, None
|