aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend.py
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-09-04 21:38:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 21:44:02 -0700
commit220a546cfae7459abf7d0e4c50bb9848fa69ff53 (patch)
tree56061252d26ad16c6426d4f616033b893a64f8b2 /tensorflow/python/keras/backend.py
parentc8be0ea9bb3a86f9bf7b1636246ecef1b9869924 (diff)
Allow configuring session options in keras when running with distribution strategy.
PiperOrigin-RevId: 211576839
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r--tensorflow/python/keras/backend.py18
1 files changed, 11 insertions, 7 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index b52ab7f05c..7768caeaf0 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -443,13 +443,7 @@ def get_session():
session = default_session
else:
if _SESSION is None:
- if not os.environ.get('OMP_NUM_THREADS'):
- config = config_pb2.ConfigProto(allow_soft_placement=True)
- else:
- num_thread = int(os.environ.get('OMP_NUM_THREADS'))
- config = config_pb2.ConfigProto(
- intra_op_parallelism_threads=num_thread, allow_soft_placement=True)
- _SESSION = session_module.Session(config=config)
+ _SESSION = session_module.Session(config=get_default_session_config())
session = _SESSION
if not _MANUAL_VAR_INIT:
with session.graph.as_default():
@@ -468,6 +462,16 @@ def set_session(session):
_SESSION = session
+def get_default_session_config():
+ if not os.environ.get('OMP_NUM_THREADS'):
+ config = config_pb2.ConfigProto(allow_soft_placement=True)
+ else:
+ num_thread = int(os.environ.get('OMP_NUM_THREADS'))
+ config = config_pb2.ConfigProto(
+ intra_op_parallelism_threads=num_thread, allow_soft_placement=True)
+ return config
+
+
# DEVICE MANIPULATION