aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-15 13:34:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-15 14:56:00 -0700
commit19ef82151be358698f9ab8702ed5575afc94f110 (patch)
tree42cc430ca40355d103543b1f59c3f6b4ada020e1
parentb05e0840d11ee30c3a66d45daeeea2495b9808e5 (diff)
Make SavedModel exports include all the SAVEABLE objects and not just global variables.
Change: 150243023
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py7
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py87
-rw-r--r--tensorflow/python/BUILD13
-rw-r--r--tensorflow/python/estimator/BUILD1
-rw-r--r--tensorflow/python/estimator/estimator.py2
-rw-r--r--tensorflow/python/estimator/estimator_test.py59
-rw-r--r--tensorflow/python/saved_model/BUILD1
-rw-r--r--tensorflow/python/saved_model/builder_impl.py8
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py30
-rw-r--r--tensorflow/python/training/saver_test.py86
-rw-r--r--tensorflow/python/training/saver_test_utils.py86
11 files changed, 303 insertions, 77 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 6f9c44dff6..4c575bb62d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -58,6 +58,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
@@ -1254,13 +1255,17 @@ class Estimator(BaseEstimator):
with tf_session.Session('') as session:
variables.initialize_local_variables()
data_flow_ops.tables_initializer()
+ resources.initialize_resources(resources.shared_resources())
saver_for_restore = saver.Saver(
- variables.global_variables(),
+ # pylint: disable=protected-access
+ variables._all_saveable_objects(),
+ # pylint: enable=protected-access
sharded=True)
saver_for_restore.restore(session, checkpoint_path)
init_op = control_flow_ops.group(
variables.local_variables_initializer(),
+ resources.initialize_resources(resources.shared_resources()),
data_flow_ops.tables_initializer())
# Perform the export
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index 659266690f..b27e99b0dc 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -50,6 +50,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
@@ -225,6 +226,49 @@ def _build_estimator_for_export_tests(tmpdir):
return est, serving_input_fn_with_asset
+def _build_estimator_for_resource_export_test():
+
+ def _input_fn():
+ iris = base.load_iris()
+ return {
+ 'feature': constant_op.constant(iris.data, dtype=dtypes.float32)
+ }, constant_op.constant(
+ iris.target, shape=[150], dtype=dtypes.int32)
+
+ feature_columns = [
+ feature_column_lib.real_valued_column('feature', dimension=4)
+ ]
+
+ def resource_constant_model_fn(unused_features, unused_labels, mode):
+ """A model_fn that loads a constant from a resource and serves it."""
+ assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.INFER)
+
+ const = constant_op.constant(-1, dtype=dtypes.int64)
+ table = lookup.MutableHashTable(
+ dtypes.string, dtypes.int64, const, name='LookupTableModel')
+ if mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL):
+ key = constant_op.constant(['key'])
+ value = constant_op.constant([42], dtype=dtypes.int64)
+ train_op_1 = table.insert(key, value)
+ training_state = lookup.MutableHashTable(
+ dtypes.string, dtypes.int64, const, name='LookupTableTrainingState')
+ training_op_2 = training_state.insert(key, value)
+ return const, const, control_flow_ops.group(train_op_1, training_op_2)
+ if mode == model_fn.ModeKeys.INFER:
+ key = constant_op.constant(['key'])
+ prediction = table.lookup(key)
+ return prediction, const, control_flow_ops.no_op()
+
+ est = estimator.Estimator(model_fn=resource_constant_model_fn)
+ est.fit(input_fn=_input_fn, steps=1)
+
+ feature_spec = feature_column_lib.create_feature_spec_for_parsing(
+ feature_columns)
+ serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
+ return est, serving_input_fn
+
+
class CheckCallsMonitor(monitors_lib.BaseMonitor):
def __init__(self, expect_calls):
@@ -753,6 +797,49 @@ class EstimatorTest(test.TestCase):
# cleanup
gfile.DeleteRecursively(tmpdir)
+ def test_export_savedmodel_with_resource(self):
+ tmpdir = tempfile.mkdtemp()
+ est, serving_input_fn = _build_estimator_for_resource_export_test()
+
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = est.export_savedmodel(export_dir_base, serving_input_fn)
+
+ self.assertTrue(gfile.Exists(export_dir_base))
+ self.assertTrue(gfile.Exists(export_dir))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(
+ compat.as_bytes(export_dir), compat.as_bytes(
+ 'saved_model.pb'))))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(
+ compat.as_bytes(export_dir), compat.as_bytes('variables'))))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.index'))))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.data-00000-of-00001'))))
+
+ # Restore, to validate that the export was well-formed.
+ with ops.Graph().as_default() as graph:
+ with session_lib.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('input_example_tensor' in graph_ops)
+ self.assertTrue('ParseExample/ParseExample' in graph_ops)
+ self.assertTrue('LookupTableModel' in graph_ops)
+ self.assertFalse('LookupTableTrainingState' in graph_ops)
+
+ # cleanup
+ gfile.DeleteRecursively(tmpdir)
+
class InferRealValuedColumnsTest(test.TestCase):
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7f416dc609..5c3e87cfd6 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -74,6 +74,7 @@ py_library(
":tensor_array_ops",
":training",
":ops",
+ ":saver_test_utils",
":test_ops", # TODO: Break testing code out into separate rule.
":util",
":weights_broadcast_ops",
@@ -2935,6 +2936,16 @@ cuda_py_tests(
],
)
+py_library(
+ name = "saver_test_utils",
+ srcs = ["training/saver_test_utils.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":data_flow_ops_gen",
+ ":training",
+ ],
+)
+
cuda_py_test(
name = "saver_test",
size = "medium",
@@ -2946,12 +2957,12 @@ cuda_py_test(
":client_testlib",
":control_flow_ops",
":data_flow_ops",
- ":data_flow_ops_gen",
":errors",
":gradients",
":math_ops",
":nn_grad",
":nn_ops",
+ ":saver_test_utils",
":partitioned_variables",
":platform",
":platform_test",
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index e6d4af2f95..bcbf06deab 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -136,6 +136,7 @@ py_test(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:layers",
+ "//tensorflow/python:saver_test_utils",
"//tensorflow/python:session",
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index d30f9093b6..677c840ff0 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -411,7 +411,7 @@ class Estimator(object):
with tf_session.Session() as session:
saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
- variables.global_variables(),
+ variables._all_saveable_objects(), # pylint: disable=protected-access
sharded=True)
saver_for_restore.restore(session, checkpoint_path)
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 061a1226fb..5b0b044ab2 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -48,6 +48,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver
+from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.util import compat
@@ -814,6 +815,20 @@ def _model_fn_for_export_tests(features, labels, mode):
'test': export_output.ClassificationOutput(scores, classes)})
+def _model_fn_with_saveables_for_export_tests(features, labels, mode):
+ _, _ = features, labels
+ table = saver_test_utils.CheckpointedOp(name='v2')
+ train_op = table.insert('k1', 30.0)
+ prediction = table.lookup('k1', 0.0)
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=prediction,
+ loss=constant_op.constant(1.),
+ train_op=train_op,
+ export_outputs={
+ 'test': export_output.PredictOutput({'prediction': prediction})})
+
+
_VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n'
_EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n'
@@ -863,6 +878,50 @@ class EstimatorExportTest(test.TestCase):
# Clean up.
gfile.DeleteRecursively(tmpdir)
+ def test_export_savedmodel_with_saveables_proto_roundtrip(self):
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(
+ model_fn=_model_fn_with_saveables_for_export_tests)
+ est.train(input_fn=dummy_input_fn, steps=1)
+ feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
+ 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = est.export_savedmodel(
+ export_dir_base, serving_input_receiver_fn)
+
+ # Check that all the files are in the right places.
+ self.assertTrue(gfile.Exists(export_dir_base))
+ self.assertTrue(gfile.Exists(export_dir))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('saved_model.pb'))))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables'))))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.index'))))
+ self.assertTrue(gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.data-00000-of-00001'))))
+
+ # Restore, to validate that the export was well-formed.
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('input_example_tensor' in graph_ops)
+ self.assertTrue('ParseExample/ParseExample' in graph_ops)
+ self.assertTrue('save/LookupTableImport' in graph_ops)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
def test_export_savedmodel_assets(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 79399c11c4..8301a73e87 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -122,6 +122,7 @@ py_test(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:lib",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:saver_test_utils",
"//tensorflow/python:state_ops",
"//tensorflow/python:util",
"//tensorflow/python:variables",
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 4fd87c04ec..7b4fabad95 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -352,10 +352,10 @@ class SavedModelBuilder(object):
else:
self._add_main_op(main_op)
- # Initialize a saver to generate a sharded output for all variables in the
+ # Initialize a saver to generate a sharded output for all saveables in the
# current scope.
saver = tf_saver.Saver(
- variables.global_variables(),
+ variables._all_saveable_objects(), # pylint: disable=protected-access
sharded=True,
write_version=saver_pb2.SaverDef.V2,
allow_empty=True)
@@ -423,10 +423,10 @@ class SavedModelBuilder(object):
else:
self._add_main_op(main_op)
- # Initialize a saver to generate a sharded output for all variables in the
+ # Initialize a saver to generate a sharded output for all saveables in the
# current scope.
saver = tf_saver.Saver(
- variables.global_variables(),
+ variables._all_saveable_objects(), # pylint: disable=protected-access
sharded=True,
write_version=saver_pb2.SaverDef.V2,
allow_empty=True)
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 03c5901bd4..a81f744175 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -39,6 +39,7 @@ from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import main_op
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.training import saver_test_utils
from tensorflow.python.util import compat
SAVED_MODEL_PATH = ("cc/saved_model/testdata/half_plus_two/00000123")
@@ -734,6 +735,35 @@ class SavedModelTest(test.TestCase):
ops.get_collection("init_op")[0].run()
self.assertEqual(3, ops.get_collection("v")[2].eval())
+ def testCustomSaveable(self):
+ export_dir = os.path.join(test.get_temp_dir(), "custom_saveable")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with session.Session(
+ graph=ops.Graph(),
+ config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
+ # CheckpointedOp is a key-value table that can be saved across sessions.
+ # The table register itself in SAVEABLE_OBJECTS collection.
+ v1 = saver_test_utils.CheckpointedOp(name="v1")
+ variables.global_variables_initializer().run()
+ v1.insert("k1", 3.0).run()
+ # Once the table is restored, we can access it through this reference.
+ ops.add_to_collection("table_ref", v1.table_ref)
+ builder.add_meta_graph_and_variables(sess, ["foo"])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with session.Session(
+ graph=ops.Graph(),
+ config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
+ loader.load(sess, ["foo"], export_dir)
+ # Instantiate a wrapper object from the checkpointed reference.
+ v1 = saver_test_utils.CheckpointedOp(
+ name="v1", table_ref=ops.get_collection("table_ref")[0])
+ self.assertEqual(b"k1", v1.keys().eval())
+ self.assertEqual(3.0, v1.values().eval())
+
def testClearDevices(self):
export_dir = os.path.join(test.get_temp_dir(), "test_clear_devices")
builder = saved_model_builder.SavedModelBuilder(export_dir)
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 4ca12cb24e..abb1f89e86 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -48,7 +48,6 @@ from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
-from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
@@ -65,63 +64,10 @@ from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
+from tensorflow.python.training import saver_test_utils
from tensorflow.python.util import compat
-class CheckpointedOp(object):
- """Op with a custom checkpointing implementation.
-
- Defined as part of the test because the MutableHashTable Python code is
- currently in contrib.
- """
-
- def __init__(self, name):
- self._table_ref = gen_data_flow_ops._mutable_hash_table(
- key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name)
- self._name = name
- self._saveable = CheckpointedOp.CustomSaveable(self, name)
- ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS,
- self._saveable)
-
- @property
- def name(self):
- return self._name
-
- @property
- def saveable(self):
- return self._saveable
-
- def insert(self, keys, values):
- return gen_data_flow_ops._lookup_table_insert(self._table_ref, keys, values)
-
- def keys(self):
- return self._export()[0]
-
- def values(self):
- return self._export()[1]
-
- def _export(self):
- return gen_data_flow_ops._lookup_table_export(self._table_ref,
- dtypes.string, dtypes.float32)
-
- class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
-
- def __init__(self, table, name):
- tensors = table._export()
- specs = [
- saver_module.BaseSaverBuilder.SaveSpec(tensors[0], "",
- name + "-keys"),
- saver_module.BaseSaverBuilder.SaveSpec(tensors[1], "",
- name + "-values")
- ]
- super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)
-
- def restore(self, restore_tensors, shapes):
- return gen_data_flow_ops._lookup_table_import(self.op._table_ref,
- restore_tensors[0],
- restore_tensors[1])
-
-
class SaverTest(test.TestCase):
def basicSaveRestore(self, variable_op):
@@ -131,7 +77,7 @@ class SaverTest(test.TestCase):
# Restore nodes for them.
v0 = variable_op(10.0, name="v0")
v1 = variable_op(20.0, name="v1")
- v2 = CheckpointedOp(name="v2")
+ v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = saver_module.Saver(
{
@@ -161,7 +107,7 @@ class SaverTest(test.TestCase):
with self.test_session() as sess:
v0 = variable_op(-1.0, name="v0")
v1 = variable_op(-1.0, name="v1")
- v2 = CheckpointedOp(name="v2")
+ v2 = saver_test_utils.CheckpointedOp(name="v2")
save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
# Assert that the variables are not initialized.
@@ -183,7 +129,7 @@ class SaverTest(test.TestCase):
with self.test_session() as sess:
v0_2 = variable_op(1000.0, name="v0")
v1_2 = variable_op(2000.0, name="v1")
- v2_2 = CheckpointedOp(name="v2")
+ v2_2 = saver_test_utils.CheckpointedOp(name="v2")
save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable})
v2_2.insert("k1000", 3000.0).run()
variables.global_variables_initializer().run()
@@ -276,7 +222,7 @@ class SaverTest(test.TestCase):
def testSameName(self):
with ops_lib.Graph().as_default():
v0 = variables.Variable([10.0], name="v0")
- v2 = CheckpointedOp(name="v2")
+ v2 = saver_test_utils.CheckpointedOp(name="v2")
# Saving one variable under two names raises an error.
with self.assertRaisesRegexp(
@@ -299,7 +245,7 @@ class SaverTest(test.TestCase):
# Restore nodes for them.
v0 = variables.Variable(10.0, name="v0")
v1 = variables.Variable(20.0, name="v1")
- v2 = CheckpointedOp(name="v2")
+ v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = saver_module.Saver([v0, v1, v2.saveable])
variables.global_variables_initializer().run()
@@ -321,7 +267,7 @@ class SaverTest(test.TestCase):
with self.test_session(graph=ops_lib.Graph()) as sess:
v0 = variables.Variable(-1.0, name="v0")
v1 = variables.Variable(-1.0, name="v1")
- v2 = CheckpointedOp(name="v2")
+ v2 = saver_test_utils.CheckpointedOp(name="v2")
save = saver_module.Saver([v0, v1, v2.saveable])
with self.assertRaisesWithPredicateMatch(
@@ -346,7 +292,7 @@ class SaverTest(test.TestCase):
with self.test_session(graph=ops_lib.Graph()) as sess:
v0_2 = variables.Variable(1000.0, name="v0")
v1_2 = variables.Variable(2000.0, name="v1")
- v2_2 = CheckpointedOp(name="v2")
+ v2_2 = saver_test_utils.CheckpointedOp(name="v2")
save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable])
v2_2.insert("k1000", 3000.0).run()
variables.global_variables_initializer().run()
@@ -418,7 +364,7 @@ class SaverTest(test.TestCase):
with session.Session("", graph=ops_lib.Graph()) as sess:
one = variables.Variable(1.0)
twos = variables.Variable([2.0, 2.0, 2.0])
- v2 = CheckpointedOp(name="v2")
+ v2 = saver_test_utils.CheckpointedOp(name="v2")
init = variables.global_variables_initializer()
save = saver_module.Saver()
init.run()
@@ -428,7 +374,7 @@ class SaverTest(test.TestCase):
with session.Session("", graph=ops_lib.Graph()) as sess:
one = variables.Variable(0.0)
twos = variables.Variable([0.0, 0.0, 0.0])
- v2 = CheckpointedOp(name="v2")
+ v2 = saver_test_utils.CheckpointedOp(name="v2")
# Saver with no arg, defaults to 'all variables'.
save = saver_module.Saver()
save.restore(sess, save_path)
@@ -593,10 +539,10 @@ class SaveRestoreShardedTest(test.TestCase):
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v0 = variables.Variable(10, name="v0")
- t0 = CheckpointedOp(name="t0")
+ t0 = saver_test_utils.CheckpointedOp(name="t0")
with sess.graph.device("/cpu:1"):
v1 = variables.Variable(20, name="v1")
- t1 = CheckpointedOp(name="t1")
+ t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
"v0": v0,
@@ -623,7 +569,7 @@ class SaveRestoreShardedTest(test.TestCase):
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v0 = variables.Variable(111, name="v0")
- t0 = CheckpointedOp(name="t0")
+ t0 = saver_test_utils.CheckpointedOp(name="t0")
save = saver_module.Saver({"v0": v0, "t0": t0.saveable}, sharded=True)
variables.global_variables_initializer().run()
t0.insert("k11", 33.0).run()
@@ -641,7 +587,7 @@ class SaveRestoreShardedTest(test.TestCase):
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v1 = variables.Variable(222)
- t1 = CheckpointedOp(name="t1")
+ t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver({"v1": v1, "t1": t1.saveable}, sharded=True)
variables.global_variables_initializer().run()
t1.insert("k22", 44.0).run()
@@ -659,10 +605,10 @@ class SaveRestoreShardedTest(test.TestCase):
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v0 = variables.Variable(111, name="v0")
- t0 = CheckpointedOp(name="t0")
+ t0 = saver_test_utils.CheckpointedOp(name="t0")
with sess.graph.device("/cpu:1"):
v1 = variables.Variable(222, name="v1")
- t1 = CheckpointedOp(name="t1")
+ t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
"v0": v0,
diff --git a/tensorflow/python/training/saver_test_utils.py b/tensorflow/python/training/saver_test_utils.py
new file mode 100644
index 0000000000..5f31e2aa53
--- /dev/null
+++ b/tensorflow/python/training/saver_test_utils.py
@@ -0,0 +1,86 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Utility classes for testing checkpointing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.training import saver as saver_module
+
+
+class CheckpointedOp(object):
+ """Op with a custom checkpointing implementation.
+
+ Defined as part of the test because the MutableHashTable Python code is
+ currently in contrib.
+ """
+
+ # pylint: disable=protected-access
+ def __init__(self, name, table_ref=None):
+ if table_ref is None:
+ self.table_ref = gen_data_flow_ops._mutable_hash_table(
+ key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name)
+ else:
+ self.table_ref = table_ref
+ self._name = name
+ self._saveable = CheckpointedOp.CustomSaveable(self, name)
+ ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS,
+ self._saveable)
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def saveable(self):
+ return self._saveable
+
+ def insert(self, keys, values):
+ return gen_data_flow_ops._lookup_table_insert(self.table_ref, keys, values)
+
+ def lookup(self, keys, default):
+ return gen_data_flow_ops._lookup_table_find(self.table_ref, keys, default)
+
+ def keys(self):
+ return self._export()[0]
+
+ def values(self):
+ return self._export()[1]
+
+ def _export(self):
+ return gen_data_flow_ops._lookup_table_export(self.table_ref, dtypes.string,
+ dtypes.float32)
+
+ class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
+ """A custom saveable for CheckpointedOp."""
+
+ def __init__(self, table, name):
+ tensors = table._export()
+ specs = [
+ saver_module.BaseSaverBuilder.SaveSpec(tensors[0], "",
+ name + "-keys"),
+ saver_module.BaseSaverBuilder.SaveSpec(tensors[1], "",
+ name + "-values")
+ ]
+ super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)
+
+ def restore(self, restore_tensors, shapes):
+ return gen_data_flow_ops._lookup_table_import(
+ self.op.table_ref, restore_tensors[0], restore_tensors[1])
+ # pylint: enable=protected-access