aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
blob: 3cd0b39c871c53f807d5736ae15ff3d108efc69e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific base classes for Unary and Binary Ops.

#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

// Compiles a binary elementwise op: checks that the two input shapes are
// broadcast-compatible under TensorFlow's rules, works out which dimensions
// of the higher-rank input the lower-rank input should be matched against,
// and hands both to the virtual Computation() to emit the actual operation.
// On incompatible shapes an InvalidArgument status is set on `ctx` and no
// output is produced.
void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
  const TensorShape lhs_shape = ctx->InputShape(0);
  const TensorShape rhs_shape = ctx->InputShape(1);

  // By TensorFlow conventions the inputs may not have the same
  // shapes, in which case they will be automatically broadcast if
  // possible before mapping. Use the standard TensorFlow helper to
  // compute valid broadcast shapes, but rely below on XLA to
  // automatically perform the broadcast assuming its valid shapes are
  // a superset of TensorFlow's valid shapes.
  BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape));
  if (!bcast.IsValid()) {
    ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
                                           lhs_shape.DebugString(), " vs. ",
                                           rhs_shape.DebugString()));
    return;
  }

  // Fetch the expressions containing the input tensors.
  auto lhs_handle = ctx->Input(0);
  auto rhs_handle = ctx->Input(1);

  // If the ranks of the inputs don't match, TensorFlow automatically
  // reshapes the smaller by padding with dimensions of size 1 as a
  // prefix. In other words to pad a 5-vector to a 3-dimensional
  // tensor it is reshaped to have shape [1,1,5]. XLA's automatic
  // broadcast code is able to broadcast from lower to higher rank,
  // but doesn't assume you want to pad as a prefix of the dimensions,
  // and instead needs to be told which dimensions of the higher rank
  // tensor to match to the lower rank tensor. In this example it
  // would be dimensions [2]. If we were matching a matrix against a
  // 4-D tensor the dimensions to match would be [2,3],
  // etc. extend_dimension encodes the general case. It is left empty
  // when both inputs already have the same rank.
  std::vector<int64> extend_dimension;
  const int max_rank = std::max(lhs_shape.dims(), rhs_shape.dims());
  const int min_rank = std::min(lhs_shape.dims(), rhs_shape.dims());
  if (min_rank != max_rank) {
    extend_dimension.reserve(min_rank);
    for (int i = 0; i < min_rank; ++i) {
      // Match the lower rank tensor along the larger-numbered
      // dimensions of the higher rank tensor.
      extend_dimension.push_back(max_rank - min_rank + i);
    }
  }

  // Call virtual method to emit the computation.
  xla::ComputationDataHandle output =
      Computation(ctx, lhs_handle, lhs_shape.dim_sizes(), rhs_handle,
                  rhs_shape.dim_sizes(), bcast, extend_dimension);

  // The post-broadcast shape is available from bcast.output_shape();
  // it is not re-checked here. We rely on subclassed Computations to
  // implement the same broadcast semantics.
  ctx->SetOutput(0, output);
}

/* static */ std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
XlaBinaryOp::Broadcast(xla::ComputationBuilder* builder,
                       const xla::ComputationDataHandle& lhs,
                       const xla::ComputationDataHandle& rhs,
                       const BCast& broadcast_helper) {
  // MapN performs no implicit broadcasting, so both operands are
  // expanded by hand here. The BCast helper guarantees that
  // lhs.reshape(x_reshape()).broadcast(x_bcast()) and
  // rhs.reshape(y_reshape()).broadcast(y_bcast()) end up with the same
  // shape, which is what MapN requires.

  // Produces the dimension order [0, rank, 1, rank+1, ...] that pairs
  // every prepended broadcast dimension with the original dimension it
  // belongs to.
  const auto interleaved_order = [](int rank) -> std::vector<int64> {
    std::vector<int64> order;
    for (int d = 0; d < rank; ++d) {
      order.push_back(d);
      order.push_back(d + rank);
    }
    return order;
  };

  // Step 1: reshape each input. Dimensions are flattened in order, so
  // this should only touch metadata.
  auto x_reshaped = builder->Reshape(lhs, broadcast_helper.x_reshape());
  auto y_reshaped = builder->Reshape(rhs, broadcast_helper.y_reshape());

  // Step 2: broadcast. Some of the requested broadcast sizes are 1; to
  // keep this code simple we leave it to the XLA optimizer to deal with
  // those efficiently.
  auto x_expanded = builder->Broadcast(x_reshaped, broadcast_helper.x_bcast());
  auto y_expanded = builder->Broadcast(y_reshaped, broadcast_helper.y_bcast());

  const int x_rank = broadcast_helper.x_bcast().size();
  const int y_rank = broadcast_helper.y_bcast().size();

  // Step 3: collapse to the output shape. Broadcasting prepended its
  // dimensions, so each side now has twice the rank it should. Merging
  // each original dimension with its prepended partner fixes that:
  // with a reshaped lhs of [5,2,3] and x_bcast of [2,1,7], the
  // broadcast result is [2,1,7,5,2,3] and the wanted shape is
  // [10,2,21].
  auto x_result = builder->Reshape(x_expanded, interleaved_order(x_rank),
                                   broadcast_helper.output_shape());
  auto y_result = builder->Reshape(y_expanded, interleaved_order(y_rank),
                                   broadcast_helper.output_shape());

  return {x_result, y_result};
}

// Emits the binary op as an XLA Map over the two operands. The scalar
// function applied by the Map is produced by the virtual
// BuildMapLambda(). When the operand shapes differ, both operands are
// first expanded to a common shape via Broadcast().
xla::ComputationDataHandle XlaBinaryMapOp::Computation(
    XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs,
    const gtl::ArraySlice<int64>& lhs_shape,
    const xla::ComputationDataHandle& rhs,
    const gtl::ArraySlice<int64>& rhs_shape, const BCast& broadcast_helper,
    const std::vector<int64>& extend_dimensions) {
  xla::ComputationBuilder* const outer_builder = ctx->builder();

  // A separate builder, named after this kernel, holds the per-element
  // lambda.
  xla::ComputationBuilder lambda_builder(outer_builder->client(),
                                         ctx->op_kernel().name());
  xla::PrimitiveType element_type;
  TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &element_type));

  // The lambda takes two scalars of the input element type.
  const auto scalar_shape = xla::ShapeUtil::MakeShape(element_type, {});
  const xla::ComputationDataHandle x =
      lambda_builder.Parameter(0, scalar_shape, "x");
  const xla::ComputationDataHandle y =
      lambda_builder.Parameter(1, scalar_shape, "y");

  // The subclass fills in the lambda body through the virtual hook.
  BuildMapLambda(&lambda_builder, x, y);
  xla::Computation elementwise_fn = lambda_builder.Build().ConsumeValueOrDie();

  // Identical shapes: nothing to broadcast, map over the inputs directly.
  if (lhs_shape == rhs_shape) {
    CHECK_EQ(0, extend_dimensions.size());
    return outer_builder->Map({lhs, rhs}, elementwise_fn);
  }

  // Otherwise bring both sides to the final shape first, then map.
  const auto expanded = Broadcast(outer_builder, lhs, rhs, broadcast_helper);
  return outer_builder->Map({expanded.first, expanded.second}, elementwise_fn);
}

}  // namespace tensorflow