aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_ops_common.cc
blob: f86d2ddd9a9cf36046f6127cddbcb229e44fc9d2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {

// Constructs a BinaryOpShared kernel.
//
//   ctx: kernel construction context; forwarded to the OpKernel base.
//   out: dtype passed as the single expected output type.
//   in:  dtype passed as the expected type of both inputs.
//
// The result of the signature match is handed to OP_REQUIRES_OK, so a
// failure is reported through `ctx` rather than returned to the caller.
BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out,
                               DataType in)
    : OpKernel(ctx) {
  const auto signature_status = ctx->MatchSignature({in, in}, {out});
  OP_REQUIRES_OK(ctx, signature_status);
}

// Sets an Unimplemented status on `ctx` whose message names the shapes of
// input 0 and input 1 as the unsupported broadcast pair.
void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) {
  const auto& lhs_shape = ctx->input(0).shape();
  const auto& rhs_shape = ctx->input(1).shape();
  ctx->SetStatus(errors::Unimplemented(
      "Broadcast between ", lhs_shape.ShortDebugString(), " and ",
      rhs_shape.ShortDebugString(), " is not supported yet."));
}

// Copies the dimension sizes of `shape` into a BCast::Vec, in dimension
// order (index 0 first).
//
// The number of dimensions is known before the loop starts, so the vector's
// capacity is reserved once instead of being grown by successive push_back
// calls, and the rank is read a single time rather than on every iteration.
static BCast::Vec FromShape(const TensorShape& shape) {
  const int rank = shape.dims();
  BCast::Vec ret;
  ret.reserve(rank);
  for (int i = 0; i < rank; ++i) ret.push_back(shape.dim_size(i));
  return ret;
}

// Builds a TensorShape by appending each element of `vec` as a dimension,
// in the order the elements appear in `vec`.
static TensorShape ToShape(const BCast::Vec& vec) {
  TensorShape result;
  for (const auto dim_size : vec) {
    result.AddDim(dim_size);
  }
  return result;
}

// Prepares the per-call state for a binary op.
//
// `bcast` is initialized from the shapes of input 0 and input 1. If the
// resulting broadcast is valid, output 0 is allocated with the broadcast
// output shape and stored in `out`; an allocation failure is reported through
// OP_REQUIRES_OK. Otherwise an InvalidArgument status naming both input
// shapes is set on `ctx` and no output is allocated.
//
// Nothing is returned from a constructor, so callers need to inspect the
// status on `ctx` to learn whether setup succeeded.
BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx)
    : bcast(FromShape(ctx->input(0).shape()),
            FromShape(ctx->input(1).shape())) {
  if (bcast.IsValid()) {
    // Compatible shapes: allocate the result with the broadcast shape.
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, ToShape(bcast.output_shape()), &out));
  } else {
    // Shapes are only looked up again on the error path, for the message.
    const auto& lhs_shape = ctx->input(0).shape();
    const auto& rhs_shape = ctx->input(1).shape();
    ctx->SetStatus(errors::InvalidArgument(
        "Incompatible shapes: ", lhs_shape.ShortDebugString(), " vs. ",
        rhs_shape.ShortDebugString()));
  }
}

}  // namespace tensorflow