diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-04-18 08:08:23 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-04-18 09:12:35 -0700 |
commit | 3c280f6fa0e0fcaa3d2cee5d2d8bb7ab3e25319f (patch) | |
tree | fc67b33f56cc465486453b49789ea6a4d97b639d /tensorflow/contrib/util | |
parent | 517d3af445d85e2f6945fcdfc4fed4e46b1e0e35 (diff) |
Added a format for saving an inference graph that can be memmapped, and a utility to convert a frozen graph into this format.
Change: 120128412
Diffstat (limited to 'tensorflow/contrib/util')
5 files changed, 403 insertions, 0 deletions
diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index c0be2b9c14..80495c9b8a 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -7,6 +7,47 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) +# Convertor of a frozen graph definition into the memmapped format. +cc_library( + name = "convert_graphdef_memmapped_format_lib", + srcs = ["convert_graphdef_memmapped_format_lib.cc"], + hdrs = ["convert_graphdef_memmapped_format_lib.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core/kernels:immutable_constant_op", + ], +) + +cc_binary( + name = "convert_graphdef_memmapped_format", + srcs = ["convert_graphdef_memmapped_format.cc"], + deps = [ + ":convert_graphdef_memmapped_format_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "convert_graphdef_memmapped_format_test", + srcs = ["convert_graphdef_memmapped_format_test.cc"], + deps = [ + ":convert_graphdef_memmapped_format_lib", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_binary( name = "inspect_checkpoint", srcs = ["inspect_checkpoint.cc"], diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc new file mode 100644 index 0000000000..811761efd6 --- /dev/null +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc @@ -0,0 +1,88 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility that converts a "frozen" inference graph (output from the +// freeze_graph utility) into a format in which large Const ops are converted to +// ImmutableConst ops which are memmapped when the graph is executed by +// TensorFlow. +// +// tensorflow/contrib/util/convert_graphdef_memmapped_format +// --in_graph=frozen.model --out_graph=memmapped.mmodel +// +// Parameters: +// in_graph - name of a file with a frozen GraphDef proto in binary format +// out_graph - name of the output file, where the graph in memmapped format will +// be saved. +// min_conversion_tensor_size - tensors with fewer than this many bytes of data +// will not be converted to ImmutableConst format and are kept in the graph. + +#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { + +int ParseFlagsAndConvertGraph(int argc, char* argv[]) { + string in_graph = ""; + string out_graph = ""; + int min_conversion_tensor_size = 10000; + const bool parse_result = ParseFlags( + &argc, argv, + {// input graph + Flag("in_graph", &in_graph), + // output graph + Flag("out_graph", &out_graph), + // constants whose tensors hold fewer than this many bytes won't + // be converted into ImmutableConst (be memmapped). 
+ Flag("min_conversion_tensor_size", &min_conversion_tensor_size)}); + // We need to call this to set up global state for TensorFlow. + port::InitMain(argv[0], &argc, &argv); + if (!parse_result) { + LOG(ERROR) << "Error parsing command-line flags."; + return -1; + } + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1]; + return -1; + } + if (in_graph.empty()) { + LOG(ERROR) << "in_graph graph can't be empty"; + return -1; + } + if (out_graph.empty()) { + LOG(ERROR) << "out_graph graph can't be empty"; + return -1; + } + if (min_conversion_tensor_size <= 0) { + LOG(ERROR) << "min_conversion_tensor_size must be > 0"; + return -1; + } + const auto result = ConvertConstantsToImmutable(in_graph, out_graph, + min_conversion_tensor_size); + if (!result.ok()) { + LOG(ERROR) << "Conversion failed " << result.error_message(); + return -1; + } + return 0; +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char* argv[]) { + return tensorflow::ParseFlagsAndConvertGraph(argc, argv); +} diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc new file mode 100644 index 0000000000..7697a7f3d2 --- /dev/null +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc @@ -0,0 +1,156 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h" + +#include <unordered_set> +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/kernels/immutable_constant_op.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/memmapped_file_system_writer.h" + +namespace tensorflow { +namespace { +class NodeConverter { + public: + // Converts one node. In-place updates node_def, writes the tensor in + // memmapped + // format, using writer. If the conversion has been done, convert_counter is + // increased. + Status ConvertConstantsToImmutable(NodeDef* node_def, + MemmappedFileSystemWriter* writer, + int* convert_counter, + int min_conversion_size_bytes) { + // Check the size. + const AttrValue& value = node_def->attr().at("value"); + const TensorProto& tensor_proto = value.tensor(); + + // Create copies of tensor datatype and shape, to put into the operator + // after + // the tensor is destroyed. + const DataType tensor_data_type = tensor_proto.dtype(); + const TensorShapeProto tensor_shape = tensor_proto.tensor_shape(); + + // Create Tensor from value and write it in memmapped format. 
+ Tensor parsed(tensor_proto.dtype()); + if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + tensor_proto.DebugString()); + } + if (parsed.TotalBytes() < min_conversion_size_bytes) { + return Status::OK(); + } + + const string memmapped_region_name = + MemmappedFileSystem::kMemmappedPackagePrefix + + ConvertVariableNameToUniqueRegionName(node_def->name()); + + TF_RETURN_IF_ERROR(writer->SaveTensor(parsed, memmapped_region_name)); + + node_def->set_op("ImmutableConst"); + + // Erase all attributes and leave only attributes that can be understood by + // ImmutableConst. + auto* mutable_attr = node_def->mutable_attr(); + mutable_attr->clear(); + + { + AttrValue attr_value; + attr_value.set_type(tensor_data_type); + mutable_attr->insert({ImmutableConstantOp::kDTypeAttr, attr_value}); + } + { + AttrValue attr_value; + *(attr_value.mutable_shape()) = tensor_shape; + mutable_attr->insert({ImmutableConstantOp::kShapeAttr, attr_value}); + } + { + AttrValue attr_value; + attr_value.set_s(memmapped_region_name); + mutable_attr->insert( + {ImmutableConstantOp::kMemoryRegionNameAttr, attr_value}); + } + ++*convert_counter; + return Status::OK(); + } + + private: + string ConvertVariableNameToUniqueRegionName(const string& variable_name) { + string region_name = SanitizeVariableName(variable_name); + while (!used_names_.insert(region_name).second) { + region_name += '_'; + } + return region_name; + } + + static string SanitizeVariableName(const string& variable_name) { + string result; + for (char c : variable_name) { + if ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || + (c >= '0' && c <= '9') || c == '_' || c == '.') { + result += c; + } else { + result += '_'; + } + } + return result; + } + std::unordered_set<string> used_names_; +}; + +} // namespace + +// Loads the graph, replaces operators, and writes it out. 
+Status ConvertConstantsToImmutable(const string& in_graph_filename, + const string& out_graph_filename, + int min_conversion_size_bytes) { + Env* default_env = Env::Default(); + GraphDef graph_def; + const auto load_graph_status = + ReadBinaryProto(default_env, in_graph_filename, &graph_def); + if (!load_graph_status.ok()) { + return tensorflow::errors::NotFound("Failed to load graph at '", + in_graph_filename, "' : ", + load_graph_status.error_message()); + } + + NodeConverter node_converter; + + // Create output writer. + MemmappedFileSystemWriter writer; + TF_RETURN_IF_ERROR(writer.InitializeToFile(default_env, out_graph_filename)); + + // Iterate over graph nodes, looking for Const and replacing it with + // ImmutableConst. + int convert_counter = 0; + for (int i = 0; i < graph_def.node_size(); ++i) { + const NodeDef& node = graph_def.node(i); + if (node.op() == "Const") { + // Try to convert to ImmutableConst + TF_RETURN_IF_ERROR(node_converter.ConvertConstantsToImmutable( + graph_def.mutable_node(i), &writer, &convert_counter, + min_conversion_size_bytes)); + } + } + TF_RETURN_IF_ERROR(writer.SaveProtobuf( + graph_def, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef)); + TF_RETURN_IF_ERROR(writer.FlushAndClose()); + LOG(INFO) << "Converted " << convert_counter << " nodes"; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h new file mode 100644 index 0000000000..e6fd1bb132 --- /dev/null +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h @@ -0,0 +1,34 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_ + +#include <string> + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Converts a "frozen" inference graph (output from the freeze_graph utility) +// into a format in which large Const ops are converted to ImmutableConst ops +// which are memmapped when the graph is executed by TensorFlow. +Status ConvertConstantsToImmutable(const string& in_graph_filename, + const string& out_graph_filename, + int min_conversion_size_bytes); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_ diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc new file mode 100644 index 0000000000..7710fc38ef --- /dev/null +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/memmapped_file_system.h" + +namespace tensorflow { +namespace { + +TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) { + const string dir = testing::TmpDir(); + const string filename_pb = io::JoinPath(dir, "graphdef.pb"); + + // Create a simple graph and write it to filename_pb. 
+ constexpr int kTensorWidth = 4000; + constexpr int kTensorHeight = 100; + const TensorShape kTestTensorShape({kTensorWidth, kTensorHeight}); + const TensorShape kTestTensorShapeT({kTensorHeight, kTensorWidth}); + + Tensor test_tensor1(DT_FLOAT, kTestTensorShape); + test::FillFn<float>(&test_tensor1, [](int) -> float { return 2.0; }); + + Tensor test_tensor2(DT_FLOAT, kTestTensorShapeT); + test::FillFn<float>(&test_tensor2, [](int) -> float { return 3.0; }); + + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* node1 = ops::Const(test_tensor1, b.opts()); + Node* node2 = ops::Const(test_tensor2, b.opts()); + const string result_name = ops::MatMul(node1, node2, b.opts())->name(); + + GraphDef graph_def; + TF_ASSERT_OK(b.ToGraphDef(&graph_def)); + string graph_def_serialized; + graph_def.SerializeToString(&graph_def_serialized); + TF_ASSERT_OK( + WriteStringToFile(Env::Default(), filename_pb, graph_def_serialized)); + + const string filename_mmap = io::JoinPath(dir, "graphdef.mmap"); + TF_ASSERT_OK(ConvertConstantsToImmutable(filename_pb, filename_mmap, 10000)); + + // Create and initialize MemmappedEnv from the converted file. + MemmappedEnv memmapped_env(Env::Default()); + TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename_mmap)); + + // Load the graph and run calculations. 
+ SessionOptions session_options; + session_options.env = &memmapped_env; + std::unique_ptr<Session> session(NewSession(session_options)); + ASSERT_TRUE(session != nullptr) << "Failed to create session"; + GraphDef loaded_graph_def; + TF_ASSERT_OK(ReadBinaryProto( + &memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, + &loaded_graph_def)); + + TF_ASSERT_OK(session->Create(loaded_graph_def)) << "Can't create test graph"; + std::vector<Tensor> outputs; + TF_ASSERT_OK(session->Run({}, {result_name + ":0"}, {}, &outputs)); + ASSERT_EQ(outputs.size(), 1); + EXPECT_EQ(outputs.front().flat<float>()(0), 2.0f * 3.0f * kTensorHeight); + EXPECT_EQ(outputs.front().flat<float>()(1), 2.0f * 3.0f * kTensorHeight); + EXPECT_EQ(outputs.front().flat<float>()(2), 2.0f * 3.0f * kTensorHeight); +} + +} // namespace +} // namespace tensorflow |