author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-01 16:20:47 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-05-01 16:23:11 -0700
commit    fb8f040f2a927c6df149238da7c4278cf781d081 (patch)
tree      f6201f6c66fbaf35661d708773b9a1d4f43c303b
parent    210abebd3febdd2c44ab5021bcebf8f1f5d451c4 (diff)
Allow `warm_start_from` argument to be a SavedModel path.
PiperOrigin-RevId: 195015356
-rw-r--r--  tensorflow/python/estimator/estimator.py       28
-rw-r--r--  tensorflow/python/estimator/estimator_test.py  35
2 files changed, 54 insertions, 9 deletions
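
Note (not part of the original commit): a minimal usage sketch of what this change enables. After this change, `warm_start_from` may be a SavedModel export directory rather than only a checkpoint path. The `make_model_fn`, `input_fn`, and feature spec below are illustrative assumptions modeled on the test added in estimator_test.py further down; TF 1.x Estimator APIs are assumed.

    import tempfile
    import tensorflow as tf

    def make_model_fn(init_value):
      def model_fn(features, labels, mode):
        del features, labels  # unused; the variable below is what gets warm-started
        tf.get_variable('x', initializer=init_value)
        global_step = tf.train.get_global_step()
        return tf.estimator.EstimatorSpec(
            mode,
            predictions={'y': tf.constant(1.0)},
            loss=tf.constant(1.0),
            train_op=tf.assign_add(global_step, 1),
            export_outputs={'test': tf.estimator.export.ClassificationOutput(
                tf.constant([4.2]), tf.constant(['label']))})
      return model_fn

    def input_fn():
      return {'x': tf.constant([[1.0]])}, tf.constant([[1.0]])

    # Train a first model and export it as a SavedModel.
    est = tf.estimator.Estimator(model_fn=make_model_fn(42.))
    est.train(input_fn, steps=10)
    feature_spec = {'x': tf.VarLenFeature(dtype=tf.int64)}
    serving_input_receiver_fn = (
        tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
    export_dir = est.export_savedmodel(tempfile.mkdtemp(), serving_input_receiver_fn)

    # With this change, the SavedModel directory itself can be warm-started from.
    warm_started = tf.estimator.Estimator(
        model_fn=make_model_fn(36.), warm_start_from=export_dir)
    warm_started.train(input_fn, steps=5)
    print(warm_started.get_variable_value('x'))  # 42.0, restored from the SavedModel
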
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 0970f00124..3691c99dda 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -155,12 +155,12 @@ class Estimator(object):
config: Configuration object.
params: `dict` of hyper parameters that will be passed into `model_fn`.
Keys are names of parameters, values are basic python types.
- warm_start_from: Optional string filepath to a checkpoint to warm-start
- from, or a `tf.estimator.WarmStartSettings` object to
- fully configure warm-starting. If the string filepath is
- provided instead of a `WarmStartSettings`, then all
- variables are warm-started, and it is assumed that
- vocabularies and Tensor names are unchanged.
+ warm_start_from: Optional string filepath to a checkpoint or SavedModel to
+ warm-start from, or a `tf.estimator.WarmStartSettings`
+ object to fully configure warm-starting. If the string
+ filepath is provided instead of a `WarmStartSettings`,
+ then all variables are warm-started, and it is assumed
+ that vocabularies and Tensor names are unchanged.
Raises:
ValueError: parameters of `model_fn` don't match `params`.
@@ -1502,7 +1502,7 @@ def _get_default_warm_start_settings(warm_start_from):
Args:
warm_start_from: Either a string representing the filepath of a checkpoint
- to initialize from, or an instance of WarmStartSettings.
+ or SavedModel to initialize from, or an instance of WarmStartSettings.
Returns:
Either None or an instance of WarmStartSettings.
@@ -1513,9 +1513,19 @@ def _get_default_warm_start_settings(warm_start_from):
"""
if warm_start_from is None:
return None
- if isinstance(warm_start_from, six.string_types):
+ if isinstance(warm_start_from, (six.string_types, six.binary_type)):
+ # Infer that this is a SavedModel if export_path +
+ # 'variables/variables.index' exists, and if so, construct the
+ # WarmStartSettings pointing to export_path + 'variables/variables'.
+ if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from),
+ compat.as_bytes('variables/variables.index'))):
+ logging.info('Warm-starting from a SavedModel')
+ return WarmStartSettings(ckpt_to_initialize_from=os.path.join(
+ compat.as_bytes(warm_start_from),
+ compat.as_bytes('variables/variables')))
return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
elif isinstance(warm_start_from, WarmStartSettings):
return warm_start_from
else:
- raise ValueError('warm_start_from must be a string or a WarmStartSettings')
+ raise ValueError('warm_start_from must be a string or a WarmStartSettings, '
+ 'instead got {}'.format(type(warm_start_from)))
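
Note (illustrative, not part of the commit): when `warm_start_from` is a SavedModel export directory, the branch added above is roughly equivalent to building the `WarmStartSettings` by hand against the checkpoint shards stored under the SavedModel's `variables/` subdirectory. The `export_dir` path below is hypothetical.

    import os
    import tensorflow as tf

    export_dir = '/tmp/export/1525215600'  # hypothetical SavedModel export directory
    ckpt_index = os.path.join(export_dir, 'variables/variables.index')
    if tf.gfile.Exists(ckpt_index):
      # SavedModel case: point at the checkpoint prefix inside the SavedModel.
      settings = tf.estimator.WarmStartSettings(
          ckpt_to_initialize_from=os.path.join(export_dir, 'variables/variables'))
    else:
      # Plain checkpoint path: unchanged behavior, warm-start all variables.
      settings = tf.estimator.WarmStartSettings(ckpt_to_initialize_from=export_dir)
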
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 74114fab3b..4d958f8b43 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -658,6 +658,41 @@ class EstimatorTrainTest(test.TestCase):
5, estimator._load_global_step_from_checkpoint_dir(
warm_started_est.model_dir))
+ def test_warm_starts_from_savedmodel(self):
+ def _make_model_fn(x):
+ def _variable_creating_and_export_model_fn(features, labels, mode):
+ _, _ = features, labels
+ variable_scope.get_variable('x', initializer=x)
+ global_step = training.get_global_step()
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions={'y': constant_op.constant(1.0)},
+ loss=constant_op.constant(1.),
+ train_op=state_ops.assign_add(global_step, 1),
+ export_outputs={'test': export_output.ClassificationOutput(
+ constant_op.constant([4.2]), constant_op.constant(['label']))})
+ return _variable_creating_and_export_model_fn
+
+ est = estimator.Estimator(model_fn=_make_model_fn(42.))
+ est.train(dummy_input_fn, steps=10)
+ feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
+ 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ tmpdir = tempfile.mkdtemp()
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = est.export_savedmodel(
+ export_dir_base, serving_input_receiver_fn)
+
+ warm_started_est = estimator.Estimator(
+ model_fn=_make_model_fn(36.),
+ warm_start_from=export_dir)
+ warm_started_est.train(dummy_input_fn, steps=5)
+ # warm_start is called after the model_fn, so x should have the value
+ # from the SavedModel.
+ self.assertEqual(42., warm_started_est.get_variable_value('x'))
+
def test_max_step(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
est.train(dummy_input_fn, max_steps=5)