aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-06-25 15:17:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-25 15:23:14 -0700
commit55b3cac99dba6c5b882ecca88263a93e60b2c0f9 (patch)
tree620b13c92a4ed6010a26cd6db9685966d6a4b3a1 /tensorflow/python/client
parent79d11c035c83968b91afc6291d8b3d35a6991d47 (diff)
Guard ops modification and Session.run with a group lock. This lock allows multiple ops modifications to happen at the same time, but no Session.run can happen until the modifications are done. And vice-versa.
PiperOrigin-RevId: 202028326
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r--tensorflow/python/client/session.py2
-rw-r--r--tensorflow/python/client/session_test.py69
2 files changed, 52 insertions, 19 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 35aa37ac6d..f3b788f931 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -1291,7 +1291,7 @@ class BaseSession(SessionInterface):
raise type(e)(node_def, op, message)
def _extend_graph(self):
- with self._graph._lock: # pylint: disable=protected-access
+ with self._graph._session_run_lock(): # pylint: disable=protected-access
tf_session.ExtendSession(self._session)
# The threshold to run garbage collection to delete dead tensors.
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index e49d067105..b72e029d1c 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import random
import os
import sys
import threading
@@ -1040,40 +1041,72 @@ class SessionTest(test_util.TensorFlowTestCase):
for t in threads:
t.join()
- def testParallelRunAndBuild(self):
+ @staticmethod
+ def _build_graph():
+ time.sleep(random.random() * 0.1)
+ # Do some graph construction. Try to exercise non-trivial paths.
+ graph = ops.get_default_graph()
+ gdef = None
+ for _ in range(10):
+ x = array_ops.placeholder(dtype=dtypes.float32)
+ with ops.colocate_with(x):
+ y = array_ops.placeholder(dtype=dtypes.float32)
+ with ops.device('/cpu:0'):
+ z = control_flow_ops.while_loop(
+ lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
+ with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
+ gradients_impl.gradients(z, [x, y])
+ if gdef is None:
+ gdef = graph.as_graph_def()
+ else:
+ importer.import_graph_def(gdef, name='import')
+
+ def testParallelRunAndSingleBuild(self):
with session.Session() as sess:
c = constant_op.constant(5.0)
stop = threading.Event()
def run_loop():
while not stop.is_set():
+ time.sleep(random.random() * 0.1)
self.assertEqual(sess.run(c), 5.0)
- threads = [self.checkedThread(target=run_loop) for _ in range(100)]
+ threads = [self.checkedThread(target=run_loop) for _ in range(10)]
for t in threads:
t.start()
- # Do some graph construction. Try to exercise non-trivial paths.
- graph = ops.get_default_graph()
- gdef = None
- for _ in range(10):
- x = array_ops.placeholder(dtype=dtypes.float32)
- with ops.colocate_with(x):
- y = array_ops.placeholder(dtype=dtypes.float32)
- with ops.device('/cpu:0'):
- z = control_flow_ops.while_loop(
- lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
- with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
- gradients_impl.gradients(z, [x, y])
- if gdef is None:
- gdef = graph.as_graph_def()
- else:
- importer.import_graph_def(gdef, name='import')
+ SessionTest._build_graph()
stop.set()
for t in threads:
t.join()
+ def testParallelRunAndParallelBuild(self):
+ with session.Session() as sess:
+ c = constant_op.constant(5.0)
+ stop = threading.Event()
+
+ def run_loop():
+ while not stop.is_set():
+ time.sleep(random.random() * 0.1)
+ self.assertEqual(sess.run(c), 5.0)
+
+ run_threads = [self.checkedThread(target=run_loop) for _ in range(10)]
+ for t in run_threads:
+ t.start()
+
+ build_threads = [self.checkedThread(target=SessionTest._build_graph)
+ for _ in range(10)]
+ for t in build_threads:
+ t.start()
+ for t in build_threads:
+ t.join()
+
+ # Let the run_threads run until the build threads are finished.
+ stop.set()
+ for t in run_threads:
+ t.join()
+
def testRunFeedDict(self):
with session.Session() as s:
x = array_ops.zeros([2])