path: root/tensorflow/python/ops/list_ops.py
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to manipulate lists of tensors."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_list_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_list_ops import *
# pylint: enable=wildcard-import


ops.NotDifferentiable("TensorListConcat")
ops.NotDifferentiable("TensorListPushBackBatch")


@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
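  """Gradient for TensorListPushBack: pop the element gradient off dresult."""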
  return gen_list_ops.tensor_list_pop_back(
      dresult, element_dtype=op.get_attr("element_dtype"))


@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
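  """Gradient for TensorListPopBack: push delement back onto dlist."""
  # dlist is None when no gradient flows back through the list output.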
  if dlist is None:
    dlist = gen_list_ops.empty_tensor_list(
        element_dtype=delement.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  return gen_list_ops.tensor_list_push_back(dlist, delement)


@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
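  """Gradient for TensorListStack: unstack the gradient back into a list."""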
  return gen_list_ops.tensor_list_from_tensor(dtensor,
                                              element_shape=dtensor.shape[1:])


@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor."""
  # None when the leading dimension of the input is not statically known.
  num_elements = op.inputs[0].shape[0].value
  if dlist is None:
    dlist = gen_list_ops.empty_tensor_list(
        element_dtype=op.inputs[0].dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  tensor_grad = gen_list_ops.tensor_list_stack(
      dlist, element_dtype=op.inputs[0].dtype, num_elements=num_elements)
  shape_grad = None
  return tensor_grad, shape_grad


@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem."""
  list_size = gen_list_ops.tensor_list_length(op.inputs[0])
  list_grad = gen_list_ops.tensor_list_set_item(
      gen_list_ops.tensor_list_reserve(
          gen_list_ops.tensor_list_element_shape(op.inputs[0],
                                                 shape_type=dtypes.int32),
          list_size, element_dtype=ditem.dtype),
      index=op.inputs[1],
      item=ditem)
  index_grad = None
  return list_grad, index_grad


@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
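  """Gradient for TensorListSetItem.

  The element gradient is read from dlist at `index`; the list gradient is
  dlist with that slot zeroed out.
  """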
  _, index, item = op.inputs
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=array_ops.zeros_like(item))
  index_grad = None
  element_grad = gen_list_ops.tensor_list_get_item(
      dlist, index, element_dtype=item.dtype)
  return list_grad, index_grad, element_grad


@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
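  """Gradient for TensorListGather: scatter the gradient back into a list."""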
  _, indices = op.inputs
  return gen_list_ops.tensor_list_scatter(
      tensor=dtensor, indices=indices,
      element_shape=ops.convert_to_tensor(-1, dtype=dtypes.int32)), None


@ops.RegisterGradient("TensorListScatter")
def _TensorListScatterGrad(op, dlist):
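  """Gradient for TensorListScatter: gather the gradient at `indices`."""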
  t, indices, _ = op.inputs
  # One gradient per input: no gradient flows to `indices` or `element_shape`.
  return gen_list_ops.tensor_list_gather(
      dlist, indices, element_dtype=t.dtype), None, None