diff options
author | 2017-01-13 16:17:49 -0800 | |
---|---|---|
committer | 2017-01-13 16:28:22 -0800 | |
commit | e4a235a0358d7b6d5c0830536b786076121fb766 (patch) | |
tree | 6ef4a4ab436645b4a8e16353bf8672aedbab7ddf /tensorflow/core/kernels/inplace_ops.cc | |
parent | 9438ace9743a4607827e9b5e9c131f4e1a1dd2c6 (diff) |
Internal change.
Change: 144497247
Diffstat (limited to 'tensorflow/core/kernels/inplace_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/inplace_ops.cc | 84 |
1 files changed, 43 insertions, 41 deletions
diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index 01af1b56e2..5f1f5b652c 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -24,49 +24,8 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { - typedef Eigen::ThreadPoolDevice CPUDevice; -// TODO(apassos): validate the shapes better. -class InplaceOpBase : public OpKernel { - public: - explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - auto value = ctx->input(0); - auto loc = ctx->input(1); - auto update = ctx->input(2); - - OP_REQUIRES(ctx, TensorShapeUtils::IsVector(loc.shape()), - errors::InvalidArgument("loc must be a vector. ", - loc.shape().DebugString())); - OP_REQUIRES( - ctx, value.dims() == update.dims(), - errors::InvalidArgument("value and update shape doesn't match: ", - value.shape().DebugString(), " vs. ", - update.shape().DebugString())); - for (int i = 1; i < value.dims(); ++i) { - OP_REQUIRES( - ctx, value.dim_size(i) == update.dim_size(i), - errors::InvalidArgument("value and update shape doesn't match ", - value.shape().DebugString(), " vs. ", - update.shape().DebugString())); - } - OP_REQUIRES(ctx, loc.dim_size(0) == update.dim_size(0), - errors::InvalidArgument("loc and update shape doesn't match: ", - loc.shape().DebugString(), " vs. ", - update.shape().DebugString())); - - Tensor output = value; // This creates an alias intentionally. - OP_REQUIRES_OK(ctx, DoCompute(ctx, update, loc, &output)); - ctx->set_output(0, output); - } - - protected: - virtual Status DoCompute(OpKernelContext* ctx, const Tensor& value, - const Tensor& loc, Tensor* output) = 0; -}; - namespace functor { template <typename T> @@ -112,6 +71,48 @@ Status DoInplace(const CPUDevice& d, InplaceOpType op, const Tensor& value, } // end namespace functor +namespace { + +// TODO(apassos): validate the shapes better. +class InplaceOpBase : public OpKernel { + public: + explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + auto value = ctx->input(0); + auto loc = ctx->input(1); + auto update = ctx->input(2); + + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(loc.shape()), + errors::InvalidArgument("loc must be a vector. ", + loc.shape().DebugString())); + OP_REQUIRES( + ctx, value.dims() == update.dims(), + errors::InvalidArgument("value and update shape doesn't match: ", + value.shape().DebugString(), " vs. ", + update.shape().DebugString())); + for (int i = 1; i < value.dims(); ++i) { + OP_REQUIRES( + ctx, value.dim_size(i) == update.dim_size(i), + errors::InvalidArgument("value and update shape doesn't match ", + value.shape().DebugString(), " vs. ", + update.shape().DebugString())); + } + OP_REQUIRES(ctx, loc.dim_size(0) == update.dim_size(0), + errors::InvalidArgument("loc and update shape doesn't match: ", + loc.shape().DebugString(), " vs. ", + update.shape().DebugString())); + + Tensor output = value; // This creates an alias intentionally. + OP_REQUIRES_OK(ctx, DoCompute(ctx, update, loc, &output)); + ctx->set_output(0, output); + } + + protected: + virtual Status DoCompute(OpKernelContext* ctx, const Tensor& value, + const Tensor& loc, Tensor* output) = 0; +}; + template <typename Device, functor::InplaceOpType op> class InplaceOp : public InplaceOpBase { public: @@ -237,4 +238,5 @@ REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") InplaceOp<CPUDevice, functor::I_UPDATE>); #endif +} // end namespace } // end namespace tensorflow |