aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2015-12-15 14:47:57 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-12-15 14:47:57 -0800
commit42e65468aaa4a5238ee899b2f8400fdd76bff0eb (patch)
treed835db432b515929fcc07660342fdf6e918ef2f2
parent21feee989ae36b079b5274071064f8af9d3018df (diff)
Add support to scatter_update to allow string and bool tensors.
Change: 110286577
-rw-r--r--tensorflow/core/kernels/scatter_op.cc6
-rw-r--r--tensorflow/core/kernels/scatter_op_test.cc50
2 files changed, 40 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 84dd625a9f..20fe1ed67a 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -147,10 +147,8 @@ class ScatterUpdateOp : public OpKernel {
#define REGISTER_SCATTER_UPDATE_INT32(type) REGISTER_SCATTER_UPDATE(type, int32)
#define REGISTER_SCATTER_UPDATE_INT64(type) REGISTER_SCATTER_UPDATE(type, int64)
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_INT32);
-TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_INT64);
-REGISTER_SCATTER_UPDATE_INT32(bool)
-REGISTER_SCATTER_UPDATE_INT64(bool)
+TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_INT32);
+TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_INT64);
#undef REGISTER_SCATTER_UPDATE_INT64
#undef REGISTER_SCATTER_UPDATE_INT32
diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc
index 5d245bf9bb..d751570acb 100644
--- a/tensorflow/core/kernels/scatter_op_test.cc
+++ b/tensorflow/core/kernels/scatter_op_test.cc
@@ -38,19 +38,45 @@ namespace {
class ScatterUpdateOpTest : public OpsTestBase {
protected:
- void MakeOp(DataType index_type) {
+ void MakeOp(DataType variable_ref_type, DataType index_type) {
RequireDefaultOps();
ASSERT_OK(NodeDefBuilder("myop", "ScatterUpdate")
- .Input(FakeInput(DT_FLOAT_REF))
+ .Input(FakeInput(variable_ref_type))
.Input(FakeInput(index_type))
- .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(RemoveRefType(variable_ref_type)))
.Finalize(node_def()));
ASSERT_OK(InitOp());
}
};
+TEST_F(ScatterUpdateOpTest, Simple_StringType) {
+ MakeOp(DT_STRING_REF, DT_INT32);
+ AddInputFromArray<string>(TensorShape({1}), {"Brain"});
+ AddInputFromArray<int32>(TensorShape({1}), {0});
+ AddInputFromArray<string>(TensorShape({1}), {"TensorFlow"});
+ ASSERT_OK(RunOpKernel());
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_STRING, TensorShape({1}));
+ test::FillValues<string>(&expected, {"TensorFlow"});
+ test::ExpectTensorEqual<string>(expected, params_tensor);
+}
+
+TEST_F(ScatterUpdateOpTest, Simple_BoolType) {
+ MakeOp(DT_BOOL_REF, DT_INT32);
+ AddInputFromArray<bool>(TensorShape({1}), {false});
+ AddInputFromArray<int32>(TensorShape({1}), {0});
+ AddInputFromArray<bool>(TensorShape({1}), {true});
+ ASSERT_OK(RunOpKernel());
+ // Check the new state of the input
+ Tensor params_tensor = *mutable_input(0).tensor;
+ Tensor expected(allocator(), DT_BOOL, TensorShape({1}));
+ test::FillValues<bool>(&expected, {true});
+ test::ExpectTensorEqual<bool>(expected, params_tensor);
+}
+
TEST_F(ScatterUpdateOpTest, Simple_TwoD32) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
@@ -69,7 +95,7 @@ TEST_F(ScatterUpdateOpTest, Simple_TwoD32) {
}
TEST_F(ScatterUpdateOpTest, Simple_Two64) {
- MakeOp(DT_INT64);
+ MakeOp(DT_FLOAT_REF, DT_INT64);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
@@ -88,7 +114,7 @@ TEST_F(ScatterUpdateOpTest, Simple_Two64) {
}
TEST_F(ScatterUpdateOpTest, Simple_ZeroD) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
@@ -104,7 +130,7 @@ TEST_F(ScatterUpdateOpTest, Simple_ZeroD) {
}
TEST_F(ScatterUpdateOpTest, Simple_OneD) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
@@ -120,7 +146,7 @@ TEST_F(ScatterUpdateOpTest, Simple_OneD) {
}
TEST_F(ScatterUpdateOpTest, HigherRank) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0});
@@ -136,7 +162,7 @@ TEST_F(ScatterUpdateOpTest, HigherRank) {
}
TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
@@ -151,7 +177,7 @@ TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) {
}
TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0});
@@ -166,7 +192,7 @@ TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) {
}
TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
@@ -184,7 +210,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
}
TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
- MakeOp(DT_INT32);
+ MakeOp(DT_FLOAT_REF, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),