/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include #include "tensorflow/core/util/tensor_slice_reader.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/saved_tensor_slice_util.h" #include "tensorflow/core/util/tensor_slice_reader_cache.h" #include "tensorflow/core/util/tensor_slice_writer.h" namespace tensorflow { namespace checkpoint { namespace { // A simple test where we write a few tensor slices with a number of tensor // slice writers and then read them back from a tensor slice reader. // // We have a 2-d tensor of shape 4 X 5 that looks like this: // // 0 1 2 3 4 // 5 6 7 8 9 // 10 11 12 13 14 // 15 16 17 18 19 // // We assume this is a row-major matrix. void SimpleFloatHelper( const TensorSliceWriter::CreateBuilderFunction& create_function, TensorSliceReader::OpenTableFunction open_function) { const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint"); TensorShape shape({4, 5}); // File #0 contains a slice that is the top two rows: // // 0 1 2 3 4 // 5 6 7 8 9 // . . . . . // . . . . . { const string fname = strings::StrCat(fname_base, "_0"); TensorSliceWriter writer(fname, create_function); const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; TensorSlice slice = TensorSlice::ParseOrDie("0,2:-"); TF_CHECK_OK(writer.Add("test", shape, slice, data)); TF_CHECK_OK(writer.Finish()); } // File #1 contains two slices: // // slice #0 is the bottom left corner // . . . . . // . . . . . // 10 11 12 . . // 15 16 17 . . // // slice #1 is the bottom right corner // . . . . . // . . . . . // . . . . . // . . . 18 19 { const string fname = strings::StrCat(fname_base, "_1"); TensorSliceWriter writer(fname, create_function); // slice #0 { const float data[] = {10, 11, 12, 15, 16, 17}; TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3"); TF_CHECK_OK(writer.Add("test", shape, slice, data)); } // slice #1 { const float data[] = {18, 19}; TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2"); TF_CHECK_OK(writer.Add("test", shape, slice, data)); } TF_CHECK_OK(writer.Finish()); } // Notice that we leave a hole in the tensor // . . . . . // . . . . . // . . . (13) (14) // . . . . . // Now we need to read the tensor slices const string filepattern = strings::StrCat(fname_base, "_*"); TensorSliceReader reader(filepattern, std::move(open_function)); TF_EXPECT_OK(reader.status()); EXPECT_EQ(2, reader.num_files()); // We query some of the tensors { TensorShape shape; DataType type; EXPECT_TRUE(reader.HasTensor("test", &shape, &type)); EXPECT_EQ("[4,5]", shape.DebugString()); EXPECT_EQ(DT_FLOAT, type); EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr)); } // Now we query some slices // // Slice #1 is an exact match // 0 1 2 3 4 // 5 6 7 8 9 // . . . . . // . . . . . { TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; float results[10]; EXPECT_TRUE(reader.CopySliceData("test", s, results)); for (int i = 0; i < 10; ++i) { EXPECT_EQ(expected[i], results[i]); } } // Slice #2 is a subset match // . . . . . // 5 6 7 8 9 // . . . . . // . . . . . { TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); float expected[] = {5, 6, 7, 8, 9}; float results[5]; EXPECT_TRUE(reader.CopySliceData("test", s, results)); for (int i = 0; i < 5; ++i) { EXPECT_EQ(expected[i], results[i]); } } // Slice #4 includes the hole and so there is no match // . . . . . // . . 7 8 9 // . . 12 13 14 // . . . . . { TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); float results[6]; EXPECT_FALSE(reader.CopySliceData("test", s, results)); } } TEST(TensorSliceReaderTest, SimpleFloat) { SimpleFloatHelper(CreateTableTensorSliceBuilder, OpenTableTensorSliceReader); } template void SimpleIntXHelper( const TensorSliceWriter::CreateBuilderFunction& create_function, TensorSliceReader::OpenTableFunction open_function, const string& checkpoint_file) { const string fname_base = io::JoinPath(testing::TmpDir(), checkpoint_file); TensorShape shape({4, 5}); // File #0 contains a slice that is the top two rows: // // 0 1 2 3 4 // 5 6 7 8 9 // . . . . . // . . . . . { const string fname = strings::StrCat(fname_base, "_0"); TensorSliceWriter writer(fname, create_function); const T data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; TensorSlice slice = TensorSlice::ParseOrDie("0,2:-"); TF_CHECK_OK(writer.Add("test", shape, slice, data)); TF_CHECK_OK(writer.Finish()); } // File #1 contains two slices: // // slice #0 is the bottom left corner // . . . . . // . . . . . // 10 11 12 . . // 15 16 17 . . // // slice #1 is the bottom right corner // . . . . . // . . . . . // . . . . . // . . . 18 19 { const string fname = strings::StrCat(fname_base, "_1"); TensorSliceWriter writer(fname, create_function); // slice #0 { const T data[] = {10, 11, 12, 15, 16, 17}; TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3"); TF_CHECK_OK(writer.Add("test", shape, slice, data)); } // slice #1 { const T data[] = {18, 19}; TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2"); TF_CHECK_OK(writer.Add("test", shape, slice, data)); } TF_CHECK_OK(writer.Finish()); } // Notice that we leave a hole in the tensor // . . . . . // . . . . . // . . . (13) (14) // . . . . . // Now we need to read the tensor slices const string filepattern = strings::StrCat(fname_base, "_*"); TensorSliceReader reader(filepattern, std::move(open_function)); TF_EXPECT_OK(reader.status()); EXPECT_EQ(2, reader.num_files()); // We query some of the tensors { TensorShape shape; DataType type; EXPECT_TRUE(reader.HasTensor("test", &shape, &type)); EXPECT_EQ("[4,5]", shape.DebugString()); EXPECT_EQ(DataTypeToEnum::v(), type); EXPECT_FALSE(reader.HasTensor("don't exist", nullptr, nullptr)); } // Now we query some slices // // Slice #1 is an exact match // 0 1 2 3 4 // 5 6 7 8 9 // . . . . . // . . . . . { TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); T expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; U results[10]; EXPECT_TRUE(reader.CopySliceData("test", s, results)); for (int i = 0; i < 10; ++i) { EXPECT_EQ(expected[i], results[i]); } } // Slice #2 is a subset match // . . . . . // 5 6 7 8 9 // . . . . . // . . . . . { TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); T expected[] = {5, 6, 7, 8, 9}; U results[5]; EXPECT_TRUE(reader.CopySliceData("test", s, results)); for (int i = 0; i < 5; ++i) { EXPECT_EQ(expected[i], results[i]); } } // Slice #4 includes the hole and so there is no match // . . . . . // . . 7 8 9 // . . 12 13 14 // . . . . . { TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); U results[6]; EXPECT_FALSE(reader.CopySliceData("test", s, results)); } } #define TEST_SIMPLE_INT(TYPE, SAVED_TYPE) \ TEST(TensorSliceReaderTest, Simple##TYPE) { \ SimpleIntXHelper(CreateTableTensorSliceBuilder, \ OpenTableTensorSliceReader, \ #TYPE "_checkpoint"); \ } TEST_SIMPLE_INT(int32, int32) TEST_SIMPLE_INT(int64, int64) TEST_SIMPLE_INT(int16, int32) TEST_SIMPLE_INT(int8, int32) TEST_SIMPLE_INT(uint8, int32) void CachedTensorSliceReaderTesterHelper( const TensorSliceWriter::CreateBuilderFunction& create_function, const TensorSliceReader::OpenTableFunction& open_function) { const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint"); TensorShape shape({4, 5}); // File #0 contains a slice that is the top two rows: // // 0 1 2 3 4 // 5 6 7 8 9 // . . . . . // . . . . . { const string fname = strings::StrCat(fname_base, "_0"); TensorSliceWriter writer(fname, create_function); const float data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; TensorSlice slice = TensorSlice::ParseOrDie("0,2:-"); TF_CHECK_OK(writer.Add("test", shape, slice, data)); TF_CHECK_OK(writer.Finish()); } // File #1 contains two slices: // // slice #0 is the bottom left corner // . . . . . // . . . . . // 10 11 12 . . // 15 16 17 . . // // slice #1 is the bottom right corner // . . . . . // . . . . . // . . . . . // . . . 18 19 { const string fname = strings::StrCat(fname_base, "_1"); TensorSliceWriter writer(fname, create_function); // slice #0 { const float data[] = {10, 11, 12, 15, 16, 17}; TensorSlice slice = TensorSlice::ParseOrDie("2,2:0,3"); TF_CHECK_OK(writer.Add("test", shape, slice, data)); } // slice #1 { const float data[] = {18, 19}; TensorSlice slice = TensorSlice::ParseOrDie("3,1:3,2"); TF_CHECK_OK(writer.Add("test", shape, slice, data)); } TF_CHECK_OK(writer.Finish()); } // Notice that we leave a hole in the tensor // . . . . . // . . . . . // . . . (13) (14) // . . . . . // Now we need to read the tensor slices TensorSliceReaderCache cache; const string filepattern = strings::StrCat(fname_base, "_*"); const TensorSliceReader* reader = cache.GetReader( filepattern, open_function, TensorSliceReader::kLoadAllShards); EXPECT_TRUE(reader != nullptr); EXPECT_EQ(2, reader->num_files()); // We query some of the tensors { TensorShape shape; DataType type; EXPECT_TRUE(reader->HasTensor("test", &shape, &type)); EXPECT_EQ("[4,5]", shape.DebugString()); EXPECT_EQ(DT_FLOAT, type); EXPECT_FALSE(reader->HasTensor("don't exist", nullptr, nullptr)); } // Make sure the reader is cached. const TensorSliceReader* reader2 = cache.GetReader( filepattern, open_function, TensorSliceReader::kLoadAllShards); EXPECT_EQ(reader, reader2); reader = cache.GetReader("file_does_not_exist", open_function, TensorSliceReader::kLoadAllShards); EXPECT_TRUE(reader == nullptr); } TEST(CachedTensorSliceReaderTest, SimpleFloat) { CachedTensorSliceReaderTesterHelper(CreateTableTensorSliceBuilder, OpenTableTensorSliceReader); } static void VersionTest(const VersionDef& versions, const string& error) { const string path = io::JoinPath(testing::TmpDir(), "checkpoint"); { // Prepare an empty checkpoint with some version information SavedTensorSlices sts; *sts.mutable_meta()->mutable_versions() = versions; string contents; EXPECT_TRUE(sts.SerializeToString(&contents)); // Write it to disk TensorSliceWriter::Builder* builder; TF_ASSERT_OK(CreateTableTensorSliceBuilder(path, &builder)); builder->Add(kSavedTensorSlicesKey, contents); int64 file_size; TF_EXPECT_OK(builder->Finish(&file_size)); delete builder; } // Read it back in and verify that we get the expected error TensorSliceReader reader(path, OpenTableTensorSliceReader); EXPECT_TRUE(reader.status().code() == error::INVALID_ARGUMENT && str_util::StartsWith(reader.status().error_message(), error)) << "Expected error starting with '" << errors::InvalidArgument(error) << "', got '" << reader.status() << "'"; } TEST(CheckpointVersionTest, MinConsumer) { VersionDef versions; versions.set_producer(TF_CHECKPOINT_VERSION + 1); versions.set_min_consumer(TF_CHECKPOINT_VERSION + 1); VersionTest( versions, strings::StrCat("Checkpoint min consumer version ", TF_CHECKPOINT_VERSION + 1, " above current version ", TF_CHECKPOINT_VERSION, " for TensorFlow")); } TEST(CheckpointVersionTest, MinProducer) { VersionDef versions; versions.set_producer(TF_CHECKPOINT_VERSION_MIN_PRODUCER - 1); VersionTest(versions, strings::StrCat("Checkpoint producer version ", TF_CHECKPOINT_VERSION_MIN_PRODUCER - 1, " below min producer ", TF_CHECKPOINT_VERSION_MIN_PRODUCER, " supported by TensorFlow")); } TEST(CheckpointVersionTest, BadConsumer) { VersionDef versions; versions.set_producer(TF_CHECKPOINT_VERSION + 1); versions.add_bad_consumers(TF_CHECKPOINT_VERSION); VersionTest( versions, strings::StrCat( "Checkpoint disallows consumer version ", TF_CHECKPOINT_VERSION, ". Please upgrade TensorFlow: this version is likely buggy.")); } } // namespace } // namespace checkpoint } // namespace tensorflow