From ea92856334a74519055a4e22aba7ceee5dd039a2 Mon Sep 17 00:00:00 2001 From: Geoffrey Irving Date: Wed, 10 Feb 2016 13:52:47 -0800 Subject: Add versions to checkpoints Checkpoints now have a version scheme analogous to that for GraphDefs. We have no plans to ever deprecate a checkpoint version, but it's good to have the scheme in place in case we need to. Change: 114364388 --- tensorflow/core/BUILD | 2 + tensorflow/core/framework/versions.cc | 55 +++++++++++++++++++++ tensorflow/core/framework/versions.h | 38 +++++++++++++++ tensorflow/core/graph/graph_constructor.cc | 28 ++--------- tensorflow/core/public/version.h | 14 ++++++ tensorflow/core/util/saved_tensor_slice.proto | 5 ++ tensorflow/core/util/tensor_slice_reader.cc | 6 +++ tensorflow/core/util/tensor_slice_reader_test.cc | 61 ++++++++++++++++++++++++ tensorflow/core/util/tensor_slice_writer.cc | 7 ++- tensorflow/core/util/tensor_slice_writer_test.cc | 6 +++ 10 files changed, 198 insertions(+), 24 deletions(-) create mode 100644 tensorflow/core/framework/versions.cc create mode 100644 tensorflow/core/framework/versions.h diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 4a6f35dcc3..39052af115 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -806,6 +806,7 @@ tf_cuda_library( [ "framework/**/*.h", "framework/**/*.cc", + "public/version.h", "util/**/*.h", "util/**/*.cc", ], @@ -922,6 +923,7 @@ tf_cuda_library( "client/**/*.cc", "common_runtime/*.h", "common_runtime/*.cc", + "framework/versions.h", "graph/**/*.h", "graph/**/*.cc", "public/session.h", diff --git a/tensorflow/core/framework/versions.cc b/tensorflow/core/framework/versions.cc new file mode 100644 index 0000000000..1373d7140f --- /dev/null +++ b/tensorflow/core/framework/versions.cc @@ -0,0 +1,55 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/versions.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +Status CheckVersions(const VersionDef& versions, int consumer, int min_producer, + const char* upper_name, const char* lower_name) { + // Guard against the caller misordering the arguments + if (consumer < min_producer) { + return errors::Internal(upper_name, " version check has consumer ", + consumer, " < min_producer ", min_producer, "."); + } + + // Check versions + if (versions.producer() < min_producer) { + return errors::InvalidArgument( + upper_name, " producer version ", versions.producer(), + " below min producer ", min_producer, " supported by TensorFlow ", + TF_VERSION_STRING, ". Please regenerate your ", lower_name, "."); + } + if (versions.min_consumer() > consumer) { + return errors::InvalidArgument( + upper_name, " min consumer version ", versions.min_consumer(), + " above current version ", consumer, " for TensorFlow ", + TF_VERSION_STRING, ". Please upgrade TensorFlow."); + } + for (const int bad_consumer : versions.bad_consumers()) { + if (bad_consumer == consumer) { + return errors::InvalidArgument( + upper_name, " disallows consumer version ", bad_consumer, + ". Please upgrade TensorFlow: this version is likely buggy."); + } + } + + // All good! + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/framework/versions.h b/tensorflow/core/framework/versions.h new file mode 100644 index 0000000000..757481fdc7 --- /dev/null +++ b/tensorflow/core/framework/versions.h @@ -0,0 +1,38 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_VERSIONS_H_ +#define TENSORFLOW_FRAMEWORK_VERSIONS_H_ + +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Check whether data with the given versions is compatible with the given +// consumer and min producer. upper_name and lower_name are used to form +// error messages upon failure. Example usage: +// +// #include "tensorflow/core/public/version.h" +// +// TF_RETURN_ERROR(CheckVersions(versions, TF_GRAPH_DEF_VERSION, +// TF_GRAPH_DEF_VERSION_MIN_PRODUCER, +// "GraphDef", "graph")); +Status CheckVersions(const VersionDef& versions, int consumer, int min_producer, + const char* upper_name, const char* lower_name); + +} // namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_VERSIONS_H_ diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index e1d80782b6..6c5873d0c1 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/versions.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/optimizer_cse.h" @@ -46,29 +47,10 @@ class GraphConstructor { GraphConstructor(const GraphConstructorOptions& opts, const GraphDef* gdef, Graph* g, Status* status) : opts_(opts), gdef_(gdef), g_(g), status_(status) { - if (gdef->versions().producer() < TF_GRAPH_DEF_VERSION_MIN_PRODUCER) { - *status = errors::InvalidArgument( - "GraphDef producer version ", gdef->versions().producer(), - " below min producer ", TF_GRAPH_DEF_VERSION_MIN_PRODUCER, - " supported by TensorFlow ", TF_VERSION_STRING, - ". Please regenerate your graph."); - return; - } - if (gdef->versions().min_consumer() > TF_GRAPH_DEF_VERSION) { - *status = errors::InvalidArgument( - "GraphDef min consumer version ", gdef->versions().min_consumer(), - " above current version ", TF_GRAPH_DEF_VERSION, " for TensorFlow ", - TF_VERSION_STRING, ". Please upgrade TensorFlow."); - return; - } - for (const int bad_consumer : gdef->versions().bad_consumers()) { - if (bad_consumer == TF_GRAPH_DEF_VERSION) { - *status = errors::InvalidArgument( - "GraphDef disallows consumer version ", bad_consumer, - ". Please upgrade TensorFlow: this version is likely buggy."); - return; - } - } + *status = + CheckVersions(gdef->versions(), TF_GRAPH_DEF_VERSION, + TF_GRAPH_DEF_VERSION_MIN_PRODUCER, "GraphDef", "graph"); + if (!status->ok()) return; g->set_versions(gdef->versions()); BuildNodeIndex(); InitFromEdges(); diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 39cf5b06e9..4c0a9b8d70 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -66,4 +66,18 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 #define TF_GRAPH_DEF_VERSION 8 +// Checkpoint compatibility versions (the versions field in SavedSliceMeta). +// +// The checkpoint versions have the same semantics as GraphDef versions, but the +// numbering scheme is separate. We have no plans to ever deprecate checkpoint +// versions, but it's good to have this in place in case we ever need to. +// +// Version history: +// +// 0. Checkpoints saved before checkpoint versioning. +// 1. First real version (10feb2015). +#define TF_CHECKPOINT_VERSION_MIN_PRODUCER 0 +#define TF_CHECKPOINT_VERSION_MIN_CONSUMER 0 +#define TF_CHECKPOINT_VERSION 1 + #endif // TENSORFLOW_CORE_PUBLIC_VERSION_H_ diff --git a/tensorflow/core/util/saved_tensor_slice.proto b/tensorflow/core/util/saved_tensor_slice.proto index 75d47e0079..d1e4e85edc 100644 --- a/tensorflow/core/util/saved_tensor_slice.proto +++ b/tensorflow/core/util/saved_tensor_slice.proto @@ -27,6 +27,7 @@ import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/tensor_slice.proto"; import "tensorflow/core/framework/tensor.proto"; import "tensorflow/core/framework/types.proto"; +import "tensorflow/core/framework/versions.proto"; // Metadata describing the set of slices of the same tensor saved in a // checkpoint file. @@ -49,6 +50,10 @@ message SavedSliceMeta { message SavedTensorSliceMeta { // Each SavedSliceMeta describes the slices for one tensor. repeated SavedSliceMeta tensor = 1; + + // Compatibility version of this checkpoint. See core/public/version.h + // for version history. + VersionDef versions = 2; }; // Saved tensor slice: it stores the name of the tensors, the slice, and the diff --git a/tensorflow/core/util/tensor_slice_reader.cc b/tensorflow/core/util/tensor_slice_reader.cc index b8732b4773..480dc64747 100644 --- a/tensorflow/core/util/tensor_slice_reader.cc +++ b/tensorflow/core/util/tensor_slice_reader.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/util/tensor_slice_reader.h" #include +#include "tensorflow/core/framework/versions.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/lib/io/iterator.h" @@ -25,6 +26,7 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/public/version.h" #include "tensorflow/core/util/saved_tensor_slice_util.h" #include "tensorflow/core/util/tensor_slice_util.h" @@ -155,6 +157,10 @@ void TensorSliceReader::LoadShard(int shard) const { fname); return; } + status_ = CheckVersions(sts.meta().versions(), TF_CHECKPOINT_VERSION, + TF_CHECKPOINT_VERSION_MIN_PRODUCER, "Checkpoint", + "checkpoint"); + if (!status_.ok()) return; for (const SavedSliceMeta& ssm : sts.meta().tensor()) { TensorShape ssm_shape(ssm.shape()); for (const TensorSliceProto& tsp : ssm.slice()) { diff --git a/tensorflow/core/util/tensor_slice_reader_test.cc b/tensorflow/core/util/tensor_slice_reader_test.cc index 928ebb8140..62008952d9 100644 --- a/tensorflow/core/util/tensor_slice_reader_test.cc +++ b/tensorflow/core/util/tensor_slice_reader_test.cc @@ -20,10 +20,12 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/version.h" #include "tensorflow/core/util/saved_tensor_slice_util.h" #include "tensorflow/core/util/tensor_slice_reader_cache.h" #include "tensorflow/core/util/tensor_slice_writer.h" @@ -393,6 +395,65 @@ TEST(CachedTensorSliceReaderTest, SimpleFloat) { OpenTableTensorSliceReader); } +static void VersionTest(const VersionDef& versions, const string& error) { + const string path = io::JoinPath(testing::TmpDir(), "checkpoint"); + + { + // Prepare an empty checkpoint with some version information + SavedTensorSlices sts; + sts.mutable_meta()->mutable_versions()->CopyFrom(versions); + string contents; + EXPECT_TRUE(sts.SerializeToString(&contents)); + + // Write it to disk + TensorSliceWriter::Builder* builder; + TF_ASSERT_OK(CreateTableTensorSliceBuilder(path, &builder)); + builder->Add(kSavedTensorSlicesKey, contents); + int64 file_size; + builder->Finish(&file_size); + delete builder; + } + + // Read it back in and verify that we get the expected error + TensorSliceReader reader(path, OpenTableTensorSliceReader); + EXPECT_TRUE(reader.status().code() == error::INVALID_ARGUMENT && + StringPiece(reader.status().error_message()).starts_with(error)) + << "Expected error starting with '" << errors::InvalidArgument(error) + << "', got '" << reader.status() << "'"; +} + +TEST(CheckpointVersionTest, MinConsumer) { + VersionDef versions; + versions.set_producer(TF_CHECKPOINT_VERSION + 1); + versions.set_min_consumer(TF_CHECKPOINT_VERSION + 1); + VersionTest( + versions, + strings::StrCat("Checkpoint min consumer version ", + TF_CHECKPOINT_VERSION + 1, " above current version ", + TF_CHECKPOINT_VERSION, " for TensorFlow")); +} + +TEST(CheckpointVersionTest, MinProducer) { + VersionDef versions; + versions.set_producer(TF_CHECKPOINT_VERSION_MIN_PRODUCER - 1); + VersionTest(versions, strings::StrCat("Checkpoint producer version ", + TF_CHECKPOINT_VERSION_MIN_PRODUCER - 1, + " below min producer ", + TF_CHECKPOINT_VERSION_MIN_PRODUCER, + " supported by TensorFlow")); +} + +TEST(CheckpointVersionTest, BadConsumer) { + VersionDef versions; + versions.set_producer(TF_CHECKPOINT_VERSION + 1); + versions.add_bad_consumers(TF_CHECKPOINT_VERSION); + VersionTest( + versions, + strings::StrCat( + "Checkpoint disallows consumer version ", TF_CHECKPOINT_VERSION, + ". Please upgrade TensorFlow: this version is likely buggy.")); +} + } // namespace } // namespace checkpoint diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc index 3a5328f2ac..53f9361973 100644 --- a/tensorflow/core/util/tensor_slice_writer.cc +++ b/tensorflow/core/util/tensor_slice_writer.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/version.h" #include "tensorflow/core/util/saved_tensor_slice_util.h" namespace tensorflow { @@ -81,7 +82,11 @@ TensorSliceWriter::TensorSliceWriter(const string& filename, : filename_(filename), create_builder_(create_builder), tmpname_(strings::StrCat(filename, ".tempstate", random::New64())), - slices_(0) {} + slices_(0) { + VersionDef* versions = sts_.mutable_meta()->mutable_versions(); + versions->set_producer(TF_CHECKPOINT_VERSION); + versions->set_min_consumer(TF_CHECKPOINT_VERSION_MIN_CONSUMER); +} Status TensorSliceWriter::Finish() { Builder* b; diff --git a/tensorflow/core/util/tensor_slice_writer_test.cc b/tensorflow/core/util/tensor_slice_writer_test.cc index 625ad1b212..b7996b223c 100644 --- a/tensorflow/core/util/tensor_slice_writer_test.cc +++ b/tensorflow/core/util/tensor_slice_writer_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" #include "tensorflow/core/util/saved_tensor_slice_util.h" #include "tensorflow/core/util/tensor_slice_reader.h" @@ -148,6 +149,11 @@ void TensorSliceWriteTestHelper::CheckEntries(const string& fname) { // We also expect two entries for the tensors EXPECT_TRUE(sts.has_meta()); EXPECT_EQ(4, sts.meta().tensor_size()); + // We should have written nontrivial version information + EXPECT_LT(0, TF_CHECKPOINT_VERSION); + EXPECT_EQ(TF_CHECKPOINT_VERSION, sts.meta().versions().producer()); + EXPECT_EQ(TF_CHECKPOINT_VERSION_MIN_CONSUMER, + sts.meta().versions().min_consumer()); // We don't expect any data in the first block. EXPECT_FALSE(sts.has_data()); // The two tensors should be stored in the same order as they are first -- cgit v1.2.3