diff options
Diffstat (limited to 'tensorflow/python/saved_model/loader_impl.py')
-rw-r--r-- | tensorflow/python/saved_model/loader_impl.py | 55 |
1 files changed, 19 insertions, 36 deletions
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index e5f649fdab..16077f52fa 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -116,11 +116,14 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None): return asset_tensor_dict -def _get_main_op_tensor(meta_graph_def_to_load): +def _get_main_op_tensor( + meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY): """Gets the main op tensor, if one exists. Args: meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded. + init_op_key: name of collection to check; should be one of MAIN_OP_KEY + or the deprecated LEGACY_INIT_OP_KEY Returns: The main op tensor, if it exists and `None` otherwise. @@ -131,38 +134,15 @@ def _get_main_op_tensor(meta_graph_def_to_load): """ collection_def = meta_graph_def_to_load.collection_def main_op_tensor = None - if constants.MAIN_OP_KEY in collection_def: - main_ops = collection_def[constants.MAIN_OP_KEY].node_list.value + if init_op_key in collection_def: + main_ops = collection_def[init_op_key].node_list.value if len(main_ops) != 1: - raise RuntimeError("Expected exactly one SavedModel main op.") - main_op_tensor = ops.get_collection(constants.MAIN_OP_KEY)[0] + raise RuntimeError("Expected exactly one SavedModel main op. " + "Found: {}".format(main_ops)) + main_op_tensor = ops.get_collection(init_op_key)[0] return main_op_tensor -def _get_legacy_init_op_tensor(meta_graph_def_to_load): - """Gets the legacy init op tensor, if one exists. - - Args: - meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded. - - Returns: - The legacy init op tensor, if it exists and `None` otherwise. - - Raises: - RuntimeError: If the collection def corresponding to the legacy init op key - has other than exactly one tensor. - """ - collection_def = meta_graph_def_to_load.collection_def - legacy_init_op_tensor = None - if constants.LEGACY_INIT_OP_KEY in collection_def: - legacy_init_ops = collection_def[ - constants.LEGACY_INIT_OP_KEY].node_list.value - if len(legacy_init_ops) != 1: - raise RuntimeError("Expected exactly one legacy serving init op.") - legacy_init_op_tensor = ops.get_collection(constants.LEGACY_INIT_OP_KEY)[0] - return legacy_init_op_tensor - - @tf_export("saved_model.loader.maybe_saved_model_directory") def maybe_saved_model_directory(export_dir): """Checks whether the provided export directory could contain a SavedModel. @@ -284,12 +264,15 @@ class SavedModelLoader(object): **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. Returns: - Saver defined by the MetaGraph, which can be used to restore the variable - values. + A tuple of + * Saver defined by the MetaGraph, which can be used to restore the + variable values. + * List of `Operation`/`Tensor` objects returned from + `tf.import_graph_def` (may be `None`). """ meta_graph_def = self.get_meta_graph_def_from_tags(tags) with graph.as_default(): - return tf_saver.import_meta_graph( + return tf_saver._import_meta_graph_with_return_elements( # pylint: disable=protected-access meta_graph_def, import_scope=import_scope, **saver_kwargs) def restore_variables(self, sess, saver, import_scope=None): @@ -340,8 +323,8 @@ class SavedModelLoader(object): self._export_dir, meta_graph_def, import_scope=import_scope) main_op_tensor = ( - _get_main_op_tensor(meta_graph_def) or - (_get_legacy_init_op_tensor(meta_graph_def))) + _get_main_op_tensor(meta_graph_def, constants.MAIN_OP_KEY) or + _get_main_op_tensor(meta_graph_def, constants.LEGACY_INIT_OP_KEY)) if main_op_tensor is not None: sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) @@ -361,8 +344,8 @@ class SavedModelLoader(object): `MetagraphDef` proto of the graph that was loaded. """ with sess.graph.as_default(): - saver = self.load_graph(sess.graph, tags, import_scope, - **saver_kwargs) + saver, _ = self.load_graph(sess.graph, tags, import_scope, + **saver_kwargs) self.restore_variables(sess, saver, import_scope) self.run_init_ops(sess, tags, import_scope) return self.get_meta_graph_def_from_tags(tags) |