1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
|
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
// Given shapes of two tensors, computes the reduction indices for the
// gradient computation.
//
// TODO(zhifengc):
// 1. Adds support for n-ary (n >= 2).
class BCastGradArgsOp : public OpKernel {
public:
explicit BCastGradArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(
ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}));
}
void Compute(OpKernelContext* ctx) override {
OP_REQUIRES(
ctx, ctx->num_inputs() == 2,
errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
gtl::InlinedVector<BCast::Vec, 4> shapes;
for (int i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& in = ctx->input(i);
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()),
errors::InvalidArgument("In[", i, "] must be a vector.",
in.shape().ShortDebugString()));
BCast::Vec vec;
for (int64 i = 0; i < in.NumElements(); ++i) {
vec.push_back(in.vec<int32>()(i));
}
shapes.push_back(vec);
}
BCast bcast(shapes[0], shapes[1]);
OP_REQUIRES(ctx, bcast.IsValid(),
errors::InvalidArgument(
"Incompatible shapes: [", str_util::Join(shapes[0], ","),
"] vs. [", str_util::Join(shapes[1], ","), "]"));
Output(ctx, 0, bcast.grad_x_reduce_idx());
Output(ctx, 1, bcast.grad_y_reduce_idx());
}
private:
void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) {
const int len = v.size();
Tensor* o = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o));
for (int i = 0; i < len; ++i) o->flat<int32>()(i) = v[i];
}
TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp);
};
// Registers BCastGradArgsOp as the DEVICE_CPU kernel for the
// "BroadcastGradientArgs" op. The arguments "s0", "s1", "r0" and "r1" are
// all declared as host-memory arguments.
// NOTE(review): presumably s0/s1 are the two shape inputs and r0/r1 the two
// reduction-index outputs -- confirm against the op definition.
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
.Device(DEVICE_CPU)
.HostMemory("s0")
.HostMemory("s1")
.HostMemory("r0")
.HostMemory("r1"),
BCastGradArgsOp);
// Registers the same BCastGradArgsOp class as the DEVICE_GPU kernel for the
// "BroadcastGradientArgs" op. The arguments "s0", "s1", "r0" and "r1" are
// all declared as host-memory arguments.
// NOTE(review): presumably this is what lets the kernel read and write the
// tensor elements directly from host code even when placed on a GPU device
// -- confirm against the op definition and the HostMemory() documentation.
REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
.Device(DEVICE_GPU)
.HostMemory("s0")
.HostMemory("s1")
.HostMemory("r0")
.HostMemory("r1"),
BCastGradArgsOp);
} // end namespace tensorflow
|