aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
blob: f210bfbd886e48b8d7972393ed1899491486646c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Native XLA implementations of indexing ops.

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"

namespace tensorflow {
namespace {

// The logic below uses a custom-call to implement argmax.
//
// Also see b/29507024 for first-class XLA support for indexing ops.
class ArgMaxCustomCallOp : public XlaOpKernel {
 public:
  explicit ArgMaxCustomCallOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    const TensorShape dimension_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(dimension_shape),
                errors::InvalidArgument(
                    "dim must be a scalar, but received tensor of shape: ",
                    dimension_shape.DebugString()));

    // We require that the dimension argument is a constant, since it lets us
    // dispatch to a specialized custom-call function without any run-time
    // overhead, when compiling ahead-of-time.
    xla::Literal literal;
    OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal));
    const int32 dim = literal.Get<int32>({});
    OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0"));
    OP_REQUIRES(
        ctx, dim < input_shape.dims(),
        errors::InvalidArgument("dim must be < input rank (",
                                input_shape.dims(), "), but got: ", dim));
    const int64 dim_size = input_shape.dim_size(dim);
    OP_REQUIRES(ctx, dim_size > 0,
                errors::InvalidArgument(
                    "Reduction axis ", dim,
                    " is empty in shape: ", input_shape.DebugString()));

    // The output shape is the input shape contracted along dim.
    TensorShape output_shape;
    for (int d = 0; d < input_shape.dims() - 1; ++d) {
      output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1));
    }

    // For now we use a custom-call, only for the 1d and 2d cases.
    OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(),
                errors::InvalidArgument(
                    "ArgMax implementation requires a CustomCall on CPU"));
    xla::XlaBuilder& b = *ctx->builder();

    // XLA passes <out> to the function, so it is not included here.
    std::vector<xla::XlaOp> args;
    args.push_back(ctx->Input(0));
    args.push_back(xla::ConstantLiteral(
        &b, xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
    if (input_shape.dims() > 1) {
      // Don't bother passing the output shape and dim for the 1d case, since
      // the shape is always a scalar and the dim is always 0.
      args.push_back(xla::ConstantLiteral(
          &b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
      args.push_back(
          xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
    }

    // The argmax function expects row-major layout.
    xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(
        xla::S64, output_shape.dim_sizes());
    std::vector<xla::Shape> arg_shapes;
    for (const xla::XlaOp& arg : args) {
      auto shape_status = b.GetShape(arg);
      OP_REQUIRES_OK(ctx, shape_status.status());
      xla::Shape arg_shape = shape_status.ConsumeValueOrDie();
      *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout(
          xla::ShapeUtil::Rank(arg_shape));
      arg_shapes.push_back(std::move(arg_shape));
    }

    // Tell XLA to call the custom code, defined in
    // index_ops_kernel_argmax_float_1d.cc.
    xla::XlaOp output;
    switch (input_shape.dims()) {
      case 1:
        output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args,
                                           xla_shape, arg_shapes);
        break;
      case 2:
        output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args,
                                           xla_shape, arg_shapes);
        break;
      default:
        OP_REQUIRES(ctx, false,
                    errors::Unimplemented(
                        "Argmax is only implemented for 1d and 2d tensors"
                        ", but got shape: ",
                        input_shape.DebugString()));
    }
    ctx->SetOutput(0, output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp);
};

// Registers ArgMaxCustomCallOp as the XLA kernel for the "ArgMax" op on
// DEVICE_CPU_XLA_JIT, restricted to the case where the "T" type attribute is
// DT_FLOAT. The "dimension" input is declared a compile-time constant so the
// kernel can read the reduction axis while compiling.
// NOTE(review): presumably "dimension" corresponds to input index 1 that
// Compile() reads — confirm against the ArgMax op definition.
// NOTE(review): no constraint is placed here on the index or output type
// attributes; verify that is intended.
REGISTER_XLA_OP(Name("ArgMax")
                    .TypeConstraint("T", DT_FLOAT)
                    .Device(DEVICE_CPU_XLA_JIT)
                    .CompileTimeConstInput("dimension"),
                ArgMaxCustomCallOp);

}  // namespace
}  // namespace tensorflow