aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/BUILD
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-08-02 15:47:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 15:51:17 -0700
commit1bf206bc82f600886f1e19c9860f09f18984346b (patch)
treefbd6ee10df16e491142017e96120181b81a72ec5 /tensorflow/python/BUILD
parent6fbbad97e293cc39bde32495e92614c69a9a7896 (diff)
Split checkpoint management utility functions out of saver.py
Pure refactor, in preparation for adding a higher level checkpoint management utility. This utility will also need to work with the Checkpoint proto, and globbing it on to saver.py seems dirty. PiperOrigin-RevId: 207179646
Diffstat (limited to 'tensorflow/python/BUILD')
-rw-r--r--tensorflow/python/BUILD70
1 files changed, 61 insertions, 9 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 2b8110a999..7cf8ddb1d9 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3216,6 +3216,7 @@ py_library(
"training/checkpointable/**/*.py",
# The following targets have their own build rules (same name as the
# file):
+ "training/checkpoint_management.py",
"training/saveable_object.py",
"training/saver.py",
"training/training_util.py",
@@ -3223,8 +3224,10 @@ py_library(
),
srcs_version = "PY2AND3",
deps = [
+ "saver",
":array_ops",
":array_ops_gen",
+ ":checkpoint_management",
":checkpoint_ops_gen",
":client",
":control_flow_ops",
@@ -3236,25 +3239,20 @@ py_library(
":framework_ops",
":gradients",
":init_ops",
- ":distribute",
":io_ops",
- ":io_ops_gen",
":layers_base",
- ":lib",
":lookup_ops",
":math_ops",
":platform",
- ":protos_all_py",
":pywrap_tensorflow",
":random_ops",
":resource_variable_ops",
":resources",
- "saver",
- ":saveable_object",
":sdca_ops",
+ ":session",
":sparse_ops",
+ ":sparse_tensor",
":state_ops",
- ":string_ops",
":summary",
":training_ops_gen",
":training_util",
@@ -3264,6 +3262,7 @@ py_library(
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
# `layers` dependency only exists due to the use of a small utility.
@@ -3281,11 +3280,25 @@ py_library(
)
py_library(
+ name = "checkpoint_management",
+ srcs = ["training/checkpoint_management.py"],
+ deps = [
+ ":errors",
+ ":lib",
+ ":platform",
+ ":protos_all_py",
+ ":util",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_library(
name = "saver",
srcs = ["training/saver.py"],
srcs_version = "PY2AND3",
deps = [
":array_ops",
+ ":checkpoint_management",
":constant_op",
":control_flow_ops",
":device",
@@ -3294,9 +3307,7 @@ py_library(
":framework_ops",
":io_ops",
":io_ops_gen",
- ":lib",
":platform",
- ":protos_all_py",
":pywrap_tensorflow",
":resource_variable_ops",
":saveable_object",
@@ -4423,6 +4434,42 @@ cuda_py_test(
tags = ["multi_gpu"],
)
+cuda_py_test(
+ name = "checkpoint_management_test",
+ size = "small",
+ srcs = [
+ "training/checkpoint_management_test.py",
+ ],
+ additional_deps = [
+ ":array_ops",
+ ":client_testlib",
+ ":control_flow_ops",
+ ":data_flow_ops",
+ ":errors",
+ ":gradients",
+ ":math_ops",
+ ":nn_grad",
+ ":nn_ops",
+ ":saver_test_utils",
+ ":partitioned_variables",
+ ":platform",
+ ":platform_test",
+ ":pywrap_tensorflow",
+ ":random_ops",
+ ":resource_variable_ops",
+ ":sparse_ops",
+ ":summary",
+ ":training",
+ ":util",
+ ":variable_scope",
+ ":variables",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
py_test(
name = "saver_large_variable_test",
size = "medium",
@@ -4489,6 +4536,7 @@ tf_py_test(
srcs = ["training/supervisor_test.py"],
additional_deps = [
":array_ops",
+ ":checkpoint_management",
":client_testlib",
":errors",
":framework",
@@ -4496,6 +4544,7 @@ tf_py_test(
":io_ops",
":parsing_ops",
":platform",
+ ":saver",
":summary",
":training",
":variables",
@@ -4609,10 +4658,13 @@ py_test(
tags = ["notsan"], # b/67945581
deps = [
":array_ops",
+ ":checkpoint_management",
":client_testlib",
":control_flow_ops",
":errors",
":framework_for_generated_wrappers",
+ ":resource_variable_ops",
+ ":saver",
":session",
":state_ops",
":summary",