Diffstat (limited to 'tensorflow/core/kernels/bcast_ops.cc')
-rw-r--r--  tensorflow/core/kernels/bcast_ops.cc  71
1 file changed, 71 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/kernels/bcast_ops.cc b/tensorflow/core/kernels/bcast_ops.cc
new file mode 100644
index 0000000000..bb1492e5b4
--- /dev/null
+++ b/tensorflow/core/kernels/bcast_ops.cc
@@ -0,0 +1,71 @@
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/util/bcast.h"
+
+namespace tensorflow {
+
+// Given shapes of two tensors, computes the reduction indices for the
+// gradient computation.
+//
+// TODO(zhifengc):
+// 1. Add support for n-ary ops (n > 2).
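+//
+// Illustrative example (values chosen here, not from the op definition):
+// given input shapes s0 = [2, 3, 5] and s1 = [1], s1 is broadcast along
+// every dimension, so the op returns r0 = [] and r1 = [0, 1, 2], the axes
+// over which the respective gradients must be summed.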
+class BCastGradArgsOp : public OpKernel {
+ public:
+  explicit BCastGradArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(
+        ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}));
+  }
+
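+  // Reads the two input shape vectors, matches them with BCast, and
+  // emits the reduction axes for each operand's gradient.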
+  void Compute(OpKernelContext* ctx) override {
+    OP_REQUIRES(
+        ctx, ctx->num_inputs() == 2,
+        errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
+    gtl::InlinedVector<BCast::Vec, 4> shapes;
+    for (int i = 0; i < ctx->num_inputs(); ++i) {
+      const Tensor& in = ctx->input(i);
+      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()),
+                  errors::InvalidArgument("In[", i, "] must be a vector: ",
+                                          in.shape().ShortDebugString()));
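+      // Copy the int32 shape tensor into a BCast::Vec of int64 dims.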
+      BCast::Vec vec;
+      for (int64 j = 0; j < in.NumElements(); ++j) {
+        vec.push_back(in.vec<int32>()(j));
+      }
+      shapes.push_back(vec);
+    }
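+    // Match the two shapes under numpy-style broadcasting rules; BCast
+    // records which axes each operand's gradient must be summed over.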
+    BCast bcast(shapes[0], shapes[1]);
+    OP_REQUIRES(ctx, bcast.IsValid(),
+                errors::InvalidArgument(
+                    "Incompatible shapes: [", str_util::Join(shapes[0], ","),
+                    "] vs. [", str_util::Join(shapes[1], ","), "]"));
+    Output(ctx, 0, bcast.grad_x_reduce_idx());
+    Output(ctx, 1, bcast.grad_y_reduce_idx());
+  }
+
+ private:
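+  // Writes v as the idx-th output: a 1-D int32 tensor of length v.size().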
+  void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) {
+    const int len = v.size();
+    Tensor* o = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o));
+    for (int i = 0; i < len; ++i) o->flat<int32>()(i) = v[i];
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp);
+};
+
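+// The shape vectors are tiny and are consumed on the host, so both the
+// CPU and GPU registrations pin all inputs and outputs to host memory.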
+REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
+                            .Device(DEVICE_CPU)
+                            .HostMemory("s0")
+                            .HostMemory("s1")
+                            .HostMemory("r0")
+                            .HostMemory("r1"),
+                        BCastGradArgsOp);
+REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("s0")
+                            .HostMemory("s1")
+                            .HostMemory("r0")
+                            .HostMemory("r1"),
+                        BCastGradArgsOp);
+
+} // end namespace tensorflow