diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-01 16:20:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-01 16:23:11 -0700 |
commit | fb8f040f2a927c6df149238da7c4278cf781d081 (patch) | |
tree | f6201f6c66fbaf35661d708773b9a1d4f43c303b | |
parent | 210abebd3febdd2c44ab5021bcebf8f1f5d451c4 (diff) |
Allow `warm_start_from` argument to be a SavedModel path.
PiperOrigin-RevId: 195015356
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 28 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 35 |
2 files changed, 54 insertions, 9 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 0970f00124..3691c99dda 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -155,12 +155,12 @@ class Estimator(object): config: Configuration object. params: `dict` of hyper parameters that will be passed into `model_fn`. Keys are names of parameters, values are basic python types. - warm_start_from: Optional string filepath to a checkpoint to warm-start - from, or a `tf.estimator.WarmStartSettings` object to - fully configure warm-starting. If the string filepath is - provided instead of a `WarmStartSettings`, then all - variables are warm-started, and it is assumed that - vocabularies and Tensor names are unchanged. + warm_start_from: Optional string filepath to a checkpoint or SavedModel to + warm-start from, or a `tf.estimator.WarmStartSettings` + object to fully configure warm-starting. If the string + filepath is provided instead of a `WarmStartSettings`, + then all variables are warm-started, and it is assumed + that vocabularies and Tensor names are unchanged. Raises: ValueError: parameters of `model_fn` don't match `params`. @@ -1502,7 +1502,7 @@ def _get_default_warm_start_settings(warm_start_from): Args: warm_start_from: Either a string representing the filepath of a checkpoint - to initialize from, or an instance of WarmStartSettings. + or SavedModel to initialize from, or an instance of WarmStartSettings. Returns: Either None or an instance of WarmStartSettings. @@ -1513,9 +1513,19 @@ def _get_default_warm_start_settings(warm_start_from): """ if warm_start_from is None: return None - if isinstance(warm_start_from, six.string_types): + if isinstance(warm_start_from, (six.string_types, six.binary_type)): + # Infer that this is a SavedModel if export_path + + # 'variables/variables.index' exists, and if so, construct the + # WarmStartSettings pointing to export_path + 'variables/variables'. + if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from), + compat.as_bytes('variables/variables.index'))): + logging.info('Warm-starting from a SavedModel') + return WarmStartSettings(ckpt_to_initialize_from=os.path.join( + compat.as_bytes(warm_start_from), + compat.as_bytes('variables/variables'))) return WarmStartSettings(ckpt_to_initialize_from=warm_start_from) elif isinstance(warm_start_from, WarmStartSettings): return warm_start_from else: - raise ValueError('warm_start_from must be a string or a WarmStartSettings') + raise ValueError('warm_start_from must be a string or a WarmStartSettings, ' + 'instead got {}'.format(type(warm_start_from))) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 74114fab3b..4d958f8b43 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -658,6 +658,41 @@ class EstimatorTrainTest(test.TestCase): 5, estimator._load_global_step_from_checkpoint_dir( warm_started_est.model_dir)) + def test_warm_starts_from_savedmodel(self): + def _make_model_fn(x): + def _variable_creating_and_export_model_fn(features, labels, mode): + _, _ = features, labels + variable_scope.get_variable('x', initializer=x) + global_step = training.get_global_step() + return model_fn_lib.EstimatorSpec( + mode, + predictions={'y': constant_op.constant(1.0)}, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(global_step, 1), + export_outputs={'test': export_output.ClassificationOutput( + constant_op.constant([4.2]), constant_op.constant(['label']))}) + return _variable_creating_and_export_model_fn + + est = estimator.Estimator(model_fn=_make_model_fn(42.)) + est.train(dummy_input_fn, steps=10) + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + tmpdir = tempfile.mkdtemp() + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = est.export_savedmodel( + export_dir_base, serving_input_receiver_fn) + + warm_started_est = estimator.Estimator( + model_fn=_make_model_fn(36.), + warm_start_from=export_dir) + warm_started_est.train(dummy_input_fn, steps=5) + # warm_start is called after the model_fn, so x should have the value + # from the SavedModel. + self.assertEqual(42., warm_started_est.get_variable_value('x')) + def test_max_step(self): est = estimator.Estimator(model_fn=model_fn_global_step_incrementer) est.train(dummy_input_fn, max_steps=5) |