aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-02-26 14:38:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 14:44:21 -0800
commit153e10a037c5e348834108ff46d9dccdf0cfb9a9 (patch)
tree411a202a638a10e08092b243cadf05a6293705be
parenta80896d3b3a2358f324dc4cd429409ea9acc8a09 (diff)
Enable de/serialization of nested control flow.
This is a follow-up to the previous commit (https://github.com/tensorflow/tensorflow/commit/23851760b7b099214bdd4f1b88156d7ac2bdd2a2). It adds the new proto schemas, enables the behavior for reading and writing the new protos, and adds a test for de/serializing nested while loops. There's still a bug preventing deserializing conds, which will be addressed in another change. PiperOrigin-RevId: 187082713
-rw-r--r--tensorflow/core/protobuf/control_flow.proto17
-rw-r--r--tensorflow/python/ops/control_flow_ops.py54
-rw-r--r--tensorflow/python/training/saver_test.py56
3 files changed, 88 insertions, 39 deletions
diff --git a/tensorflow/core/protobuf/control_flow.proto b/tensorflow/core/protobuf/control_flow.proto
index 2c9476a08a..3c05b4f0e2 100644
--- a/tensorflow/core/protobuf/control_flow.proto
+++ b/tensorflow/core/protobuf/control_flow.proto
@@ -17,6 +17,15 @@ message ValuesDef {
map<string, string> external_values = 2;
}
+// Container for any kind of control flow context. Any other control flow
+// contexts that are added below should also be added here.
+message ControlFlowContextDef {
+ oneof ctxt {
+ CondContextDef cond_ctxt = 1;
+ WhileContextDef while_ctxt = 2;
+ }
+}
+
// Protocol buffer representing a CondContext object.
message CondContextDef {
// Name of the context.
@@ -33,6 +42,9 @@ message CondContextDef {
// Values and external values in control flow context.
ValuesDef values_def = 5;
+
+ // Contexts contained inside this context (e.g. nested conds).
+ repeated ControlFlowContextDef nested_contexts = 6;
}
// Protocol buffer representing a WhileContext object.
@@ -70,5 +82,8 @@ message WhileContextDef {
// Optional name of the maximum_iterations tensor.
string maximum_iterations_name = 11;
- // Next available id: 12.
+ // Contexts contained inside this context (e.g. nested whiles).
+ repeated ControlFlowContextDef nested_contexts = 12;
+
+ // Next available id: 13.
}
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 8d5ab72670..85944efbe8 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1767,13 +1767,9 @@ class CondContext(ControlFlowContext):
context_def.branch = self._branch
context_def.values_def.MergeFrom(super(CondContext, self)._to_values_def(
export_scope))
- # TODO(b/72868227): enable this once the corresponding control_flow.proto
- # changes have been checked in (they aren't checked in and this is
- # disabled for now to ensure forwards compatibility).
- if False: # pylint: disable=using-constant-test
- for nested in self._nested_contexts:
- nested_def = context_def.nested_contexts.add()
- nested.to_control_flow_context_def(nested_def)
+ for nested in self._nested_contexts:
+ nested_def = context_def.nested_contexts.add()
+ nested.to_control_flow_context_def(nested_def)
return context_def
else:
@@ -1785,14 +1781,10 @@ class CondContext(ControlFlowContext):
ret = CondContext(context_def=context_def,
import_scope=import_scope)
- # TODO(b/72868227): remove "if hasattr(...)" once the corresponding
- # control_flow.proto changes have been checked in (they aren't checked in
- # and this is here for now to ensure forwards compatibility).
- if hasattr(context_def, "nested_contexts"):
- ret.Enter()
- for nested_def in context_def.nested_contexts:
- from_control_flow_context_def(nested_def)
- ret.Exit()
+ ret.Enter()
+ for nested_def in context_def.nested_contexts:
+ from_control_flow_context_def(nested_def)
+ ret.Exit()
return ret
def to_control_flow_context_def(self, context_def, export_scope=None):
@@ -2110,10 +2102,7 @@ def cond(pred,
# Only add non-nested conds to the collection. Any nested control flow will
# be encapsulated in the root context.
assert context_t.outer_context == context_f.outer_context
- # TODO(b/72868227): remove "if True..." once the corresponding
- # control_flow.proto changes have been checked in (they aren't checked in
- # and this is disabled for now to ensure forwards compatibility).
- if True or context_t.outer_context is None:
+ if context_t.outer_context is None:
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
@@ -2336,13 +2325,9 @@ class WhileContext(ControlFlowContext):
context_def.values_def.MergeFrom(
super(WhileContext, self)._to_values_def(
export_scope=export_scope))
- # TODO(b/72868227): remove "if True..." once the corresponding
- # control_flow.proto changes have been checked in (they aren't checked in
- # and this is disabled for now to ensure forwards compatibility).
- if False: # pylint: disable=using-constant-test
- for nested in self._nested_contexts:
- nested_def = context_def.nested_contexts.add()
- nested.to_control_flow_context_def(nested_def)
+ for nested in self._nested_contexts:
+ nested_def = context_def.nested_contexts.add()
+ nested.to_control_flow_context_def(nested_def)
return context_def
else:
@@ -2364,14 +2349,10 @@ class WhileContext(ControlFlowContext):
"""
ret = WhileContext(context_def=context_def,
import_scope=import_scope)
- # TODO(b/72868227): remove "if hasattr(...)" once the corresponding
- # control_flow.proto changes have been checked in (they aren't checked in
- # and this is disabled for now to ensure forwards compatibility).
- if hasattr(context_def, "nested_contexts"):
- ret.Enter()
- for nested_def in context_def.nested_contexts:
- from_control_flow_context_def(nested_def, import_scope=import_scope)
- ret.Exit()
+ ret.Enter()
+ for nested_def in context_def.nested_contexts:
+ from_control_flow_context_def(nested_def, import_scope=import_scope)
+ ret.Exit()
return ret
def GetWhileContext(self):
@@ -3216,10 +3197,7 @@ def while_loop(cond,
swap_memory=swap_memory)
# Only add non-nested loops to the collection. Any nested control flow will
# be encapsulated in the root context.
- # TODO(b/72868227): enable condition once the corresponding
- # control_flow.proto changes have been checked in (they aren't checked in
- # and this is disabled for now to ensure forwards compatibility).
- if True or loop_context.outer_context is None:
+ if loop_context.outer_context is None:
ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
if maximum_iterations is not None:
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index f00f98db00..b366ed30f3 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -53,6 +53,7 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
@@ -2040,6 +2041,61 @@ class MetaGraphTest(test.TestCase):
self._testGraphExtensionRestore(test_dir)
self._testRestoreFromTrainGraphWithControlContext(test_dir)
+ def testNestedWhileLoops(self):
+ test_dir = self._get_test_dir("nested_whiles")
+ filename = os.path.join(test_dir, "metafile")
+ saver_ckpt = os.path.join(test_dir, "saver.ckpt")
+
+ # Create two simple nested while loops.
+ with ops_lib.Graph().as_default():
+ def body(i, x):
+ _, r = control_flow_ops.while_loop(lambda j, y: j < 3,
+ lambda j, y: (j + 1, y + x),
+ [0, 0])
+ return i + 1, x + r
+
+ var = variables.Variable(0)
+ var_name = var.name
+
+ _, output = control_flow_ops.while_loop(lambda i, x: i < 5, body,
+ [0, var])
+ output_name = output.name
+
+ init_op = variables.global_variables_initializer()
+
+ # Generate a MetaGraphDef containing the nested loops.
+ with session.Session() as sess:
+ sess.run(init_op)
+ sess.run(output)
+ saver = saver_module.Saver()
+ saver.save(sess, saver_ckpt)
+ saver.export_meta_graph(filename)
+
+ # Build and run the gradients of the nested while loop. We use this below
+ # to verify that the gradients are correct with an imported MetaGraphDef.
+ grad = gradients_impl.gradients([output], [var])
+ with session.Session() as sess:
+ sess.run(init_op)
+ expected_grad_value = sess.run(grad)
+
+ # Restore the MetaGraphDef into a new Graph.
+ with ops_lib.Graph().as_default():
+ with session.Session() as sess:
+ saver = saver_module.import_meta_graph(filename)
+ saver.restore(sess, saver_ckpt)
+
+ # Make sure we can still build gradients and get the same result.
+ var = ops_lib.get_default_graph().get_tensor_by_name(var_name)
+ output = ops_lib.get_default_graph().get_tensor_by_name(output_name)
+ grad = gradients_impl.gradients([output], [var])
+
+ init_op = variables.global_variables_initializer()
+
+ with session.Session() as sess:
+ sess.run(init_op)
+ actual_grad_value = sess.run(grad)
+ self.assertEqual(expected_grad_value, actual_grad_value)
+
def testStrippedOpListDef(self):
with self.test_session():
# Creates a graph.