aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/where_op.cc
blob: 42d1365e64592c6609c6daf83678f7dbd056a23f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include "tensorflow/core/kernels/where_op.h"

#include <memory>
#include <numeric>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/platform/cuda.h"

using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
#endif  // GOOGLE_CUDA

namespace tensorflow {

// Shorthand aliases for the Eigen device types used to specialize the
// functors in this file.
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace functor {

namespace {
// Returns the number of elements in [begin, end) that compare unequal to T(0).
//
// std::accumulate accumulates in the type of its `init` argument, so the seed
// is spelled int64{0}: a `long` seed (0L) is only 32 bits wide on LLP64
// platforms (e.g. 64-bit Windows) and would silently truncate the running
// count for inputs with more than 2^31 - 1 non-zero elements.
template <typename T>
int64 CountAccumulator(const T* begin, const T* end) {
  return std::accumulate(begin, end, int64{0}, [](int64 accum, const T& val) {
    return accum + (val != T(0));
  });
}

// bool specialization: each element promotes to 0 or 1 when added to the
// int64 accumulator, so the elements can be summed directly.
template <>
int64 CountAccumulator<bool>(const bool* begin, const bool* end) {
  return std::accumulate(begin, end, int64{0});
}

}  // namespace

template <typename T>
struct NumTrue<CPUDevice, T, int64> {
  // Stores in `num_true` how many elements of `input` are unequal to T(0).
  // The count is computed directly on the calling thread with
  // CountAccumulator; `ctx` and `d` are not used. Always returns OK.
  static Status Compute(OpKernelContext* ctx, const CPUDevice& d,
                        typename TTypes<T>::ConstFlat input,
                        TTypes<int64>::Scalar num_true) {
    const T* const first = input.data();
    const T* const last = first + input.size();
    const int64 nonzero_count = CountAccumulator<T>(first, last);
    num_true() = nonzero_count;
    return Status::OK();
  }
};

template <int DIMS, typename T, typename TIndex>
struct Where<CPUDevice, DIMS, T, TIndex> {
  // Splits the flat row-major offset `index` into DIMS coordinates using
  // `strides`, and stores them in row `true_n` of `output`.
  EIGEN_ALWAYS_INLINE static void WriteIndexRowMajor(
      typename TTypes<int64>::Matrix output,
      const typename Eigen::DSizes<TIndex, DIMS>& strides, TIndex true_n,
      TIndex index) {
    TIndex remaining = index;
    for (int axis = 0; axis < DIMS; ++axis) {
      const TIndex coord = remaining / strides[axis];
      output(true_n, axis) = coord;
      remaining -= coord * strides[axis];
    }
  }

  // Scans `input` in row-major order. For every element unequal to T(0) it
  // writes that element's coordinates into row *found_true of `output`, but
  // only while that row is within output.dimension(0). It then increments
  // *found_true. On return, *found_true has been advanced by the number of
  // non-zero elements seen, even if `output` had too few rows to hold them.
  // `ctx` and `d` are unused; always returns OK.
  EIGEN_ALWAYS_INLINE static Status Compute(
      OpKernelContext* ctx, const CPUDevice& d,
      typename TTypes<T, DIMS>::ConstTensor input,
      typename TTypes<int64>::Matrix output, TIndex* found_true) {
    // The stride computation below is only valid for row-major storage.
    EIGEN_STATIC_ASSERT((static_cast<int>(decltype(input)::Layout) ==
                         static_cast<int>(Eigen::RowMajor)),
                        INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR);

    const Eigen::DSizes<Eigen::DenseIndex, DIMS> dims = input.dimensions();

    // strides[k] is the product of dims[k + 1 .. DIMS - 1]; the innermost
    // axis is contiguous (stride 1).
    Eigen::DSizes<TIndex, DIMS> strides;
    strides[DIMS - 1] = 1;
    for (int axis = DIMS - 1; axis > 0; --axis) {
      strides[axis - 1] = strides[axis] * dims[axis];
    }

    const T* const values = input.data();
    const Eigen::DenseIndex num_elements = input.size();
    const Eigen::DenseIndex num_rows = output.dimension(0);
    for (Eigen::DenseIndex flat = 0; flat < num_elements; ++flat) {
      if (values[flat] != T(0)) {
        if (FastBoundsCheck(*found_true, num_rows)) {
          WriteIndexRowMajor(output, strides, *found_true, flat);
        }
        ++*found_true;
      }
    }
    return Status::OK();
  }
};

}  // namespace functor

template <typename T>
class WhereCPUOp : public OpKernel {
 public:
  explicit WhereCPUOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);

