diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/tests')
2 files changed, 14 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc index 66cfed4ac2..e2a6f12481 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -166,7 +166,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) { GraphTransformationsSet graph_transformation_set; graph_transformation_set.Add(new toco::ResolveConstantConcatenation); EXPECT_THAT(model.GetArrayMap().size(), 5); - (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + bool modified; + ASSERT_TRUE((*graph_transformation_set.begin()) + ->Run(&model, /*op_index=*/0, &modified) + .ok()); EXPECT_THAT(model.GetArrayMap().size(), 1); auto& concatenated_array = (*model.GetArrayMap().begin()).second; @@ -185,7 +188,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) { GraphTransformationsSet graph_transformation_set; graph_transformation_set.Add(new toco::ResolveConstantConcatenation); EXPECT_THAT(model.GetArrayMap().size(), 5); - (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + bool modified; + ASSERT_TRUE((*graph_transformation_set.begin()) + ->Run(&model, /*op_index=*/0, &modified) + .ok()); EXPECT_THAT(model.GetArrayMap().size(), 1); auto& concatenated_array = (*model.GetArrayMap().begin()).second; @@ -204,7 +210,10 @@ TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) { GraphTransformationsSet graph_transformation_set; graph_transformation_set.Add(new toco::ResolveConstantConcatenation); EXPECT_THAT(model.GetArrayMap().size(), 5); - (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + bool modified; + ASSERT_TRUE((*graph_transformation_set.begin()) + ->Run(&model, /*op_index=*/0, &modified) + .ok()); EXPECT_THAT(model.GetArrayMap().size(), 1); auto& concatenated_array = (*model.GetArrayMap().begin()).second; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc index a53abc9941..57d85a0435 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc @@ -50,7 +50,8 @@ void RunResolveSum(const std::vector<float>& input, sum_op->inputs = {"input0", "input1"}; sum_op->outputs = {"output"}; model.operators.push_back(std::move(sum_op)); - ResolveConstantUnaryOperator().Run(&model, 0); + bool modified; + ASSERT_TRUE(ResolveConstantUnaryOperator().Run(&model, 0, &modified).ok()); EXPECT_EQ(model.GetArray("output").GetBuffer<ArrayDataType::kFloat>().data, expected_output); EXPECT_EQ(model.GetArray("output").shape().dims(), output_shape); |