diff options
author | Priya Gupta <priyag@google.com> | 2018-09-04 21:38:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-04 21:44:02 -0700 |
commit | 220a546cfae7459abf7d0e4c50bb9848fa69ff53 (patch) | |
tree | 56061252d26ad16c6426d4f616033b893a64f8b2 /tensorflow/python/keras/backend.py | |
parent | c8be0ea9bb3a86f9bf7b1636246ecef1b9869924 (diff) |
Allow configuring session options in keras when running with distribution strategy.
PiperOrigin-RevId: 211576839
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r-- | tensorflow/python/keras/backend.py | 18 |
1 files changed, 11 insertions, 7 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index b52ab7f05c..7768caeaf0 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -443,13 +443,7 @@ def get_session(): session = default_session else: if _SESSION is None: - if not os.environ.get('OMP_NUM_THREADS'): - config = config_pb2.ConfigProto(allow_soft_placement=True) - else: - num_thread = int(os.environ.get('OMP_NUM_THREADS')) - config = config_pb2.ConfigProto( - intra_op_parallelism_threads=num_thread, allow_soft_placement=True) - _SESSION = session_module.Session(config=config) + _SESSION = session_module.Session(config=get_default_session_config()) session = _SESSION if not _MANUAL_VAR_INIT: with session.graph.as_default(): @@ -468,6 +462,16 @@ def set_session(session): _SESSION = session +def get_default_session_config(): + if not os.environ.get('OMP_NUM_THREADS'): + config = config_pb2.ConfigProto(allow_soft_placement=True) + else: + num_thread = int(os.environ.get('OMP_NUM_THREADS')) + config = config_pb2.ConfigProto( + intra_op_parallelism_threads=num_thread, allow_soft_placement=True) + return config + + # DEVICE MANIPULATION |