diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-10-25 06:38:41 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-25 07:49:52 -0700 |
commit | 2826f62516e68f37d1fad06e02f2a914ddd3b10f (patch) | |
tree | d764e2fa7a8ab4fac1e3209e31df59e071ff3574 /tensorflow/contrib/util | |
parent | 3662acf8247dda84dedda4f97fcbd07c6c1a4e10 (diff) |
Disabling conversion to memmapped format for constant types that can't be
mapped.
Change: 137155441
Diffstat (limited to 'tensorflow/contrib/util')
-rw-r--r-- | tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc | 18 | ||||
-rw-r--r-- | tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc | 53 |
2 files changed, 70 insertions, 1 deletions
diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc index 68cb20d0b5..1f079027ef 100644 --- a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc @@ -16,8 +16,10 @@ limitations under the License. #include <unordered_set> #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/immutable_constant_op.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -45,13 +47,27 @@ class NodeConverter { const DataType tensor_data_type = tensor_proto.dtype(); const TensorShapeProto tensor_shape = tensor_proto.tensor_shape(); + // Check that the tensor type is POD, only these types are supported for + // memmapping. + // DataType enum is explicitly converted to int to avoid errors with passing + // enum type are a parameter type to std::unordered_set. + static std::unordered_set<int> supported_types{ +#define TYPE_FOR_SET(type) static_cast<int>(DataTypeToEnum<type>::value), + TF_CALL_POD_TYPES(TYPE_FOR_SET) +#undef ADD_TYPE + }; + + if (supported_types.count(static_cast<int>(tensor_data_type)) == 0) { + return Status::OK(); + } + // Create Tensor from value and write it in memmapped format. Tensor parsed(tensor_proto.dtype()); if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { return errors::InvalidArgument("Cannot parse tensor from proto: ", tensor_proto.DebugString()); } - if (parsed.TotalBytes() < min_conversion_size_bytes) { + if (parsed.TotalBytes() < static_cast<size_t>(min_conversion_size_bytes)) { return Status::OK(); } diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc index d64dca7b63..cb1e7577cf 100644 --- a/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc @@ -26,6 +26,15 @@ limitations under the License. namespace tensorflow { namespace { +bool GraphHasImmutableConstNodes(const GraphDef& graph_def) { + for (const auto& node : graph_def.node()) { + if (node.op() == "ImmutableConst") { + return true; + } + } + return false; +} + TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) { const string dir = testing::TmpDir(); const string filename_pb = io::JoinPath(dir, "graphdef.pb"); @@ -69,6 +78,7 @@ TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) { TF_ASSERT_OK(ReadBinaryProto( &memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, &loaded_graph_def)); + ASSERT_TRUE(GraphHasImmutableConstNodes(loaded_graph_def)); TF_ASSERT_OK(session->Create(loaded_graph_def)) << "Can't create test graph"; std::vector<Tensor> outputs; @@ -79,5 +89,48 @@ TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) { EXPECT_EQ(outputs.front().flat<float>()(2), 2.0f * 3.0f * kTensorHeight); } +TEST(ConvertGraphdefMemmappedFormatTest, NotSupportedTypesConvert) { + // Create a graph with strings. + const string dir = testing::TmpDir(); + const string filename_pb = io::JoinPath(dir, "string_graphdef.pb"); + + constexpr int kTensorWidth = 4000; + constexpr int kTensorHeight = 100; + const TensorShape kTestTensorShape({kTensorWidth, kTensorHeight}); + Tensor test_tensor1(DT_STRING, kTestTensorShape); + test::FillFn<string>(&test_tensor1, [](int) -> string { return "ABC"; }); + + Tensor test_tensor2(DT_STRING, kTestTensorShape); + test::FillFn<string>(&test_tensor2, [](int) -> string { return "XYZ"; }); + auto root = Scope::NewRootScope().ExitOnError(); + ops::Output m = ops::Add(root, test_tensor1, test_tensor2); + const string result_name = m.node()->name(); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + string graph_def_serialized; + graph_def.SerializeToString(&graph_def_serialized); + TF_ASSERT_OK( + WriteStringToFile(Env::Default(), filename_pb, graph_def_serialized)); + + const string filename_mmap = io::JoinPath(dir, "string_graphdef.mmap"); + TF_ASSERT_OK(ConvertConstantsToImmutable(filename_pb, filename_mmap, 1000)); + + // Create and initialize MemmappedEnv from the converted file. + MemmappedEnv memmapped_env(Env::Default()); + TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename_mmap)); + + // Load the graph and run calculations. + SessionOptions session_options; + session_options.env = &memmapped_env; + std::unique_ptr<Session> session(NewSession(session_options)); + ASSERT_TRUE(session != nullptr) << "Failed to create session"; + GraphDef loaded_graph_def; + TF_ASSERT_OK(ReadBinaryProto( + &memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, + &loaded_graph_def)); + ASSERT_FALSE(GraphHasImmutableConstNodes(loaded_graph_def)); +} + } // namespace } // namespace tensorflow |