path: root/tensorflow/core/kernels/scatter_nd_op.cc
author    Eugene Brevdo <ebrevdo@google.com>  2016-11-07 16:01:57 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-11-08 16:12:19 -0800
commit    fd05b5ebc56316eb6ac9fcb74234979fee2fc5f9 (patch)
tree      051f2d5145673d8bbebe0646434860c888991815 /tensorflow/core/kernels/scatter_nd_op.cc
parent    aac685b7209b03ffd356ea6860366467b335d402 (diff)
Changes to scatter_nd ops
* Rewrite CPU impl to be single-threaded and use vectorization; avoids race conditions (see the second sketch below). Removes use of the generator.
* Remove scatter_nd_mul and scatter_nd_div to reduce binary size until we figure out a better way to reduce the templating pain.
* Modify scatter_nd to add for repeated indices as opposed to update (this is the appropriate gradient for gather_nd, for example).
* Clean up docstrings.

Change: 138452341
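The third bullet is the behavioral change most likely to surprise callers: where duplicate indices previously raced or overwrote one another, scatter_nd now sums their contributions, which is what the gradient of gather_nd requires. A minimal standalone sketch of that accumulate semantics (plain C++, no Eigen or TensorFlow types; ScatterAddRows and its shapes are illustrative, not the kernel's actual code):

#include <cstdio>
#include <vector>

// Hypothetical helper mirroring the new semantics: each update row is
// *added* into the output row named by its index, so repeated indices
// accumulate instead of overwriting one another. Running this as one
// loop on one thread is also what removes the race condition the old
// multi-threaded implementation had on duplicate indices.
void ScatterAddRows(const std::vector<int>& indices,  // one row per update
                    const std::vector<std::vector<float>>& updates,
                    std::vector<std::vector<float>>& output) {
  for (size_t u = 0; u < indices.size(); ++u) {
    for (size_t j = 0; j < updates[u].size(); ++j) {
      output[indices[u]][j] += updates[u][j];  // add, not assign
    }
  }
}

int main() {
  std::vector<std::vector<float>> out(4, std::vector<float>(1, 0.f));
  // Index 1 appears twice: the two updates sum to 30 rather than the
  // last one (20) simply replacing the first.
  ScatterAddRows({1, 1, 3}, {{10.f}, {20.f}, {5.f}}, out);
  for (const auto& row : out) std::printf("%g\n", row[0]);  // 0 30 0 5
  return 0;
}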
Diffstat (limited to 'tensorflow/core/kernels/scatter_nd_op.cc')
-rw-r--r--  tensorflow/core/kernels/scatter_nd_op.cc  154
1 file changed, 61 insertions(+), 93 deletions(-)
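In the diff below, the per-rank flat_outer_dims<T, IXDIM + 1>() views are replaced by a single 2-D {num_elements / slice_size, slice_size} view plus an output_shape_prefix array, leaving the functor to linearize each IXDIM-length index into a row number. A sketch of that linearization under the usual row-major convention (standalone C++; FlattenIndex and its -1 return are illustrative, not the functor's actual code):

#include <array>
#include <cstdint>
#include <cstdio>

// Row-major linearization of an index prefix: the same arithmetic lets
// the functor address one 2-D {num_rows, slice_size} view for every
// IXDIM. Returns -1 on an out-of-range component, loosely mirroring how
// the kernel surfaces a bad index position (bad_i) to OP_REQUIRES.
template <int IXDIM>
std::int64_t FlattenIndex(const std::array<std::int64_t, IXDIM>& index,
                          const std::array<std::int64_t, IXDIM>& shape_prefix) {
  std::int64_t row = 0;
  for (int d = 0; d < IXDIM; ++d) {
    if (index[d] < 0 || index[d] >= shape_prefix[d]) return -1;  // bad index
    row = row * shape_prefix[d] + index[d];
  }
  return row;
}

int main() {
  // For an output of shape [2, 3, slice]: index (1, 2) selects row
  // 1 * 3 + 2 = 5 of the flattened {6, slice_size} matrix.
  std::printf("%lld\n",
              static_cast<long long>(FlattenIndex<2>({1, 2}, {2, 3})));
  return 0;
}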
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 83b38d7338..5aeb3d2c0e 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -146,43 +146,48 @@ class ScatterNdOp : public OpKernel {
&num_updates, &slice_size);
if (!c->status().ok()) return;
- Tensor scratch;
- OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));
-
- auto scratch_scalar = scratch.scalar<Index>();
auto indices_flat = indices.flat_inner_dims<Index>();
auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out));
+ functor::SetZeroFunctor<Device, T> fill;
+ fill(c->eigen_device<Device>(), out->flat<T>());
+ auto output_matrix = out->template shaped<T, 2>(
+ {shape.num_elements() / slice_size, slice_size});
+
Index bad_i = -1;
- switch (indices_nd) {
-#define PARAMS_CASE(IXDIM) \
- case IXDIM: { \
- Tensor* out = nullptr; \
- OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out)); \
- functor::SetZeroFunctor<Device, T> fill; \
- fill(c->eigen_device<Device>(), out->flat<T>()); \
- if (shape.num_elements() > 0) { \
- auto output_flat = out->flat_outer_dims<T, (IXDIM) + 1>(); \
- functor::ScatterNdFunctor<Device, T, Index, \
- scatter_nd_op::UpdateOp::ASSIGN, (IXDIM)> \
- functor; \
- bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \
- output_flat, indices_flat, updates_flat, output_flat); \
- } \
+
+ if (shape.num_elements() > 0) {
+ switch (indices_nd) {
+#define PARAMS_CASE(IXDIM) \
+ case IXDIM: { \
+ typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix; \
+ for (int i = 0; i < IXDIM; ++i) { \
+ output_shape_prefix[i] = shape.dim_size(i); \
+ } \
+ functor::ScatterNdFunctor<Device, T, Index, scatter_nd_op::UpdateOp::ADD, \
+ IXDIM> \
+ functor; \
+ bad_i = \
+ functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
+ output_matrix, indices_flat, updates_flat, output_matrix); \
} break
- PARAMS_CASE(0);
- PARAMS_CASE(1);
- PARAMS_CASE(2);
- PARAMS_CASE(3);
- PARAMS_CASE(4);
- PARAMS_CASE(5);
+ // TODO(simister): Re-enable this once binary size is under control.
+ // PARAMS_CASE(0);
+ PARAMS_CASE(1);
+ PARAMS_CASE(2);
+ PARAMS_CASE(3);
+ PARAMS_CASE(4);
+ PARAMS_CASE(5);
#undef PARAMS_CASE
- default:
- OP_REQUIRES(c, false,
- errors::InvalidArgument(
- "Only indices.shape[-1] values between 0 and 5 "
- "are currently supported. Requested rank: ",
- indices_nd));
+ default:
+ OP_REQUIRES(c, false,
+ errors::InvalidArgument(
+ "Only indices.shape[-1] values between 1 and 5 "
+ "are currently supported. Requested rank: ",
+ indices_nd));
+ }
}
OP_REQUIRES(
c, bad_i < 0,
@@ -236,24 +241,27 @@ class ScatterNdUpdateOp : public OpKernel {
&indices_nd, &num_updates, &slice_size);
if (!c->status().ok()) return;
- Tensor scratch;
- OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));
-
- auto scratch_scalar = scratch.scalar<Index>();
auto indices_flat = indices.flat_inner_dims<Index>();
auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
-
+ auto params_matrix = params.template shaped<T, 2>(
+ {params_shape.num_elements() / slice_size, slice_size});
Index bad_i = -1;
c->forward_ref_input_to_ref_output(0, 0);
+
switch (indices_nd) {
-#define PARAMS_CASE(IXDIM) \
- case IXDIM: { \
- auto params_flat = params.flat_outer_dims<T, (IXDIM) + 1>(); \
- functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor; \
- bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar, \
- params_flat, indices_flat, updates_flat, params_flat); \
+#define PARAMS_CASE(IXDIM) \
+ case IXDIM: { \
+ typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix; \
+ for (int i = 0; i < IXDIM; ++i) { \
+ output_shape_prefix[i] = params_shape.dim_size(i); \
+ } \
+ functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor; \
+ bad_i = \
+ functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
+ params_matrix, indices_flat, updates_flat, params_matrix); \
} break
- PARAMS_CASE(0);
+ // TODO(simister): Re-enable this once binary size is under control.
+ // PARAMS_CASE(0);
PARAMS_CASE(1);
PARAMS_CASE(2);
PARAMS_CASE(3);
@@ -306,11 +314,13 @@ class ScatterNdUpdateOp : public OpKernel {
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
scatter_nd_op::UpdateOp::ADD); \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
- scatter_nd_op::UpdateOp::SUB); \
- REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \
- scatter_nd_op::UpdateOp::MUL); \
- REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \
- scatter_nd_op::UpdateOp::DIV);
+ scatter_nd_op::UpdateOp::SUB);
+// TODO(simister): Find a way to reduce amount of templated generated code
+// to reduce build size, then re-enable these additional operations.
+// REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul", \
+// scatter_nd_op::UpdateOp::MUL); \
+// REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv", \
+// scatter_nd_op::UpdateOp::DIV);
#define REGISTER_SCATTER_ND(type, dev) \
REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
@@ -329,8 +339,9 @@ class ScatterNdUpdateOp : public OpKernel {
#define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
-TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
-TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU);
+// TODO(simister): Re-enable all types after binary size is under control.
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
// Registers GPU kernels.
#if GOOGLE_CUDA
@@ -356,47 +367,4 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU);
#undef REGISTER_SCATTER_ND_KERNEL
#undef REGISTER_SCATTER_ND_KERNEL_INDEX
-#if GOOGLE_CUDA
-// Forward declarations of the functor specializations for GPU.
-namespace functor {
-
-#define DECLARE_GPU_SPECS_OP(T, Index, op, NDIM) \
- template <> \
- Index ScatterNdFunctor<GPUDevice, T, Index, op, NDIM>::operator()( \
- OpKernelContext* c, const GPUDevice& d, \
- typename TTypes<T, IXDIM>::Tensor params, \
- typename TTypes<Index, 2>::ConstTensor indices, \
- typename TTypes<T, 2>::ConstTensor updates); \
- extern template struct ScatterNdFunctor<GPUDevice, T, Index, op>;
-
-#define DECLARE_GPU_SPECS_OPS(T, Index, op) \
- DECLARE_GPU_SPECS_OP(T, Index, op, 0); \
- DECLARE_GPU_SPECS_OP(T, Index, op, 1); \
- DECLARE_GPU_SPECS_OP(T, Index, op, 2); \
- DECLARE_GPU_SPECS_OP(T, Index, op, 3); \
- DECLARE_GPU_SPECS_OP(T, Index, op, 4); \
- DECLARE_GPU_SPECS_OP(T, Index, op, 5)
-
-#define DECLARE_GPU_SPECS_INDEX(T, Index) \
- DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
- DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ADD); \
- DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::SUB); \
- DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::MUL); \
- DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::DIV);
-
-#define DECLARE_GPU_SPECS(T) \
- DECLARE_GPU_SPECS_INDEX(T, int32); \
- DECLARE_GPU_SPECS_INDEX(T, int64);
-
-// TODO(simister): Re-enable when GPU support is working.
-// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);
-
-#undef DECLARE_GPU_SPECS
-#undef DECLARE_GPU_SPECS_INDEX
-#undef DECLARE_GPU_SPECS_OPS
-#undef DECLARE_GPU_SPECS_OP
-
-} // namespace functor
-#endif // GOOGLE_CUDA
-
} // namespace tensorflow