diff options
author | 2018-01-10 13:44:27 -0800 | |
---|---|---|
committer | 2018-01-10 13:48:49 -0800 | |
commit | d426feef7f92acc70a7e8cd61a17d9eeeae2efc9 (patch) | |
tree | 2d1e5c6ae4e6fc1619f806096e6bea6522656102 | |
parent | a6cbbeadfc368684472e420a0ec42c39e5fa5567 (diff) |
Fix python/framework/subscribe.py and test to work with C API enabled.
PiperOrigin-RevId: 181511142
-rw-r--r-- | tensorflow/python/framework/subscribe.py | 23 | ||||
-rw-r--r-- | tensorflow/python/framework/subscribe_test.py | 7 |
2 files changed, 17 insertions, 13 deletions
diff --git a/tensorflow/python/framework/subscribe.py b/tensorflow/python/framework/subscribe.py index cdcb74e88f..7797d991da 100644 --- a/tensorflow/python/framework/subscribe.py +++ b/tensorflow/python/framework/subscribe.py @@ -137,11 +137,18 @@ def _subscribe_new(tensor, side_effects, control_cache): # are subscribed at the same time, we remove the control dependency from # the original op only once and we add the dependencies to all the # new identities. + if ops._USE_C_API: # pylint: disable=protected-access + new_control_inputs = consumer_op.control_inputs + else: + # Make a copy so we don't modify the actual control inputs (this is fixed + # in the C API). + new_control_inputs = list(consumer_op.control_inputs) + if tensor.op in new_control_inputs: + new_control_inputs.remove(tensor.op) + new_control_inputs.append(out.op) # pylint: disable=protected-access - if tensor.op in consumer_op._control_inputs: - consumer_op._control_inputs.remove(tensor.op) - consumer_op._control_inputs.append(out.op) - consumer_op._recompute_node_def() + consumer_op._remove_all_control_inputs() + consumer_op._add_control_inputs(new_control_inputs) # pylint: enable=protected-access return out @@ -167,12 +174,8 @@ def _subscribe_extend(tensor, side_effects): for s in side_effects: outs += s(source_tensor) - for out in outs: - out_type = type(out) - if out_type is ops.Tensor: - out = out.op - tensor.op._control_inputs.append(out) # pylint: disable=protected-access - tensor.op._recompute_node_def() # pylint: disable=protected-access + out_ops = [out.op if isinstance(out, ops.Tensor) else out for out in outs] + tensor.op._add_control_inputs(out_ops) # pylint: disable=protected-access return tensor diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py index 01df20241d..8b95b25e82 100644 --- a/tensorflow/python/framework/subscribe_test.py +++ b/tensorflow/python/framework/subscribe_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import googletest +@test_util.with_c_api class SubscribeTest(test_util.TensorFlowTestCase): def _ExpectSubscribedIdentities(self, container): @@ -58,12 +59,12 @@ class SubscribeTest(test_util.TensorFlowTestCase): return t c0 = c - self.assertTrue(c0.op in d.op._control_inputs) + self.assertTrue(c0.op in d.op.control_inputs) c = subscribe.subscribe(c, lambda t: script_ops.py_func(sub, [t], [t.dtype])) # Verify that control dependencies are correctly moved to the subscription. - self.assertFalse(c0.op in d.op._control_inputs) - self.assertTrue(c.op in d.op._control_inputs) + self.assertFalse(c0.op in d.op.control_inputs) + self.assertTrue(c.op in d.op.control_inputs) with self.test_session() as sess: c_out = sess.run([c]) |