aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
blob: 591e61b4c82836bc1995cd11c4c0314c9d854e50 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA Stack operators.

#include <limits>
#include <vector>

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {

// Looks up the shape of the buffer that backs a Stack.
//
// The value of `resource` is required to be a tuple; the XLA shape of its
// element 0 is converted to a TensorShape and written to `*stack_shape`.
//
// Returns the error reported by `builder->GetShape` if the shape of the
// resource value cannot be obtained, fails a TF_RET_CHECK if that shape is
// not a tuple, and otherwise returns the status of the shape conversion.
Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource,
                     TensorShape* stack_shape) {
  auto state_shape_or = builder->GetShape(resource->value());
  if (state_shape_or.ok()) {
    const xla::Shape state_shape = state_shape_or.ValueOrDie();
    TF_RET_CHECK(xla::ShapeUtil::IsTuple(state_shape));
    return XLAShapeToTensorShape(
        xla::ShapeUtil::GetTupleElementShape(state_shape, 0), stack_shape);
  }
  return state_shape_or.status();
}

// The Stack operator is not told the shape of its elements, so the storage
// for a Stack is created lazily, when the first write is compiled.
//
// Behavior:
//   * Fails with InvalidArgument if the dtype recorded on `resource` differs
//     from `dtype`.
//   * If `resource` has not been initialized yet, records `dtype` and
//     `elem_shape` on it and sets it to its zero value.
//   * If `resource` is already initialized, verifies that the shape of its
//     buffer equals [tensor_array_size] + `elem_shape` and fails with
//     InvalidArgument otherwise.
//
// TODO(phawkins): consider changing the API of the stack operators to
// allow an optional element shape at stack construction time.
Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource,
                            DataType dtype, const TensorShape& elem_shape) {
  const DataType resource_dtype = resource->type();
  if (resource_dtype != dtype) {
    return errors::InvalidArgument("Stack dtype is ",
                                   DataTypeString(resource_dtype),
                                   " but op has dtype ", DataTypeString(dtype),
                                   ".");
  }

  // Shape the backing buffer must have: one leading dimension for the stack
  // slots, followed by the dimensions of a single element.
  TensorShape expected_shape;
  expected_shape.AddDim(resource->tensor_array_size());
  expected_shape.AppendShape(elem_shape);

  if (resource->initialized()) {
    // Storage already exists; it only has to agree with this element shape.
    TensorShape current_shape;
    TF_RETURN_IF_ERROR(GetStackShape(builder, resource, &current_shape));
    if (expected_shape != current_shape) {
      return errors::InvalidArgument(
          "Mismatched Stack shapes: ", expected_shape.DebugString(), " vs ",
          current_shape.DebugString());
    }
    return Status::OK();
  }

  // First write seen for this Stack: create its storage now.
  TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
  return resource->SetZeroValue(builder);
}

class StackOp : public XlaOpKernel {
 public:
  explicit StackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("stack_name", &stack_name_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    int64 size;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size));
    OP_REQUIRES(
        ctx, size >= 0,
        errors::InvalidArgument(
            "XLA compilation requires a fixed stack size upper bound. If "
            "you are using tf.while_loop, set the maximum_iterations parameter "
            "to fix this issue."));

    // We defer initializing the Stack resource until we see the first push.
    // Otherwise we do not know the shape of the stack elements.
    xla::XlaOp value;
    XlaContext& xc = XlaContext::Get(ctx);
    XlaResource* resource;
    string name = strings::StrCat("Stack: ", stack_name_);
    OP_REQUIRES_OK(
        ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_,
                               TensorShape(), value, /*tensor_array_size=*/size,
                               /*tensor_array_gradients=*/{}, &resource));
    ctx->SetResourceOutput(0, resource);
  }

 private:
  DataType dtype_;
  string stack_name_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
};

REGISTER_XLA_OP(Name("StackV2").CompileTimeConstInput("max_size"), StackOp);

class StackPushOp : public XlaOpKernel {
 public:
  explicit StackPushOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();
    TensorShape elem_shape = ctx->InputShape(1);

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    // Initializes the Stack, if the element shape was not already known.
    OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape));

    xla::XlaOp ta = xla::GetTupleElement(resource->value(), 0);
    xla::XlaOp index = xla::GetTupleElement(resource->value(), 1);
    xla::XlaOp value = ctx->Input(1);

    // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
    auto start_indices =
        xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
                 xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));

    TensorShape slice_shape = elem_shape;
    slice_shape.InsertDim(0, 1LL);
    auto update = xla::Reshape(value, slice_shape.dim_sizes());

    // TODO(phawkins): We don't check the index is in bounds --- there is no
    // error mechanism in XLA.
    OP_REQUIRES_OK(ctx,
                   resource->SetValue(xla::Tuple(
                       b, {xla::DynamicUpdateSlice(ta, update, start_indices),
                           xla::Add(index, xla::ConstantR0<int32>(b, 1))})));

    ctx->SetOutput(0, value);
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackPushOp);
};

REGISTER_XLA_OP(Name("StackPushV2"), StackPushOp);

class StackPopOp : public XlaOpKernel {
 public:
  explicit StackPopOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("elem_type", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::XlaBuilder* b = ctx->builder();

    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));

    // There is a somewhat subtle issue here: here "uninitialized" means we have
    // not yet seen a pop in the order that we compile operators, not the order
    // that we run them. However, in practice the two orders should be the same
    // for the sole user of the stack operators (loop gradients).
    OP_REQUIRES(ctx, resource->initialized(),
                errors::InvalidArgument("Stack pop on uninitialized stack"));

    TensorShape stack_shape;
    OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape));

    xla::XlaOp state = resource->value();
    xla::XlaOp ta = xla::GetTupleElement(state, 0);
    xla::XlaOp index = xla::GetTupleElement(state, 1);

    index = Sub(index, xla::ConstantR0<int32>(b, 1));
    OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index})));

    // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
    auto start_indices =
        xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
                 xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}}));

    auto slice_shape = stack_shape.dim_sizes();
    slice_shape[0] = 1LL;

    // TODO(phawkins): We don't check the index is in bounds --- there is no
    // error mechanism in XLA.
    xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);

    // Remove the leading '1' dimension.
    std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
    ctx->SetOutput(0, xla::Reshape(read, value_shape));
  }

 private:
  DataType dtype_;

  TF_DISALLOW_COPY_AND_ASSIGN(StackPopOp);
};

REGISTER_XLA_OP(Name("StackPopV2"), StackPopOp);

// XLA kernel for StackCloseV2. Compiling it emits nothing.
class StackCloseOp : public XlaOpKernel {
 public:
  explicit StackCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  // Intentionally empty: no XLA operations are generated for a close.
  void Compile(XlaOpKernelContext* ctx) override {}

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StackCloseOp);
};

REGISTER_XLA_OP(Name("StackCloseV2"), StackCloseOp);

}  // anonymous namespace
}  // namespace tensorflow