aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-01-10 13:44:27 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-10 13:48:49 -0800
commitd426feef7f92acc70a7e8cd61a17d9eeeae2efc9 (patch)
tree2d1e5c6ae4e6fc1619f806096e6bea6522656102
parenta6cbbeadfc368684472e420a0ec42c39e5fa5567 (diff)
Fix python/framework/subscribe.py and test to work with C API enabled.
PiperOrigin-RevId: 181511142
-rw-r--r--tensorflow/python/framework/subscribe.py23
-rw-r--r--tensorflow/python/framework/subscribe_test.py7
2 files changed, 17 insertions, 13 deletions
diff --git a/tensorflow/python/framework/subscribe.py b/tensorflow/python/framework/subscribe.py
index cdcb74e88f..7797d991da 100644
--- a/tensorflow/python/framework/subscribe.py
+++ b/tensorflow/python/framework/subscribe.py
@@ -137,11 +137,18 @@ def _subscribe_new(tensor, side_effects, control_cache):
# are subscribed at the same time, we remove the control dependency from
# the original op only once and we add the dependencies to all the
# new identities.
+ if ops._USE_C_API: # pylint: disable=protected-access
+ new_control_inputs = consumer_op.control_inputs
+ else:
+ # Make a copy so we don't modify the actual control inputs (this is fixed
+ # in the C API).
+ new_control_inputs = list(consumer_op.control_inputs)
+ if tensor.op in new_control_inputs:
+ new_control_inputs.remove(tensor.op)
+ new_control_inputs.append(out.op)
# pylint: disable=protected-access
- if tensor.op in consumer_op._control_inputs:
- consumer_op._control_inputs.remove(tensor.op)
- consumer_op._control_inputs.append(out.op)
- consumer_op._recompute_node_def()
+ consumer_op._remove_all_control_inputs()
+ consumer_op._add_control_inputs(new_control_inputs)
# pylint: enable=protected-access
return out
@@ -167,12 +174,8 @@ def _subscribe_extend(tensor, side_effects):
for s in side_effects:
outs += s(source_tensor)
- for out in outs:
- out_type = type(out)
- if out_type is ops.Tensor:
- out = out.op
- tensor.op._control_inputs.append(out) # pylint: disable=protected-access
- tensor.op._recompute_node_def() # pylint: disable=protected-access
+ out_ops = [out.op if isinstance(out, ops.Tensor) else out for out in outs]
+ tensor.op._add_control_inputs(out_ops) # pylint: disable=protected-access
return tensor
diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py
index 01df20241d..8b95b25e82 100644
--- a/tensorflow/python/framework/subscribe_test.py
+++ b/tensorflow/python/framework/subscribe_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
+@test_util.with_c_api
class SubscribeTest(test_util.TensorFlowTestCase):
def _ExpectSubscribedIdentities(self, container):
@@ -58,12 +59,12 @@ class SubscribeTest(test_util.TensorFlowTestCase):
return t
c0 = c
- self.assertTrue(c0.op in d.op._control_inputs)
+ self.assertTrue(c0.op in d.op.control_inputs)
c = subscribe.subscribe(c,
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
# Verify that control dependencies are correctly moved to the subscription.
- self.assertFalse(c0.op in d.op._control_inputs)
- self.assertTrue(c.op in d.op._control_inputs)
+ self.assertFalse(c0.op in d.op.control_inputs)
+ self.assertTrue(c.op in d.op.control_inputs)
with self.test_session() as sess:
c_out = sess.run([c])