aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_ops_common.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_common.cc')
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.cc42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
new file mode 100644
index 0000000000..f86d2ddd9a
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -0,0 +1,42 @@
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+
+BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out,
+ DataType in)
+ : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out}));
+}
+
+void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) {
+ ctx->SetStatus(errors::Unimplemented(
+ "Broadcast between ", ctx->input(0).shape().ShortDebugString(), " and ",
+ ctx->input(1).shape().ShortDebugString(), " is not supported yet."));
+}
+
+static BCast::Vec FromShape(const TensorShape& shape) {
+ BCast::Vec ret;
+ for (int i = 0; i < shape.dims(); ++i) ret.push_back(shape.dim_size(i));
+ return ret;
+}
+
+static TensorShape ToShape(const BCast::Vec& vec) {
+ TensorShape shape;
+ for (auto elem : vec) shape.AddDim(elem);
+ return shape;
+}
+
+BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx)
+ : bcast(FromShape(ctx->input(0).shape()),
+ FromShape(ctx->input(1).shape())) {
+ if (!bcast.IsValid()) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "Incompatible shapes: ", ctx->input(0).shape().ShortDebugString(),
+ " vs. ", ctx->input(1).shape().ShortDebugString()));
+ return;
+ }
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, ToShape(bcast.output_shape()), &out));
+}
+
+} // namespace tensorflow