    OP_REQUIRES(
        context, input.dtype() != DT_HALF,
        errors::Unimplemented("No WhereOp available for float16/half type on "
                              "GPU; dying in CPU WhereOp to avoid silently "
                              "creating costly copies from device."));

    const int input_dims = input.dims();

    Tensor num_true;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DT_INT64, TensorShape({}), &num_true));
    auto num_true_t = num_true.scalar<int64>();

    Status s = functor::NumTrue<CPUDevice, T, int64>::Compute(
        context, context->eigen_device<CPUDevice>(), input.flat<T>(),
        num_true_t);
    OP_REQUIRES_OK(context, s);
    TensorShape output_shape({num_true_t(), input_dims});
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));

    // TODO(ebrevdo): Replace single-threaded copy with a
    // multithreaded block copy by getting block counts above instead
    // of a global NumTrue, then having each block filled in in
    // separate threads below.
    int64 found_true = 0;

#define HANDLE_DIM(NDIM)                                                      \
  case NDIM: {                                                                \
    Status s = functor::Where<CPUDevice, NDIM, T, int64>::Compute(            \
        context, context->eigen_device<CPUDevice>(), input.tensor<T, NDIM>(), \
        output->matrix<int64>(), &found_true);                                \
    OP_REQUIRES_OK(context, s);                                               \
  } break;

    switch (input_dims) {
      HANDLE_DIM(1);
      HANDLE_DIM(2);
      HANDLE_DIM(3);
      HANDLE_DIM(4);
      HANDLE_DIM(5);

      default:
        OP_REQUIRES(context, false,
                    errors::InvalidArgument(
                        "WhereOp : Unhandled input dimensions: ", input_dims));
    }
#undef HANDLE_DIM

    OP_REQUIRES(
        context, found_true == num_true_t(),
        errors::InvalidArgument(
            "WhereOp: Race condition between counting the number of true "
            "elements and writing them.  When counting, saw ",
            num_true_t(), " elements; but when writing their indices, saw ",
            found_true, " elements."));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(WhereCPUOp);
};

// Registers the CPU Where kernel for every numeric type and for bool.
#define REGISTER_WHERE_CPU_KERNEL(type)                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Where").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      WhereCPUOp<type>);

TF_CALL_NUMBER_TYPES(REGISTER_WHERE_CPU_KERNEL);
TF_CALL_bool(REGISTER_WHERE_CPU_KERNEL);

#undef REGISTER_WHERE_CPU_KERNEL

#if GOOGLE_CUDA

