// tensorflow/core/kernels/bcast_ops.cc

#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

// Given the shapes of two tensors, computes the reduction indices needed for
// the gradient computation.
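//
// For example, given input shapes s0 = [2, 3, 5] and s1 = [1, 5], the op
// returns r0 = [] and r1 = [0, 1]: the gradient with respect to the second
// input must be summed over dimensions 0 and 1 of the broadcast shape.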
//
// TODO(zhifengc):
//   1. Add support for n-ary (n >= 2).
class BCastGradArgsOp : public OpKernel {
 public:
  explicit BCastGradArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
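    // The two shape inputs and the two reduction-index outputs are all int32.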
    OP_REQUIRES_OK(
        ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}));
  }

  void Compute(OpKernelContext* ctx) override {
    OP_REQUIRES(
        ctx, ctx->num_inputs() == 2,
        errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
    gtl::InlinedVector<BCast::Vec, 4> shapes;
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      const Tensor& in = ctx->input(i);
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()),
                  errors::InvalidArgument("In[", i, "] must be a vector. ",
                                          in.shape().ShortDebugString()));
      BCast::Vec vec;
      for (int64 j = 0; j < in.NumElements(); ++j) {
        vec.push_back(in.vec<int32>()(j));
      }
      shapes.push_back(vec);
    }
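    // BCast verifies that the two shapes are broadcast-compatible and
    // computes, for each input, the dimensions of the broadcast shape along
    // which that input's gradient must be summed.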
    BCast bcast(shapes[0], shapes[1]);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "Incompatible shapes: [", str_util::Join(shapes[0], ","),
                    "] vs. [", str_util::Join(shapes[1], ","), "]"));
    Output(ctx, 0, bcast.grad_x_reduce_idx());
    Output(ctx, 1, bcast.grad_y_reduce_idx());
  }

 private:
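  // Writes the reduction indices in 'v' to output 'idx' as an int32 vector
  // of length v.size().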
  void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) {
    const int len = v.size();
    Tensor* o = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o));
    for (int i = 0; i < len; ++i) o->flat<int32>()(i) = v[i];
  }

  TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp);
};

REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_CPU)
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp);
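
// The GPU registration below also pins all inputs and outputs to host
// memory: the shape vectors involved are tiny, so the computation runs on
// the host even when the kernel is placed on a GPU device.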
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
                            .Device(DEVICE_GPU)
                            .HostMemory("s0")
                            .HostMemory("s1")
                            .HostMemory("r0")
                            .HostMemory("r1"),
                        BCastGradArgsOp);

}  // end namespace tensorflow