diff options
author | 2016-12-21 20:50:31 -0800 | |
---|---|---|
committer | 2016-12-21 21:05:54 -0800 | |
commit | 0f0e29e7ba06c50fe4a1a7718e63731b96563a8d (patch) | |
tree | bfbb7dcc4dba793bfef84d0b1681f20ad4eeff2f /tensorflow/tools/graph_transforms/transform_graph_test.cc | |
parent | be60473c88175dbc9359c9d1bbb384518757ee81 (diff) |
Create Graph Transform Tool for rewriting model files.
Change: 142729497
Diffstat (limited to 'tensorflow/tools/graph_transforms/transform_graph_test.cc')
-rw-r--r-- | tensorflow/tools/graph_transforms/transform_graph_test.cc | 228 |
1 files changed, 228 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/transform_graph_test.cc b/tensorflow/tools/graph_transforms/transform_graph_test.cc new file mode 100644 index 0000000000..12df4051fb --- /dev/null +++ b/tensorflow/tools/graph_transforms/transform_graph_test.cc @@ -0,0 +1,228 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/graph_transforms/transform_graph.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declared here so we don't have to expose it in the public header. +Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params, + bool* ignore_errors); + +namespace { +Status test_empty_graph_transform(const GraphDef& graph_def, + const TransformFuncContext& context, + GraphDef* result) { + result->Clear(); + return Status::OK(); +} +} // namespace + +REGISTER_GRAPH_TRANSFORM("test_empty_graph_transform", + test_empty_graph_transform); + +class TransformGraphTest : public ::testing::Test { + protected: + void TestConstantFolding() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const int width = 100; + + Tensor a_data(DT_FLOAT, TensorShape({width})); + test::FillIota<float>(&a_data, 1.0f); + Output a_const = + Const(root.WithOpName("a_expect_removed"), Input::Initializer(a_data)); + + Tensor b_data(DT_FLOAT, TensorShape({width})); + test::FillIota<float>(&b_data, 1.0f); + Output b_const = + Const(root.WithOpName("b_expect_removed"), Input::Initializer(b_data)); + + Output add = Add(root.WithOpName("add_expect_removed"), a_const, b_const); + + Output placeholder = + Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT); + + Output mul = + Mul(root.WithOpName("output_expect_remains"), add, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + string graph_def_serialized; + graph_def.SerializeToString(&graph_def_serialized); + const string dir = testing::TmpDir(); + const string in_filename_pb = io::JoinPath(dir, "in_graphdef.pb"); + const string out_filename_pb = io::JoinPath(dir, "out_graphdef.pb"); + TF_ASSERT_OK(WriteStringToFile(Env::Default(), in_filename_pb, + graph_def_serialized)); + + std::vector<string> args = {"some_binary", + "--in_graph=" + in_filename_pb, + "--out_graph=" + out_filename_pb, + "--inputs=placeholder_expect_remains", + "--outputs=output_expect_remains", + "--transforms=fold_constants"}; + const int argc = 6; + EXPECT_EQ(argc, args.size()); + char* argv[argc]; + std::vector<char*> char_strings; + for (int i = 0; i < argc; ++i) { + string arg = args[i]; + char* char_string = new char[arg.size() + 1]; + std::copy_n(arg.c_str(), arg.size() + 1, char_string); + argv[i] = char_string; + char_strings.push_back(char_string); + } + ParseFlagsAndTransformGraph(argc, argv, false); + for (char* char_string : char_strings) { + delete[] char_string; + } + + GraphDef out_graph_def; + TF_EXPECT_OK( + ReadBinaryProto(Env::Default(), out_filename_pb, &out_graph_def)); + + std::map<string, const NodeDef*> out_node_map; + graph_transforms::MapNamesToNodes(out_graph_def, &out_node_map); + + for (const NodeDef& node : out_graph_def.node()) { + const StringPiece name(node.name()); + const int occurrence_count = out_node_map.count(node.name()); + if (name.ends_with("expect_removed")) { + EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name(); + } + if (name.ends_with("expect_remains")) { + EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name(); + } + } + } + + void TestTransformRegistration() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Output placeholder = + Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT); + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + EXPECT_EQ(1, graph_def.node().size()); + TF_ASSERT_OK(TransformGraph({}, {}, {{"test_empty_graph_transform", {}}}, + &graph_def)); + EXPECT_EQ(0, graph_def.node().size()); + + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + Status no_such_status = + TransformGraph({}, {}, {{"test_no_such_transform", {}}}, &graph_def); + EXPECT_TRUE( + StringPiece(no_such_status.ToString()).contains("not recognized")); + } + + void TestParseTransformParameters() { + TransformParameters params_list; + + ParseTransformParameters("foo", ¶ms_list); + EXPECT_EQ(1, params_list.size()); + EXPECT_EQ("foo", params_list[0].first); + EXPECT_TRUE(params_list[0].second.empty()); + + ParseTransformParameters("foo bar", ¶ms_list); + EXPECT_EQ(2, params_list.size()); + EXPECT_EQ("foo", params_list[0].first); + EXPECT_TRUE(params_list[0].second.empty()); + EXPECT_EQ("bar", params_list[1].first); + EXPECT_TRUE(params_list[1].second.empty()); + + ParseTransformParameters("foo() bar()", ¶ms_list); + EXPECT_EQ(2, params_list.size()); + EXPECT_EQ("foo", params_list[0].first); + EXPECT_TRUE(params_list[0].second.empty()); + EXPECT_EQ("bar", params_list[1].first); + EXPECT_TRUE(params_list[1].second.empty()); + + ParseTransformParameters("foo(bob_something=sue)", ¶ms_list); + EXPECT_EQ(1, params_list.size()); + EXPECT_EQ("foo", params_list[0].first); + EXPECT_EQ(1, params_list[0].second.count("bob_something")); + EXPECT_EQ(1, params_list[0].second["bob_something"].size()); + EXPECT_EQ("sue", params_list[0].second["bob_something"][0]); + + ParseTransformParameters("bar(a=1, b=2, a=3)", ¶ms_list); + EXPECT_EQ(1, params_list.size()); + EXPECT_EQ("bar", params_list[0].first); + EXPECT_EQ(1, params_list[0].second.count("a")); + EXPECT_EQ(2, params_list[0].second["a"].size()); + EXPECT_EQ("1", params_list[0].second["a"][0]); + EXPECT_EQ("3", params_list[0].second["a"][1]); + EXPECT_EQ(1, params_list[0].second.count("b")); + EXPECT_EQ(1, params_list[0].second["b"].size()); + EXPECT_EQ("2", params_list[0].second["b"][0]); + + ParseTransformParameters("bar(a=\"1\", b=\"1,2,3\", a=3)", ¶ms_list); + EXPECT_EQ(1, params_list.size()); + EXPECT_EQ("bar", params_list[0].first); + EXPECT_EQ(1, params_list[0].second.count("a")); + EXPECT_EQ(2, params_list[0].second["a"].size()); + EXPECT_EQ("1", params_list[0].second["a"][0]); + EXPECT_EQ("3", params_list[0].second["a"][1]); + EXPECT_EQ(1, params_list[0].second.count("b")); + EXPECT_EQ(1, params_list[0].second["b"].size()); + EXPECT_EQ("1,2,3", params_list[0].second["b"][0]); + } + + void TestShouldIgnoreErrors() { + bool ignore_errors; + TF_EXPECT_OK( + ShouldIgnoreErrors({{"ignore_errors", {"true"}}}, &ignore_errors)); + EXPECT_TRUE(ignore_errors); + + TF_EXPECT_OK( + ShouldIgnoreErrors({{"ignore_errors", {"false"}}}, &ignore_errors)); + EXPECT_FALSE(ignore_errors); + + TF_EXPECT_OK(ShouldIgnoreErrors({}, &ignore_errors)); + EXPECT_FALSE(ignore_errors); + + EXPECT_FALSE( + ShouldIgnoreErrors({{"ignore_errors", {"foo"}}}, &ignore_errors).ok()); + } +}; + +TEST_F(TransformGraphTest, TestConstantFolding) { TestConstantFolding(); } + +TEST_F(TransformGraphTest, TestTransformRegistration) { + TestTransformRegistration(); +} + +TEST_F(TransformGraphTest, TestParseTransformParameters) { + TestParseTransformParameters(); +} + +TEST_F(TransformGraphTest, TestShouldIgnoreErrors) { TestShouldIgnoreErrors(); } + +} // namespace graph_transforms +} // namespace tensorflow |