aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py4
-rw-r--r--tensorflow/python/training/input.py2
-rw-r--r--tensorflow/python/training/saver.py12
3 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 17e07e171a..aae757b99a 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -336,7 +336,7 @@ class CheckpointSaverListener(object):
`CheckpointSaverHook`, as in this example:
```python
- class ExampleCheckpointSaverListerner(CheckpointSaverListener):
+ class ExampleCheckpointSaverListener(CheckpointSaverListener):
def begin(self):
# You can add ops to the graph here.
print('Starting the session.')
@@ -352,7 +352,7 @@ class CheckpointSaverListener(object):
print('Done with the session.')
...
- listener = ExampleCheckpointSaverListerner()
+ listener = ExampleCheckpointSaverListener()
saver_hook = tf.train.CheckpointSaverHook(
checkpoint_dir, listeners=[listener])
with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 992184ec9e..bd9985a7c5 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -58,6 +58,8 @@ _restore_sparse = sparse_ops._take_many_sparse_from_tensors_map
def match_filenames_once(pattern, name=None):
"""Save the list of files matching pattern, so it is only computed once.
+ NOTE: The order of the files returned can be non-deterministic.
+
Args:
pattern: A file pattern (glob), or 1D tensor of file patterns.
name: A name for the operations (optional).
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 3888e9bba4..0c1c8e664b 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1597,9 +1597,9 @@ class Saver(object):
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
- A string: path prefix used for the checkpoint files. If the saver is
- sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
- is the number of shards created.
+ A string: path prefix used for the checkpoint files. If checkpoint
+ format is V1 and the saver is sharded, this string ends with:
+ '-?????-of-nnnnn' where 'nnnnn' is the number of shards created.
If the saver is empty, returns None.
Raises:
@@ -1749,6 +1749,12 @@ class Saver(object):
return
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
+ if (os.path.isfile(save_path) and
+ self._write_version not in (
+ saver_pb2.SaverDef.V1, saver_pb2.SaverDef.LEGACY)):
+ raise ValueError("The specified path: %s is a file."
+ " Please specify only the path prefix"
+ " to the checkpoint files." % save_path)
logging.info("Restoring parameters from %s", save_path)
if context.in_graph_mode():
sess.run(self.saver_def.restore_op_name,