aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/util
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-25 06:38:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-25 07:49:52 -0700
commit2826f62516e68f37d1fad06e02f2a914ddd3b10f (patch)
treed764e2fa7a8ab4fac1e3209e31df59e071ff3574 /tensorflow/contrib/util
parent3662acf8247dda84dedda4f97fcbd07c6c1a4e10 (diff)
Disabling conversion to memmapped format for constant types that can't be
mapped. Change: 137155441
Diffstat (limited to 'tensorflow/contrib/util')
-rw-r--r--tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc18
-rw-r--r--tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc53
2 files changed, 70 insertions, 1 deletions
diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc
index 68cb20d0b5..1f079027ef 100644
--- a/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc
+++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_lib.cc
@@ -16,8 +16,10 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/immutable_constant_op.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -45,13 +47,27 @@ class NodeConverter {
const DataType tensor_data_type = tensor_proto.dtype();
const TensorShapeProto tensor_shape = tensor_proto.tensor_shape();
+ // Check that the tensor type is POD, only these types are supported for
+ // memmapping.
+ // DataType enum is explicitly converted to int to avoid errors with passing
+ // enum type are a parameter type to std::unordered_set.
+ static std::unordered_set<int> supported_types{
+#define TYPE_FOR_SET(type) static_cast<int>(DataTypeToEnum<type>::value),
+ TF_CALL_POD_TYPES(TYPE_FOR_SET)
+#undef ADD_TYPE
+ };
+
+ if (supported_types.count(static_cast<int>(tensor_data_type)) == 0) {
+ return Status::OK();
+ }
+
// Create Tensor from value and write it in memmapped format.
Tensor parsed(tensor_proto.dtype());
if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
tensor_proto.DebugString());
}
- if (parsed.TotalBytes() < min_conversion_size_bytes) {
+ if (parsed.TotalBytes() < static_cast<size_t>(min_conversion_size_bytes)) {
return Status::OK();
}
diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc
index d64dca7b63..cb1e7577cf 100644
--- a/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc
+++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc
@@ -26,6 +26,15 @@ limitations under the License.
namespace tensorflow {
namespace {
+bool GraphHasImmutableConstNodes(const GraphDef& graph_def) {
+ for (const auto& node : graph_def.node()) {
+ if (node.op() == "ImmutableConst") {
+ return true;
+ }
+ }
+ return false;
+}
+
TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
const string dir = testing::TmpDir();
const string filename_pb = io::JoinPath(dir, "graphdef.pb");
@@ -69,6 +78,7 @@ TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
TF_ASSERT_OK(ReadBinaryProto(
&memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
&loaded_graph_def));
+ ASSERT_TRUE(GraphHasImmutableConstNodes(loaded_graph_def));
TF_ASSERT_OK(session->Create(loaded_graph_def)) << "Can't create test graph";
std::vector<Tensor> outputs;
@@ -79,5 +89,48 @@ TEST(ConvertGraphdefMemmappedFormatTest, ConvertModel) {
EXPECT_EQ(outputs.front().flat<float>()(2), 2.0f * 3.0f * kTensorHeight);
}
+TEST(ConvertGraphdefMemmappedFormatTest, NotSupportedTypesConvert) {
+ // Create a graph with strings.
+ const string dir = testing::TmpDir();
+ const string filename_pb = io::JoinPath(dir, "string_graphdef.pb");
+
+ constexpr int kTensorWidth = 4000;
+ constexpr int kTensorHeight = 100;
+ const TensorShape kTestTensorShape({kTensorWidth, kTensorHeight});
+ Tensor test_tensor1(DT_STRING, kTestTensorShape);
+ test::FillFn<string>(&test_tensor1, [](int) -> string { return "ABC"; });
+
+ Tensor test_tensor2(DT_STRING, kTestTensorShape);
+ test::FillFn<string>(&test_tensor2, [](int) -> string { return "XYZ"; });
+ auto root = Scope::NewRootScope().ExitOnError();
+ ops::Output m = ops::Add(root, test_tensor1, test_tensor2);
+ const string result_name = m.node()->name();
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+ string graph_def_serialized;
+ graph_def.SerializeToString(&graph_def_serialized);
+ TF_ASSERT_OK(
+ WriteStringToFile(Env::Default(), filename_pb, graph_def_serialized));
+
+ const string filename_mmap = io::JoinPath(dir, "string_graphdef.mmap");
+ TF_ASSERT_OK(ConvertConstantsToImmutable(filename_pb, filename_mmap, 1000));
+
+ // Create and initialize MemmappedEnv from the converted file.
+ MemmappedEnv memmapped_env(Env::Default());
+ TF_ASSERT_OK(memmapped_env.InitializeFromFile(filename_mmap));
+
+ // Load the graph and run calculations.
+ SessionOptions session_options;
+ session_options.env = &memmapped_env;
+ std::unique_ptr<Session> session(NewSession(session_options));
+ ASSERT_TRUE(session != nullptr) << "Failed to create session";
+ GraphDef loaded_graph_def;
+ TF_ASSERT_OK(ReadBinaryProto(
+ &memmapped_env, MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
+ &loaded_graph_def));
+ ASSERT_FALSE(GraphHasImmutableConstNodes(loaded_graph_def));
+}
+
} // namespace
} // namespace tensorflow