aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-23 10:43:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-23 10:45:45 -0700
commit7b78417a00e6805557d530c1f1fcc8b2a44d6e2e (patch)
tree400694f98f7a31fb946a92384dac0ce201c9d0f9 /tensorflow/contrib/checkpoint
parentc78d4e8e7e032986789b0755b399b6c9ad274b5d (diff)
Add a checkpointable list data structure
Allows tracking of Layers and other checkpointable objects by number. Fixes #19250. PiperOrigin-RevId: 197749961
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py6
-rw-r--r--tensorflow/contrib/checkpoint/python/BUILD16
2 files changed, 17 insertions, 5 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index af8df72618..bd0bc9e56b 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -18,11 +18,14 @@ Visualization and inspection:
@@dot_graph_from_checkpoint
@@object_metadata
-Creating and managing dependencies:
+Managing dependencies:
@@Checkpointable
@@CheckpointableObjectGraph
@@NoDependency
@@split_dependency
+
+Checkpointable data structures:
+@@List
@@UniqueNameTracker
"""
@@ -36,6 +39,7 @@ from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkp
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
from tensorflow.python.training.checkpointable.base import Checkpointable
from tensorflow.python.training.checkpointable.base import NoDependency
+from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.util import object_metadata
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD
index 53f4e97f99..0b67619c11 100644
--- a/tensorflow/contrib/checkpoint/python/BUILD
+++ b/tensorflow/contrib/checkpoint/python/BUILD
@@ -11,6 +11,7 @@ py_library(
":containers",
":split_dependency",
":visualize",
+ "//tensorflow/python/training/checkpointable:data_structures",
],
)
@@ -30,8 +31,8 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:training",
"//tensorflow/python/training/checkpointable:base",
+ "//tensorflow/python/training/checkpointable:util",
"@six_archive//:six",
],
)
@@ -44,6 +45,7 @@ py_library(
deps = [
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/training/checkpointable:base",
],
)
@@ -55,8 +57,9 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:training",
"//tensorflow/python/eager:test",
+ "//tensorflow/python/training/checkpointable:base",
+ "//tensorflow/python/training/checkpointable:util",
],
)
@@ -67,6 +70,8 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python/training/checkpointable:base",
+ "//tensorflow/python/training/checkpointable:util",
],
)
@@ -75,10 +80,13 @@ py_test(
srcs = ["visualize_test.py"],
deps = [
":visualize",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
+ "//tensorflow/python/keras:engine",
+ "//tensorflow/python/keras:layers",
+ "//tensorflow/python/training/checkpointable:util",
],
)