diff options
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_common.cc')
-rw-r--r-- | tensorflow/core/kernels/cwise_ops_common.cc | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc new file mode 100644 index 0000000000..f86d2ddd9a --- /dev/null +++ b/tensorflow/core/kernels/cwise_ops_common.cc @@ -0,0 +1,42 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { + +BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out, + DataType in) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out})); +} + +void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) { + ctx->SetStatus(errors::Unimplemented( + "Broadcast between ", ctx->input(0).shape().ShortDebugString(), " and ", + ctx->input(1).shape().ShortDebugString(), " is not supported yet.")); +} + +static BCast::Vec FromShape(const TensorShape& shape) { + BCast::Vec ret; + for (int i = 0; i < shape.dims(); ++i) ret.push_back(shape.dim_size(i)); + return ret; +} + +static TensorShape ToShape(const BCast::Vec& vec) { + TensorShape shape; + for (auto elem : vec) shape.AddDim(elem); + return shape; +} + +BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx) + : bcast(FromShape(ctx->input(0).shape()), + FromShape(ctx->input(1).shape())) { + if (!bcast.IsValid()) { + ctx->SetStatus(errors::InvalidArgument( + "Incompatible shapes: ", ctx->input(0).shape().ShortDebugString(), + " vs. ", ctx->input(1).shape().ShortDebugString())); + return; + } + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, ToShape(bcast.output_shape()), &out)); +} + +} // namespace tensorflow |