author    Sherry Moore <sherrym@google.com>  2016-09-13 09:08:10 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-09-13 10:17:12 -0700
commit    e91ea774cfce323ef5812274f92fa2f13d2f3637
tree      03aafc080c1ac79e284577cb4a484f096da984c6
parent    db37a626fa81fe613aeaa7539e7730b5caaacce6
Added to_proto and from_proto for CondContext and WhileContext.
Change: 133013495
-rw-r--r--  tensorflow/core/protobuf/control_flow.proto       66
-rw-r--r--  tensorflow/python/framework/ops.py                  4
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py          205
-rw-r--r--  tensorflow/python/ops/control_flow_ops_test.py      31
-rw-r--r--  tensorflow/python/training/saver.py                 12
-rw-r--r--  tensorflow/python/training/saver_test.py            55
6 files changed, 350 insertions(+), 23 deletions(-)
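
Editor's note, not part of the commit: a minimal sketch of the end-to-end behavior this change enables, assuming the 2016-era TensorFlow Python API used elsewhere in this diff (tf.cond, tf.train.export_meta_graph, tf.train.import_meta_graph). Contexts created by cond() are recorded in the new "cond_context" collection, serialized to CondContextDef protos by the registered to_proto, and rebuilt by from_proto on import.

import tensorflow as tf

with tf.Graph().as_default():
  x = tf.constant(2)
  y = tf.constant(5)
  # cond() adds one CondContext per branch to the "cond_context" collection.
  tf.cond(tf.less(x, y), lambda: tf.mul(x, 17), lambda: tf.add(y, 23))
  # export_meta_graph serializes each collected context through the
  # registered CondContext.to_proto (producing CondContextDef protos).
  meta_graph_def = tf.train.export_meta_graph()

with tf.Graph().as_default():
  # import_meta_graph rebuilds the CondContext objects from those protos
  # through the registered CondContext.from_proto.
  tf.train.import_meta_graph(meta_graph_def)
  restored = tf.get_collection("cond_context")  # two contexts, one per branch
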
diff --git a/tensorflow/core/protobuf/control_flow.proto b/tensorflow/core/protobuf/control_flow.proto
new file mode 100644
index 0000000000..24f42322c0
--- /dev/null
+++ b/tensorflow/core/protobuf/control_flow.proto
@@ -0,0 +1,66 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "ControlFlowProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+// Control flow context related protocol buffers.
+
+// Protocol buffer representing the values in ControlFlowContext.
+message ValuesDef {
+ // Value names that have been seen in this context.
+ repeated string values = 1;
+
+ // Value names referenced by but external to this context.
+ map<string, string> external_values = 2;
+}
+
+// Protocol buffer representing a CondContext object.
+message CondContextDef {
+ // Name of the context.
+ string context_name = 1;
+
+ // Name of the pred tensor.
+ string pred_name = 2;
+
+ // Name of the pivot tensor.
+ string pivot_name = 3;
+
+ // Branch prediction. 0 or 1.
+ int32 branch = 4;
+
+ // Values and external values in control flow context.
+ ValuesDef values_def = 5;
+}
+
+// Protocol buffer representing a WhileContext object.
+message WhileContextDef {
+ // Name of the context.
+ string context_name = 1;
+
+ // The number of iterations allowed to run in parallel.
+ int32 parallel_iterations = 2;
+
+ // Whether backprop is enabled for this while loop.
+ bool back_prop = 3;
+
+ // Whether GPU-CPU memory swap is enabled for this loop.
+ bool swap_memory = 4;
+
+ // Name of the pivot tensor.
+ string pivot_name = 5;
+
+ // Name of the pivot_for_pred tensor.
+ string pivot_for_pred_name = 6;
+
+ // Name of the pivot_for_body tensor.
+ string pivot_for_body_name = 7;
+
+ // List of names for exit tensors.
+ repeated string loop_exit_names = 8;
+
+ // Values and external values in control flow context.
+ ValuesDef values_def = 9;
+}
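
For orientation only (editor's addition, not in the patch): a CondContextDef for one branch of a simple cond might be populated along the following lines; every name and value below is hypothetical.

from tensorflow.core.protobuf import control_flow_pb2

context_def = control_flow_pb2.CondContextDef()
context_def.context_name = "cond/cond_text"    # hypothetical context name
context_def.pred_name = "Less:0"               # boolean predicate tensor
context_def.pivot_name = "cond/switch_t:0"     # pivot tensor for this branch
context_def.branch = 1                         # 1 = true branch, 0 = false branch
context_def.values_def.values.extend(
    ["Const:0", "Less:0", "cond/switch_t:0"])  # value names seen in the context
# Map from a name outside the context to the name it resolves to inside it.
context_def.values_def.external_values["Const:0"] = "cond/Switch_1:1"
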
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 8cd5b0fef1..caa9695c8a 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -3948,6 +3948,10 @@ class GraphKeys(object):
SUMMARY_OP = "summary_op"
GLOBAL_STEP = "global_step"
+ # Keys for control flow contexts.
+ COND_CONTEXT = "cond_context"
+ WHILE_CONTEXT = "while_context"
+
def add_to_collection(name, value):
"""Wrapper for `Graph.add_to_collection()` using the default graph.
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index f1dc8b3515..6b29f18398 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -75,6 +75,7 @@ import collections
import six
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -1333,13 +1334,35 @@ class ControlFlowContext(object):
Pushed and popped by ctxt.Enter() and ctxt.Exit()
"""
- def __init__(self):
+ def __init__(self, values_def=None):
self._outer_context = ops.get_default_graph()._get_control_flow_context()
self._context_stack = []
- # Values that have been already seen in this context.
- self._values = set()
- # Values referenced by but external to this context.
+ if values_def:
+ self._init_values_from_proto(values_def)
+ else:
+ # Values that have been already seen in this context.
+ self._values = set()
+ # Values referenced by but external to this context.
+ self._external_values = {}
+
+ def _init_values_from_proto(self, values_def):
+ """Initializes values and external_values from `ValuesDef` protocol buffer.
+
+ Args:
+ values_def: `ValuesDef` protocol buffer.
+ """
+ assert isinstance(values_def, control_flow_pb2.ValuesDef)
+ self._values = set(values_def.values)
+ g = ops.get_default_graph()
self._external_values = {}
+ for k, v in values_def.external_values.items():
+ self._external_values[k] = g.as_graph_element(v)
+ op_names = set([op.split(":")[0]
+ for op in self._values - set(self._external_values)])
+ for op in op_names:
+ # pylint: disable=protected-access
+ g.as_graph_element(op)._set_control_flow_context(self)
+ # pylint: enable=protected-access
@property
def outer_context(self):
@@ -1354,6 +1377,23 @@ class ControlFlowContext(object):
def back_prop(self):
raise NotImplementedError("Abstract method")
+ def _to_proto(self):
+ """Converts the values to a `ValuesDef` protocol buffer.
+
+ Returns:
+ A `ValuesDef` protocol buffer.
+ """
+ values_def = control_flow_pb2.ValuesDef()
+ values_def.values.extend([v for v in sorted(self._values)])
+ for k, v in self._external_values.items():
+ values_def.external_values[k] = v.name
+ return values_def
+
+ @staticmethod
+ def _from_proto(values_def):
+ """Returns a `ControlFlowContext` created from `values_def`."""
+ return ControlFlowContext(values_def=values_def)
+
def AddName(self, name):
self._values.add(name)
@@ -1428,15 +1468,41 @@ class ControlFlowContext(object):
class CondContext(ControlFlowContext):
"""The context for the conditional construct."""
- def __init__(self, pred, pivot, branch):
- ControlFlowContext.__init__(self)
- self._pred = pred # The boolean tensor for the cond predicate
- self._pivot = pivot # The predicate tensor in this branch
- self._branch = branch # 0 or 1 representing this branch
+ def __init__(self, pred=None, pivot=None, branch=None,
+ name="cond_text", context_def=None):
+ self._name = ops.get_default_graph().unique_name(name)
- # Values considered to have been already seen in this context.
- self._values.add(pred.name)
- self._values.add(pivot.name)
+ if context_def:
+ self._init_from_proto(context_def)
+ else:
+ # Initializes the default fields.
+ ControlFlowContext.__init__(self)
+ self._pred = pred # The boolean tensor for the cond predicate
+ self._pivot = pivot # The predicate tensor in this branch
+ self._branch = branch # 0 or 1 representing this branch
+
+ # Values considered to have been already seen in this context.
+ self._values.add(pred.name)
+ self._values.add(pivot.name)
+
+ def _init_from_proto(self, context_def):
+ """Creates a new `CondContext` from protocol buffer.
+
+ Args:
+ context_def: `CondContextDef` protocol buffer.
+ """
+ assert isinstance(context_def, control_flow_pb2.CondContextDef)
+ # Create from context_def.
+ g = ops.get_default_graph()
+ self._name = context_def.context_name
+ self._pred = g.as_graph_element(context_def.pred_name)
+ self._pivot = g.as_graph_element(context_def.pivot_name)
+ self._branch = context_def.branch
+ super(CondContext, self).__init__(values_def=context_def.values_def)
+
+ @property
+ def name(self):
+ return self._name
@property
def pred(self):
@@ -1465,6 +1531,26 @@ class CondContext(ControlFlowContext):
def GetControlPivot(self):
return self._pivot
+ def to_proto(self):
+ """Converts a `CondContext` to a `CondContextDef` protocol buffer.
+
+ Returns:
+ A `CondContextDef` protocol buffer.
+ """
+ context_def = control_flow_pb2.CondContextDef()
+ context_def.context_name = self.name
+ context_def.pred_name = self._pred.name
+ context_def.pivot_name = self._pivot.name
+ context_def.branch = self._branch
+ context_def.values_def.MergeFrom(super(CondContext, self)._to_proto())
+
+ return context_def
+
+ @staticmethod
+ def from_proto(context_def):
+ """Returns a `CondContext` object created from `context_def`."""
+ return CondContext(context_def=context_def)
+
def AddValue(self, val):
"""Add `val` to the current context and its outer context recursively."""
if val.name in self._values:
@@ -1650,6 +1736,11 @@ def cond(pred, fn1, fn2, name=None):
"%s, %s" % (val_x.dtype.name, val_y.dtype.name))
merges = [merge([x[0], x[1]])[0] for x in zip(res_f, res_t)]
merges = _convert_flows_to_tensorarrays(orig_res, merges)
+
+ # Add to collections
+ ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
+ ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
+
return merges[0] if len(merges) == 1 else merges
@@ -1659,9 +1750,30 @@ def cond(pred, fn1, fn2, name=None):
class WhileContext(ControlFlowContext):
"""The context for the loop construct."""
- def __init__(self, parallel_iterations, back_prop, swap_memory, name,
- grad_state=None):
- ControlFlowContext.__init__(self)
+ def __init__(self, parallel_iterations=10, back_prop=True, swap_memory=False,
+ name="while_context", grad_state=None, context_def=None):
+ if context_def:
+ self._init_from_proto(context_def)
+ else:
+ ControlFlowContext.__init__(self)
+ self._init_from_args(parallel_iterations, back_prop, swap_memory,
+ name)
+ # The gradient loop state.
+ self._grad_state = grad_state
+
+ def _init_from_args(self, parallel_iterations, back_prop, swap_memory,
+ name):
+ """Creates a new `WhileContext` from arguments.
+
+ Args:
+ parallel_iterations: The number of iterations allowed to run in parallel.
+ back_prop: Whether backprop is enabled for this while loop.
+ swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
+ name: Optional name prefix for the returned tensors.
+
+ Raises:
+ ValueError: If `parallel_iterations` has an invalid value.
+ """
if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
raise ValueError("`parallel_iterations` must be a positive integer: "
"%s" % parallel_iterations)
@@ -1678,8 +1790,30 @@ class WhileContext(ControlFlowContext):
self._pivot = None
# The list of exit tensors for loop variables.
self._loop_exits = None
- # The gradient loop state.
- self._grad_state = grad_state
+
+ def _init_from_proto(self, context_def):
+ """Creates a new `WhileContext` from protocol buffer.
+
+ Args:
+ context_def: `WhileContextDef` protocol buffer.
+ """
+ assert isinstance(context_def, control_flow_pb2.WhileContextDef)
+ # Create from context_def.
+ g = ops.get_default_graph()
+ self._name = context_def.context_name
+ self._parallel_iterations = context_def.parallel_iterations
+ self._back_prop = context_def.back_prop
+ self._swap_memory = context_def.swap_memory
+ self._pivot_for_pred = g.as_graph_element(context_def.pivot_for_pred_name)
+ # We use this node to control constants created by the body lambda.
+ self._pivot_for_body = g.as_graph_element(context_def.pivot_for_body_name)
+ # The boolean tensor for loop termination condition. Used in code
+ # generation for gradient computation.
+ self._pivot = g.as_graph_element(context_def.pivot_name)
+ # The list of exit tensors for loop variables.
+ self._loop_exits = [g.as_graph_element(exit_name)
+ for exit_name in context_def.loop_exit_names]
+ super(WhileContext, self).__init__(values_def=context_def.values_def)
@property
def name(self):
@@ -1715,6 +1849,31 @@ class WhileContext(ControlFlowContext):
"""The gradient loop state."""
return self._grad_state
+ def to_proto(self):
+ """Converts a `WhileContext` to a `WhileContextDef` protocol buffer.
+
+ Returns:
+ A `WhileContextDef` protocol buffer.
+ """
+ context_def = control_flow_pb2.WhileContextDef()
+ context_def.context_name = self.name
+ context_def.parallel_iterations = self._parallel_iterations
+ context_def.back_prop = self._back_prop
+ context_def.swap_memory = self._swap_memory
+ context_def.pivot_for_pred_name = self._pivot_for_pred.name
+ context_def.pivot_for_body_name = self._pivot_for_body.name
+ context_def.pivot_name = self._pivot.name
+ if self._loop_exits:
+ context_def.loop_exit_names.extend([l.name for l in self._loop_exits])
+ context_def.values_def.MergeFrom(super(WhileContext, self)._to_proto())
+
+ return context_def
+
+ @staticmethod
+ def from_proto(context_def):
+ """Returns a `WhileContext` object created from `context_def`."""
+ return WhileContext(context_def=context_def)
+
def GetWhileContext(self):
return self
@@ -2357,6 +2516,7 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
nest.assert_same_structure(loop_vars, shape_invariants)
context = WhileContext(parallel_iterations, back_prop, swap_memory, name)
+ ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, context)
result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
return result
@@ -2792,3 +2952,14 @@ ops.RegisterShape("RefMerge")(_MergeShape)
ops.RegisterShape("RefSelect")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("RefSwitch")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Switch")(common_shapes.call_cpp_shape_fn)
+
+
+ops.register_proto_function(ops.GraphKeys.COND_CONTEXT,
+ proto_type=control_flow_pb2.CondContextDef,
+ to_proto=CondContext.to_proto,
+ from_proto=CondContext.from_proto)
+
+ops.register_proto_function(ops.GraphKeys.WHILE_CONTEXT,
+ proto_type=control_flow_pb2.WhileContextDef,
+ to_proto=WhileContext.to_proto,
+ from_proto=WhileContext.from_proto)
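
A short sketch (editor's addition, not part of the patch) of the per-context round trip that this registration relies on and that the new tests below exercise, shown here for the while_loop case:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

i = tf.constant(0)
tf.while_loop(lambda i: tf.less(i, 10), lambda i: tf.add(i, 1), [i])

# Each WhileContext recorded by while_loop() should survive a proto round trip.
for ctx in tf.get_collection("while_context"):
  restored = control_flow_ops.WhileContext.from_proto(ctx.to_proto())
  assert ctx.to_proto() == restored.to_proto()
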
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index f4352fcdb5..e66f28062a 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import standard_ops as tf
from tensorflow.python.platform import googletest
from tensorflow.python.training import momentum
+from tensorflow.python.util.protobuf import compare
class GroupTestCase(TensorFlowTestCase):
@@ -226,5 +227,35 @@ class SwitchTestCase(TensorFlowTestCase):
self.assertAllEqual(grad, [1] * 3)
+class ContextTest(TensorFlowTestCase):
+
+ def testCondContext(self):
+ with self.test_session() as sess:
+ x = tf.constant(2)
+ y = tf.constant(5)
+ control_flow_ops.cond(tf.less(x, y),
+ lambda: tf.mul(x, 17),
+ lambda: tf.add(y, 23))
+ for op in sess.graph.get_operations():
+ c = op._get_control_flow_context()
+ if c:
+ compare.ProtoEq(
+ c.to_proto(),
+ control_flow_ops.CondContext.from_proto(c.to_proto()).to_proto())
+
+ def testWhileContext(self):
+ with self.test_session() as sess:
+ i = tf.constant(0)
+ c = lambda i: tf.less(i, 10)
+ b = lambda i: tf.add(i, 1)
+ tf.while_loop(c, b, [i])
+ for op in sess.graph.get_operations():
+ c = op._get_control_flow_context()
+ if c:
+ compare.ProtoEq(
+ c.to_proto(),
+ control_flow_ops.WhileContext.from_proto(c.to_proto()).to_proto())
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 2afd879b0c..d5d8daf3b8 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -619,7 +619,17 @@ class BaseSaverBuilder(object):
restore_op = self._AddRestoreOps(filename_tensor, saveables,
restore_sequentially, reshape)
- assert restore_op.name.endswith("restore_all"), restore_op.name
+ # In the following use case, it's possible for the restore op to be named
+ # something else:
+ # - Build an inference graph and export a meta_graph.
+ # - Import the inference meta_graph.
+ # - Extend the inference graph into a train graph.
+ # - Export a new meta_graph.
+ # Now the second restore op will be called "restore_all_1".
+ # As such, comment out the assert for now until we know whether supporting
+ # such a usage model makes sense.
+ #
+ # assert restore_op.name.endswith("restore_all"), restore_op.name
return saver_pb2.SaverDef(
filename_tensor_name=filename_tensor.name,
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 27147096fb..e3bbf88bfc 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -22,6 +22,7 @@ import math
import os.path
import time
import contextlib
+import random
import shutil
import tempfile
@@ -37,6 +38,7 @@ from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_module
from tensorflow.python.util import compat
@@ -1211,7 +1213,13 @@ class MetaGraphTest(tf.test.TestCase):
filename = os.path.join(test_dir, "metafile")
with self.test_session():
# Creates a graph.
- v0 = tf.Variable(10.0, name="v0")
+ v0 = tf.Variable(1.0, name="v0")
+ control_flow_ops.cond(tf.less(v0, 10),
+ lambda: tf.add(v0, 1),
+ lambda: tf.sub(v0, 1))
+ control_flow_ops.while_loop(lambda i: tf.less(i, 10),
+ lambda i: tf.add(i, 1),
+ [v0])
var = tf.Variable(tf.constant(0, dtype=tf.int64))
count_up_to = var.count_up_to(3)
input_queue = tf.FIFOQueue(30, tf.float32, shared_name="collection_queue")
@@ -1240,7 +1248,7 @@ class MetaGraphTest(tf.test.TestCase):
self.assertTrue(meta_graph_def.HasField("saver_def"))
self.assertTrue(meta_graph_def.HasField("graph_def"))
collection_def = meta_graph_def.collection_def
- self.assertEqual(len(collection_def), 10)
+ self.assertEqual(len(collection_def), 12)
with tf.Graph().as_default():
# Restores from MetaGraphDef.
@@ -1418,7 +1426,13 @@ class MetaGraphTest(tf.test.TestCase):
tf.truncated_normal([28, 128],
stddev=1.0 / math.sqrt(float(28))),
name="weights")
- biases = tf.Variable(tf.zeros([128]),
+ # The use of control_flow_ops.cond here is purely to add test coverage for
+ # the save and restore of the control flow context (it makes no sense here
+ # from a machine learning perspective). Typically, biases would be a plain
+ # Variable without the condition.
+ biases = tf.Variable(control_flow_ops.cond(tf.less(random.random(), 0.5),
+ lambda: tf.ones([128]),
+ lambda: tf.zeros([128])),
name="biases")
hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
# Hidden 2
@@ -1427,8 +1441,19 @@ class MetaGraphTest(tf.test.TestCase):
tf.truncated_normal([128, 32],
stddev=1.0 / math.sqrt(float(128))),
name="weights")
- biases = tf.Variable(tf.zeros([32]),
- name="biases")
+
+ # The use of control_flow_ops.while_loop here is purely to add test
+ # coverage for the save and restore of the control flow context (it makes
+ # no sense here from a machine learning perspective). Typically, biases
+ # would be a plain Variable without the loop.
+ def loop_cond(it, _):
+ return it < 2
+ def loop_body(it, biases):
+ biases += tf.constant(0.1, shape=[32])
+ return it + 1, biases
+ _, biases = control_flow_ops.while_loop(
+ loop_cond, loop_body,
+ [tf.constant(0), tf.Variable(tf.zeros([32]))])
hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
# Linear
with tf.name_scope("softmax_linear"):
@@ -1456,6 +1481,7 @@ class MetaGraphTest(tf.test.TestCase):
def _testGraphExtensionRestore(self):
test_dir = os.path.join(self.get_temp_dir(), "graph_extension")
filename = os.path.join(test_dir, "metafile")
+ train_filename = os.path.join(test_dir, "train_metafile")
saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
with self.test_session(graph=tf.Graph()) as sess:
# Restores from MetaGraphDef.
@@ -1484,11 +1510,30 @@ class MetaGraphTest(tf.test.TestCase):
# Runs train_op.
train_op = optimizer.minimize(loss)
+ tf.add_to_collection("train_op", train_op)
+
+ # Runs train_op.
+ sess.run(train_op)
+
+ # Generates MetaGraphDef.
+ tf.train.export_meta_graph(train_filename)
+
+ def _testRestoreFromTrainGraphWithControlContext(self):
+ test_dir = os.path.join(self.get_temp_dir(), "graph_extension")
+ train_filename = os.path.join(test_dir, "train_metafile")
+ saver0_ckpt = os.path.join(test_dir, "saver0.ckpt")
+ with self.test_session(graph=tf.Graph()) as sess:
+ # Restores from MetaGraphDef.
+ new_saver = tf.train.import_meta_graph(train_filename)
+ # Restores from checkpoint.
+ new_saver.restore(sess, saver0_ckpt)
+ train_op = tf.get_collection("train_op")[0]
sess.run(train_op)
def testGraphExtension(self):
self._testGraphExtensionSave()
self._testGraphExtensionRestore()
+ self._testRestoreFromTrainGraphWithControlContext()
def testStrippedOpListDef(self):
with self.test_session():