1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
|
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {

// Construction-time check shared by element-wise binary kernels: the op must
// take exactly two inputs of dtype `in` and produce one output of dtype `out`.
// A mismatch is reported on `ctx` via OP_REQUIRES_OK.
BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out,
                               DataType in)
    : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out}));
}

// Records an Unimplemented status on `ctx` that names the shapes of inputs 0
// and 1, whose broadcast combination this kernel does not handle.
void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) {
  const auto lhs_desc = ctx->input(0).shape().ShortDebugString();
  const auto rhs_desc = ctx->input(1).shape().ShortDebugString();
  ctx->SetStatus(errors::Unimplemented("Broadcast between ", lhs_desc, " and ",
                                       rhs_desc, " is not supported yet."));
}

// Copies the dimension sizes of `shape`, in index order 0..dims()-1, into a
// BCast::Vec.
static BCast::Vec FromShape(const TensorShape& shape) {
  BCast::Vec sizes;
  const int rank = shape.dims();
  for (int d = 0; d < rank; ++d) {
    sizes.push_back(shape.dim_size(d));
  }
  return sizes;
}

// Inverse of FromShape: builds a TensorShape whose dimensions are the entries
// of `vec`, appended in order.
static TensorShape ToShape(const BCast::Vec& vec) {
  TensorShape result;
  for (const auto dim_size : vec) {
    result.AddDim(dim_size);
  }
  return result;
}

// Computes the broadcast of the shapes of inputs 0 and 1. When they are
// compatible, output 0 is allocated with the broadcast output shape and stored
// in `out`; an allocation failure is reported on `ctx` via OP_REQUIRES_OK.
// When they are incompatible, an InvalidArgument status is recorded on `ctx`
// and no output is allocated.
BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx)
    : bcast(FromShape(ctx->input(0).shape()),
            FromShape(ctx->input(1).shape())) {
  if (bcast.IsValid()) {
    const TensorShape out_shape = ToShape(bcast.output_shape());
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
  } else {
    ctx->SetStatus(errors::InvalidArgument(
        "Incompatible shapes: ", ctx->input(0).shape().ShortDebugString(),
        " vs. ", ctx->input(1).shape().ShortDebugString()));
  }
}

}  // namespace tensorflow
|