author     Jack Rae <jwrae@google.com>                      2016-09-08 03:05:08 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-09-08 04:18:22 -0700
commit     88dbc53dd7d51a6d1fd3dd4a0ba41bce950426b0 (patch)
tree       72796e944c5b027c1b923888697b1681b8959c64
parent     68d90864e7156c29dbf72697979bdd5d3174ac2d (diff)
Add scatter_mul and scatter_div state ops for CPU and GPU.
Note: there is no fp16 support on the GPU. Also adds CudaAtomicMul and CudaAtomicDiv to cuda_kernel_helper.h (again without fp16 support).
Change: 132539643
-rw-r--r--  tensorflow/core/kernels/scatter_op.cc                38
-rw-r--r--  tensorflow/core/kernels/scatter_op.h                  2
-rw-r--r--  tensorflow/core/kernels/scatter_op_gpu.cu.cc         18
-rw-r--r--  tensorflow/core/kernels/scatter_op_test.cc           20
-rw-r--r--  tensorflow/core/ops/state_ops.cc                     80
-rw-r--r--  tensorflow/core/util/cuda_kernel_helper.h           106
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                 12
-rw-r--r--  tensorflow/python/kernel_tests/scatter_ops_test.py   93
-rw-r--r--  tensorflow/python/ops/standard_ops.py                 2
-rw-r--r--  tensorflow/python/ops/state_grad.py                   6
-rw-r--r--  tensorflow/python/ops/state_ops.py                    4
11 files changed, 336 insertions, 45 deletions
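
Before the per-file diffs, a quick sketch of the semantics the new kernels implement: each entry of indices selects a row of the variable, and the matching slice of updates is multiplied into (ScatterMul) or divided out of (ScatterDiv) that row in place. The loop below is illustrative only — plain C++, not TensorFlow code; params stands in for the variable's flat row-major buffer and all names are made up:

    // Reference semantics for ScatterMul on a flat row-major buffer.
    // For ScatterDiv, replace *= with /=.
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    void ScatterMulReference(std::vector<float>& params, std::int64_t row_size,
                             const std::vector<std::int64_t>& indices,
                             const std::vector<float>& updates) {
      // updates holds indices.size() rows of row_size elements each, i.e.
      // updates.shape == indices.shape + params.shape[1:].
      for (std::size_t i = 0; i < indices.size(); ++i) {
        for (std::int64_t j = 0; j < row_size; ++j) {
          params[indices[i] * row_size + j] *= updates[i * row_size + j];
        }
      }
    }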
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 30e81737f8..1516455cc6 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -55,6 +55,20 @@ struct Assign<scatter_op::UpdateOp::SUB> {
p -= u;
}
};
+template <>
+struct Assign<scatter_op::UpdateOp::MUL> {
+ template <typename Params, typename Update>
+ static void Run(Params p, Update u) {
+ p *= u;
+ }
+};
+template <>
+struct Assign<scatter_op::UpdateOp::DIV> {
+ template <typename Params, typename Update>
+ static void Run(Params p, Update u) {
+ p /= u;
+ }
+};
} // namespace
@@ -195,8 +209,10 @@ struct ScatterFunctor<CPUDevice, T, Index, op> {
REGISTER_SCATTER_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_KERNEL_INDEX(type, int64, dev, name, op);
-#define REGISTER_SCATTER_ADD_SUB(type, dev) \
+#define REGISTER_SCATTER_ARITHEMTIC(type, dev) \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterAdd", scatter_op::UpdateOp::ADD); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterDiv", scatter_op::UpdateOp::DIV); \
+ REGISTER_SCATTER_KERNEL(type, dev, "ScatterMul", scatter_op::UpdateOp::MUL); \
REGISTER_SCATTER_KERNEL(type, dev, "ScatterSub", scatter_op::UpdateOp::SUB);
#define REGISTER_SCATTER_UPDATE(type, dev) \
@@ -204,28 +220,30 @@ struct ScatterFunctor<CPUDevice, T, Index, op> {
scatter_op::UpdateOp::ASSIGN);
// Registers CPU kernels.
-#define REGISTER_SCATTER_ADD_SUB_CPU(type) REGISTER_SCATTER_ADD_SUB(type, CPU);
+#define REGISTER_SCATTER_ARITHEMTIC_CPU(type) \
+ REGISTER_SCATTER_ARITHEMTIC(type, CPU);
#define REGISTER_SCATTER_UPDATE_CPU(type) REGISTER_SCATTER_UPDATE(type, CPU);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ADD_SUB_CPU);
+TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);
// Registers GPU kernels.
#if GOOGLE_CUDA
-#define REGISTER_SCATTER_ADD_SUB_GPU(type) REGISTER_SCATTER_ADD_SUB(type, GPU);
+#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
+ REGISTER_SCATTER_ARITHEMTIC(type, GPU);
#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ADD_SUB_GPU);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);
#endif // GOOGLE_CUDA
#undef REGISTER_SCATTER_ADD
-#undef REGISTER_SCATTER_ADD_SUB
-#undef REGISTER_SCATTER_ADD_SUB_CPU
-#undef REGISTER_SCATTER_ADD_SUB_GPU
+#undef REGISTER_SCATTER_ARITHEMTIC
+#undef REGISTER_SCATTER_ARITHEMTIC_CPU
+#undef REGISTER_SCATTER_ARITHEMTIC_GPU
#undef REGISTER_SCATTER_UPDATE
#undef REGISTER_SCATTER_UPDATE_CPU
#undef REGISTER_SCATTER_UPDATE_GPU
@@ -248,7 +266,9 @@ namespace functor {
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
- DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
+ DECLARE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
#define DECLARE_GPU_SPECS(T) \
DECLARE_GPU_SPECS_INDEX(T, int32); \
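
On the CPU side the new ops reuse the existing machinery: the hunk above adds Assign specializations for MUL and DIV keyed on the UpdateOp enum, and the registration macros stamp out kernels per value type, index type, and device. A minimal standalone sketch of that specialization-on-enum dispatch, with simplified names and scalar rows (not the TensorFlow code itself):

    // Sketch: one generic scatter loop, with the per-element operation
    // selected at compile time by specializing Assign on the enum value.
    #include <cstddef>
    #include <vector>

    enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV };

    template <UpdateOp op> struct Assign;
    template <> struct Assign<UpdateOp::MUL> {
      static void Run(float& p, float u) { p *= u; }
    };
    template <> struct Assign<UpdateOp::DIV> {
      static void Run(float& p, float u) { p /= u; }
    };

    template <UpdateOp op>
    void ScatterRows(std::vector<float>& params,
                     const std::vector<int>& indices,
                     const std::vector<float>& updates) {
      for (std::size_t i = 0; i < indices.size(); ++i) {
        Assign<op>::Run(params[indices[i]], updates[i]);
      }
    }
    // Usage: ScatterRows<UpdateOp::DIV>(params, indices, updates);

The real functor applies the update to whole rows rather than single scalars, but the dispatch idea is the same as in this sketch.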
diff --git a/tensorflow/core/kernels/scatter_op.h b/tensorflow/core/kernels/scatter_op.h
index ae5646cd20..6b35555b83 100644
--- a/tensorflow/core/kernels/scatter_op.h
+++ b/tensorflow/core/kernels/scatter_op.h
@@ -27,7 +27,7 @@ class OpKernelContext;
namespace scatter_op {
-enum class UpdateOp { ASSIGN, ADD, SUB };
+enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV };
} // namespace scatter_op
diff --git a/tensorflow/core/kernels/scatter_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
index 213c62402a..e51579f032 100644
--- a/tensorflow/core/kernels/scatter_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_op_gpu.cu.cc
@@ -54,6 +54,14 @@ __global__ void ScatterOpCustomKernel(
CudaAtomicSub(params + params_i, ldg(updates + updates_i));
break;
}
+ case scatter_op::UpdateOp::MUL: {
+ CudaAtomicMul(params + params_i, ldg(updates + updates_i));
+ break;
+ }
+ case scatter_op::UpdateOp::DIV: {
+ CudaAtomicDiv(params + params_i, ldg(updates + updates_i));
+ break;
+ }
}
}
}
@@ -86,10 +94,12 @@ struct ScatterFunctor<GPUDevice, T, Index, op> {
#define DEFINE_GPU_SPECS_OP(T, Index, op) \
template struct functor::ScatterFunctor<GPUDevice, T, Index, op>;
-#define DEFINE_GPU_SPECS_INDEX(T, Index) \
- DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
- DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
- DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB);
+#define DEFINE_GPU_SPECS_INDEX(T, Index) \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ASSIGN); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::ADD); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::SUB); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::MUL); \
+ DEFINE_GPU_SPECS_OP(T, Index, scatter_op::UpdateOp::DIV);
#define DEFINE_GPU_SPECS(T) \
DEFINE_GPU_SPECS_INDEX(T, int32); \
diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc
index dca0dc56a0..50ccbf2317 100644
--- a/tensorflow/core/kernels/scatter_op_test.cc
+++ b/tensorflow/core/kernels/scatter_op_test.cc
@@ -287,11 +287,31 @@ static void BM_ScatterAddInt64(int iters, int embedding_size) {
BM_ScatterHelper<int64>(iters, embedding_size, "ScatterAdd");
}
+static void BM_ScatterMulInt32(int iters, int embedding_size) {
+ BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMul");
+}
+static void BM_ScatterMulInt64(int iters, int embedding_size) {
+ BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMul");
+}
+
+static void BM_ScatterDivInt32(int iters, int embedding_size) {
+ BM_ScatterHelper<int32>(iters, embedding_size, "ScatterDiv");
+}
+static void BM_ScatterDivInt64(int iters, int embedding_size) {
+ BM_ScatterHelper<int64>(iters, embedding_size, "ScatterDiv");
+}
+
BENCHMARK(BM_ScatterUpdateInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterUpdateInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterMulInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterMulInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+
+BENCHMARK(BM_ScatterDivInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+BENCHMARK(BM_ScatterDivInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index cc0c652107..6de3f97548 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -348,6 +348,86 @@ use_locking: If True, the subtraction will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
)doc");
+REGISTER_OP("ScatterMul")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .SetShapeFn(ScatterUpdateShape)
+ .Doc(R"doc(
+Multiplies sparse updates into a variable reference.
+
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] *= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] *= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions multiply.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of updated values to multiply to `ref`.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the operation will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
+REGISTER_OP("ScatterDiv")
+ .Input("ref: Ref(T)")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Output("output_ref: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .SetShapeFn(ScatterUpdateShape)
+ .Doc(R"doc(
+Divides a variable reference by sparse updates.
+
+This operation computes
+
+ # Scalar indices
+ ref[indices, ...] /= updates[...]
+
+ # Vector indices (for each i)
+ ref[indices[i], ...] /= updates[i, ...]
+
+ # High rank indices (for each i, ..., j)
+ ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
+
+This operation outputs `ref` after the update is done.
+This makes it easier to chain operations that need to use the reset value.
+
+Duplicate entries are handled correctly: if multiple `indices` reference
+the same location, their contributions divide.
+
+Requires `updates.shape = indices.shape + ref.shape[1:]`.
+
+ref: Should be from a `Variable` node.
+indices: A tensor of indices into the first dimension of `ref`.
+updates: A tensor of values that `ref` is divided by.
+output_ref:= Same as `ref`. Returned as a convenience for operations that want
+ to use the updated values after the update is done.
+use_locking: If True, the operation will be protected by a lock;
+ otherwise the behavior is undefined, but may exhibit less contention.
+)doc");
+
REGISTER_OP("CountUpTo")
.Input("ref: Ref(T)")
.Output("output: T")
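
The ScatterMul/ScatterDiv docs above call out that duplicate entries compose: if several indices point at the same row, all of their contributions are multiplied in (or divided out), so up to floating-point rounding the order of application does not matter. Under the documented shape rule, a ref of shape [5, 3] with indices of shape [2, 4] requires updates of shape [2, 4, 3]. A tiny self-contained check of the duplicate behavior (illustrative C++, not the op itself):

    // Two updates hitting the same row: ref[0] ends up scaled by 2 * 3 = 6.
    #include <cstdio>

    int main() {
      float ref[3] = {10.f, 20.f, 30.f};
      const int indices[2] = {0, 0};        // same row twice
      const float updates[2] = {2.f, 3.f};
      for (int i = 0; i < 2; ++i) ref[indices[i]] *= updates[i];
      std::printf("ref[0] = %g\n", ref[0]);  // prints 60
      return 0;
    }

For ScatterDiv the same inputs would leave ref[0] at 10 / (2 * 3).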
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 60d9822677..488c28e530 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -79,6 +79,8 @@ __device__ __host__ inline T ldg(const T* address) {
// Reason of guarding: NVCC cannot compile the "::" in "cuda_builtin::atomicOp".
#ifdef __GCUDACC__
+using cuda_builtin::__float_as_int;
+using cuda_builtin::__int_as_float;
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return cuda_builtin::atomic##op(address, val); }
#else
@@ -101,7 +103,7 @@ USE_CUDA_ATOMIC(Max, uint64);
// The uint64 overload of atomicMax() is only available for __CUDA_ARCH__ >=
// 350. If not satisfied, we provide a custom implementation using atomicCAS().
CUDA_ATOMIC_WRAPPER(Max, uint64) {
- uint64* address_as_ull = (uint64*)address;
+ uint64* address_as_ull = reinterpret_cast<uint64*>(address);
uint64 old = *address_as_ull, assumed;
do {
@@ -116,7 +118,7 @@ CUDA_ATOMIC_WRAPPER(Max, uint64) {
// Custom implementation of atomicAdd for double.
// This implementation is copied from CUDA manual.
CUDA_ATOMIC_WRAPPER(Add, double) {
- uint64* address_as_ull = (uint64*)address;
+ uint64* address_as_ull = reinterpret_cast<uint64*>(address);
uint64 old = *address_as_ull, assumed;
do {
@@ -221,6 +223,106 @@ WRAPPED_ATOMIC_SUB(double);
#undef WRAPPED_ATOMIC_SUB
+// For atomicMul.
+CUDA_ATOMIC_WRAPPER(Mul, int32) {
+ int32 old = *address, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address, assumed, val * assumed);
+ } while (assumed != old);
+ return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, uint32) {
+ uint32 old = *address, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address, assumed, val * assumed);
+ } while (assumed != old);
+ return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, uint64) {
+ uint64 old = *address, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address, assumed, val * assumed);
+ } while (assumed != old);
+ return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, float) {
+ int32* address_as_int = reinterpret_cast<int32*>(address);
+ int32 old = *address_as_int, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_int, assumed,
+ __float_as_int(val * __int_as_float(assumed)));
+ } while (assumed != old);
+ return __int_as_float(old);
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, double) {
+ uint64* address_as_ull = reinterpret_cast<uint64*>(address);
+ uint64 old = *address_as_ull, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_ull, assumed,
+ __double_as_longlong(val * __longlong_as_double(assumed)));
+ } while (assumed != old);
+ return __longlong_as_double(old);
+}
+
+// For atomicDiv.
+CUDA_ATOMIC_WRAPPER(Div, int32) {
+ int32 old = *address, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address, assumed, assumed / val);
+ } while (assumed != old);
+ return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Div, uint32) {
+ uint32 old = *address, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address, assumed, assumed / val);
+ } while (assumed != old);
+ return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Div, uint64) {
+ uint64 old = *address, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address, assumed, assumed / val);
+ } while (assumed != old);
+ return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Div, float) {
+ int32* address_as_int = reinterpret_cast<int32*>(address);
+ int32 old = *address_as_int, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_int, assumed,
+ __float_as_int(__int_as_float(assumed) / val));
+ } while (assumed != old);
+ return __int_as_float(old);
+}
+
+CUDA_ATOMIC_WRAPPER(Div, double) {
+ uint64* address_as_ull = reinterpret_cast<uint64*>(address);
+ uint64 old = *address_as_ull, assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_ull, assumed,
+ __double_as_longlong(__longlong_as_double(assumed) / val));
+ } while (assumed != old);
+ return __longlong_as_double(old);
+}
+
#undef USE_CUDA_ATOMIC
#undef CUDA_ATOMIC_WRAPPER
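
The float and double wrappers added above all follow the same atomicCAS retry pattern: read the current word, reinterpret its bits as a floating-point value, compute the new value, and commit it only if the word has not changed in the meantime, retrying otherwise. Below is a host-side analogue of CudaAtomicMul for float, written as a sketch with std::atomic standing in for atomicCAS and memcpy standing in for __float_as_int / __int_as_float (not part of the commit):

    // CPU emulation of the CAS retry loop used by CudaAtomicMul(float*, float).
    #include <atomic>
    #include <cstdint>
    #include <cstring>

    float AtomicMulFloat(std::atomic<std::uint32_t>* address, float val) {
      std::uint32_t expected = address->load();
      while (true) {
        float old_f;
        std::memcpy(&old_f, &expected, sizeof(float));   // like __int_as_float
        const float new_f = old_f * val;
        std::uint32_t desired;
        std::memcpy(&desired, &new_f, sizeof(float));    // like __float_as_int
        // On failure `expected` is refreshed with the current bits, mirroring
        // how the GPU loop re-reads `assumed` and recomputes the product.
        if (address->compare_exchange_weak(expected, desired)) {
          return old_f;  // like atomicCAS, return the value seen before the op
        }
      }
    }

The integer Mul/Div wrappers use the same loop directly on the stored word, without the bit reinterpretation; dividing by zero through the Div wrappers is as undefined as ordinary integer division.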
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 65b8539e4a..ad257a4339 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -212,7 +212,6 @@ cuda_py_tests(
"pooling_ops_test.py",
"random_gamma_test.py",
"rnn_test.py",
- "scatter_ops_test.py",
"seq2seq_test.py",
"slice_op_test.py",
"sparse_matmul_op_test.py",
@@ -224,6 +223,17 @@ cuda_py_tests(
],
)
+cuda_py_tests(
+ name = "large_kernel_tests",
+ size = "large",
+ srcs = [
+ "scatter_ops_test.py",
+ ],
+ additional_deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
# TODO(gpapan): Revisit the gradient of extract_image_patches_op to resolve
# http://b/31080670.
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index 2303a0650c..5e2bf59989 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Tests for tensorflow.ops.tf.scatter."""
from __future__ import absolute_import
from __future__ import division
@@ -38,9 +37,37 @@ def _NumpySub(ref, indices, updates):
ref[indx] -= updates[i]
+def _NumpyMul(ref, indices, updates):
+ for i, indx in np.ndenumerate(indices):
+ ref[indx] *= updates[i]
+
+
+def _NumpyDiv(ref, indices, updates):
+ for i, indx in np.ndenumerate(indices):
+ ref[indx] /= updates[i]
+
+
+def _NumpyUpdate(ref, indices, updates):
+ for i, indx in np.ndenumerate(indices):
+ ref[indx] = updates[i]
+
+
+_TF_OPS_TO_NUMPY = {
+ tf.scatter_update: _NumpyUpdate,
+ tf.scatter_add: _NumpyAdd,
+ tf.scatter_sub: _NumpySub,
+ tf.scatter_mul: _NumpyMul,
+ tf.scatter_div: _NumpyDiv,
+}
+
+
class ScatterTest(tf.test.TestCase):
- def _VariableRankTest(self, np_scatter, tf_scatter, vtype, itype, use_gpu,
+ def _VariableRankTest(self,
+ tf_scatter,
+ vtype,
+ itype,
+ use_gpu,
repeat_indices=False):
np.random.seed(8)
with self.test_session(use_gpu=use_gpu):
@@ -54,54 +81,64 @@ class ScatterTest(tf.test.TestCase):
indices = indices[:size]
if size > 1 and repeat_indices:
# Add some random repeats.
- indices = indices[:size//2]
- for _ in range(size-size//2):
+ indices = indices[:size // 2]
+ for _ in range(size - size // 2):
# Randomly append some repeats.
- indices = np.append(indices, indices[np.random.randint(size//2)])
+ indices = np.append(indices,
+ indices[np.random.randint(size // 2)])
np.random.shuffle(indices)
indices = indices.reshape(indices_shape)
- updates = _AsType(np.random.randn(*(indices_shape + extra_shape)),
- vtype)
+ updates = _AsType(
+ np.random.randn(*(indices_shape + extra_shape)), vtype)
+ # Clips small values to avoid division by zero.
+ def clip_small_values(x):
+ return 1e-4 * np.sign(x) if np.abs(x) < 1e-4 else x
+ updates = np.vectorize(clip_small_values)(updates)
old = _AsType(np.random.randn(*((first_dim,) + extra_shape)), vtype)
# Scatter via numpy
new = old.copy()
+ np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
np_scatter(new, indices, updates)
# Scatter via tensorflow
ref = tf.Variable(old)
ref.initializer.run()
tf_scatter(ref, indices, updates).eval()
- # Compare
self.assertAllClose(ref.eval(), new)
- def _VariableRankTests(self, np_scatter, tf_scatter):
+ def _VariableRankTests(self, tf_scatter, repeat_indices=False):
for vtype in (np.float32, np.float64):
for itype in (np.int32, np.int64):
for use_gpu in (False, True):
- self._VariableRankTest(np_scatter, tf_scatter, vtype, itype, use_gpu)
+ self._VariableRankTest(tf_scatter, vtype, itype, use_gpu,
+ repeat_indices)
def testVariableRankUpdate(self):
- def update(ref, indices, updates):
- ref[indices] = updates
- self._VariableRankTests(update, tf.scatter_update)
+ self._VariableRankTests(tf.scatter_update)
def testVariableRankAdd(self):
- self._VariableRankTests(_NumpyAdd, tf.scatter_add)
+ self._VariableRankTests(tf.scatter_add)
def testVariableRankSub(self):
- self._VariableRankTests(_NumpySub, tf.scatter_sub)
+ self._VariableRankTests(tf.scatter_sub)
- def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter):
- for vtype in (np.float32, np.float64):
- for itype in (np.int32, np.int64):
- for use_gpu in (False, True):
- self._VariableRankTest(np_scatter, tf_scatter, vtype, itype, use_gpu,
- repeat_indices=True)
+ def testVariableRankMul(self):
+ self._VariableRankTests(tf.scatter_mul)
+
+ def testVariableRankDiv(self):
+ self._VariableRankTests(tf.scatter_div)
+
+ def testRepeatIndicesAdd(self):
+ self._VariableRankTests(tf.scatter_add, True)
+
+ def testRepeatIndicesSub(self):
+ self._VariableRankTests(tf.scatter_sub, True)
+
+ def testRepeatIndicesMul(self):
+ self._VariableRankTests(tf.scatter_mul, True)
- def testScatterRepeatIndices(self):
- """This tests scatter_add using indices that repeat."""
- self._ScatterRepeatIndicesTest(_NumpyAdd, tf.scatter_add)
- self._ScatterRepeatIndicesTest(_NumpySub, tf.scatter_sub)
+ def testRepeatIndicesDiv(self):
+ self._VariableRankTests(tf.scatter_div, True)
def testBooleanScatterUpdate(self):
with self.test_session(use_gpu=False) as session:
@@ -115,7 +152,7 @@ class ScatterTest(tf.test.TestCase):
self.assertAllEqual([False, True], var.eval())
def testScatterOutOfRangeCpu(self):
- for op in (tf.scatter_add, tf.scatter_sub, tf.scatter_update):
+ for op, _ in _TF_OPS_TO_NUMPY.items():
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
updates = np.array([-3, -4, -5]).astype(np.float32)
with self.test_session(use_gpu=False):
@@ -139,7 +176,7 @@ class ScatterTest(tf.test.TestCase):
def _disabledTestScatterOutOfRangeGpu(self):
if not tf.test.IsBuiltWithCuda():
return
- for op in (tf.scatter_add, tf.scatter_sub, tf.scatter_update):
+ for op, _ in _TF_OPS_TO_NUMPY.items():
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
updates = np.array([-3, -4, -5]).astype(np.float32)
# With GPU, the code ignores indices that are out of range.
@@ -159,5 +196,5 @@ class ScatterTest(tf.test.TestCase):
op(ref, indices, updates).eval()
-if __name__ == "__main__":
+if __name__ == '__main__':
tf.test.main()
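
One detail of the refactored test worth noting: because the same random updates now also drive ScatterDiv, clip_small_values pushes every update away from zero before use, so neither the TensorFlow op nor the numpy reference divides by a value near zero. A C++ counterpart of that guard, for illustration only (unlike the numpy version, an exact zero maps to +1e-4 here):

    // Keep divisor magnitudes at least 1e-4 so the ScatterDiv comparison
    // between the kernel and the reference stays finite and stable.
    #include <cmath>

    float ClipSmallValues(float x) {
      const float kEps = 1e-4f;
      return std::fabs(x) < kEps ? std::copysign(kEps, x) : x;
    }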
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 3e109c7c70..9267b8ef2e 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -63,6 +63,8 @@ from tensorflow.python.ops.state_ops import assign_add
from tensorflow.python.ops.state_ops import assign_sub
from tensorflow.python.ops.state_ops import count_up_to
from tensorflow.python.ops.state_ops import scatter_add
+from tensorflow.python.ops.state_ops import scatter_div
+from tensorflow.python.ops.state_ops import scatter_mul
from tensorflow.python.ops.state_ops import scatter_sub
from tensorflow.python.ops.state_ops import scatter_update
from tensorflow.python.ops.string_ops import *
diff --git a/tensorflow/python/ops/state_grad.py b/tensorflow/python/ops/state_grad.py
index d6fcda3dfb..7d0940bf26 100644
--- a/tensorflow/python/ops/state_grad.py
+++ b/tensorflow/python/ops/state_grad.py
@@ -35,3 +35,9 @@ ops.NoGradient("ScatterAdd")
ops.NoGradient("ScatterSub")
+
+
+ops.NoGradient("ScatterMul")
+
+
+ops.NoGradient("ScatterDiv")
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 63fae74812..6220d1e532 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -88,6 +88,8 @@ automatically by the optimizers in most cases.
@@scatter_update
@@scatter_add
@@scatter_sub
+@@scatter_mul
+@@scatter_div
@@sparse_mask
@@IndexedSlices
@@ -227,6 +229,8 @@ def _CountUpToShape(op):
@ops.RegisterShape("ScatterAdd")
@ops.RegisterShape("ScatterSub")
+@ops.RegisterShape("ScatterMul")
+@ops.RegisterShape("ScatterDiv")
@ops.RegisterShape("ScatterUpdate")
def _ScatterUpdateShape(op):
"""Shape function for the sparse update ops."""