diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/kernels/cwise_ops_common.cc |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_common.cc')
-rw-r--r-- | tensorflow/core/kernels/cwise_ops_common.cc | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc new file mode 100644 index 0000000000..f86d2ddd9a --- /dev/null +++ b/tensorflow/core/kernels/cwise_ops_common.cc @@ -0,0 +1,42 @@ +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { + +BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out, + DataType in) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out})); +} + +void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) { + ctx->SetStatus(errors::Unimplemented( + "Broadcast between ", ctx->input(0).shape().ShortDebugString(), " and ", + ctx->input(1).shape().ShortDebugString(), " is not supported yet.")); +} + +static BCast::Vec FromShape(const TensorShape& shape) { + BCast::Vec ret; + for (int i = 0; i < shape.dims(); ++i) ret.push_back(shape.dim_size(i)); + return ret; +} + +static TensorShape ToShape(const BCast::Vec& vec) { + TensorShape shape; + for (auto elem : vec) shape.AddDim(elem); + return shape; +} + +BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx) + : bcast(FromShape(ctx->input(0).shape()), + FromShape(ctx->input(1).shape())) { + if (!bcast.IsValid()) { + ctx->SetStatus(errors::InvalidArgument( + "Incompatible shapes: ", ctx->input(0).shape().ShortDebugString(), + " vs. ", ctx->input(1).shape().ShortDebugString())); + return; + } + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, ToShape(bcast.output_shape()), &out)); +} + +} // namespace tensorflow |