author    | 2016-11-07 16:01:57 -0800
committer | 2016-11-08 16:12:19 -0800
commit    | fd05b5ebc56316eb6ac9fcb74234979fee2fc5f9 (patch)
tree      | 051f2d5145673d8bbebe0646434860c888991815 /tensorflow/core/kernels/scatter_nd_op.cc
parent    | aac685b7209b03ffd356ea6860366467b335d402 (diff)
Changes to scatter_nd ops
* Rewrite CPU impl to be single-threaded and use vectorization; avoids race conditions. Removes use of the generator.
* Remove scatter_nd_mul and scatter_nd_div to reduce binary size until
we figure out a better way to reduce the templating pain
* Modify scatter_nd to add values for repeated indices rather than overwriting
them (this is the appropriate gradient for gather_nd, for example; see the
sketch after the diffstat below)
* Clean up docstrings.
Change: 138452341
Diffstat (limited to 'tensorflow/core/kernels/scatter_nd_op.cc')
-rw-r--r-- | tensorflow/core/kernels/scatter_nd_op.cc | 154
1 file changed, 61 insertions(+), 93 deletions(-)
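
The behavioral core of the first hunk is easy to miss inside the macro churn: the output is now zero-filled up front and the functor is instantiated with `scatter_nd_op::UpdateOp::ADD` instead of `ASSIGN`, so several updates aimed at the same location accumulate, and the single-threaded CPU loop means they cannot race. A minimal standalone sketch of that semantics (an illustrative helper, not the actual TensorFlow kernel):

```cpp
// A minimal, illustrative sketch (NOT the TensorFlow kernel) of the
// accumulate-on-duplicates semantics this commit gives scatter_nd:
// the output starts at zero and each update slice is *added* in, so
// repeated indices sum instead of overwriting one another. Running the
// loop single-threaded, as the rewritten CPU impl does, also removes
// the race between two updates that target the same row.
#include <cstdio>
#include <vector>

// Scatters rows of `updates` (each `slice_size` wide) into `output`
// (zero-initialized, num_rows x slice_size). Returns the position of the
// first out-of-range index, or -1 if all are valid (mirrors `bad_i`).
int ScatterNdAddSketch(const std::vector<int>& indices,
                       const std::vector<float>& updates, int slice_size,
                       std::vector<float>* output) {
  const int num_rows = static_cast<int>(output->size()) / slice_size;
  for (int i = 0; i < static_cast<int>(indices.size()); ++i) {
    const int row = indices[i];
    if (row < 0 || row >= num_rows) return i;  // first bad index
    for (int j = 0; j < slice_size; ++j) {
      (*output)[row * slice_size + j] += updates[i * slice_size + j];
    }
  }
  return -1;
}

int main() {
  // Rows 1 and 1 collide: with ADD semantics they accumulate to 5 + 7 = 12.
  std::vector<float> output(4, 0.0f);  // 4 rows, slice_size = 1
  const int bad_i =
      ScatterNdAddSketch({1, 3, 1}, {5.0f, 2.0f, 7.0f}, 1, &output);
  std::printf("bad_i=%d output=[%g %g %g %g]\n", bad_i, output[0], output[1],
              output[2], output[3]);  // prints: bad_i=-1 output=[0 12 0 2]
}
```

This is exactly the reduction a gather_nd gradient needs: every occurrence of an index contributes its slice of the upstream gradient, rather than the last write winning.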
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 83b38d7338..5aeb3d2c0e 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -146,43 +146,48 @@ class ScatterNdOp : public OpKernel {
                    &num_updates, &slice_size);
     if (!c->status().ok()) return;
 
-    Tensor scratch;
-    OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));
-
-    auto scratch_scalar = scratch.scalar<Index>();
     auto indices_flat = indices.flat_inner_dims<Index>();
     auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
 
+    Tensor* out = nullptr;
+    OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out));
+    functor::SetZeroFunctor<Device, T> fill;
+    fill(c->eigen_device<Device>(), out->flat<T>());
+    auto output_matrix = out->template shaped<T, 2>(
+        {shape.num_elements() / slice_size, slice_size});
+
     Index bad_i = -1;
-    switch (indices_nd) {
-#define PARAMS_CASE(IXDIM)                                                    \
-  case IXDIM: {                                                               \
-    Tensor* out = nullptr;                                                    \
-    OP_REQUIRES_OK(c, c->allocate_output(0, shape, &out));                    \
-    functor::SetZeroFunctor<Device, T> fill;                                  \
-    fill(c->eigen_device<Device>(), out->flat<T>());                          \
-    if (shape.num_elements() > 0) {                                           \
-      auto output_flat = out->flat_outer_dims<T, (IXDIM) + 1>();              \
-      functor::ScatterNdFunctor<Device, T, Index,                             \
-                                scatter_nd_op::UpdateOp::ASSIGN, (IXDIM)>     \
-          functor;                                                            \
-      bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar,  \
-                      output_flat, indices_flat, updates_flat, output_flat);  \
-    }                                                                         \
+
+    if (shape.num_elements() > 0) {
+      switch (indices_nd) {
+#define PARAMS_CASE(IXDIM)                                                    \
+  case IXDIM: {                                                               \
+    typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix;      \
+    for (int i = 0; i < IXDIM; ++i) {                                         \
+      output_shape_prefix[i] = shape.dim_size(i);                             \
+    }                                                                         \
+    functor::ScatterNdFunctor<Device, T, Index, scatter_nd_op::UpdateOp::ADD, \
+                              IXDIM>                                          \
+        functor;                                                              \
+    bad_i =                                                                   \
+        functor(c->eigen_device<Device>(), slice_size, output_shape_prefix,   \
+                output_matrix, indices_flat, updates_flat, output_matrix);    \
   } break
-      PARAMS_CASE(0);
-      PARAMS_CASE(1);
-      PARAMS_CASE(2);
-      PARAMS_CASE(3);
-      PARAMS_CASE(4);
-      PARAMS_CASE(5);
+        // TODO(simister): Re-enable this once binary size is under control.
+        // PARAMS_CASE(0);
+        PARAMS_CASE(1);
+        PARAMS_CASE(2);
+        PARAMS_CASE(3);
+        PARAMS_CASE(4);
+        PARAMS_CASE(5);
 #undef PARAMS_CASE
-      default:
-        OP_REQUIRES(c, false,
-                    errors::InvalidArgument(
-                        "Only indices.shape[-1] values between 0 and 5 "
-                        "are currently supported. Requested rank: ",
-                        indices_nd));
+        default:
+          OP_REQUIRES(c, false,
+                      errors::InvalidArgument(
+                          "Only indices.shape[-1] values between 1 and 5 "
+                          "are currently supported. Requested rank: ",
+                          indices_nd));
+      }
     }
     OP_REQUIRES(
         c, bad_i < 0,
@@ -236,24 +241,27 @@ class ScatterNdUpdateOp : public OpKernel {
                                 &indices_nd, &num_updates, &slice_size);
     if (!c->status().ok()) return;
 
-    Tensor scratch;
-    OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape(), &scratch));
-
-    auto scratch_scalar = scratch.scalar<Index>();
     auto indices_flat = indices.flat_inner_dims<Index>();
     auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
-
+    auto params_matrix = params.template shaped<T, 2>(
+        {params_shape.num_elements() / slice_size, slice_size});
     Index bad_i = -1;
     c->forward_ref_input_to_ref_output(0, 0);
+
     switch (indices_nd) {
-#define PARAMS_CASE(IXDIM)                                                  \
-  case IXDIM: {                                                             \
-    auto params_flat = params.flat_outer_dims<T, (IXDIM) + 1>();            \
-    functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor;         \
-    bad_i = functor(c->eigen_device<Device>(), slice_size, scratch_scalar,  \
-                    params_flat, indices_flat, updates_flat, params_flat);  \
+#define PARAMS_CASE(IXDIM)                                                  \
+  case IXDIM: {                                                             \
+    typename Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix;    \
+    for (int i = 0; i < IXDIM; ++i) {                                       \
+      output_shape_prefix[i] = params_shape.dim_size(i);                    \
+    }                                                                       \
+    functor::ScatterNdFunctor<Device, T, Index, op, IXDIM> functor;         \
+    bad_i =                                                                 \
+        functor(c->eigen_device<Device>(), slice_size, output_shape_prefix, \
+                params_matrix, indices_flat, updates_flat, params_matrix);  \
   } break
-      PARAMS_CASE(0);
+      // TODO(simister): Re-enable this once binary size is under control.
+      // PARAMS_CASE(0);
       PARAMS_CASE(1);
       PARAMS_CASE(2);
       PARAMS_CASE(3);
@@ -306,11 +314,13 @@ class ScatterNdUpdateOp : public OpKernel {
   REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd",      \
                                     scatter_nd_op::UpdateOp::ADD);  \
   REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub",      \
-                                    scatter_nd_op::UpdateOp::SUB);  \
-  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul",      \
-                                    scatter_nd_op::UpdateOp::MUL);  \
-  REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv",      \
-                                    scatter_nd_op::UpdateOp::DIV);
+                                    scatter_nd_op::UpdateOp::SUB);
+// TODO(simister): Find a way to reduce amount of templated generated code
+// to reduce build size, then re-enable these additional operations.
+// REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdMul",     \
+//                                   scatter_nd_op::UpdateOp::MUL); \
+// REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdDiv",     \
+//                                   scatter_nd_op::UpdateOp::DIV);
 
 #define REGISTER_SCATTER_ND(type, dev) \
   REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
@@ -329,8 +339,9 @@ class ScatterNdUpdateOp : public OpKernel {
 #define REGISTER_SCATTER_ND_CPU(type) REGISTER_SCATTER_ND(type, CPU);
 
 TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
-TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
-TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU);
+// TODO(simister): Re-enable all types after binary size is under control.
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
 
 // Registers GPU kernels.
 #if GOOGLE_CUDA
@@ -356,47 +367,4 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_CPU);
 #undef REGISTER_SCATTER_ND_KERNEL
 #undef REGISTER_SCATTER_ND_KERNEL_INDEX
 
-#if GOOGLE_CUDA
-// Forward declarations of the functor specializations for GPU.
-namespace functor {
-
-#define DECLARE_GPU_SPECS_OP(T, Index, op, NDIM)                     \
-  template <>                                                        \
-  Index ScatterNdFunctor<GPUDevice, T, Index, op, NDIM>::operator()( \
-      OpKernelContext* c, const GPUDevice& d,                        \
-      typename TTypes<T, IXDIM>::Tensor params,                      \
-      typename TTypes<Index, 2>::ConstTensor indices,                \
-      typename TTypes<T, 2>::ConstTensor updates);                   \
-  extern template struct ScatterNdFunctor<GPUDevice, T, Index, op>;
-
-#define DECLARE_GPU_SPECS_OPS(T, Index, op) \
-  DECLARE_GPU_SPECS_OP(T, Index, op, 0);    \
-  DECLARE_GPU_SPECS_OP(T, Index, op, 1);    \
-  DECLARE_GPU_SPECS_OP(T, Index, op, 2);    \
-  DECLARE_GPU_SPECS_OP(T, Index, op, 3);    \
-  DECLARE_GPU_SPECS_OP(T, Index, op, 4);    \
-  DECLARE_GPU_SPECS_OP(T, Index, op, 5)
-
-#define DECLARE_GPU_SPECS_INDEX(T, Index)                           \
-  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ASSIGN); \
-  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::ADD);    \
-  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::SUB);    \
-  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::MUL);    \
-  DECLARE_GPU_SPECS_OPS(T, Index, scatter_nd_op::UpdateOp::DIV);
-
-#define DECLARE_GPU_SPECS(T)         \
-  DECLARE_GPU_SPECS_INDEX(T, int32); \
-  DECLARE_GPU_SPECS_INDEX(T, int64);
-
-// TODO(simister): Re-enable when GPU support is working.
-// TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);
-
-#undef DECLARE_GPU_SPECS
-#undef DECLARE_GPU_SPECS_INDEX
-#undef DECLARE_GPU_SPECS_OPS
-#undef DECLARE_GPU_SPECS_OP
-
-} // namespace functor
-#endif // GOOGLE_CUDA
-
 } // namespace tensorflow
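
For readers tracing the new functor contract: both call sites now pass `slice_size`, an `output_shape_prefix` holding the first IXDIM dimension sizes, the tensor reshaped to a 2-D `{num_rows, slice_size}` matrix, and the flattened indices and updates, and they get back `bad_i`. A hedged sketch of the index arithmetic that contract implies (all names here are hypothetical, inferred from the diff rather than taken from the actual functor code):

```cpp
// A rough sketch, inferred only from the call sites in the diff above, of
// what the reworked functor presumably does with `output_shape_prefix`:
// fold an IXDIM-dimensional index row into a single row offset of the
// reshaped {num_rows, slice_size} matrix using row-major strides, and
// report the first out-of-range update as `bad_i` (-1 means all valid).
// Every name and the exact signature here are assumptions for illustration.
#include <array>
#include <cstdint>
#include <cstdio>

template <int IXDIM>
int64_t FlattenIndexRow(const std::array<int64_t, IXDIM>& output_shape_prefix,
                        const int64_t* index_row, bool* out_of_range) {
  int64_t offset = 0;
  *out_of_range = false;
  for (int d = 0; d < IXDIM; ++d) {
    const int64_t ix = index_row[d];
    if (ix < 0 || ix >= output_shape_prefix[d]) {
      *out_of_range = true;  // the caller would record this update as bad_i
      return -1;
    }
    offset = offset * output_shape_prefix[d] + ix;  // row-major fold
  }
  return offset;  // row into output_matrix / params_matrix
}

int main() {
  // With a shape prefix of {4, 5}, index (2, 3) lands in row 2 * 5 + 3 = 13.
  bool bad = false;
  const int64_t index_row[2] = {2, 3};
  const int64_t row = FlattenIndexRow<2>({4, 5}, index_row, &bad);
  std::printf("row=%lld out_of_range=%d\n", static_cast<long long>(row), bad);
}
```

Passing the dimension sizes directly, in place of the old scratch scalar, appears to be what lets the single-threaded loop bounds-check each index component and compute offsets without any shared temporary state.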