From 42e65468aaa4a5238ee899b2f8400fdd76bff0eb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 15 Dec 2015 14:47:57 -0800 Subject: Add support to scatter_update to allow string and bool tensors. Change: 110286577 --- tensorflow/core/kernels/scatter_op.cc | 6 ++-- tensorflow/core/kernels/scatter_op_test.cc | 50 +++++++++++++++++++++++------- 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 84dd625a9f..20fe1ed67a 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -147,10 +147,8 @@ class ScatterUpdateOp : public OpKernel { #define REGISTER_SCATTER_UPDATE_INT32(type) REGISTER_SCATTER_UPDATE(type, int32) #define REGISTER_SCATTER_UPDATE_INT64(type) REGISTER_SCATTER_UPDATE(type, int64) -TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_INT32); -TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_INT64); -REGISTER_SCATTER_UPDATE_INT32(bool) -REGISTER_SCATTER_UPDATE_INT64(bool) +TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_INT32); +TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_INT64); #undef REGISTER_SCATTER_UPDATE_INT64 #undef REGISTER_SCATTER_UPDATE_INT32 diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc index 5d245bf9bb..d751570acb 100644 --- a/tensorflow/core/kernels/scatter_op_test.cc +++ b/tensorflow/core/kernels/scatter_op_test.cc @@ -38,19 +38,45 @@ namespace { class ScatterUpdateOpTest : public OpsTestBase { protected: - void MakeOp(DataType index_type) { + void MakeOp(DataType variable_ref_type, DataType index_type) { RequireDefaultOps(); ASSERT_OK(NodeDefBuilder("myop", "ScatterUpdate") - .Input(FakeInput(DT_FLOAT_REF)) + .Input(FakeInput(variable_ref_type)) .Input(FakeInput(index_type)) - .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(RemoveRefType(variable_ref_type))) .Finalize(node_def())); ASSERT_OK(InitOp()); } }; +TEST_F(ScatterUpdateOpTest, Simple_StringType) { + MakeOp(DT_STRING_REF, DT_INT32); + AddInputFromArray(TensorShape({1}), {"Brain"}); + AddInputFromArray(TensorShape({1}), {0}); + AddInputFromArray(TensorShape({1}), {"TensorFlow"}); + ASSERT_OK(RunOpKernel()); + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_STRING, TensorShape({1})); + test::FillValues(&expected, {"TensorFlow"}); + test::ExpectTensorEqual(expected, params_tensor); +} + +TEST_F(ScatterUpdateOpTest, Simple_BoolType) { + MakeOp(DT_BOOL_REF, DT_INT32); + AddInputFromArray(TensorShape({1}), {false}); + AddInputFromArray(TensorShape({1}), {0}); + AddInputFromArray(TensorShape({1}), {true}); + ASSERT_OK(RunOpKernel()); + // Check the new state of the input + Tensor params_tensor = *mutable_input(0).tensor; + Tensor expected(allocator(), DT_BOOL, TensorShape({1})); + test::FillValues(&expected, {true}); + test::ExpectTensorEqual(expected, params_tensor); +} + TEST_F(ScatterUpdateOpTest, Simple_TwoD32) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray(TensorShape({5, 3}), @@ -69,7 +95,7 @@ TEST_F(ScatterUpdateOpTest, Simple_TwoD32) { } TEST_F(ScatterUpdateOpTest, Simple_Two64) { - MakeOp(DT_INT64); + MakeOp(DT_FLOAT_REF, DT_INT64); // Feed and run AddInputFromArray(TensorShape({5, 3}), @@ -88,7 +114,7 @@ TEST_F(ScatterUpdateOpTest, Simple_Two64) { } TEST_F(ScatterUpdateOpTest, Simple_ZeroD) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray(TensorShape({5}), {0, 0, 0, 0, 0}); @@ -104,7 +130,7 @@ TEST_F(ScatterUpdateOpTest, Simple_ZeroD) { } TEST_F(ScatterUpdateOpTest, Simple_OneD) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray(TensorShape({5}), {0, 0, 0, 0, 0}); @@ -120,7 +146,7 @@ TEST_F(ScatterUpdateOpTest, Simple_OneD) { } TEST_F(ScatterUpdateOpTest, HigherRank) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0}); @@ -136,7 +162,7 @@ TEST_F(ScatterUpdateOpTest, HigherRank) { } TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray(TensorShape({5, 3}), @@ -151,7 +177,7 @@ TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) { } TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0}); @@ -166,7 +192,7 @@ TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) { } TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray(TensorShape({5, 3}), @@ -184,7 +210,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { } TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) { - MakeOp(DT_INT32); + MakeOp(DT_FLOAT_REF, DT_INT32); // Feed and run AddInputFromArray(TensorShape({5, 3}), -- cgit v1.2.3