aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-02-02 11:24:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-02 11:28:07 -0800
commit51792c887abd425693a3c36f16ea221b949f7277 (patch)
treef7d1fb6dda7a3a9486221a381b041d8421539d49
parent4b83dea191761967e5d4c705caac6078f8360a1e (diff)
Register resource_scatter_update for string types.
PiperOrigin-RevId: 184309674
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc3
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc2
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py12
3 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 9cc8e03e3a..6ce53e725f 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -635,6 +635,9 @@ class ResourceScatterUpdateOp : public OpKernel {
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU);
+REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
+ scatter_op::UpdateOp::ASSIGN);
+
// Registers GPU kernels.
#if GOOGLE_CUDA
#define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index f6cfbf873a..8dae7e1ff5 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -193,7 +193,7 @@ REGISTER_OP("ResourceScatterUpdate")
.Input("resource: resource")
.Input("indices: Tindices")
.Input("updates: dtype")
- .Attr("dtype: numbertype")
+ .Attr("dtype: type")
.Attr("Tindices: {int32, int64}")
.SetShapeFn([](InferenceContext* c) {
ShapeAndType handle_shape_and_type;
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index b4b555591d..cd94579688 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.util import compat
@test_util.with_c_api
@@ -170,6 +171,17 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
+ def testScatterUpdateString(self):
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.string, shape=[1, 1])
+ self.evaluate(resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([["a"]], dtype=dtypes.string)))
+ self.evaluate(resource_variable_ops.resource_scatter_update(
+ handle, [0], constant_op.constant([["b"]], dtype=dtypes.string)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string)
+ self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]),
+ compat.as_bytes("b"))
+
# TODO(alive): get this to work in Eager mode.
def testGPU(self):
with self.test_session(use_gpu=True):