diff options
author | 2017-04-06 17:21:41 -0800 | |
---|---|---|
committer | 2017-04-06 18:42:10 -0700 | |
commit | 7f28f166092e8f6621bc264e12a7201a22f76997 (patch) | |
tree | 8ad4318c808d872248428b796257f276010d47a6 | |
parent | 0ca0b63fddda9bc81ad3bce00594b06e0c543ea9 (diff) |
Allow to set session ConfigProto in RunConfig and use it in Estimator.
Change: 152454548
5 files changed, 34 insertions, 7 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 107454dca1..29ea692f8f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -362,6 +362,11 @@ class BaseEstimator( self._config = config logging.info('Using config: %s', str(vars(self._config))) + if self._config.session_config is None: + self._session_config = config_pb2.ConfigProto(allow_soft_placement=True) + else: + self._session_config = self._config.session_config + # Model directory. if (model_dir is not None) and (self._config.model_dir is not None): if model_dir != self._config.model_dir: @@ -829,7 +834,7 @@ class BaseEstimator( eval_ops=update_op, final_ops=eval_dict, hooks=hooks, - config=config_pb2.ConfigProto(allow_soft_placement=True)) + config=self._session_config) current_global_step = eval_results[global_step_key] _write_dict_to_summary(eval_dir, eval_results, current_global_step) @@ -864,7 +869,7 @@ class BaseEstimator( session_creator=monitored_session.ChiefSessionCreator( checkpoint_filename_with_path=checkpoint_path, scaffold=infer_ops.scaffold, - config=config_pb2.ConfigProto(allow_soft_placement=True))) + config=self._session_config)) if not as_iterable: with mon_sess: if not mon_sess.should_stop(): @@ -976,7 +981,7 @@ class BaseEstimator( chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks, save_checkpoint_secs=0, # Saving is handled by a hook. save_summaries_steps=self._config.save_summary_steps, - config=config_pb2.ConfigProto(allow_soft_placement=True) + config=self._session_config ) as mon_sess: loss = None while not mon_sess.should_stop(): diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index bc7465bbc2..37ee814b62 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -214,7 +214,8 @@ class RunConfig(ClusterConfig): keep_checkpoint_max=5, keep_checkpoint_every_n_hours=10000, evaluation_master='', - model_dir=None): + model_dir=None, + session_config=None): """Constructor. Note that the superclass `ClusterConfig` may set properties like @@ -246,6 +247,9 @@ class RunConfig(ClusterConfig): evaluation_master: the master on which to perform evaluation. model_dir: directory where model parameters, graph etc are saved. If `None`, see `Estimator` about where the model will be saved. + session_config: a ConfigProto used to set session parameters, or None. + Note - using this argument, it is easy to provide settings which break + otherwise perfectly good models. Use with care. """ super(RunConfig, self).__init__( master=master, evaluation_master=evaluation_master) @@ -261,6 +265,7 @@ class RunConfig(ClusterConfig): self._tf_random_seed = tf_random_seed self._save_summary_steps = save_summary_steps self._save_checkpoints_secs = save_checkpoints_secs + self._session_config = session_config if save_checkpoints_secs == RunConfig._USE_DEFAULT: if save_checkpoints_steps is None: self._save_checkpoints_secs = 600 @@ -346,6 +351,10 @@ class RunConfig(ClusterConfig): return self._save_checkpoints_steps @property + def session_config(self): + return self._session_config + + @property def keep_checkpoint_max(self): return self._keep_checkpoint_max diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index d92c6526a7..80c5bbf684 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -141,6 +141,11 @@ class Estimator(object): logging.info('Using config: %s', str(vars(self._config))) + if self._config.session_config is None: + self._session_config = config_pb2.ConfigProto(allow_soft_placement=True) + else: + self._session_config = self._config.session_config + self._device_fn = _get_replica_device_setter(self._config) if model_fn is None: @@ -317,7 +322,7 @@ class Estimator(object): session_creator=training.ChiefSessionCreator( checkpoint_filename_with_path=checkpoint_path, scaffold=estimator_spec.scaffold, - config=config_pb2.ConfigProto(allow_soft_placement=True)), + config=self._session_config), hooks=hooks) as mon_sess: while not mon_sess.should_stop(): preds_evaluated = mon_sess.run(predictions) @@ -580,7 +585,7 @@ class Estimator(object): chief_only_hooks=chief_hooks + estimator_spec.training_chief_hooks, save_checkpoint_secs=0, # Saving is handled by a hook. save_summaries_steps=self._config.save_summary_steps, - config=config_pb2.ConfigProto(allow_soft_placement=True)) as mon_sess: + config=self._session_config) as mon_sess: loss = None while not mon_sess.should_stop(): _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss]) @@ -635,7 +640,7 @@ class Estimator(object): eval_ops=update_op, final_ops=eval_dict, hooks=hooks, - config=config_pb2.ConfigProto(allow_soft_placement=True)) + config=self._session_config) _write_dict_to_summary( output_dir=eval_dir, diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index c6e6c60991..79b55c6853 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -73,6 +73,10 @@ class RunConfig(object): return 600 @property + def session_config(self): + return None + + @property def save_checkpoints_steps(self): return None diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt index 5f3dee5b40..8fd991a317 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt @@ -47,6 +47,10 @@ tf_class { mtype: "<type \'property\'>" } member { + name: "session_config" + mtype: "<type \'property\'>" + } + member { name: "task_id" mtype: "<type \'property\'>" } |