diff options
author | 2018-02-02 11:24:11 -0800 | |
---|---|---|
committer | 2018-02-02 11:28:07 -0800 | |
commit | 51792c887abd425693a3c36f16ea221b949f7277 (patch) | |
tree | f7d1fb6dda7a3a9486221a381b041d8421539d49 | |
parent | 4b83dea191761967e5d4c705caac6078f8360a1e (diff) |
Register resource_scatter_update for string types.
PiperOrigin-RevId: 184309674
-rw-r--r-- | tensorflow/core/kernels/resource_variable_ops.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/ops/resource_variable_ops.cc | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/resource_variable_ops_test.py | 12 |
3 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 9cc8e03e3a..6ce53e725f 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -635,6 +635,9 @@ class ResourceScatterUpdateOp : public OpKernel { TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ARITHEMTIC_CPU); +REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate", + scatter_op::UpdateOp::ASSIGN); + // Registers GPU kernels. #if GOOGLE_CUDA #define REGISTER_SCATTER_ARITHEMTIC_GPU(type) \ diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index f6cfbf873a..8dae7e1ff5 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -193,7 +193,7 @@ REGISTER_OP("ResourceScatterUpdate") .Input("resource: resource") .Input("indices: Tindices") .Input("updates: dtype") - .Attr("dtype: numbertype") + .Attr("dtype: type") .Attr("Tindices: {int32, int64}") .SetShapeFn([](InferenceContext* c) { ShapeAndType handle_shape_and_type; diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index b4b555591d..cd94579688 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.util import compat @test_util.with_c_api @@ -170,6 +171,17 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[3]]) + def testScatterUpdateString(self): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.string, shape=[1, 1]) + self.evaluate(resource_variable_ops.assign_variable_op( + handle, constant_op.constant([["a"]], dtype=dtypes.string))) + self.evaluate(resource_variable_ops.resource_scatter_update( + handle, [0], constant_op.constant([["b"]], dtype=dtypes.string))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.string) + self.assertEqual(compat.as_bytes(self.evaluate(read)[0][0]), + compat.as_bytes("b")) + # TODO(alive): get this to work in Eager mode. def testGPU(self): with self.test_session(use_gpu=True): |