aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Lukasz Kaiser <lukaszkaiser@google.com>2017-04-06 17:21:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-06 18:42:10 -0700
commit7f28f166092e8f6621bc264e12a7201a22f76997 (patch)
tree8ad4318c808d872248428b796257f276010d47a6
parent0ca0b63fddda9bc81ad3bce00594b06e0c543ea9 (diff)
Allow to set session ConfigProto in RunConfig and use it in Estimator.
Change: 152454548
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py11
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/run_config.py11
-rw-r--r--tensorflow/python/estimator/estimator.py11
-rw-r--r--tensorflow/python/estimator/run_config.py4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt4
5 files changed, 34 insertions, 7 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 107454dca1..29ea692f8f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -362,6 +362,11 @@ class BaseEstimator(
self._config = config
logging.info('Using config: %s', str(vars(self._config)))
+ if self._config.session_config is None:
+ self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+ else:
+ self._session_config = self._config.session_config
+
# Model directory.
if (model_dir is not None) and (self._config.model_dir is not None):
if model_dir != self._config.model_dir:
@@ -829,7 +834,7 @@ class BaseEstimator(
eval_ops=update_op,
final_ops=eval_dict,
hooks=hooks,
- config=config_pb2.ConfigProto(allow_soft_placement=True))
+ config=self._session_config)
current_global_step = eval_results[global_step_key]
_write_dict_to_summary(eval_dir, eval_results, current_global_step)
@@ -864,7 +869,7 @@ class BaseEstimator(
session_creator=monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=checkpoint_path,
scaffold=infer_ops.scaffold,
- config=config_pb2.ConfigProto(allow_soft_placement=True)))
+ config=self._session_config))
if not as_iterable:
with mon_sess:
if not mon_sess.should_stop():
@@ -976,7 +981,7 @@ class BaseEstimator(
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=self._config.save_summary_steps,
- config=config_pb2.ConfigProto(allow_soft_placement=True)
+ config=self._session_config
) as mon_sess:
loss = None
while not mon_sess.should_stop():
diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
index bc7465bbc2..37ee814b62 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
@@ -214,7 +214,8 @@ class RunConfig(ClusterConfig):
keep_checkpoint_max=5,
keep_checkpoint_every_n_hours=10000,
evaluation_master='',
- model_dir=None):
+ model_dir=None,
+ session_config=None):
"""Constructor.
Note that the superclass `ClusterConfig` may set properties like
@@ -246,6 +247,9 @@ class RunConfig(ClusterConfig):
evaluation_master: the master on which to perform evaluation.
model_dir: directory where model parameters, graph etc are saved. If
`None`, see `Estimator` about where the model will be saved.
+ session_config: a ConfigProto used to set session parameters, or None.
+ Note - using this argument, it is easy to provide settings which break
+ otherwise perfectly good models. Use with care.
"""
super(RunConfig, self).__init__(
master=master, evaluation_master=evaluation_master)
@@ -261,6 +265,7 @@ class RunConfig(ClusterConfig):
self._tf_random_seed = tf_random_seed
self._save_summary_steps = save_summary_steps
self._save_checkpoints_secs = save_checkpoints_secs
+ self._session_config = session_config
if save_checkpoints_secs == RunConfig._USE_DEFAULT:
if save_checkpoints_steps is None:
self._save_checkpoints_secs = 600
@@ -346,6 +351,10 @@ class RunConfig(ClusterConfig):
return self._save_checkpoints_steps
@property
+ def session_config(self):
+ return self._session_config
+
+ @property
def keep_checkpoint_max(self):
return self._keep_checkpoint_max
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index d92c6526a7..80c5bbf684 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -141,6 +141,11 @@ class Estimator(object):
logging.info('Using config: %s', str(vars(self._config)))
+ if self._config.session_config is None:
+ self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+ else:
+ self._session_config = self._config.session_config
+
self._device_fn = _get_replica_device_setter(self._config)
if model_fn is None:
@@ -317,7 +322,7 @@ class Estimator(object):
session_creator=training.ChiefSessionCreator(
checkpoint_filename_with_path=checkpoint_path,
scaffold=estimator_spec.scaffold,
- config=config_pb2.ConfigProto(allow_soft_placement=True)),
+ config=self._session_config),
hooks=hooks) as mon_sess:
while not mon_sess.should_stop():
preds_evaluated = mon_sess.run(predictions)
@@ -580,7 +585,7 @@ class Estimator(object):
chief_only_hooks=chief_hooks + estimator_spec.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=self._config.save_summary_steps,
- config=config_pb2.ConfigProto(allow_soft_placement=True)) as mon_sess:
+ config=self._session_config) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
@@ -635,7 +640,7 @@ class Estimator(object):
eval_ops=update_op,
final_ops=eval_dict,
hooks=hooks,
- config=config_pb2.ConfigProto(allow_soft_placement=True))
+ config=self._session_config)
_write_dict_to_summary(
output_dir=eval_dir,
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index c6e6c60991..79b55c6853 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -73,6 +73,10 @@ class RunConfig(object):
return 600
@property
+ def session_config(self):
+ return None
+
+ @property
def save_checkpoints_steps(self):
return None
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
index 5f3dee5b40..8fd991a317 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
@@ -47,6 +47,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "session_config"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "task_id"
mtype: "<type \'property\'>"
}