diff options
author | 2016-09-27 15:37:53 -0800 | |
---|---|---|
committer | 2016-09-27 16:49:44 -0700 | |
commit | 69cfb3b2e71c92b4837ef265b26680227022b861 (patch) | |
tree | 5e347aa316bd523863acba6f38026b08732c9783 | |
parent | 3d8c93548d933860ab4b1dc2caa84031f4c4878b (diff) |
Use a sharded saver in SavedModel.
Change: 134471468
-rw-r--r-- | tensorflow/cc/saved_model/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/cc/saved_model/constants.h | 4 | ||||
-rw-r--r-- | tensorflow/cc/saved_model/loader.cc | 18 | ||||
-rw-r--r-- | tensorflow/cc/saved_model/loader_test.cc | 14 | ||||
-rw-r--r-- | tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb | bin | 0 -> 4300 bytes | |||
-rw-r--r-- | tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/checkpoint | 2 | ||||
-rw-r--r-- | tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/saved_model_variables-00000-of-00001 | bin | 0 -> 169 bytes | |||
-rw-r--r-- | tensorflow/python/saved_model/builder.py | 4 | ||||
-rw-r--r-- | tensorflow/python/saved_model/constants.py | 1 | ||||
-rw-r--r-- | tensorflow/python/saved_model/loader.py | 2 |
10 files changed, 36 insertions, 10 deletions
diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index f0a72a60d6..cf677ac4e3 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -52,6 +52,7 @@ filegroup( srcs = glob([ "testdata/half_plus_two/**", "testdata/half_plus_two_pbtxt/**", + "testdata/half_plus_two_sharded/**", ]), ) diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index 305a6e1793..0972ca5e60 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -30,6 +30,10 @@ constexpr char kSavedModelVariablesDirectory[] = "variables"; // SavedModel variables filename. constexpr char kSavedModelVariablesFilename[] = "saved_model_variables"; +// SavedModel sharded variables filename. +constexpr char kSavedModelVariablesShardedFilename[] = + "saved_model_variables-\?\?\?\?\?-of-\?\?\?\?\?"; + // Commonly used tags. constexpr char kSavedModelTagServe[] = "serve"; constexpr char kSavedModelTagTrain[] = "train"; diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 2863bc99fa..a88cae6c93 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -86,14 +86,18 @@ Status Restore(const RunOptions& run_options, const string& export_dir, const StringPiece restore_op_name, const StringPiece variable_filename_const_op_name, Session* session) { - const string variables_path = io::JoinPath( - export_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename); - if (!Env::Default()->FileExists(variables_path)) { - return Status( - error::Code::NOT_FOUND, - "Could not find checkpointed variables at: " + variables_path); + // Find path to variables to be restored in export directory. + string variables_path = + io::JoinPath(export_dir, kSavedModelVariablesDirectory); + const string unsharded_variables_path = + io::JoinPath(variables_path, kSavedModelVariablesFilename); + if (Env::Default()->FileExists(unsharded_variables_path)) { + variables_path = unsharded_variables_path; + } else { + const string sharded_variables_path = + io::JoinPath(variables_path, kSavedModelVariablesShardedFilename); + variables_path = sharded_variables_path; } - // Add variables to the graph. Tensor variables_path_tensor(DT_STRING, TensorShape({})); variables_path_tensor.scalar<string>()() = variables_path; diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index cef28d96f9..b2d55a9ade 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -28,6 +28,8 @@ namespace { constexpr char kTestDataPb[] = "cc/saved_model/testdata/half_plus_two"; constexpr char kTestDataPbTxt[] = "cc/saved_model/testdata/half_plus_two_pbtxt"; +constexpr char kTestDataSharded[] = + "cc/saved_model/testdata/half_plus_two_sharded"; class LoaderTest : public ::testing::Test { protected: @@ -110,6 +112,18 @@ TEST_F(LoaderTest, PbtxtFormat) { CheckSavedModelBundle(bundle); } +TEST_F(LoaderTest, ShardedVariables) { + SavedModelBundle bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + TF_ASSERT_OK(LoadSavedModel(export_dir, {kSavedModelTagServe}, + session_options, run_options, &bundle)); + CheckSavedModelBundle(bundle); +} + TEST_F(LoaderTest, InvalidExportPath) { SavedModelBundle bundle; RunOptions run_options; diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb Binary files differnew file mode 100644 index 0000000000..0a87f3306f --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/saved_model.pb diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/checkpoint b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/checkpoint new file mode 100644 index 0000000000..b5b6425e93 --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/checkpoint @@ -0,0 +1,2 @@ +model_checkpoint_path: "/tmp/saved_model/half_plus_two/variables/saved_model_variables-?????-of-00001" +all_model_checkpoint_paths: "/tmp/saved_model/half_plus_two/variables/saved_model_variables-?????-of-00001" diff --git a/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/saved_model_variables-00000-of-00001 b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/saved_model_variables-00000-of-00001 Binary files differnew file mode 100644 index 0000000000..e1ac9e900e --- /dev/null +++ b/tensorflow/cc/saved_model/testdata/half_plus_two_sharded/variables/saved_model_variables-00000-of-00001 diff --git a/tensorflow/python/saved_model/builder.py b/tensorflow/python/saved_model/builder.py index 76f99d00d0..23768f1247 100644 --- a/tensorflow/python/saved_model/builder.py +++ b/tensorflow/python/saved_model/builder.py @@ -252,7 +252,7 @@ class SavedModelBuilder(object): # Save asset files, if any. self._save_assets(assets_collection) - saver = tf_saver.Saver(variables.all_variables()) + saver = tf_saver.Saver(variables.all_variables(), sharded=True) meta_graph_def = saver.export_meta_graph() # Tag the meta graph def and add it to the SavedModel. @@ -298,7 +298,7 @@ class SavedModelBuilder(object): compat.as_text(constants.VARIABLES_FILENAME)) # Save the variables and export meta graph def. - saver = tf_saver.Saver(variables.all_variables()) + saver = tf_saver.Saver(variables.all_variables(), sharded=True) saver.save(sess, variables_path, write_meta_graph=False) meta_graph_def = saver.export_meta_graph() diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py index 440edcc512..5c48103c6e 100644 --- a/tensorflow/python/saved_model/constants.py +++ b/tensorflow/python/saved_model/constants.py @@ -31,3 +31,4 @@ TAG_TRAINING = "train" VARIABLES_DIRECTORY = "variables" VARIABLES_FILENAME = "saved_model_variables" +VARIABLES_FILENAME_SHARDED = VARIABLES_FILENAME + "-?????-of-?????" diff --git a/tensorflow/python/saved_model/loader.py b/tensorflow/python/saved_model/loader.py index eb49210078..c262b6380f 100644 --- a/tensorflow/python/saved_model/loader.py +++ b/tensorflow/python/saved_model/loader.py @@ -154,7 +154,7 @@ def load(sess, tags, export_dir): variables_path = os.path.join( compat.as_bytes(export_dir), compat.as_bytes(constants.VARIABLES_DIRECTORY), - compat.as_bytes(constants.VARIABLES_FILENAME)) + compat.as_bytes(constants.VARIABLES_FILENAME_SHARDED)) # Restore the variables using the built saver in the provided session. saver.restore(sess, variables_path) |