aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/training.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r--tensorflow/python/estimator/training.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index f789bcbbac..52fb1d39ae 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -677,11 +677,19 @@ class _TrainingExecutor(object):
'RunConfig or set the TF_CONFIG environment variable.')
logging.info('Start Tensorflow server.')
+
+ if config.session_config is None:
+ session_config=config_pb2.ConfigProto(log_device_placement=False)
+ else:
+ session_config=config_pb2.ConfigProto(
+ log_device_placement=False,
+ gpu_options=config.session_config.gpu_options)
+
server = server_lib.Server(
config.cluster_spec,
job_name=config.task_type,
task_index=config.task_id,
- config=config_pb2.ConfigProto(log_device_placement=False),
+ config=session_config,
start=False)
server.start()
return server