aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-07 20:25:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-08 16:17:03 -0800
commite1dddfb401fa9bdb964f961f6f61c5df8961d0c2 (patch)
tree4cc48cb789ec70be7331059b9c3794debe15b46d
parent10d6df2f95ca3617a4371c2965e5cc98fcd803aa (diff)
Improve performance of tf.subscribe on large graphs by caching control outputs.
Change: 138473617
-rw-r--r--tensorflow/python/framework/subscribe.py62
-rw-r--r--tensorflow/python/framework/subscribe_test.py30
2 files changed, 76 insertions, 16 deletions
diff --git a/tensorflow/python/framework/subscribe.py b/tensorflow/python/framework/subscribe.py
index 53d299a976..1f3bca71b2 100644
--- a/tensorflow/python/framework/subscribe.py
+++ b/tensorflow/python/framework/subscribe.py
@@ -57,25 +57,51 @@ def _recursive_apply(tensors, apply_fn):
(tensors, tensors_type))
-def _control_outputs(op):
- """Returns the control_input consumers for the supplied `Operation`.
-
- Args:
- op: The `Operation` to find consumers of.
- Yields:
- A list of ops that have op as a control dependency.
- """
- for o in op.graph.get_operations():
- if op in o.control_inputs:
- yield o
-
-
-def _subscribe(tensor, side_effects):
+class _ControlOutputCache(object):
+ """Helper class to manage calculating and caching control_outputs in graph."""
+
+ def __init__(self):
+ self.cache = {}
+
+ def calc_control_outputs(self, graph):
+ """Returns the map of control_outputs for a given graph.
+
+ Args:
+ graph: The graph to parse.
+ Returns:
+ A map of the control outputs.
+ """
+ control_outputs = {}
+ for op in graph.get_operations():
+ for control_input in op.control_inputs:
+ if control_input not in control_outputs:
+ control_outputs[control_input] = set()
+ control_outputs[control_input].add(op)
+ return control_outputs
+
+ def get_control_outputs(self, op):
+ """Return the control outputs for a given op.
+
+ Args:
+ op: The op to fetch control outputs for.
+ Returns:
+ Iterable of control output ops.
+ """
+ if op.graph not in self.cache:
+ control_outputs = self.calc_control_outputs(op.graph)
+ self.cache[op.graph] = control_outputs
+ else:
+ control_outputs = self.cache[op.graph]
+ return control_outputs.get(op, [])
+
+
+def _subscribe(tensor, side_effects, control_cache):
"""Helper method that subscribes a single tensor to a list of side_effects.
Args:
tensor: `tf.Tensor`
side_effects: List of side_effect functions see subscribe for details.
+ control_cache: `_ControlOutputCache` helper to get control_outputs faster.
Returns:
The modified replacement to the passed in tensor which triggers the side
effects.
@@ -84,7 +110,7 @@ def _subscribe(tensor, side_effects):
for consumer_op in list(tensor.consumers()): # explicit copy
update_input.append((consumer_op, list(consumer_op.inputs).index(tensor)))
- update_control_input = list(_control_outputs(tensor.op))
+ update_control_input = control_cache.get_control_outputs(tensor.op)
# Trailing slash on name scope to replace the scope.
name_scope = tensor.op.name + '/subscription/'
@@ -141,4 +167,8 @@ def subscribe(tensors, side_effects):
"""
if not hasattr(side_effects, '__iter__'):
side_effects = [side_effects]
- return _recursive_apply(tensors, lambda t: _subscribe(t, side_effects))
+
+ control_outputs = _ControlOutputCache()
+ result = _recursive_apply(
+ tensors, lambda t: _subscribe(t, side_effects, control_outputs))
+ return result
diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py
index 8371c2cfc4..ed56f80d22 100644
--- a/tensorflow/python/framework/subscribe_test.py
+++ b/tensorflow/python/framework/subscribe_test.py
@@ -54,6 +54,36 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertEquals(d_out, [42])
self.assertEquals(shared, [2, 2, 2])
+ def testCaching(self):
+ """Confirm caching of control output is recacluated between calls."""
+ a = tf.constant(1)
+ b = tf.constant(2)
+ with tf.control_dependencies([a]):
+ c = tf.constant(42)
+
+ shared = {}
+
+ def sub(t):
+ shared[t] = shared.get(t, 0) + 1
+ return t
+
+ a = subscribe.subscribe(a, lambda t: tf.py_func(sub, [t], [t.dtype]))
+
+ with tf.control_dependencies([b]):
+ d = tf.constant(11)
+
+ # If it was using outdated cached control_outputs then
+ # evaling would not trigger the new subscription.
+ b = subscribe.subscribe(b, lambda t: tf.py_func(sub, [t], [t.dtype]))
+
+ with self.test_session() as sess:
+ c_out = sess.run([c])
+ d_out = sess.run([d])
+
+ self.assertEquals(c_out, [42])
+ self.assertEquals(d_out, [11])
+ self.assertEquals(shared, {2: 1, 1: 1})
+
if __name__ == '__main__':
googletest.main()