aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-06-15 16:34:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-15 16:38:58 -0700
commit95e31cbb09c40322acc467902080c1cfd9ec88b1 (patch)
treec45cb3659ccf932ef6cb996e9f983d505ddca652
parentc5436fc0c1047719e4f86942729db4dbc9c76113 (diff)
Add checkpoint-utils to the tf.train module.
PiperOrigin-RevId: 159171746
-rw-r--r--tensorflow/python/BUILD22
-rw-r--r--tensorflow/python/training/checkpoint_utils.py5
-rw-r--r--tensorflow/python/training/training.py9
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.pbtxt16
4 files changed, 49 insertions, 3 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 4a395290e0..64b052839f 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3148,7 +3148,6 @@ cuda_py_tests(
"training/adagrad_da_test.py",
"training/adagrad_test.py",
"training/basic_loops_test.py",
- "training/checkpoint_utils_test.py",
"training/coordinator_test.py",
"training/device_setter_test.py",
"training/ftrl_test.py",
@@ -3353,6 +3352,27 @@ py_test(
)
py_test(
+ name = "checkpoint_utils_test",
+ size = "small",
+ srcs = ["training/checkpoint_utils_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ ":client",
+ ":client_testlib",
+ ":framework_for_generated_wrappers",
+ ":io_ops",
+ ":partitioned_variables",
+ ":platform",
+ ":pywrap_tensorflow",
+ ":state_ops",
+ ":training",
+ ":variable_scope",
+ ":variables",
+ ],
+)
+
+py_test(
name = "monitored_session_test",
size = "small",
srcs = ["training/monitored_session_test.py"],
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index d52cf9a436..ddf04e21e6 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import six
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
@@ -27,7 +28,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver
-from tensorflow.python.training import training as train
+
__all__ = [
"load_checkpoint", "load_variable", "list_variables", "init_from_checkpoint"
@@ -55,7 +56,7 @@ def load_checkpoint(ckpt_dir_or_file):
if filename is None:
raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
"given directory %s" % ckpt_dir_or_file)
- return train.NewCheckpointReader(filename)
+ return pywrap_tensorflow.NewCheckpointReader(filename)
def load_variable(ckpt_dir_or_file, name):
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index f4ac3c9758..e2a7b28e2b 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -85,6 +85,10 @@ See the @{$python/train} guide.
@@create_global_step
@@assert_global_step
@@write_graph
+@@load_checkpoint
+@@load_variable
+@@list_variables
+@@init_from_checkpoint
"""
# Optimizers.
@@ -142,6 +146,11 @@ from tensorflow.python.training.basic_session_run_hooks import GlobalStepWaiterH
from tensorflow.python.training.basic_session_run_hooks import FinalOpsHook
from tensorflow.python.training.basic_session_run_hooks import FeedFnHook
from tensorflow.python.training.basic_loops import basic_train_loop
+from tensorflow.python.training.checkpoint_utils import init_from_checkpoint
+from tensorflow.python.training.checkpoint_utils import list_variables
+from tensorflow.python.training.checkpoint_utils import load_checkpoint
+from tensorflow.python.training.checkpoint_utils import load_variable
+
from tensorflow.python.training.device_setter import replica_device_setter
from tensorflow.python.training.monitored_session import Scaffold
from tensorflow.python.training.monitored_session import MonitoredTrainingSession
diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt
index 58fd5760c1..c295537965 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt
@@ -305,6 +305,10 @@ tf_module {
argspec: "args=[\'meta_graph_or_file\', \'clear_devices\', \'import_scope\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "init_from_checkpoint"
+ argspec: "args=[\'ckpt_dir_or_file\', \'assignment_map\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "input_producer"
argspec: "args=[\'input_tensor\', \'element_shape\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'summary_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\', \'None\'], "
}
@@ -321,6 +325,18 @@ tf_module {
argspec: "args=[\'tensor\', \'num_epochs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
+ name: "list_variables"
+ argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "load_checkpoint"
+ argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "load_variable"
+ argspec: "args=[\'ckpt_dir_or_file\', \'name\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "match_filenames_once"
argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}