diff options
author | Mustafa Ispir <ispir@google.com> | 2017-06-15 16:34:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-15 16:38:58 -0700 |
commit | 95e31cbb09c40322acc467902080c1cfd9ec88b1 (patch) | |
tree | c45cb3659ccf932ef6cb996e9f983d505ddca652 | |
parent | c5436fc0c1047719e4f86942729db4dbc9c76113 (diff) |
Add checkpoint-utils to the tf.train module.
PiperOrigin-RevId: 159171746
-rw-r--r-- | tensorflow/python/BUILD | 22 | ||||
-rw-r--r-- | tensorflow/python/training/checkpoint_utils.py | 5 | ||||
-rw-r--r-- | tensorflow/python/training/training.py | 9 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/tensorflow.train.pbtxt | 16 |
4 files changed, 49 insertions, 3 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 4a395290e0..64b052839f 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3148,7 +3148,6 @@ cuda_py_tests( "training/adagrad_da_test.py", "training/adagrad_test.py", "training/basic_loops_test.py", - "training/checkpoint_utils_test.py", "training/coordinator_test.py", "training/device_setter_test.py", "training/ftrl_test.py", @@ -3353,6 +3352,27 @@ py_test( ) py_test( + name = "checkpoint_utils_test", + size = "small", + srcs = ["training/checkpoint_utils_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":client", + ":client_testlib", + ":framework_for_generated_wrappers", + ":io_ops", + ":partitioned_variables", + ":platform", + ":pywrap_tensorflow", + ":state_ops", + ":training", + ":variable_scope", + ":variables", + ], +) + +py_test( name = "monitored_session_test", size = "small", srcs = ["training/monitored_session_test.py"], diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index d52cf9a436..ddf04e21e6 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -20,6 +20,7 @@ from __future__ import print_function import six +from tensorflow.python import pywrap_tensorflow from tensorflow.python.ops import io_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs @@ -27,7 +28,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver -from tensorflow.python.training import training as train + __all__ = [ "load_checkpoint", "load_variable", "list_variables", "init_from_checkpoint" @@ -55,7 +56,7 @@ def load_checkpoint(ckpt_dir_or_file): if filename is None: raise ValueError("Couldn't find 'checkpoint' file or checkpoints in " "given directory %s" % ckpt_dir_or_file) - return train.NewCheckpointReader(filename) + return pywrap_tensorflow.NewCheckpointReader(filename) def load_variable(ckpt_dir_or_file, name): diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index f4ac3c9758..e2a7b28e2b 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -85,6 +85,10 @@ See the @{$python/train} guide. @@create_global_step @@assert_global_step @@write_graph +@@load_checkpoint +@@load_variable +@@list_variables +@@init_from_checkpoint """ # Optimizers. @@ -142,6 +146,11 @@ from tensorflow.python.training.basic_session_run_hooks import GlobalStepWaiterH from tensorflow.python.training.basic_session_run_hooks import FinalOpsHook from tensorflow.python.training.basic_session_run_hooks import FeedFnHook from tensorflow.python.training.basic_loops import basic_train_loop +from tensorflow.python.training.checkpoint_utils import init_from_checkpoint +from tensorflow.python.training.checkpoint_utils import list_variables +from tensorflow.python.training.checkpoint_utils import load_checkpoint +from tensorflow.python.training.checkpoint_utils import load_variable + from tensorflow.python.training.device_setter import replica_device_setter from tensorflow.python.training.monitored_session import Scaffold from tensorflow.python.training.monitored_session import MonitoredTrainingSession diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt index 58fd5760c1..c295537965 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt @@ -305,6 +305,10 @@ tf_module { argspec: "args=[\'meta_graph_or_file\', \'clear_devices\', \'import_scope\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\'], " } member_method { + name: "init_from_checkpoint" + argspec: "args=[\'ckpt_dir_or_file\', \'assignment_map\'], varargs=None, keywords=None, defaults=None" + } + member_method { name: "input_producer" argspec: "args=[\'input_tensor\', \'element_shape\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'summary_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\', \'None\'], " } @@ -321,6 +325,18 @@ tf_module { argspec: "args=[\'tensor\', \'num_epochs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { + name: "list_variables" + argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_checkpoint" + argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_variable" + argspec: "args=[\'ckpt_dir_or_file\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { name: "match_filenames_once" argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } |