diff options
author | Geoffrey Irving <geoffreyi@google.com> | 2016-02-10 13:52:47 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-02-10 16:02:39 -0800 |
commit | ea92856334a74519055a4e22aba7ceee5dd039a2 (patch) | |
tree | 5542d52c1be0a7802866ee83be77cdf82d4fc32d /tensorflow/core/util/tensor_slice_reader_test.cc | |
parent | b2757fb5797252bb1ef12bbf04be796cc83000cc (diff) |
Add versions to checkpoints
Checkpoints now have a version scheme analogous to that for GraphDefs. We have
no plans to ever deprecate a checkpoint version, but it's good to have the
scheme in place in case we need to.
Change: 114364388
Diffstat (limited to 'tensorflow/core/util/tensor_slice_reader_test.cc')
-rw-r--r-- | tensorflow/core/util/tensor_slice_reader_test.cc | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/core/util/tensor_slice_reader_test.cc b/tensorflow/core/util/tensor_slice_reader_test.cc index 928ebb8140..62008952d9 100644 --- a/tensorflow/core/util/tensor_slice_reader_test.cc +++ b/tensorflow/core/util/tensor_slice_reader_test.cc @@ -20,10 +20,12 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/version.h" #include "tensorflow/core/util/saved_tensor_slice_util.h" #include "tensorflow/core/util/tensor_slice_reader_cache.h" #include "tensorflow/core/util/tensor_slice_writer.h" @@ -393,6 +395,65 @@ TEST(CachedTensorSliceReaderTest, SimpleFloat) { OpenTableTensorSliceReader); } +static void VersionTest(const VersionDef& versions, const string& error) { + const string path = io::JoinPath(testing::TmpDir(), "checkpoint"); + + { + // Prepare an empty checkpoint with some version information + SavedTensorSlices sts; + sts.mutable_meta()->mutable_versions()->CopyFrom(versions); + string contents; + EXPECT_TRUE(sts.SerializeToString(&contents)); + + // Write it to disk + TensorSliceWriter::Builder* builder; + TF_ASSERT_OK(CreateTableTensorSliceBuilder(path, &builder)); + builder->Add(kSavedTensorSlicesKey, contents); + int64 file_size; + builder->Finish(&file_size); + delete builder; + } + + // Read it back in and verify that we get the expected error + TensorSliceReader reader(path, OpenTableTensorSliceReader); + EXPECT_TRUE(reader.status().code() == error::INVALID_ARGUMENT && + StringPiece(reader.status().error_message()).starts_with(error)) + << "Expected error starting with '" << errors::InvalidArgument(error) + << "', got '" << reader.status() << "'"; +} + +TEST(CheckpointVersionTest, MinConsumer) { + VersionDef versions; + versions.set_producer(TF_CHECKPOINT_VERSION + 1); + versions.set_min_consumer(TF_CHECKPOINT_VERSION + 1); + VersionTest( + versions, + strings::StrCat("Checkpoint min consumer version ", + TF_CHECKPOINT_VERSION + 1, " above current version ", + TF_CHECKPOINT_VERSION, " for TensorFlow")); +} + +TEST(CheckpointVersionTest, MinProducer) { + VersionDef versions; + versions.set_producer(TF_CHECKPOINT_VERSION_MIN_PRODUCER - 1); + VersionTest(versions, strings::StrCat("Checkpoint producer version ", + TF_CHECKPOINT_VERSION_MIN_PRODUCER - 1, + " below min producer ", + TF_CHECKPOINT_VERSION_MIN_PRODUCER, + " supported by TensorFlow")); +} + +TEST(CheckpointVersionTest, BadConsumer) { + VersionDef versions; + versions.set_producer(TF_CHECKPOINT_VERSION + 1); + versions.add_bad_consumers(TF_CHECKPOINT_VERSION); + VersionTest( + versions, + strings::StrCat( + "Checkpoint disallows consumer version ", TF_CHECKPOINT_VERSION, + ". Please upgrade TensorFlow: this version is likely buggy.")); +} + } // namespace } // namespace checkpoint |