diff options
Diffstat (limited to 'tensorflow/core/kernels/bcast_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/bcast_ops.cc | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/bcast_ops.cc b/tensorflow/core/kernels/bcast_ops.cc new file mode 100644 index 0000000000..bb1492e5b4 --- /dev/null +++ b/tensorflow/core/kernels/bcast_ops.cc @@ -0,0 +1,71 @@ +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { + +// Given shapes of two tensors, computes the reduction indices for the +// gradient computation. +// +// TODO(zhifengc): +// 1. Adds support for n-ary (n >= 2). +class BCastGradArgsOp : public OpKernel { + public: + explicit BCastGradArgsOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK( + ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32, DT_INT32})); + } + + void Compute(OpKernelContext* ctx) override { + OP_REQUIRES( + ctx, ctx->num_inputs() == 2, + errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); + gtl::InlinedVector<BCast::Vec, 4> shapes; + for (int i = 0; i < ctx->num_inputs(); ++i) { + const Tensor& in = ctx->input(i); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()), + errors::InvalidArgument("In[", i, "] must be a vector.", + in.shape().ShortDebugString())); + BCast::Vec vec; + for (int64 i = 0; i < in.NumElements(); ++i) { + vec.push_back(in.vec<int32>()(i)); + } + shapes.push_back(vec); + } + BCast bcast(shapes[0], shapes[1]); + OP_REQUIRES(ctx, bcast.IsValid(), + errors::InvalidArgument( + "Incompatible shapes: [", str_util::Join(shapes[0], ","), + "] vs. [", str_util::Join(shapes[1], ","), "]")); + Output(ctx, 0, bcast.grad_x_reduce_idx()); + Output(ctx, 1, bcast.grad_y_reduce_idx()); + } + + private: + void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) { + const int len = v.size(); + Tensor* o = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(idx, TensorShape({len}), &o)); + for (int i = 0; i < len; ++i) o->flat<int32>()(i) = v[i]; + } + + TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp); +}; + +REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs") + .Device(DEVICE_CPU) + .HostMemory("s0") + .HostMemory("s1") + .HostMemory("r0") + .HostMemory("r1"), + BCastGradArgsOp); +REGISTER_KERNEL_BUILDER(Name("BroadcastGradientArgs") + .Device(DEVICE_GPU) + .HostMemory("s0") + .HostMemory("s1") + .HostMemory("r0") + .HostMemory("r1"), + BCastGradArgsOp); + +} // end namespace tensorflow |