namespace functor {

// Declares, for one (T, Tindex) pair, the explicit specialization of
// NumTrue<GPUDevice, T, Tindex>::Compute. It also adds an explicit
// instantiation declaration (`extern template`), so this translation unit
// does not instantiate the struct itself. The definitions are presumably
// provided by a separate CUDA translation unit that is not visible here.
#define DECLARE_GPU_NUMTRUE(T, Tindex)                                      \
  template <>                                                               \
  Status NumTrue<GPUDevice, T, Tindex>::Compute(                            \
      OpKernelContext* ctx, const GPUDevice& d, TTypes<T>::ConstFlat input, \
      TTypes<Tindex>::Scalar num_true);                                     \
  extern template struct NumTrue<GPUDevice, T, Tindex>

// Both index widths are declared because WhereGPUOp chooses int32 or int64
// at runtime based on the input's element count.
#define DECLARE_GPU_NUMTRUE_TYPE(T) \
  DECLARE_GPU_NUMTRUE(T, int32);    \
  DECLARE_GPU_NUMTRUE(T, int64);

// NOTE(review): NumTrue is declared for all number types plus bool, which is
// wider than TF_CALL_WHERE_GPU_TYPES (used for Where and kernel registration
// below).
TF_CALL_NUMBER_TYPES(DECLARE_GPU_NUMTRUE_TYPE);
TF_CALL_bool(DECLARE_GPU_NUMTRUE_TYPE);

#undef DECLARE_GPU_NUMTRUE_TYPE
#undef DECLARE_GPU_NUMTRUE

// Same pattern for Where<GPUDevice, Dims, T, Tindex>::Compute: an explicit
// specialization declaration plus an explicit instantiation declaration.
#define DECLARE_GPU_WHERE_INDEX(Dims, T, Tindex)                  \
  template <>                                                     \
  Status Where<GPUDevice, Dims, T, Tindex>::Compute(              \
      OpKernelContext* ctx, const GPUDevice& d,                   \
      typename TTypes<T, Dims>::ConstTensor input,                \
      typename TTypes<int64>::Matrix output, Tindex* found_true); \
  extern template struct Where<GPUDevice, Dims, T, Tindex>;
// Expands to both index widths for one (Dims, T).
#define DECLARE_GPU_WHERE(Dims, T)         \
  DECLARE_GPU_WHERE_INDEX(Dims, T, int32); \
  DECLARE_GPU_WHERE_INDEX(Dims, T, int64);

// Ranks 1 through 5; these match the ranks dispatched by WhereGPUOp.
#define DECLARE_GPU_WHERE_TYPES(T) \
  DECLARE_GPU_WHERE(1, T);         \
  DECLARE_GPU_WHERE(2, T);         \
  DECLARE_GPU_WHERE(3, T);         \
  DECLARE_GPU_WHERE(4, T);         \
  DECLARE_GPU_WHERE(5, T);

TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_WHERE_TYPES);

#undef DECLARE_GPU_WHERE_TYPES
#undef DECLARE_GPU_WHERE
#undef DECLARE_GPU_WHERE_INDEX

}  // namespace functor

// Asynchronous GPU implementation of Where.
//
// Outline:
//   1. The NumTrue kernel is queued to count the non-zero elements.
//   2. A device-to-host memcpy of that count is queued on the op's stream.
//   3. A host callback is registered through the event manager. When it runs,
//      it reads the count, allocates the output, and queues the Where kernel
//      that writes the coordinates. It then calls `done`.
template <typename T>
class WhereGPUOp : public AsyncOpKernel {
 public:
  explicit WhereGPUOp(OpKernelConstruction* context) : AsyncOpKernel(context) {}

  // Chooses the device-side index type. int32 is used when the element count
  // is strictly below std::numeric_limits<int32>::max(); otherwise int64 is
  // used.
  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    const Tensor& input = context->input(0);
    const int input_dims = input.dims();

