aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
blob: 56f237d5887fd9c88bb74bafcc5e44470f8807bf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace {

// TODO(phawkins): implement double-sized windowed reductions in XLA and remove
// the type constraint.
constexpr std::array<DataType, 3> kScanOpTypes = {
    {DT_HALF, DT_BFLOAT16, DT_FLOAT}};

class ScanOp : public XlaOpKernel {
 public:
  ScanOp(OpKernelConstruction* ctx, bool sum) : XlaOpKernel(ctx), sum_(sum) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    const TensorShape tensor_axis_shape = ctx->InputShape(1);

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis_shape),
                errors::InvalidArgument("ScanOp: axis must be a scalar, not ",
                                        tensor_axis_shape.DebugString()));

    int64 axis;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &axis));
    if (axis < 0) {
      axis += input_shape.dims();
    }
    OP_REQUIRES(
        ctx, FastBoundsCheck(axis, input_shape.dims()),
        errors::InvalidArgument("ScanOp: Expected scan axis in the range [",
                                -input_shape.dims(), ", ", input_shape.dims(),
                                "), but got ", axis));

    DataType dtype = XlaHelpers::SumAccumulationType(ctx->input_type(0));

    if (input_shape.num_elements() == 0) {
      // Exit early if there is nothing to compute.
      ctx->SetOutput(0, ctx->Input(0));
      return;
    }

    xla::XlaBuilder* builder = ctx->builder();

    std::vector<int64> window_strides(input_shape.dims(), 1);
    std::vector<int64> window_dims(input_shape.dims(), 1);
    window_dims[axis] = input_shape.dim_size(axis);

    std::vector<std::pair<int64, int64>> padding(input_shape.dims(), {0, 0});
    padding[axis].first = input_shape.dim_size(axis) - 1;
    // In exclusive mode, add an extra padding element so there is a complete
    // window of padding before the data starts.
    if (exclusive_) {
      ++padding[axis].first;
    }
    if (reverse_) {
      std::swap(padding[axis].first, padding[axis].second);
    }

    xla::XlaOp init;
    const xla::XlaComputation* reducer;
    if (sum_) {
      init = XlaHelpers::Zero(builder, dtype);
      reducer = ctx->GetOrCreateAdd(dtype);
    } else {
      init = XlaHelpers::One(builder, dtype);
      reducer = ctx->GetOrCreateMul(dtype);
    }
    auto output = xla::ReduceWindowWithGeneralPadding(
        XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init,
        *reducer, window_dims, window_strides, padding);
    output =
        XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0));

    // In exclusive mode, we have computed an extra element containing the sum
    // of all the input elements. Slice off this extra "last" element.
    if (exclusive_) {
      if (reverse_) {
        output =
            xla::SliceInDim(output, 1, input_shape.dim_size(axis) + 1, 1, axis);

      } else {
        output =
            xla::SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis);
      }
    }
    ctx->SetOutput(0, output);
  }

 private:
  const bool sum_;  // True=cumulative sum. False=cumulative product.
  bool reverse_;
  bool exclusive_;
};

class CumsumOp : public ScanOp {
 public:
  explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {}
};
REGISTER_XLA_OP(Name("Cumsum")
                    .TypeConstraint("T", kScanOpTypes)
                    .CompileTimeConstInput("axis"),
                CumsumOp);

class CumprodOp : public ScanOp {
 public:
  explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {}
};
REGISTER_XLA_OP(Name("Cumprod")
                    .TypeConstraint("T", kScanOpTypes)
                    .CompileTimeConstInput("axis"),
                CumprodOp);

}  // anonymous namespace
}  // namespace tensorflow