diff options
author | 2015-12-15 14:47:57 -0800 | |
---|---|---|
committer | 2015-12-15 14:47:57 -0800 | |
commit | 42e65468aaa4a5238ee899b2f8400fdd76bff0eb (patch) | |
tree | d835db432b515929fcc07660342fdf6e918ef2f2 | |
parent | 21feee989ae36b079b5274071064f8af9d3018df (diff) |
Add support to scatter_update to allow string and bool tensors.
Change: 110286577
-rw-r--r-- | tensorflow/core/kernels/scatter_op.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/kernels/scatter_op_test.cc | 50 |
2 files changed, 40 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 84dd625a9f..20fe1ed67a 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -147,10 +147,8 @@ class ScatterUpdateOp : public OpKernel { #define REGISTER_SCATTER_UPDATE_INT32(type) REGISTER_SCATTER_UPDATE(type, int32) #define REGISTER_SCATTER_UPDATE_INT64(type) REGISTER_SCATTER_UPDATE(type, int64) -TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_INT32); -TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_INT64); -REGISTER_SCATTER_UPDATE_INT32(bool) -REGISTER_SCATTER_UPDATE_INT64(bool) +TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_INT32); +TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_INT64); #undef REGISTER_SCATTER_UPDATE_INT64 #undef REGISTER_SCATTER_UPDATE_INT32 diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc index 5d245bf9bb..d751570acb 100644 --- a/tensorflow/core/kernels/scatter_op_test.cc +++ b/tensorflow/core/kernels/scatter_op_test.cc @@ -38,19 +38,45 @@ namespace { class ScatterUpdateOpTest : public OpsTestBase { protected: - void MakeOp(DataType index_type) { + void MakeOp(DataType variable_ref_type, DataType index_type) { RequireDefaultOps(); ASSERT_OK(NodeDefBuilder("myop", "ScatterUpdate") - .Input(FakeInput(DT_FLOAT_REF)) + .Input(FakeInput(variable_ref_type)) .Input(FakeInput(index_type)) - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(RemoveRefType(variable_ref_type))) .Finalize(node_def())); ASSERT_OK(InitOp()); } }; +TEST_F(ScatterUpdateOpTest, Simple_StringType) { + MakeOp(DT_STRING_REF, DT_INT32); + AddInputFromArray<string>(TensorShape({1}), {"Brain"}); + AddInputFromArray<int32>(TensorShape({1}), {0}); + AddInputFromArray<string>(TensorShape({1}), {"TensorFlow"}); + ASSERT_OK(RunOpKernel()); + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_STRING, TensorShape({1})); + test::FillValues<string>(&expected, {"TensorFlow"}); + test::ExpectTensorEqual<string>(expected, params_tensor); +} + +TEST_F(ScatterUpdateOpTest, Simple_BoolType) { + MakeOp(DT_BOOL_REF, DT_INT32); + AddInputFromArray<bool>(TensorShape({1}), {false}); + AddInputFromArray<int32>(TensorShape({1}), {0}); + AddInputFromArray<bool>(TensorShape({1}), {true}); + ASSERT_OK(RunOpKernel()); + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_BOOL, TensorShape({1})); + test::FillValues<bool>(&expected, {true}); + test::ExpectTensorEqual<bool>(expected, params_tensor); +} + TEST_F(ScatterUpdateOpTest, Simple_TwoD32) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray<float>(TensorShape({5, 3}), @@ -69,7 +95,7 @@ TEST_F(ScatterUpdateOpTest, Simple_TwoD32) { } TEST_F(ScatterUpdateOpTest, Simple_Two64) { - MakeOp(DT_INT64); + MakeOp(DT_FLOAT_REF, DT_INT64); // Feed and run AddInputFromArray<float>(TensorShape({5, 3}), @@ -88,7 +114,7 @@ TEST_F(ScatterUpdateOpTest, Simple_Two64) { } TEST_F(ScatterUpdateOpTest, Simple_ZeroD) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0}); @@ -104,7 +130,7 @@ TEST_F(ScatterUpdateOpTest, Simple_ZeroD) { } TEST_F(ScatterUpdateOpTest, Simple_OneD) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0}); @@ -120,7 +146,7 @@ TEST_F(ScatterUpdateOpTest, Simple_OneD) { } TEST_F(ScatterUpdateOpTest, HigherRank) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0}); @@ -136,7 +162,7 @@ TEST_F(ScatterUpdateOpTest, HigherRank) { } TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray<float>(TensorShape({5, 3}), @@ -151,7 +177,7 @@ TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) { } TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0}); @@ -166,7 +192,7 @@ TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) { } TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray<float>(TensorShape({5, 3}), @@ -184,7 +210,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { } TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray<float>(TensorShape({5, 3}), |