aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-18 01:44:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-18 01:48:49 -0700
commit8a90e8487d39cedc8fdd2cbab5a6c237c2f3a1cb (patch)
tree9c58651fb3873cc9b14789303560783660c9a634
parent28d9223a55d3c02e91781e02ff8b3f6a31bdd66a (diff)
Fixes in control_flow_ops.
PiperOrigin-RevId: 162325768
-rw-r--r--tensorflow/python/ops/control_flow_ops.py12
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py7
2 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index f2c3dae03c..44d6c7e275 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1372,8 +1372,10 @@ class ControlFlowContext(object):
k = ops.prepend_name_scope(k, import_scope)
self._external_values[k] = g.as_graph_element(
ops.prepend_name_scope(v, import_scope))
- op_names = set([op.split(":")[0]
- for op in self._values - set(self._external_values)])
+ op_names = set([
+ op.split(":")[0]
+ for op in self._values - set(self._external_values.keys())
+ ])
for op in op_names:
# pylint: disable=protected-access
g.as_graph_element(op)._set_control_flow_context(self)
@@ -1801,7 +1803,7 @@ def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
if not callable(false_fn):
raise TypeError("false_fn must be callable.")
- with ops.name_scope(name, "cond", [pred]) as name:
+ with ops.name_scope(name, "cond", [pred]):
# Add the Switch to the graph.
if isinstance(pred, bool):
raise TypeError("pred must not be a Python bool")
@@ -2757,7 +2759,7 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
```
"""
- with ops.name_scope(name, "while", loop_vars) as name:
+ with ops.name_scope(name, "while", loop_vars):
if not loop_vars:
raise ValueError("No loop variables provided")
if not callable(cond):
@@ -2770,7 +2772,7 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
if shape_invariants is not None:
nest.assert_same_structure(loop_vars, shape_invariants)
- context = WhileContext(parallel_iterations, back_prop, swap_memory, name)
+ context = WhileContext(parallel_iterations, back_prop, swap_memory)
ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, context)
result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
return result
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 07c47e05d2..c23957443f 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -43,7 +43,6 @@ import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.platform import googletest
from tensorflow.python.training import momentum
from tensorflow.python.util import nest
-from tensorflow.python.util.protobuf import compare
TestTuple = collections.namedtuple("TestTuple", "a b")
@@ -399,7 +398,7 @@ class ContextTest(TensorFlowTestCase):
for op in sess.graph.get_operations():
c = op._get_control_flow_context()
if c:
- compare.ProtoEq(
+ self.assertProtoEquals(
c.to_proto(),
control_flow_ops.CondContext.from_proto(c.to_proto()).to_proto())
@@ -412,7 +411,7 @@ class ContextTest(TensorFlowTestCase):
for op in sess.graph.get_operations():
c = op._get_control_flow_context()
if c:
- compare.ProtoEq(
+ self.assertProtoEquals(
c.to_proto(),
control_flow_ops.WhileContext.from_proto(c.to_proto()).to_proto())
@@ -437,7 +436,7 @@ class ContextTest(TensorFlowTestCase):
c_with_scope._external_values, {"test_scope/a": b2})
# Calling _to_proto() with export_scope should remove "test_scope".
- compare.ProtoEq(
+ self.assertProtoEquals(
c._to_proto(),
c_with_scope._to_proto(export_scope="test_scope"))