aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/inplace_ops.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-01-13 16:17:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-13 16:28:22 -0800
commite4a235a0358d7b6d5c0830536b786076121fb766 (patch)
tree6ef4a4ab436645b4a8e16353bf8672aedbab7ddf /tensorflow/core/kernels/inplace_ops.cc
parent9438ace9743a4607827e9b5e9c131f4e1a1dd2c6 (diff)
Internal change.
Change: 144497247
Diffstat (limited to 'tensorflow/core/kernels/inplace_ops.cc')
-rw-r--r--tensorflow/core/kernels/inplace_ops.cc84
1 files changed, 43 insertions, 41 deletions
diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc
index 01af1b56e2..5f1f5b652c 100644
--- a/tensorflow/core/kernels/inplace_ops.cc
+++ b/tensorflow/core/kernels/inplace_ops.cc
@@ -24,49 +24,8 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-
typedef Eigen::ThreadPoolDevice CPUDevice;
-// TODO(apassos): validate the shapes better.
-class InplaceOpBase : public OpKernel {
- public:
- explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- auto value = ctx->input(0);
- auto loc = ctx->input(1);
- auto update = ctx->input(2);
-
- OP_REQUIRES(ctx, TensorShapeUtils::IsVector(loc.shape()),
- errors::InvalidArgument("loc must be a vector. ",
- loc.shape().DebugString()));
- OP_REQUIRES(
- ctx, value.dims() == update.dims(),
- errors::InvalidArgument("value and update shape doesn't match: ",
- value.shape().DebugString(), " vs. ",
- update.shape().DebugString()));
- for (int i = 1; i < value.dims(); ++i) {
- OP_REQUIRES(
- ctx, value.dim_size(i) == update.dim_size(i),
- errors::InvalidArgument("value and update shape doesn't match ",
- value.shape().DebugString(), " vs. ",
- update.shape().DebugString()));
- }
- OP_REQUIRES(ctx, loc.dim_size(0) == update.dim_size(0),
- errors::InvalidArgument("loc and update shape doesn't match: ",
- loc.shape().DebugString(), " vs. ",
- update.shape().DebugString()));
-
- Tensor output = value; // This creates an alias intentionally.
- OP_REQUIRES_OK(ctx, DoCompute(ctx, update, loc, &output));
- ctx->set_output(0, output);
- }
-
- protected:
- virtual Status DoCompute(OpKernelContext* ctx, const Tensor& value,
- const Tensor& loc, Tensor* output) = 0;
-};
-
namespace functor {
template <typename T>
@@ -112,6 +71,48 @@ Status DoInplace(const CPUDevice& d, InplaceOpType op, const Tensor& value,
} // end namespace functor
+namespace {
+
+// TODO(apassos): validate the shapes better.
+class InplaceOpBase : public OpKernel {
+ public:
+ explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ auto value = ctx->input(0);
+ auto loc = ctx->input(1);
+ auto update = ctx->input(2);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(loc.shape()),
+ errors::InvalidArgument("loc must be a vector. ",
+ loc.shape().DebugString()));
+ OP_REQUIRES(
+ ctx, value.dims() == update.dims(),
+ errors::InvalidArgument("value and update shape doesn't match: ",
+ value.shape().DebugString(), " vs. ",
+ update.shape().DebugString()));
+ for (int i = 1; i < value.dims(); ++i) {
+ OP_REQUIRES(
+ ctx, value.dim_size(i) == update.dim_size(i),
+ errors::InvalidArgument("value and update shape doesn't match ",
+ value.shape().DebugString(), " vs. ",
+ update.shape().DebugString()));
+ }
+ OP_REQUIRES(ctx, loc.dim_size(0) == update.dim_size(0),
+ errors::InvalidArgument("loc and update shape doesn't match: ",
+ loc.shape().DebugString(), " vs. ",
+ update.shape().DebugString()));
+
+ Tensor output = value; // This creates an alias intentionally.
+ OP_REQUIRES_OK(ctx, DoCompute(ctx, update, loc, &output));
+ ctx->set_output(0, output);
+ }
+
+ protected:
+ virtual Status DoCompute(OpKernelContext* ctx, const Tensor& value,
+ const Tensor& loc, Tensor* output) = 0;
+};
+
template <typename Device, functor::InplaceOpType op>
class InplaceOp : public InplaceOpBase {
public:
@@ -237,4 +238,5 @@ REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
InplaceOp<CPUDevice, functor::I_UPDATE>);
#endif
+} // end namespace
} // end namespace tensorflow