diff options
-rw-r--r-- | tensorflow/python/framework/subscribe.py | 62 | ||||
-rw-r--r-- | tensorflow/python/framework/subscribe_test.py | 30 |
2 files changed, 76 insertions, 16 deletions
diff --git a/tensorflow/python/framework/subscribe.py b/tensorflow/python/framework/subscribe.py index 53d299a976..1f3bca71b2 100644 --- a/tensorflow/python/framework/subscribe.py +++ b/tensorflow/python/framework/subscribe.py @@ -57,25 +57,51 @@ def _recursive_apply(tensors, apply_fn): (tensors, tensors_type)) -def _control_outputs(op): - """Returns the control_input consumers for the supplied `Operation`. - - Args: - op: The `Operation` to find consumers of. - Yields: - A list of ops that have op as a control dependency. - """ - for o in op.graph.get_operations(): - if op in o.control_inputs: - yield o - - -def _subscribe(tensor, side_effects): +class _ControlOutputCache(object): + """Helper class to manage calculating and caching control_outputs in graph.""" + + def __init__(self): + self.cache = {} + + def calc_control_outputs(self, graph): + """Returns the map of control_outputs for a given graph. + + Args: + graph: The graph to parse. + Returns: + A map of the control outputs. + """ + control_outputs = {} + for op in graph.get_operations(): + for control_input in op.control_inputs: + if control_input not in control_outputs: + control_outputs[control_input] = set() + control_outputs[control_input].add(op) + return control_outputs + + def get_control_outputs(self, op): + """Return the control outputs for a given op. + + Args: + op: The op to fetch control outputs for. + Returns: + Iterable of control output ops. + """ + if op.graph not in self.cache: + control_outputs = self.calc_control_outputs(op.graph) + self.cache[op.graph] = control_outputs + else: + control_outputs = self.cache[op.graph] + return control_outputs.get(op, []) + + +def _subscribe(tensor, side_effects, control_cache): """Helper method that subscribes a single tensor to a list of side_effects. Args: tensor: `tf.Tensor` side_effects: List of side_effect functions see subscribe for details. + control_cache: `_ControlOutputCache` helper to get control_outputs faster. Returns: The modified replacement to the passed in tensor which triggers the side effects. @@ -84,7 +110,7 @@ def _subscribe(tensor, side_effects): for consumer_op in list(tensor.consumers()): # explicit copy update_input.append((consumer_op, list(consumer_op.inputs).index(tensor))) - update_control_input = list(_control_outputs(tensor.op)) + update_control_input = control_cache.get_control_outputs(tensor.op) # Trailing slash on name scope to replace the scope. name_scope = tensor.op.name + '/subscription/' @@ -141,4 +167,8 @@ def subscribe(tensors, side_effects): """ if not hasattr(side_effects, '__iter__'): side_effects = [side_effects] - return _recursive_apply(tensors, lambda t: _subscribe(t, side_effects)) + + control_outputs = _ControlOutputCache() + result = _recursive_apply( + tensors, lambda t: _subscribe(t, side_effects, control_outputs)) + return result diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py index 8371c2cfc4..ed56f80d22 100644 --- a/tensorflow/python/framework/subscribe_test.py +++ b/tensorflow/python/framework/subscribe_test.py @@ -54,6 +54,36 @@ class SubscribeTest(test_util.TensorFlowTestCase): self.assertEquals(d_out, [42]) self.assertEquals(shared, [2, 2, 2]) + def testCaching(self): + """Confirm caching of control output is recacluated between calls.""" + a = tf.constant(1) + b = tf.constant(2) + with tf.control_dependencies([a]): + c = tf.constant(42) + + shared = {} + + def sub(t): + shared[t] = shared.get(t, 0) + 1 + return t + + a = subscribe.subscribe(a, lambda t: tf.py_func(sub, [t], [t.dtype])) + + with tf.control_dependencies([b]): + d = tf.constant(11) + + # If it was using outdated cached control_outputs then + # evaling would not trigger the new subscription. + b = subscribe.subscribe(b, lambda t: tf.py_func(sub, [t], [t.dtype])) + + with self.test_session() as sess: + c_out = sess.run([c]) + d_out = sess.run([d]) + + self.assertEquals(c_out, [42]) + self.assertEquals(d_out, [11]) + self.assertEquals(shared, {2: 1, 1: 1}) + if __name__ == '__main__': googletest.main() |