aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/ops.py')
-rw-r--r--tensorflow/python/framework/ops.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 98a1802490..5527f52860 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -4860,6 +4860,18 @@ class Graph(object):
else:
self._graph_control_dependencies_stack = control_dependencies
+ @property
+ def _distribution_strategy_stack(self):
+ """A stack to maintain distribution strategy context for each thread."""
+ if not hasattr(self._thread_local, "_distribution_strategy_stack"):
+ self._thread_local._distribution_strategy_stack = [] # pylint: disable=protected-access
+ return self._thread_local._distribution_strategy_stack # pylint: disable=protected-access
+
+ @_distribution_strategy_stack.setter
+ def _distribution_strategy_stack(self, _distribution_strategy_stack):
+ self._thread_local._distribution_strategy_stack = ( # pylint: disable=protected-access
+ _distribution_strategy_stack)
+
def _mutation_lock(self):
"""Returns a lock to guard code that creates & mutates ops.