author     2016-09-08 03:05:08 -0800
committer  2016-09-08 04:18:22 -0700
commit     88dbc53dd7d51a6d1fd3dd4a0ba41bce950426b0
tree       72796e944c5b027c1b923888697b1681b8959c64
parent     68d90864e7156c29dbf72697979bdd5d3174ac2d
Add scatter_mul and scatter_div state ops for CPU and GPU.
Note: no fp16 support on the GPU.
Also added CudaAtomicMul and CudaAtomicDiv wrappers to cuda_kernel_helper.h (again excluding fp16 support).
Change: 132539643
 tensorflow/core/kernels/scatter_op.cc              |  38
 tensorflow/core/kernels/scatter_op.h               |   2
 tensorflow/core/kernels/scatter_op_gpu.cu.cc       |  18
 tensorflow/core/kernels/scatter_op_test.cc         |  20
 tensorflow/core/ops/state_ops.cc                   |  80
 tensorflow/core/util/cuda_kernel_helper.h          | 106
 tensorflow/python/kernel_tests/BUILD               |  12
 tensorflow/python/kernel_tests/scatter_ops_test.py |  93
 tensorflow/python/ops/standard_ops.py              |   2
 tensorflow/python/ops/state_grad.py                |   6
 tensorflow/python/ops/state_ops.py                 |   4
 11 files changed, 336 insertions(+), 45 deletions(-)
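For orientation, this is how the two new ops are driven from Python once this change is in place. A minimal sketch using the TF 0.x session idiom; the concrete values are illustrative only:

    import tensorflow as tf

    ref = tf.Variable([1.0, 2.0, 4.0, 8.0])
    indices = tf.constant([0, 2])
    updates = tf.constant([10.0, 0.5])

    mul = tf.scatter_mul(ref, indices, updates)
    div = tf.scatter_div(ref, indices, updates)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        print(sess.run(mul))  # ref is now [10., 2., 2., 8.]
        print(sess.run(div))  # dividing by the same updates restores [1., 2., 4., 8.]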
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 30e81737f8..1516455cc6 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -55,6 +55,20 @@ struct Assign<scatter_op::UpdateOp::SUB> {
     p -= u;
   }
 };
+template <>
+struct Assign<scatter_op::UpdateOp::MUL> {
+  template <typename Params, typename Update>
+  static void Run(Params p, Update u) {
+    p *= u;
+  }
+};
+template <>
+struct Assign<scatter_op::UpdateOp::DIV> {
+  template <typename Params, typename Update>
+  static void Run(Params p, Update u) {
+    p /= u;
+  }
+};
 
 }  // namespace
 
@@ -195,8 +209,10 @@ struct ScatterFunctor<CPUDevice, T, Index, op> {
   REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
   REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
 
-#define REGISTER_SCATTER_ADD_SUB(type, dev) \
+#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
   REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
+  REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \
+  REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \
   REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);
 
 #define REGISTER_SCATTER_UPDATE(type, dev) \
@@ -204,28 +220,30 @@ struct ScatterFunctor<CPUDevice, T, Index, op> {
                           scatter_op::UpdateOp::ASSIGN);
 
 // Registers CPU kernels.
-#define REGISTER_SCATTER_ADD_SUB_CPU(type) REGISTER_SCATTER_ADD_SUB(type, CPU);
+#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
+  REGISTER_SCATTER_ARITHEMTIC(type, CPU);
 #define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);
 
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_SUB_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
 TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);
 
 // Registers GPU kernels.
 #if GOOGLE_CUDA
-#define REGISTER_SCATTER_ADD_SUB_GPU(type) REGISTER_SCATTER_ADD_SUB(type, GPU);
+#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
+  REGISTER_SCATTER_ARITHEMTIC(type, GPU);
 #define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
 
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ADD_SUB_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU);
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);
 #endif  // GOOGLE_CUDA
 
 #undef REGISTER_SCATTER_ADD
-#undef REGISTER_SCATTER_ADD_SUB
-#undef REGISTER_SCATTER_ADD_SUB_CPU
-#undef REGISTER_SCATTER_ADD_SUB_GPU
+#undef REGISTER_SCATTER_ARITHEMTIC
+#undef REGISTER_SCATTER_ARITHEMTIC_CPU
+#undef REGISTER_SCATTER_ARITHEMTIC_GPU
 #undef REGISTER_SCATTER_UPDATE
 #undef REGISTER_SCATTER_UPDATE_CPU
 #undef REGISTER_SCATTER_UPDATE_GPU
 
@@ -248,7 +266,9 @@ namespace functor {
 #define DECLARE_GPU_SPECS_INDEX(T, Index) \
   DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
   DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
-  DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);
+  DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
+  DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
+  DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
 
 #define DECLARE_GPU_SPECS(T) \
   DECLARE_GPU_SPECS_INDEX(T, int32); \
diff --git a/tensorflow/core/kernels/scatter_op.h b/tensorflow/core/kernels/scatter_op.h
index ae5646cd20..6b35555b83 100644
--- a/tensorflow/core/kernels/scatter_op.h
+++ b/tensorflow/core/kernels/scatter_op.h
@@ -27,7 +27,7 @@ class OpKernelContext;
 
 namespace scatter_op {
 
-enum class UpdateOp { ASSIGN, ADD, SUB };
+enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV };
 
 }  // namespace scatter_op
 
diff --git a/tensorflow/core/kernels/scatter_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
index 213c62402a..e51579f032 100644
--- a/tensorflow/core/kernels/scatter_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
@@ -54,6 +54,14 @@ __global__ void ScatterOpCustomKernel(
         CudaAtomicSub(params + params_i, ldg(updates + updates_i));
         break;
       }
+      case scatter_op::UpdateOp::MUL: {
+        CudaAtomicMul(params + params_i, ldg(updates + updates_i));
+        break;
+      }
+      case scatter_op::UpdateOp::DIV: {
+        CudaAtomicDiv(params + params_i, ldg(updates + updates_i));
+        break;
+      }
     }
   }
 }
@@ -86,10 +94,12 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
 #define DEFINE_GPU_SPECS_OP(T, Index, op) \
   template struct functor::ScatterFunctor<GPUDevice, T, Index, op>;
 
-#define DEFINE_GPU_SPECS_INDEX(T, Index) \
-  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
-  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
-  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);
+#define DEFINE_GPU_SPECS_INDEX(T, Index) \
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
+  DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
 
 #define DEFINE_GPU_SPECS(T) \
   DEFINE_GPU_SPECS_INDEX(T, int32); \
diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc
index dca0dc56a0..50ccbf2317 100644
--- a/tensorflow/core/kernels/scatter_op_test.cc
+++ b/tensorflow/core/kernels/scatter_op_test.cc
@@ -287,11 +287,31 @@ static void BM_ScatterAddInt64(int iters, int embedding_size) {
   BM_ScatterHelper<int64>(iters, embedding_size, "ScatterAdd");
 }
 
"ScatterAdd"); } +static void BM_ScatterMulInt32(int iters, int embedding_size) { + BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMul"); +} +static void BM_ScatterMulInt64(int iters, int embedding_size) { + BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMul"); +} + +static void BM_ScatterDivInt32(int iters, int embedding_size) { + BM_ScatterHelper<int32>(iters, embedding_size, "ScatterDiv"); +} +static void BM_ScatterDivInt64(int iters, int embedding_size) { + BM_ScatterHelper<int64>(iters, embedding_size, "ScatterDiv"); +} + BENCHMARK(BM_ScatterUpdateInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); BENCHMARK(BM_ScatterUpdateInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); BENCHMARK(BM_ScatterAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); BENCHMARK(BM_ScatterAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); +BENCHMARK(BM_ScatterMulInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); +BENCHMARK(BM_ScatterMulInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); + +BENCHMARK(BM_ScatterDivInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); +BENCHMARK(BM_ScatterDivInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index cc0c652107..6de3f97548 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -348,6 +348,86 @@ use_locking: If True, the subtraction will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. )doc"); +REGISTER_OP("ScatterMul") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .SetShapeFn(ScatterUpdateShape) + .Doc(R"doc( +Multiplies sparse updates into a variable reference. + +This operation computes + + # Scalar indices + ref[indices, ...] *= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] *= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...] + +This operation outputs `ref` after the update is done. +This makes it easier to chain operations that need to use the reset value. + +Duplicate entries are handled correctly: if multiple `indices` reference +the same location, their contributions multiply. + +Requires `updates.shape = indices.shape + ref.shape[1:]`. + +ref: Should be from a `Variable` node. +indices: A tensor of indices into the first dimension of `ref`. +updates: A tensor of updated values to multiply to `ref`. +output_ref:= Same as `ref`. Returned as a convenience for operations that want + to use the updated values after the update is done. +use_locking: If True, the operation will be protected by a lock; + otherwise the behavior is undefined, but may exhibit less contention. +)doc"); + +REGISTER_OP("ScatterDiv") + .Input("ref: Ref(T)") + .Input("indices: Tindices") + .Input("updates: T") + .Output("output_ref: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .SetShapeFn(ScatterUpdateShape) + .Doc(R"doc( +Divides a variable reference by sparse updates. + +This operation computes + + # Scalar indices + ref[indices, ...] /= updates[...] + + # Vector indices (for each i) + ref[indices[i], ...] /= updates[i, ...] + + # High rank indices (for each i, ..., j) + ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...] 
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 60d9822677..488c28e530 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -79,6 +79,8 @@ __device__ __host__ inline T ldg(const T* address) {
 // Reason of guarding: NVCC cannot compile the "::" in "cuda_builtin::atomicOp".
 #ifdef __GCUDACC__
+using cuda_builtin::__float_as_int;
+using cuda_builtin::__int_as_float;
 #define USE_CUDA_ATOMIC(op, T) \
   CUDA_ATOMIC_WRAPPER(op, T) { return cuda_builtin::atomic##op(address, val); }
 #else
@@ -101,7 +103,7 @@ USE_CUDA_ATOMIC(Max, uint64);
 // The uint64 overload of atomicMax() is only available for __CUDA_ARCH__ >=
 // 350. If not satisfied, we provide a custom implementation using atomicCAS().
 CUDA_ATOMIC_WRAPPER(Max, uint64) {
-  uint64* address_as_ull = (uint64*)address;
+  uint64* address_as_ull = reinterpret_cast<uint64*>(address);
   uint64 old = *address_as_ull, assumed;
 
   do {
@@ -116,7 +118,7 @@ CUDA_ATOMIC_WRAPPER(Max, uint64) {
 // Custom implementation of atomicAdd for double.
 // This implementation is copied from CUDA manual.
 CUDA_ATOMIC_WRAPPER(Add, double) {
-  uint64* address_as_ull = (uint64*)address;
+  uint64* address_as_ull = reinterpret_cast<uint64*>(address);
   uint64 old = *address_as_ull, assumed;
 
   do {
@@ -221,6 +223,106 @@ WRAPPED_ATOMIC_SUB(double);
 
 #undef WRAPPED_ATOMIC_SUB
 
+// For atomicMul.
+CUDA_ATOMIC_WRAPPER(Mul, int32) {
+  int32 old = *address, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address, assumed, val * assumed);
+  } while (assumed != old);
+  return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, uint32) {
+  uint32 old = *address, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address, assumed, val * assumed);
+  } while (assumed != old);
+  return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, uint64) {
+  uint64 old = *address, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address, assumed, val * assumed);
+  } while (assumed != old);
+  return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, float) {
+  int32* address_as_int = reinterpret_cast<int32*>(address);
+  int32 old = *address_as_int, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_int, assumed,
+                    __float_as_int(val * __int_as_float(assumed)));
+  } while (assumed != old);
+  return __int_as_float(old);
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, double) {
+  uint64* address_as_ull = reinterpret_cast<uint64*>(address);
+  uint64 old = *address_as_ull, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed,
+                    __double_as_longlong(val * __longlong_as_double(assumed)));
+  } while (assumed != old);
+  return __longlong_as_double(old);
+}
+
+// For atomicDiv.
+CUDA_ATOMIC_WRAPPER(Div, int32) {
+  int32 old = *address, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address, assumed, assumed / val);
+  } while (assumed != old);
+  return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Div, uint32) {
+  uint32 old = *address, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address, assumed, assumed / val);
+  } while (assumed != old);
+  return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Div, uint64) {
+  uint64 old = *address, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address, assumed, assumed / val);
+  } while (assumed != old);
+  return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Div, float) {
+  int32* address_as_int = reinterpret_cast<int32*>(address);
+  int32 old = *address_as_int, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_int, assumed,
+                    __float_as_int(__int_as_float(assumed) / val));
+  } while (assumed != old);
+  return __int_as_float(old);
+}
+
+CUDA_ATOMIC_WRAPPER(Div, double) {
+  uint64* address_as_ull = reinterpret_cast<uint64*>(address);
+  uint64 old = *address_as_ull, assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed,
+                    __double_as_longlong(__longlong_as_double(assumed) / val));
+  } while (assumed != old);
+  return __longlong_as_double(old);
+}
+
 #undef USE_CUDA_ATOMIC
 #undef CUDA_ATOMIC_WRAPPER
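All of the Mul/Div wrappers above share one pattern: read the word, compute the new value, and publish it with atomicCAS, retrying if another thread wrote the word in between. A toy single-process Python model of that retry loop (the AtomicCell helper is hypothetical, standing in for a device memory word; it is not a TF or CUDA API):

    import threading

    class AtomicCell(object):
        """Toy stand-in for a memory word with atomicCAS semantics."""

        def __init__(self, value):
            self._value = value
            self._lock = threading.Lock()

        def load(self):
            with self._lock:
                return self._value

        def compare_and_swap(self, expected, new):
            # Like atomicCAS: returns the observed value; swaps only on a match.
            with self._lock:
                old = self._value
                if old == expected:
                    self._value = new
                return old

    def atomic_mul(cell, val):
        # Mirrors CUDA_ATOMIC_WRAPPER(Mul, T): retry until no concurrent write
        # slipped in between our read and our swap.
        old = cell.load()
        while True:
            assumed = old
            old = cell.compare_and_swap(assumed, assumed * val)
            if old == assumed:
                return old

    cell = AtomicCell(3.0)
    atomic_mul(cell, 4.0)
    print(cell.load())  # 12.0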
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 65b8539e4a..ad257a4339 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -212,7 +212,6 @@ cuda_py_tests(
         "pooling_ops_test.py",
         "random_gamma_test.py",
         "rnn_test.py",
-        "scatter_ops_test.py",
         "seq2seq_test.py",
         "slice_op_test.py",
         "sparse_matmul_op_test.py",
@@ -224,6 +223,17 @@ cuda_py_tests(
     ],
 )
 
+cuda_py_tests(
+    name = "large_kernel_tests",
+    size = "large",
+    srcs = [
+        "scatter_ops_test.py",
+    ],
+    additional_deps = [
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
 # TODO(gpapan): Revisit the gradient of extract_image_patches_op to resolve
 # http://b/31080670.
 cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index 2303a0650c..5e2bf59989 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Tests for tensorflow.ops.tf.scatter."""
 from __future__ import absolute_import
 from __future__ import division
@@ -38,9 +37,37 @@ def _NumpySub(ref, indices, updates):
     ref[indx] -= updates[i]
 
 
+def _NumpyMul(ref, indices, updates):
+  for i, indx in np.ndenumerate(indices):
+    ref[indx] *= updates[i]
+
+
+def _NumpyDiv(ref, indices, updates):
+  for i, indx in np.ndenumerate(indices):
+    ref[indx] /= updates[i]
+
+
+def _NumpyUpdate(ref, indices, updates):
+  for i, indx in np.ndenumerate(indices):
+    ref[indx] = updates[i]
+
+
+_TF_OPS_TO_NUMPY = {
+    tf.scatter_update: _NumpyUpdate,
+    tf.scatter_add: _NumpyAdd,
+    tf.scatter_sub: _NumpySub,
+    tf.scatter_mul: _NumpyMul,
+    tf.scatter_div: _NumpyDiv,
+}
+
+
 class ScatterTest(tf.test.TestCase):
 
-  def _VariableRankTest(self, np_scatter, tf_scatter, vtype, itype, use_gpu,
+  def _VariableRankTest(self,
+                        tf_scatter,
+                        vtype,
+                        itype,
+                        use_gpu,
                         repeat_indices=False):
     np.random.seed(8)
     with self.test_session(use_gpu=use_gpu):
@@ -54,54 +81,64 @@ class ScatterTest(tf.test.TestCase):
         indices = indices[:size]
         if size > 1 and repeat_indices:
           # Add some random repeats.
-          indices = indices[:size//2]
-          for _ in range(size-size//2):
+          indices = indices[:size // 2]
+          for _ in range(size - size // 2):
             # Randomly append some repeats.
-            indices = np.append(indices, indices[np.random.randint(size//2)])
+            indices = np.append(indices,
+                                indices[np.random.randint(size // 2)])
           np.random.shuffle(indices)
         indices = indices.reshape(indices_shape)
-        updates = _AsType(np.random.randn(*(indices_shape + extra_shape)),
-                          vtype)
+        updates = _AsType(
+            np.random.randn(*(indices_shape + extra_shape)), vtype)
+
+        # Clips small values to avoid division by zero.
+        def clip_small_values(x):
+          return 1e-4 * np.sign(x) if np.abs(x) < 1e-4 else x
+
+        updates = np.vectorize(clip_small_values)(updates)
         old = _AsType(np.random.randn(*((first_dim,) + extra_shape)), vtype)
+
         # Scatter via numpy
         new = old.copy()
+        np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
         np_scatter(new, indices, updates)
         # Scatter via tensorflow
         ref = tf.Variable(old)
         ref.initializer.run()
         tf_scatter(ref, indices, updates).eval()
-        # Compare
         self.assertAllClose(ref.eval(), new)
 
-  def _VariableRankTests(self, np_scatter, tf_scatter):
+  def _VariableRankTests(self, tf_scatter, repeat_indices=False):
     for vtype in (np.float32, np.float64):
       for itype in (np.int32, np.int64):
         for use_gpu in (False, True):
-          self._VariableRankTest(np_scatter, tf_scatter, vtype, itype, use_gpu)
+          self._VariableRankTest(tf_scatter, vtype, itype, use_gpu,
+                                 repeat_indices)
 
   def testVariableRankUpdate(self):
-    def update(ref, indices, updates):
-      ref[indices] = updates
-    self._VariableRankTests(update, tf.scatter_update)
+    self._VariableRankTests(tf.scatter_update)
 
   def testVariableRankAdd(self):
-    self._VariableRankTests(_NumpyAdd, tf.scatter_add)
+    self._VariableRankTests(tf.scatter_add)
 
   def testVariableRankSub(self):
-    self._VariableRankTests(_NumpySub, tf.scatter_sub)
+    self._VariableRankTests(tf.scatter_sub)
 
-  def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter):
-    for vtype in (np.float32, np.float64):
-      for itype in (np.int32, np.int64):
-        for use_gpu in (False, True):
-          self._VariableRankTest(np_scatter, tf_scatter, vtype, itype, use_gpu,
-                                 repeat_indices=True)
+  def testVariableRankMul(self):
+    self._VariableRankTests(tf.scatter_mul)
+
+  def testVariableRankDiv(self):
+    self._VariableRankTests(tf.scatter_div)
+
+  def testRepeatIndicesAdd(self):
+    self._VariableRankTests(tf.scatter_add, True)
+
+  def testRepeatIndicesSub(self):
+    self._VariableRankTests(tf.scatter_sub, True)
+
+  def testRepeatIndicesMul(self):
+    self._VariableRankTests(tf.scatter_mul, True)
 
-  def testScatterRepeatIndices(self):
-    """This tests scatter_add using indices that repeat."""
-    self._ScatterRepeatIndicesTest(_NumpyAdd, tf.scatter_add)
-    self._ScatterRepeatIndicesTest(_NumpySub, tf.scatter_sub)
+  def testRepeatIndicesDiv(self):
+    self._VariableRankTests(tf.scatter_div, True)
 
   def testBooleanScatterUpdate(self):
     with self.test_session(use_gpu=False) as session:
@@ -115,7 +152,7 @@ class ScatterTest(tf.test.TestCase):
       self.assertAllEqual([False, True], var.eval())
 
   def testScatterOutOfRangeCpu(self):
-    for op in (tf.scatter_add, tf.scatter_sub, tf.scatter_update):
+    for op, _ in _TF_OPS_TO_NUMPY.items():
       params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
       updates = np.array([-3, -4, -5]).astype(np.float32)
       with self.test_session(use_gpu=False):
@@ -139,7 +176,7 @@ class ScatterTest(tf.test.TestCase):
   def _disabledTestScatterOutOfRangeGpu(self):
     if not tf.test.IsBuiltWithCuda():
       return
-    for op in (tf.scatter_add, tf.scatter_sub, tf.scatter_update):
+    for op, _ in _TF_OPS_TO_NUMPY.items():
       params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
       updates = np.array([-3, -4, -5]).astype(np.float32)
       # With GPU, the code ignores indices that are out of range.
@@ -159,5 +196,5 @@ class ScatterTest(tf.test.TestCase):
           op(ref, indices, updates).eval()
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
   tf.test.main()
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 3e109c7c70..9267b8ef2e 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -63,6 +63,8 @@ from tensorflow.python.ops.state_ops import assign_add
 from tensorflow.python.ops.state_ops import assign_sub
 from tensorflow.python.ops.state_ops import count_up_to
 from tensorflow.python.ops.state_ops import scatter_add
+from tensorflow.python.ops.state_ops import scatter_div
+from tensorflow.python.ops.state_ops import scatter_mul
 from tensorflow.python.ops.state_ops import scatter_sub
 from tensorflow.python.ops.state_ops import scatter_update
 from tensorflow.python.ops.string_ops import *
diff --git a/tensorflow/python/ops/state_grad.py b/tensorflow/python/ops/state_grad.py
index d6fcda3dfb..7d0940bf26 100644
--- a/tensorflow/python/ops/state_grad.py
+++ b/tensorflow/python/ops/state_grad.py
@@ -35,3 +35,9 @@ ops.NoGradient("ScatterAdd")
 
 ops.NoGradient("ScatterSub")
+
+
+ops.NoGradient("ScatterMul")
+
+
+ops.NoGradient("ScatterDiv")
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 63fae74812..6220d1e532 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -88,6 +88,8 @@ automatically by the optimizers in most cases.
 @@scatter_update
 @@scatter_add
 @@scatter_sub
+@@scatter_mul
+@@scatter_div
 @@sparse_mask
 @@IndexedSlices
 
@@ -227,6 +229,8 @@ def _CountUpToShape(op):
 
 @ops.RegisterShape("ScatterAdd")
 @ops.RegisterShape("ScatterSub")
+@ops.RegisterShape("ScatterMul")
+@ops.RegisterShape("ScatterDiv")
 @ops.RegisterShape("ScatterUpdate")
 def _ScatterUpdateShape(op):
   """Shape function for the sparse update ops."""
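Note that state_grad.py registers ScatterMul and ScatterDiv with ops.NoGradient, so, like the existing scatter ops, they are treated as non-differentiable. A sketch of what a caller should expect (my reading of NoGradient in this era of TF; worth verifying against a real build):

    import tensorflow as tf

    ref = tf.Variable([1.0, 2.0, 3.0])
    out = tf.scatter_mul(ref, tf.constant([0]), tf.constant([5.0]))

    # NoGradient marks the op as non-differentiable, so no gradient flows
    # back through scatter_mul to the variable.
    print(tf.gradients(out, ref))  # expected: [None]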