aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sukriti Ramesh <sukritiramesh@google.com>2016-09-27 15:37:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-27 16:49:44 -0700
commit69cfb3b2e71c92b4837ef265b26680227022b861 (patch)
tree5e347aa316bd523863acba6f38026b08732c9783
parent3d8c93548d933860ab4b1dc2caa84031f4c4878b (diff)
Use a sharded saver in SavedModel.
Change: 134471468
-rw-r--r--tensorflow/cc/saved_model/BUILD1
-rw-r--r--tensorflow/cc/saved_model/constants.h4
-rw-r--r--tensorflow/cc/saved_model/loader.cc18
-rw-r--r--tensorflow/cc/saved_model/loader_test.cc14
-rw-r--r--tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pbbin0 -> 4300 bytes
-rw-r--r--tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/checkpoint2
-rw-r--r--tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/saved_model_variables-00000-of-00001bin0 -> 169 bytes
-rw-r--r--tensorflow/python/saved_model/builder.py4
-rw-r--r--tensorflow/python/saved_model/constants.py1
-rw-r--r--tensorflow/python/saved_model/loader.py2
10 files changed, 36 insertions, 10 deletions
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD
index f0a72a60d6..cf677ac4e3 100644
--- a/tensorflow/cc/saved_model/BUILD
+++ b/tensorflow/cc/saved_model/BUILD
@@ -52,6 +52,7 @@ filegroup(
srcs = glob([
"testdata/half_plus_two/**",
"testdata/half_plus_two_pbtxt/**",
+ "testdata/half_plus_two_sharded/**",
]),
)
diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h
index 305a6e1793..0972ca5e60 100644
--- a/tensorflow/cc/saved_model/constants.h
+++ b/tensorflow/cc/saved_model/constants.h
@@ -30,6 +30,10 @@ constexpr char kSavedModelVariablesDirectory[] = "variables";
// SavedModel variables filename.
constexpr char kSavedModelVariablesFilename[] = "saved_model_variables";
+// SavedModel sharded variables filename.
+constexpr char kSavedModelVariablesShardedFilename[] =
+ "saved_model_variables-\?\?\?\?\?-of-\?\?\?\?\?";
+
// Commonly used tags.
constexpr char kSavedModelTagServe[] = "serve";
constexpr char kSavedModelTagTrain[] = "train";
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index 2863bc99fa..a88cae6c93 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -86,14 +86,18 @@ Status Restore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name,
Session* session) {
- const string variables_path = io::JoinPath(
- export_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename);
- if (!Env::Default()->FileExists(variables_path)) {
- return Status(
- error::Code::NOT_FOUND,
- "Could not find checkpointed variables at: " + variables_path);
+ // Find path to variables to be restored in export directory.
+ string variables_path =
+ io::JoinPath(export_dir, kSavedModelVariablesDirectory);
+ const string unsharded_variables_path =
+ io::JoinPath(variables_path, kSavedModelVariablesFilename);
+ if (Env::Default()->FileExists(unsharded_variables_path)) {
+ variables_path = unsharded_variables_path;
+ } else {
+ const string sharded_variables_path =
+ io::JoinPath(variables_path, kSavedModelVariablesShardedFilename);
+ variables_path = sharded_variables_path;
}
-
// Add variables to the graph.
Tensor variables_path_tensor(DT_STRING, TensorShape({}));
variables_path_tensor.scalar<string>()() = variables_path;
diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc
index cef28d96f9..b2d55a9ade 100644
--- a/tensorflow/cc/saved_model/loader_test.cc
+++ b/tensorflow/cc/saved_model/loader_test.cc
@@ -28,6 +28,8 @@ namespace {
constexpr char kTestDataPb[] = "cc/saved_model/testdata/half_plus_two";
constexpr char kTestDataPbTxt[] = "cc/saved_model/testdata/half_plus_two_pbtxt";
+constexpr char kTestDataSharded[] =
+ "cc/saved_model/testdata/half_plus_two_sharded";
class LoaderTest : public ::testing::Test {
protected:
@@ -110,6 +112,18 @@ TEST_F(LoaderTest, PbtxtFormat) {
CheckSavedModelBundle(bundle);
}
+TEST_F(LoaderTest, ShardedVariables) {
+ SavedModelBundle bundle;
+ SessionOptions session_options;
+ RunOptions run_options;
+
+ const string export_dir =
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+ TF_ASSERT_OK(LoadSavedModel(export_dir, {kSavedModelTagServe},
+ session_options, run_options, &bundle));
+ CheckSavedModelBundle(bundle);
+}
+
TEST_F(LoaderTest, InvalidExportPath) {
SavedModelBundle bundle;
RunOptions run_options;
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb
new file mode 100644
index 0000000000..0a87f3306f
--- /dev/null
+++ b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb
Binary files differ
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/checkpoint b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/checkpoint
new file mode 100644
index 0000000000..b5b6425e93
--- /dev/null
+++ b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/checkpoint
@@ -0,0 +1,2 @@
+model_checkpoint_path: "/tmp/saved_model/half_plus_two/variables/saved_model_variables-?????-of-00001"
+all_model_checkpoint_paths: "/tmp/saved_model/half_plus_two/variables/saved_model_variables-?????-of-00001"
diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/saved_model_variables-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/saved_model_variables-00000-of-00001
new file mode 100644
index 0000000000..e1ac9e900e
--- /dev/null
+++ b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/saved_model_variables-00000-of-00001
Binary files differ
diff --git a/tensorflow/python/saved_model/builder.py b/tensorflow/python/saved_model/builder.py
index 76f99d00d0..23768f1247 100644
--- a/tensorflow/python/saved_model/builder.py
+++ b/tensorflow/python/saved_model/builder.py
@@ -252,7 +252,7 @@ class SavedModelBuilder(object):
# Save asset files, if any.
self._save_assets(assets_collection)
- saver = tf_saver.Saver(variables.all_variables())
+ saver = tf_saver.Saver(variables.all_variables(), sharded=True)
meta_graph_def = saver.export_meta_graph()
# Tag the meta graph def and add it to the SavedModel.
@@ -298,7 +298,7 @@ class SavedModelBuilder(object):
compat.as_text(constants.VARIABLES_FILENAME))
# Save the variables and export meta graph def.
- saver = tf_saver.Saver(variables.all_variables())
+ saver = tf_saver.Saver(variables.all_variables(), sharded=True)
saver.save(sess, variables_path, write_meta_graph=False)
meta_graph_def = saver.export_meta_graph()
diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py
index 440edcc512..5c48103c6e 100644
--- a/tensorflow/python/saved_model/constants.py
+++ b/tensorflow/python/saved_model/constants.py
@@ -31,3 +31,4 @@ TAG_TRAINING = "train"
VARIABLES_DIRECTORY = "variables"
VARIABLES_FILENAME = "saved_model_variables"
+VARIABLES_FILENAME_SHARDED = VARIABLES_FILENAME + "-?????-of-?????"
diff --git a/tensorflow/python/saved_model/loader.py b/tensorflow/python/saved_model/loader.py
index eb49210078..c262b6380f 100644
--- a/tensorflow/python/saved_model/loader.py
+++ b/tensorflow/python/saved_model/loader.py
@@ -154,7 +154,7 @@ def load(sess, tags, export_dir):
variables_path = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.VARIABLES_DIRECTORY),
- compat.as_bytes(constants.VARIABLES_FILENAME))
+ compat.as_bytes(constants.VARIABLES_FILENAME_SHARDED))
# Restore the variables using the built saver in the provided session.
saver.restore(sess, variables_path)