aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/util
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-04-18 08:08:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-18 09:12:35 -0700
commit3c280f6fa0e0fcaa3d2cee5d2d8bb7ab3e25319f (patch)
treefc67b33f56cc465486453b49789ea6a4d97b639d /tensorflow/contrib/util
parent517d3af445d85e2f6945fcdfc4fed4e46b1e0e35 (diff)
Added a format for saving an inference graph that can be memmapped and an utility to convert a freezed graph into this format.
Change: 120128412
Diffstat (limited to 'tensorflow/contrib/util')
-rw-r--r--tensorflow/contrib/util/BUILD41
-rw-r--r--tensorflow/contrib/util/convert_graphdef_memmapped_format.cc88
-rw-r--r--tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc156
-rw-r--r--tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h34
-rw-r--r--tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc84
5 files changed, 403 insertions, 0 deletions
diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD
index c0be2b9c14..80495c9b8a 100644
--- a/tensorflow/contrib/util/BUILD
+++ b/tensorflow/contrib/util/BUILD
@@ -7,6 +7,47 @@ exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
+# Convertor of a frozen graph definition into the memmapped format.
+cc_library(
+ name = "convert_graphdef_memmapped_format_lib",
+ srcs = ["convert_graphdef_memmapped_format_lib.cc"],
+ hdrs = ["convert_graphdef_memmapped_format_lib.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/kernels:immutable_constant_op",
+ ],
+)
+
+cc_binary(
+ name = "convert_graphdef_memmapped_format",
+ srcs = ["convert_graphdef_memmapped_format.cc"],
+ deps = [
+ ":convert_graphdef_memmapped_format_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "convert_graphdef_memmapped_format_test",
+ srcs = ["convert_graphdef_memmapped_format_test.cc"],
+ deps = [
+ ":convert_graphdef_memmapped_format_lib",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensor_testutil",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
cc_binary(
name = "inspect_checkpoint",
srcs = ["inspect_checkpoint.cc"],
diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc
new file mode 100644
index 0000000000..811761efd6
--- /dev/null
+++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format.cc
@@ -0,0 +1,88 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utility that converts a "frozen" inference graph (output from the
+// freeze_graph utility) into a format in which large Const ops are converted to
+// ImmutableConst ops which are memmapped when the graph is executed by
+// TensorFlow.
+//
+// tensorflow/contrib/util/convert_graphdef_memmapped_format
+// --in_graph=frozen.model --out_graph=memmapped.mmodel
+//
+// Parameters:
+// in_graph - name of a file with a frozen GraphDef proto in binary format
+// out_graph - name of the output file, where the graph in memmapped format will
+// be saved.
+// min_conversion_size_bytes - tensors with fewer than this many bytes of data
+// will not be converted to ImmutableConst format, and kept in the graph.
+
+#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace {
+
+int ParseFlagsAndConvertGraph(int argc, char* argv[]) {
+ string in_graph = "";
+ string out_graph = "";
+ int min_conversion_tensor_size = 10000;
+ const bool parse_result = ParseFlags(
+ &argc, argv,
+ {// input graph
+ Flag("in_graph", &in_graph),
+ // output graph
+ Flag("out_graph", &out_graph),
+ // constants with tensors that have less than this number elements won't
+ // be converted into ImmutableConst (be memmapped).
+ Flag("min_conversion_tensor_size", &min_conversion_tensor_size)});
+ // We need to call this to set up global state for TensorFlow.
+ port::InitMain(argv[0], &argc, &argv);
+ if (!parse_result) {
+ LOG(ERROR) << "Error parsing command-line flags.";
+ return -1;
+ }
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1];
+ return -1;
+ }
+ if (in_graph.empty()) {
+ LOG(ERROR) << "in_graph graph can't be empty";
+ return -1;
+ }
+ if (out_graph.empty()) {
+ LOG(ERROR) << "out_graph graph can't be empty";
+ return -1;
+ }
+ if (min_conversion_tensor_size <= 0) {
+ LOG(ERROR) << "min_conversion_tensor_size must be > 0";
+ return -1;
+ }
+ const auto result = ConvertConstantsToImmutable(in_graph, out_graph,
+ min_conversion_tensor_size);
+ if (!result.ok()) {
+ LOG(ERROR) << "Conversion failed " << result.error_message();
+ return -1;
+ }
+ return 0;
+}
+
+} // namespace
+} // namespace tensorflow
+
+int main(int argc, char* argv[]) {
+ return tensorflow::ParseFlagsAndConvertGraph(argc, argv);
+}
diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc
new file mode 100644
index 0000000000..7697a7f3d2
--- /dev/null
+++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc
@@ -0,0 +1,156 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
+
+#include <unordered_set>
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/kernels/immutable_constant_op.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/memmapped_file_system_writer.h"
+
+namespace tensorflow {
+namespace {
+class NodeConverter {
+ public:
+ // Converts one node. In-place updates node_def, writes the tensor in
+ // memmapped
+ // format, using writer. If the conversion has been done, convert_counter is
+ // increased.
+ Status ConvertConstantsToImmutable(NodeDef* node_def,
+ MemmappedFileSystemWriter* writer,
+ int* convert_counter,
+ int min_conversion_size_bytes) {
+ // Check the size.
+ const AttrValue& value = node_def->attr().at("value");
+ const TensorProto& tensor_proto = value.tensor();
+
+ // Create copies of tensor datatype and shape, to put into the operator
+ // after
+ // the tensor is destroyed.
+ const DataType tensor_data_type = tensor_proto.dtype();
+ const TensorShapeProto tensor_shape = tensor_proto.tensor_shape();
+
+ // Create Tensor from value and write it in memmapped format.
+ Tensor parsed(tensor_proto.dtype());
+ if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ tensor_proto.DebugString());
+ }
+ if (parsed.TotalBytes() < min_conversion_size_bytes) {
+ return Status::OK();
+ }
+
+ const string memmapped_region_name =
+ MemmappedFileSystem::kMemmappedPackagePrefix +
+ ConvertVariableNameToUniqueRegionName(node_def->name());
+
+ TF_RETURN_IF_ERROR(writer->SaveTensor(parsed, memmapped_region_name));
+
+ node_def->set_op("ImmutableConst");
+
+ // Erase all attributes and leave only attributes that can be understood by
+ // ImmutableConst.
+ auto* mutable_attr = node_def->mutable_attr();
+ mutable_attr->clear();
+
+ {
+ AttrValue attr_value;
+ attr_value.set_type(tensor_data_type);
+ mutable_attr->insert({ImmutableConstantOp::kDTypeAttr, attr_value});
+ }
+ {
+ AttrValue attr_value;
+ *(attr_value.mutable_shape()) = tensor_shape;
+ mutable_attr->insert({ImmutableConstantOp::kShapeAttr, attr_value});
+ }
+ {
+ AttrValue attr_value;
+ attr_value.set_s(memmapped_region_name);
+ mutable_attr->insert(
+ {ImmutableConstantOp::kMemoryRegionNameAttr, attr_value});
+ }
+ ++*convert_counter;
+ return Status::OK();
+ }
+
+ private:
+ string ConvertVariableNameToUniqueRegionName(const string& variable_name) {
+ string region_name = SanitizeVariableName(variable_name);
+ while (!used_names_.insert(region_name).second) {
+ region_name += '_';
+ }
+ return region_name;
+ }
+
+ static string SanitizeVariableName(const string& variable_name) {
+ string result;
+ for (char c : variable_name) {
+ if ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
+ (c >= '0' && c <= '9') || c == '_' || c == '.') {
+ result += c;
+ } else {
+ result += '_';
+ }
+ }
+ return result;
+ }
+ std::unordered_set<string> used_names_;
+};
+
+} // namespace
+
+// Loads the graph, replaces operators, and writes it out.
+Status ConvertConstantsToImmutable(const string& in_graph_filename,
+ const string& out_graph_filename,
+ int min_conversion_size_bytes) {
+ Env* default_env = Env::Default();
+ GraphDef graph_def;
+ const auto load_graph_status =
+ ReadBinaryProto(default_env, in_graph_filename, &graph_def);
+ if (!load_graph_status.ok()) {
+ return tensorflow::errors::NotFound("Failed to load graph at '",
+ in_graph_filename, "' : ",
+ load_graph_status.error_message());
+ }
+
+ NodeConverter node_converter;
+
+ // Create output writer.
+ MemmappedFileSystemWriter writer;
+ TF_RETURN_IF_ERROR(writer.InitializeToFile(default_env, out_graph_filename));
+
+ // Iterate over graph nodes, looking for Const and replacing it with
+ // ImmutableConst.
+ int convert_counter = 0;
+ for (int i = 0; i < graph_def.node_size(); ++i) {
+ const NodeDef& node = graph_def.node(i);
+ if (node.op() == "Const") {
+ // Try to convert to ImmutableConst
+ TF_RETURN_IF_ERROR(node_converter.ConvertConstantsToImmutable(
+ graph_def.mutable_node(i), &writer, &convert_counter,
+ min_conversion_size_bytes));
+ }
+ }
+ TF_RETURN_IF_ERROR(writer.SaveProtobuf(
+ graph_def, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef));
+ TF_RETURN_IF_ERROR(writer.FlushAndClose());
+ LOG(INFO) << "Converted " << convert_counter << " nodes";
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h
new file mode 100644
index 0000000000..e6fd1bb132
--- /dev/null
+++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h
@@ -0,0 +1,34 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Converts a "frozen" inference graph (output from the freeze_graph utility)
+// into a format in which large Const ops are converted to ImmutableConst ops
+// which are memmapped when the graph is executed by TensorFlow.
+Status ConvertConstantsToImmutable(const string& in_graph_filename,
+ const string& out_graph_filename,
+ int min_conversion_size_bytes);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_UTIL_CONVERT_GRAPHDEF_MEMMAPPED_FORMAT_LIB_H_
diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc
new file mode 100644
index 0000000000..7710fc38ef
--- /dev/null
+++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/memmapped_file_system.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
+ const string dir = testing::TmpDir();
+ const string filename_pb = io::JoinPath(dir, "graphdef.pb");
+
+ // Create a simple graph and write it to filename_pb.
+ constexpr int kTensorWidth = 4000;
+ constexpr int kTensorHeight = 100;
+ const TensorShape kTestTensorShape({kTensorWidth, kTensorHeight});
+ const TensorShape kTestTensorShapeT({kTensorHeight, kTensorWidth});
+
+ Tensor test_tensor1(DT_FLOAT, kTestTensorShape);
+ test::FillFn<float>(&test_tensor1, [](int) -> float { return 2.0; });
+
+ Tensor test_tensor2(DT_FLOAT, kTestTensorShapeT);
+ test::FillFn<float>(&test_tensor2, [](int) -> float { return 3.0; });
+
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* node1 = ops::Const(test_tensor1, b.opts());
+ Node* node2 = ops::Const(test_tensor2, b.opts());
+ const string result_name = ops::MatMul(node1, node2, b.opts())->name();
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(b.ToGraphDef(&graph_def));
+ string graph_def_serialized;
+ graph_def.SerializeToString(&graph_def_serialized);
+ TF_ASSERT_OK(
+ WriteStringToFile(Env::Default(), filename_pb, graph_def_serialized));
+
+ const string filename_mmap = io::JoinPath(dir, "graphdef.mmap");
+ TF_ASSERT_OK(ConvertConstantsToImmutable(filename_pb, filename_mmap, 10000));
+
+ // Create and initialize MemmappedEnv from the converted file.
+ MemmappedEnv memmapped_env(Env::Default());
+ TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename_mmap));
+
+ // Load the graph and run calculations.
+ SessionOptions session_options;
+ session_options.env = &memmapped_env;
+ std::unique_ptr<Session> session(NewSession(session_options));
+ ASSERT_TRUE(session != nullptr) << "Failed to create session";
+ GraphDef loaded_graph_def;
+ TF_ASSERT_OK(ReadBinaryProto(
+ &memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
+ &loaded_graph_def));
+
+ TF_ASSERT_OK(session->Create(loaded_graph_def)) << "Can't create test graph";
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(session->Run({}, {result_name + ":0"}, {}, &outputs));
+ ASSERT_EQ(outputs.size(), 1);
+ EXPECT_EQ(outputs.front().flat<float>()(0), 2.0f * 3.0f * kTensorHeight);
+ EXPECT_EQ(outputs.front().flat<float>()(1), 2.0f * 3.0f * kTensorHeight);
+ EXPECT_EQ(outputs.front().flat<float>()(2), 2.0f * 3.0f * kTensorHeight);
+}
+
+} // namespace
+} // namespace tensorflow