    if (input.NumElements() < std::numeric_limits<int32>::max()) {
      ComputeAsyncType<int32>(input, input_dims, context, done);
    } else {
      ComputeAsyncType<int64>(input, input_dims, context, done);
    }
  }

  // Runs the count -> copy-to-host -> allocate -> write pipeline, using
  // Tindex for device-side counting and indexing. Failures are reported
  // through the *_ASYNC macros, which are passed `done`. On the success path,
  // done() is called at the end of the host callback.
  template <typename Tindex>
  void ComputeAsyncType(const Tensor& input, const int input_dims,
                        OpKernelContext* context, DoneCallback done) {
    // Step 0: alloc nnz
    // Step 1: call nnz kernel
    // Step 2: copy nnz to host
    // Step 3: call create_output
    // Step 4: call where kernel
    Tensor num_true;
    OP_REQUIRES_OK_ASYNC(context,
                         context->allocate_temp(DataTypeToEnum<Tindex>::v(),
                                                TensorShape({}), &num_true),
                         done);

    auto num_true_t = num_true.scalar<Tindex>();

    // Wraps the scalar temp's buffer so it can be the source of the
    // stream memcpy below.
    // NOTE(review): `num_true` goes out of scope when this function returns,
    // possibly before the queued memcpy has executed. This is presumably safe
    // because the device allocator hands memory back out in stream order.
    // Confirm.
    perftools::gputools::DeviceMemoryBase num_true_ptr(
        static_cast<void*>(num_true_t.data()));
    // Push kernel to stream to get number of true elements.
    const GPUDevice& d = context->eigen_device<GPUDevice>();
    Status s = functor::NumTrue<GPUDevice, T, Tindex>::Compute(
        context, d, input.flat<T>(), num_true_t);
    OP_REQUIRES_OK_ASYNC(context, s, done);

    // Copy num_true to host;
    // One-element host-side buffer. It is captured by value in the callback
    // below so that the callback can read the copied count.
    ScratchSpace<Tindex> num_true_host(context, 1, /* on_host */ true);

    auto stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(num_true_host.mutable_data(), num_true_ptr,
                         sizeof(Tindex))
            .ok(),
        errors::Internal("WhereOp: failed to copy num_true from device"), done);

    // Host callback, run by the event manager once the work queued above on
    // `stream` has completed.
    // NOTE(review): `d` and `input` are captured by reference and used after
    // ComputeAsyncType has returned. This relies on the device object and the
    // input tensor (both reached through `context`) staying alive until
    // `done` runs. Confirm.
    auto create_and_check_output = [context, &d, &input, input_dims,
                                    num_true_host, done]() {
      // Ensure that within the callback, the proper GPU settings are
      // configured.
      auto stream = context->op_device_context()->stream();
      ScopedActivateExecutorContext scoped_activation{stream->parent()};

      // Count of non-zero elements, as copied from the device.
      Tindex num_true = *num_true_host.data();

      // TODO(ebrevdo): Properly copy back found_true value to CPU for
      // validation checking.  Currently Where<GPUDevice>::Compute()
      // does not perform this copy back to CPU.
      Tindex found_true = -1;

      // Step 1: Allocate the output and perform the selection/copy.
      Tensor* output;
      OP_REQUIRES_OK_ASYNC(context,
                           context->allocate_output(
                               0, TensorShape({num_true, input_dims}), &output),
                           done);

// Dispatches the GPU Where functor for a static rank NDIM and reports any
// failure through OP_REQUIRES_OK_ASYNC.
#define HANDLE_DIM(NDIM)                                              \
  case NDIM: {                                                        \
    Status s = functor::Where<GPUDevice, NDIM, T, Tindex>::Compute(   \
        context, d, input.tensor<T, NDIM>(), output->matrix<int64>(), \
        &found_true);                                                 \
    OP_REQUIRES_OK_ASYNC(context, s, done);                           \
  } break;

      switch (input_dims) {
        HANDLE_DIM(1);
        HANDLE_DIM(2);
        HANDLE_DIM(3);
        HANDLE_DIM(4);
        HANDLE_DIM(5);

        default:
          OP_REQUIRES_ASYNC(
              context, false,
              errors::InvalidArgument("WhereOp: Unhandled input dimensions: ",
                                      input_dims),
              done);
      }
#undef HANDLE_DIM

      // TODO(ebrevdo): Fix the copy back to host.

      // Consistency check between the two passes. It is disabled because
      // found_true is not copied back from the device (see TODO above).
      // OP_REQUIRES_ASYNC(
      //     context, found_true == num_true,
      //     errors::InvalidArgument(
      //         "WhereOp: Race condition between counting the number of true "
      //         "elements and writing them.  When counting, saw ",
      //         num_true, " elements; but when writing their indices, saw ",
      //         found_true, " elements."),
      //     done);

      done();
    };
    context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
        stream, create_and_check_output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(WhereGPUOp);
};

// Registers the asynchronous GPU Where kernel for TF_CALL_WHERE_GPU_TYPES.
#define REGISTER_WHERE_GPU_KERNEL(type)                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Where").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      WhereGPUOp<type>);

TF_CALL_WHERE_GPU_TYPES(REGISTER_WHERE_GPU_KERNEL);

#undef REGISTER_WHERE_GPU_KERNEL

#endif  // GOOGLE_CUDA

}  // namespace tensorflow