diff options
author | 2017-03-15 13:34:20 -0800 | |
---|---|---|
committer | 2017-03-15 14:56:00 -0700 | |
commit | 19ef82151be358698f9ab8702ed5575afc94f110 (patch) | |
tree | 42cc430ca40355d103543b1f59c3f6b4ada020e1 | |
parent | b05e0840d11ee30c3a66d45daeeea2495b9808e5 (diff) |
Make SavedModel exports include all the SAVEABLE objects and not just global variables.
Change: 150243023
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/estimator.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/estimator_test.py | 87 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 13 | ||||
-rw-r--r-- | tensorflow/python/estimator/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 2 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 59 | ||||
-rw-r--r-- | tensorflow/python/saved_model/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/saved_model/builder_impl.py | 8 | ||||
-rw-r--r-- | tensorflow/python/saved_model/saved_model_test.py | 30 | ||||
-rw-r--r-- | tensorflow/python/training/saver_test.py | 86 | ||||
-rw-r--r-- | tensorflow/python/training/saver_test_utils.py | 86 |
11 files changed, 303 insertions, 77 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 6f9c44dff6..4c575bb62d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -58,6 +58,7 @@ from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging @@ -1254,13 +1255,17 @@ class Estimator(BaseEstimator): with tf_session.Session('') as session: variables.initialize_local_variables() data_flow_ops.tables_initializer() + resources.initialize_resources(resources.shared_resources()) saver_for_restore = saver.Saver( - variables.global_variables(), + # pylint: disable=protected-access + variables._all_saveable_objects(), + # pylint: enable=protected-access sharded=True) saver_for_restore.restore(session, checkpoint_path) init_op = control_flow_ops.group( variables.local_variables_initializer(), + resources.initialize_resources(resources.shared_resources()), data_flow_ops.tables_initializer()) # Perform the export diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 659266690f..b27e99b0dc 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -50,6 +50,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops @@ -225,6 +226,49 @@ def _build_estimator_for_export_tests(tmpdir): return est, serving_input_fn_with_asset +def _build_estimator_for_resource_export_test(): + + def _input_fn(): + iris = base.load_iris() + return { + 'feature': constant_op.constant(iris.data, dtype=dtypes.float32) + }, constant_op.constant( + iris.target, shape=[150], dtype=dtypes.int32) + + feature_columns = [ + feature_column_lib.real_valued_column('feature', dimension=4) + ] + + def resource_constant_model_fn(unused_features, unused_labels, mode): + """A model_fn that loads a constant from a resource and serves it.""" + assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL, + model_fn.ModeKeys.INFER) + + const = constant_op.constant(-1, dtype=dtypes.int64) + table = lookup.MutableHashTable( + dtypes.string, dtypes.int64, const, name='LookupTableModel') + if mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL): + key = constant_op.constant(['key']) + value = constant_op.constant([42], dtype=dtypes.int64) + train_op_1 = table.insert(key, value) + training_state = lookup.MutableHashTable( + dtypes.string, dtypes.int64, const, name='LookupTableTrainingState') + training_op_2 = training_state.insert(key, value) + return const, const, control_flow_ops.group(train_op_1, training_op_2) + if mode == model_fn.ModeKeys.INFER: + key = constant_op.constant(['key']) + prediction = table.lookup(key) + return prediction, const, control_flow_ops.no_op() + + est = estimator.Estimator(model_fn=resource_constant_model_fn) + est.fit(input_fn=_input_fn, steps=1) + + feature_spec = feature_column_lib.create_feature_spec_for_parsing( + feature_columns) + serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec) + return est, serving_input_fn + + class CheckCallsMonitor(monitors_lib.BaseMonitor): def __init__(self, expect_calls): @@ -753,6 +797,49 @@ class EstimatorTest(test.TestCase): # cleanup gfile.DeleteRecursively(tmpdir) + def test_export_savedmodel_with_resource(self): + tmpdir = tempfile.mkdtemp() + est, serving_input_fn = _build_estimator_for_resource_export_test() + + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = est.export_savedmodel(export_dir_base, serving_input_fn) + + self.assertTrue(gfile.Exists(export_dir_base)) + self.assertTrue(gfile.Exists(export_dir)) + self.assertTrue( + gfile.Exists( + os.path.join( + compat.as_bytes(export_dir), compat.as_bytes( + 'saved_model.pb')))) + self.assertTrue( + gfile.Exists( + os.path.join( + compat.as_bytes(export_dir), compat.as_bytes('variables')))) + self.assertTrue( + gfile.Exists( + os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.index')))) + self.assertTrue( + gfile.Exists( + os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.data-00000-of-00001')))) + + # Restore, to validate that the export was well-formed. + with ops.Graph().as_default() as graph: + with session_lib.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertTrue('LookupTableModel' in graph_ops) + self.assertFalse('LookupTableTrainingState' in graph_ops) + + # cleanup + gfile.DeleteRecursively(tmpdir) + class InferRealValuedColumnsTest(test.TestCase): diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 7f416dc609..5c3e87cfd6 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -74,6 +74,7 @@ py_library( ":tensor_array_ops", ":training", ":ops", + ":saver_test_utils", ":test_ops", # TODO: Break testing code out into separate rule. ":util", ":weights_broadcast_ops", @@ -2935,6 +2936,16 @@ cuda_py_tests( ], ) +py_library( + name = "saver_test_utils", + srcs = ["training/saver_test_utils.py"], + srcs_version = "PY2AND3", + deps = [ + ":data_flow_ops_gen", + ":training", + ], +) + cuda_py_test( name = "saver_test", size = "medium", @@ -2946,12 +2957,12 @@ cuda_py_test( ":client_testlib", ":control_flow_ops", ":data_flow_ops", - ":data_flow_ops_gen", ":errors", ":gradients", ":math_ops", ":nn_grad", ":nn_ops", + ":saver_test_utils", ":partitioned_variables", ":platform", ":platform_test", diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index e6d4af2f95..bcbf06deab 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -136,6 +136,7 @@ py_test( "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:layers", + "//tensorflow/python:saver_test_utils", "//tensorflow/python:session", "//tensorflow/python:state_ops", "//tensorflow/python:training", diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index d30f9093b6..677c840ff0 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -411,7 +411,7 @@ class Estimator(object): with tf_session.Session() as session: saver_for_restore = estimator_spec.scaffold.saver or saver.Saver( - variables.global_variables(), + variables._all_saveable_objects(), # pylint: disable=protected-access sharded=True) saver_for_restore.restore(session, checkpoint_path) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 061a1226fb..5b0b044ab2 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -48,6 +48,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import saver +from tensorflow.python.training import saver_test_utils from tensorflow.python.training import session_run_hook from tensorflow.python.training import training from tensorflow.python.util import compat @@ -814,6 +815,20 @@ def _model_fn_for_export_tests(features, labels, mode): 'test': export_output.ClassificationOutput(scores, classes)}) +def _model_fn_with_saveables_for_export_tests(features, labels, mode): + _, _ = features, labels + table = saver_test_utils.CheckpointedOp(name='v2') + train_op = table.insert('k1', 30.0) + prediction = table.lookup('k1', 0.0) + return model_fn_lib.EstimatorSpec( + mode, + predictions=prediction, + loss=constant_op.constant(1.), + train_op=train_op, + export_outputs={ + 'test': export_output.PredictOutput({'prediction': prediction})}) + + _VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n' _EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n' @@ -863,6 +878,50 @@ class EstimatorExportTest(test.TestCase): # Clean up. gfile.DeleteRecursively(tmpdir) + def test_export_savedmodel_with_saveables_proto_roundtrip(self): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator( + model_fn=_model_fn_with_saveables_for_export_tests) + est.train(input_fn=dummy_input_fn, steps=1) + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = est.export_savedmodel( + export_dir_base, serving_input_receiver_fn) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self.assertTrue(gfile.Exists(export_dir)) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.index')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.data-00000-of-00001')))) + + # Restore, to validate that the export was well-formed. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertTrue('save/LookupTableImport' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + def test_export_savedmodel_assets(self): tmpdir = tempfile.mkdtemp() est = estimator.Estimator(model_fn=_model_fn_for_export_tests) diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 79399c11c4..8301a73e87 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -122,6 +122,7 @@ py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lib", "//tensorflow/python:math_ops", + "//tensorflow/python:saver_test_utils", "//tensorflow/python:state_ops", "//tensorflow/python:util", "//tensorflow/python:variables", diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 4fd87c04ec..7b4fabad95 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -352,10 +352,10 @@ class SavedModelBuilder(object): else: self._add_main_op(main_op) - # Initialize a saver to generate a sharded output for all variables in the + # Initialize a saver to generate a sharded output for all saveables in the # current scope. saver = tf_saver.Saver( - variables.global_variables(), + variables._all_saveable_objects(), # pylint: disable=protected-access sharded=True, write_version=saver_pb2.SaverDef.V2, allow_empty=True) @@ -423,10 +423,10 @@ class SavedModelBuilder(object): else: self._add_main_op(main_op) - # Initialize a saver to generate a sharded output for all variables in the + # Initialize a saver to generate a sharded output for all saveables in the # current scope. saver = tf_saver.Saver( - variables.global_variables(), + variables._all_saveable_objects(), # pylint: disable=protected-access sharded=True, write_version=saver_pb2.SaverDef.V2, allow_empty=True) diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 03c5901bd4..a81f744175 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -39,6 +39,7 @@ from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import main_op from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import saver_test_utils from tensorflow.python.util import compat SAVED_MODEL_PATH = ("cc/saved_model/testdata/half_plus_two/00000123") @@ -734,6 +735,35 @@ class SavedModelTest(test.TestCase): ops.get_collection("init_op")[0].run() self.assertEqual(3, ops.get_collection("v")[2].eval()) + def testCustomSaveable(self): + export_dir = os.path.join(test.get_temp_dir(), "custom_saveable") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with session.Session( + graph=ops.Graph(), + config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: + # CheckpointedOp is a key-value table that can be saved across sessions. + # The table register itself in SAVEABLE_OBJECTS collection. + v1 = saver_test_utils.CheckpointedOp(name="v1") + variables.global_variables_initializer().run() + v1.insert("k1", 3.0).run() + # Once the table is restored, we can access it through this reference. + ops.add_to_collection("table_ref", v1.table_ref) + builder.add_meta_graph_and_variables(sess, ["foo"]) + + # Save the SavedModel to disk. + builder.save() + + with session.Session( + graph=ops.Graph(), + config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: + loader.load(sess, ["foo"], export_dir) + # Instantiate a wrapper object from the checkpointed reference. + v1 = saver_test_utils.CheckpointedOp( + name="v1", table_ref=ops.get_collection("table_ref")[0]) + self.assertEqual(b"k1", v1.keys().eval()) + self.assertEqual(3.0, v1.values().eval()) + def testClearDevices(self): export_dir = os.path.join(test.get_temp_dir(), "test_clear_devices") builder = saved_model_builder.SavedModelBuilder(export_dir) diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 4ca12cb24e..abb1f89e86 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -48,7 +48,6 @@ from tensorflow.python.framework import ops as ops_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops -from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables @@ -65,63 +64,10 @@ from tensorflow.python.training import adam from tensorflow.python.training import gradient_descent from tensorflow.python.training import queue_runner_impl from tensorflow.python.training import saver as saver_module +from tensorflow.python.training import saver_test_utils from tensorflow.python.util import compat -class CheckpointedOp(object): - """Op with a custom checkpointing implementation. - - Defined as part of the test because the MutableHashTable Python code is - currently in contrib. - """ - - def __init__(self, name): - self._table_ref = gen_data_flow_ops._mutable_hash_table( - key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name) - self._name = name - self._saveable = CheckpointedOp.CustomSaveable(self, name) - ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS, - self._saveable) - - @property - def name(self): - return self._name - - @property - def saveable(self): - return self._saveable - - def insert(self, keys, values): - return gen_data_flow_ops._lookup_table_insert(self._table_ref, keys, values) - - def keys(self): - return self._export()[0] - - def values(self): - return self._export()[1] - - def _export(self): - return gen_data_flow_ops._lookup_table_export(self._table_ref, - dtypes.string, dtypes.float32) - - class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject): - - def __init__(self, table, name): - tensors = table._export() - specs = [ - saver_module.BaseSaverBuilder.SaveSpec(tensors[0], "", - name + "-keys"), - saver_module.BaseSaverBuilder.SaveSpec(tensors[1], "", - name + "-values") - ] - super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name) - - def restore(self, restore_tensors, shapes): - return gen_data_flow_ops._lookup_table_import(self.op._table_ref, - restore_tensors[0], - restore_tensors[1]) - - class SaverTest(test.TestCase): def basicSaveRestore(self, variable_op): @@ -131,7 +77,7 @@ class SaverTest(test.TestCase): # Restore nodes for them. v0 = variable_op(10.0, name="v0") v1 = variable_op(20.0, name="v1") - v2 = CheckpointedOp(name="v2") + v2 = saver_test_utils.CheckpointedOp(name="v2") v2_init = v2.insert("k1", 30.0) save = saver_module.Saver( { @@ -161,7 +107,7 @@ class SaverTest(test.TestCase): with self.test_session() as sess: v0 = variable_op(-1.0, name="v0") v1 = variable_op(-1.0, name="v1") - v2 = CheckpointedOp(name="v2") + v2 = saver_test_utils.CheckpointedOp(name="v2") save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable}) # Assert that the variables are not initialized. @@ -183,7 +129,7 @@ class SaverTest(test.TestCase): with self.test_session() as sess: v0_2 = variable_op(1000.0, name="v0") v1_2 = variable_op(2000.0, name="v1") - v2_2 = CheckpointedOp(name="v2") + v2_2 = saver_test_utils.CheckpointedOp(name="v2") save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable}) v2_2.insert("k1000", 3000.0).run() variables.global_variables_initializer().run() @@ -276,7 +222,7 @@ class SaverTest(test.TestCase): def testSameName(self): with ops_lib.Graph().as_default(): v0 = variables.Variable([10.0], name="v0") - v2 = CheckpointedOp(name="v2") + v2 = saver_test_utils.CheckpointedOp(name="v2") # Saving one variable under two names raises an error. with self.assertRaisesRegexp( @@ -299,7 +245,7 @@ class SaverTest(test.TestCase): # Restore nodes for them. v0 = variables.Variable(10.0, name="v0") v1 = variables.Variable(20.0, name="v1") - v2 = CheckpointedOp(name="v2") + v2 = saver_test_utils.CheckpointedOp(name="v2") v2_init = v2.insert("k1", 30.0) save = saver_module.Saver([v0, v1, v2.saveable]) variables.global_variables_initializer().run() @@ -321,7 +267,7 @@ class SaverTest(test.TestCase): with self.test_session(graph=ops_lib.Graph()) as sess: v0 = variables.Variable(-1.0, name="v0") v1 = variables.Variable(-1.0, name="v1") - v2 = CheckpointedOp(name="v2") + v2 = saver_test_utils.CheckpointedOp(name="v2") save = saver_module.Saver([v0, v1, v2.saveable]) with self.assertRaisesWithPredicateMatch( @@ -346,7 +292,7 @@ class SaverTest(test.TestCase): with self.test_session(graph=ops_lib.Graph()) as sess: v0_2 = variables.Variable(1000.0, name="v0") v1_2 = variables.Variable(2000.0, name="v1") - v2_2 = CheckpointedOp(name="v2") + v2_2 = saver_test_utils.CheckpointedOp(name="v2") save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable]) v2_2.insert("k1000", 3000.0).run() variables.global_variables_initializer().run() @@ -418,7 +364,7 @@ class SaverTest(test.TestCase): with session.Session("", graph=ops_lib.Graph()) as sess: one = variables.Variable(1.0) twos = variables.Variable([2.0, 2.0, 2.0]) - v2 = CheckpointedOp(name="v2") + v2 = saver_test_utils.CheckpointedOp(name="v2") init = variables.global_variables_initializer() save = saver_module.Saver() init.run() @@ -428,7 +374,7 @@ class SaverTest(test.TestCase): with session.Session("", graph=ops_lib.Graph()) as sess: one = variables.Variable(0.0) twos = variables.Variable([0.0, 0.0, 0.0]) - v2 = CheckpointedOp(name="v2") + v2 = saver_test_utils.CheckpointedOp(name="v2") # Saver with no arg, defaults to 'all variables'. save = saver_module.Saver() save.restore(sess, save_path) @@ -593,10 +539,10 @@ class SaveRestoreShardedTest(test.TestCase): config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: with sess.graph.device("/cpu:0"): v0 = variables.Variable(10, name="v0") - t0 = CheckpointedOp(name="t0") + t0 = saver_test_utils.CheckpointedOp(name="t0") with sess.graph.device("/cpu:1"): v1 = variables.Variable(20, name="v1") - t1 = CheckpointedOp(name="t1") + t1 = saver_test_utils.CheckpointedOp(name="t1") save = saver_module.Saver( { "v0": v0, @@ -623,7 +569,7 @@ class SaveRestoreShardedTest(test.TestCase): config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: with sess.graph.device("/cpu:0"): v0 = variables.Variable(111, name="v0") - t0 = CheckpointedOp(name="t0") + t0 = saver_test_utils.CheckpointedOp(name="t0") save = saver_module.Saver({"v0": v0, "t0": t0.saveable}, sharded=True) variables.global_variables_initializer().run() t0.insert("k11", 33.0).run() @@ -641,7 +587,7 @@ class SaveRestoreShardedTest(test.TestCase): config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: with sess.graph.device("/cpu:0"): v1 = variables.Variable(222) - t1 = CheckpointedOp(name="t1") + t1 = saver_test_utils.CheckpointedOp(name="t1") save = saver_module.Saver({"v1": v1, "t1": t1.saveable}, sharded=True) variables.global_variables_initializer().run() t1.insert("k22", 44.0).run() @@ -659,10 +605,10 @@ class SaveRestoreShardedTest(test.TestCase): config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: with sess.graph.device("/cpu:0"): v0 = variables.Variable(111, name="v0") - t0 = CheckpointedOp(name="t0") + t0 = saver_test_utils.CheckpointedOp(name="t0") with sess.graph.device("/cpu:1"): v1 = variables.Variable(222, name="v1") - t1 = CheckpointedOp(name="t1") + t1 = saver_test_utils.CheckpointedOp(name="t1") save = saver_module.Saver( { "v0": v0, diff --git a/tensorflow/python/training/saver_test_utils.py b/tensorflow/python/training/saver_test_utils.py new file mode 100644 index 0000000000..5f31e2aa53 --- /dev/null +++ b/tensorflow/python/training/saver_test_utils.py @@ -0,0 +1,86 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Utility classes for testing checkpointing.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.training import saver as saver_module + + +class CheckpointedOp(object): + """Op with a custom checkpointing implementation. + + Defined as part of the test because the MutableHashTable Python code is + currently in contrib. + """ + + # pylint: disable=protected-access + def __init__(self, name, table_ref=None): + if table_ref is None: + self.table_ref = gen_data_flow_ops._mutable_hash_table( + key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name) + else: + self.table_ref = table_ref + self._name = name + self._saveable = CheckpointedOp.CustomSaveable(self, name) + ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS, + self._saveable) + + @property + def name(self): + return self._name + + @property + def saveable(self): + return self._saveable + + def insert(self, keys, values): + return gen_data_flow_ops._lookup_table_insert(self.table_ref, keys, values) + + def lookup(self, keys, default): + return gen_data_flow_ops._lookup_table_find(self.table_ref, keys, default) + + def keys(self): + return self._export()[0] + + def values(self): + return self._export()[1] + + def _export(self): + return gen_data_flow_ops._lookup_table_export(self.table_ref, dtypes.string, + dtypes.float32) + + class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject): + """A custom saveable for CheckpointedOp.""" + + def __init__(self, table, name): + tensors = table._export() + specs = [ + saver_module.BaseSaverBuilder.SaveSpec(tensors[0], "", + name + "-keys"), + saver_module.BaseSaverBuilder.SaveSpec(tensors[1], "", + name + "-values") + ] + super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name) + + def restore(self, restore_tensors, shapes): + return gen_data_flow_ops._lookup_table_import( + self.op.table_ref, restore_tensors[0], restore_tensors[1]) + # pylint: enable=protected-access |