diff options
author | 2017-09-06 11:03:35 -0700 | |
---|---|---|
committer | 2017-09-06 11:07:11 -0700 | |
commit | 86b94d8a17783773681b585df474bbda300b62f7 (patch) | |
tree | 18ebf42967bb59a68b885824d6231f1b85558dbe | |
parent | c4cb861fd9b752f2d4ce9205489bc4325b8ead32 (diff) |
contrib summaries work in eager-graph mode (with defun)
As a side effect, this also fixes issues related to using eager-defined variables
in graph mode.
PiperOrigin-RevId: 167744121
-rw-r--r-- | tensorflow/contrib/summary/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops.py | 98 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 27 | ||||
-rw-r--r-- | tensorflow/core/kernels/summary_kernels.cc | 7 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 7 | ||||
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 57 | ||||
-rw-r--r-- | tensorflow/python/training/training_util.py | 3 |
7 files changed, 137 insertions, 67 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index bc30502264..527deab86a 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -22,10 +22,12 @@ py_test( srcs_version = "PY2AND3", deps = [ ":summary_ops", + "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/python:training", - "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", "//tensorflow/python/eager:test", ], ) @@ -38,6 +40,7 @@ py_library( deps = [ ":gen_summary_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:summary_op_util", diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index 05e627adf1..ceaf83b70a 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -68,7 +68,8 @@ def never_record_summaries(): def create_summary_file_writer(logdir, max_queue=None, flush_secs=None, - filename_suffix=None): + filename_suffix=None, + name=None): """Creates a summary file writer in the current context.""" if max_queue is None: max_queue = constant_op.constant(10) @@ -76,7 +77,7 @@ def create_summary_file_writer(logdir, flush_secs = constant_op.constant(120) if filename_suffix is None: filename_suffix = constant_op.constant("") - resource = gen_summary_ops.summary_writer() + resource = gen_summary_ops.summary_writer(shared_name=name) gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue, flush_secs, filename_suffix) context.context().summary_writer_resource = resource @@ -84,76 +85,87 @@ def create_summary_file_writer(logdir, def _nothing(): """Convenient else branch for when summaries do not record.""" - return + return False -def generic(name, tensor, metadata, family=None): - """Writes a 
tensor summary if possible.""" +def summary_writer_function(name, tensor, function, family=None): + """Helper function to write summaries. + Args: + name: name of the summary + tensor: main tensor to form the summary + function: function taking a tag and a scope which writes the summary + family: optional, the summary's family + + Returns: + The result of writing the summary. + """ def record(): with summary_op_util.summary_scope( name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_summary(context.context().summary_writer_resource, - training_util.get_global_step(), tensor, - tag, metadata, name=scope) + function(tag, scope) + return True + return control_flow_ops.cond(should_record_summaries(), record, _nothing) +def generic(name, tensor, metadata, family=None): + """Writes a tensor summary if possible.""" + + def function(tag, scope): + gen_summary_ops.write_summary(context.context().summary_writer_resource, + training_util.get_global_step(), tensor, + tag, metadata, name=scope) + return summary_writer_function(name, tensor, function, family=family) + + def scalar(name, tensor, family=None): """Writes a scalar summary if possible.""" - def record(): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_scalar_summary( - context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, name=scope) + def function(tag, scope): + gen_summary_ops.write_scalar_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), tag, tensor, name=scope) - return control_flow_ops.cond(should_record_summaries(), record, _nothing) + return summary_writer_function(name, tensor, function, family=family) def histogram(name, tensor, family=None): """Writes a histogram summary if possible.""" - def record(): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_histogram_summary( - 
context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, name=scope) + def function(tag, scope): + gen_summary_ops.write_histogram_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), tag, tensor, name=scope) - return control_flow_ops.cond(should_record_summaries(), record, _nothing) + return summary_writer_function(name, tensor, function, family=family) def image(name, tensor, bad_color=None, max_images=3, family=None): """Writes an image summary if possible.""" - def record(): + def function(tag, scope): if bad_color is None: bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_image_summary( - context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, bad_color_, max_images, - name=scope) + gen_summary_ops.write_image_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), tag, tensor, bad_color_, max_images, + name=scope) - return control_flow_ops.cond(should_record_summaries(), record, _nothing) + return summary_writer_function(name, tensor, function, family=family) def audio(name, tensor, sample_rate, max_outputs, family=None): """Writes an audio summary if possible.""" - def record(): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_audio_summary( - context.context().summary_writer_resource, - training_util.get_global_step(), - tag, - tensor, - sample_rate=sample_rate, - max_outputs=max_outputs, - name=scope) - - return control_flow_ops.cond(should_record_summaries(), record, _nothing) + def function(tag, scope): + gen_summary_ops.write_audio_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), + tag, + tensor, + sample_rate=sample_rate, + max_outputs=max_outputs, + name=scope) + + return 
summary_writer_function(name, tensor, function, family=family) diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 56c1a16f7f..4b1f60ce4e 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -17,11 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import tempfile from tensorflow.contrib.summary import summary_ops +from tensorflow.core.util import event_pb2 +from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import test_util +from tensorflow.python.lib.io import tf_record from tensorflow.python.platform import gfile from tensorflow.python.training import training_util @@ -36,7 +40,7 @@ class TargetTest(test_util.TensorFlowTestCase): def testSummaryOps(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - summary_ops.create_summary_file_writer(logdir, max_queue=0) + summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0') summary_ops.always_record_summaries() summary_ops.generic('tensor', 1, '') summary_ops.scalar('scalar', 2.0) @@ -47,6 +51,27 @@ class TargetTest(test_util.TensorFlowTestCase): # test here that we're calling them correctly. 
self.assertTrue(gfile.Exists(logdir)) + def testDefunSummarys(self): + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1') + summary_ops.always_record_summaries() + + @function.defun + def write(): + summary_ops.scalar('scalar', 2.0) + + write() + + self.assertTrue(gfile.Exists(logdir)) + files = gfile.ListDirectory(logdir) + self.assertEqual(len(files), 1) + records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) + self.assertEqual(len(records), 2) + event = event_pb2.Event() + event.ParseFromString(records[1]) + self.assertEqual(event.summary.value[0].simple_value, 2.0) + if __name__ == '__main__': test.main() diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc index d0eca0f1e7..cfa707de71 100644 --- a/tensorflow/core/kernels/summary_kernels.cc +++ b/tensorflow/core/kernels/summary_kernels.cc @@ -40,12 +40,7 @@ class CreateSummaryFileWriterOp : public OpKernel { SummaryWriterInterface* s; OP_REQUIRES_OK(ctx, CreateSummaryWriter(max_queue, flush_millis, logdir, filename_suffix, ctx->env(), &s)); - Status status = CreateResource(ctx, HandleFromInput(ctx, 0), s); - if (!status.ok()) { - s->Unref(); - ctx->SetStatus(status); - return; - } + OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s)); } }; REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU), diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 980b6c883f..227520eea8 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -373,6 +373,13 @@ def _defun_internal(name, func, args, kwds): """Defines and returns graph-mode version of func.""" with context.graph_mode(): tmp_graph = ops.Graph() + # Copy the graph collections to ensure summaries and other things work. 
This + # lets the function access (but not mutate) collections of the containing + # graph, such as the global step and the summary writer collections. + curr_graph = ops.get_default_graph() + for collection in curr_graph.collections: + tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection( + collection) with tmp_graph.as_default(): func_inputs = _get_defun_inputs(args) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index b9bb9f3917..fdc8a5843f 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -30,6 +30,7 @@ from tensorflow.python.eager import tensor_node from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_resource_variable_ops @@ -41,6 +42,29 @@ from tensorflow.python.ops.gen_resource_variable_ops import * from tensorflow.python.util import compat +def _eager_safe_variable_handle(shape, dtype, shared_name, name, + container=None): + """Creates a variable handle with information to do shape inference.""" + handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, + shared_name=shared_name, + name=name, + container=container) + if context.in_graph_mode(): + return handle + with context.graph_mode(), ops.Graph().as_default(): + h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, + shared_name=shared_name, + name=name, + container=container) + + # Tensor._handle_data contains information for the shape-inference code to + # know the shape and dtype of the variable pointed to by a handle. 
Since + # shape inference doesn't run in eager mode we copy this data here for when + # the handle is captured by an eager mode function. + handle._handle_data = h._handle_data # pylint: disable=protected-access + return handle + + class ResourceVariable(variables.Variable): """Variable based on resource handles. @@ -231,7 +255,7 @@ class ResourceVariable(variables.Variable): if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] self._save_slice_info = None - in_graph_mode = context.in_graph_mode() + self._in_graph_mode = context.in_graph_mode() with ops.control_dependencies(None): with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name: @@ -241,7 +265,7 @@ class ResourceVariable(variables.Variable): # Use attr_scope and device(None) to simulate the behavior of # colocate_with when the variable we want to colocate with doesn't # yet exist. - if in_graph_mode: + if self._in_graph_mode: attr = attr_value_pb2.AttrValue( list=attr_value_pb2.AttrValue.ListValue( s=[compat.as_bytes("loc:@%s" % handle_name)])) @@ -249,26 +273,28 @@ class ResourceVariable(variables.Variable): with ops.name_scope("Initializer"), ops.device(None): initial_value = ops.convert_to_tensor( initial_value(), name="initial_value", dtype=dtype) - self._handle = gen_resource_variable_ops.var_handle_op( + self._handle = _eager_safe_variable_handle( shape=initial_value.get_shape(), dtype=initial_value.dtype.base_dtype, shared_name=handle_name, name=name) - self._handle_device = (self._handle.device if in_graph_mode else - context.get_default_context().device_name) + self._handle_device = ( + self._handle.device if self._in_graph_mode else + context.get_default_context().device_name) else: initial_value = initial_value() with ops.name_scope("Initializer"): initial_value = ops.convert_to_tensor( initial_value, name="initial_value", dtype=dtype) - self._handle = 
gen_resource_variable_ops.var_handle_op( + self._handle = _eager_safe_variable_handle( shape=initial_value.get_shape(), dtype=initial_value.dtype.base_dtype, shared_name=handle_name, name=name, container="") - self._handle_device = (self._handle.device if in_graph_mode else - context.get_default_context().device_name) + self._handle_device = ( + self._handle.device if self._in_graph_mode else + context.get_default_context().device_name) # pylint: enable=protected-access # Or get the initial value from a Tensor or Python object. @@ -277,7 +303,7 @@ class ResourceVariable(variables.Variable): initial_value = ops.convert_to_tensor( initial_value, name="initial_value", dtype=dtype) # pylint: disable=protected-access - if (in_graph_mode and initial_value is not None and + if (self._in_graph_mode and initial_value is not None and initial_value.op._get_control_flow_context() is not None): raise ValueError( "Initializer for variable %s is from inside a control-flow " @@ -285,21 +311,21 @@ class ResourceVariable(variables.Variable): "variable inside a loop or conditional, use a lambda as the " "initializer." 
% name) # pylint: enable=protected-access - self._handle = gen_resource_variable_ops.var_handle_op( + self._handle = _eager_safe_variable_handle( shape=initial_value.get_shape(), dtype=initial_value.dtype.base_dtype, shared_name=handle_name, name=name, container="") - self._handle_device = (self._handle.device if in_graph_mode else + self._handle_device = (self._handle.device if self._in_graph_mode else context.get_default_context().device_name) - self._initial_value = initial_value if in_graph_mode else None + self._initial_value = initial_value if self._in_graph_mode else None self._handle_name = handle_name + ":0" self._dtype = initial_value.dtype.base_dtype self._constraint = constraint - if in_graph_mode: + if self._in_graph_mode: with ops.name_scope("IsInitialized"): self._is_initialized_op = ( gen_resource_variable_ops.var_is_initialized_op(self._handle)) @@ -399,10 +425,11 @@ class ResourceVariable(variables.Variable): @property def shape(self): """The shape of this variable.""" - if context.in_graph_mode(): + if self._in_graph_mode: return tensor_shape.TensorShape(self._handle.op.get_attr("shape")) return tensor_shape.TensorShape( - gen_resource_variable_ops.variable_shape(self._handle).numpy()) + tensor_util.constant_value( + gen_resource_variable_ops.variable_shape(self._handle))) @property def create(self): diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py index bf48c75997..9f2f9b7479 100644 --- a/tensorflow/python/training/training_util.py +++ b/tensorflow/python/training/training_util.py @@ -157,6 +157,7 @@ def assert_global_step(global_step_tensor): raise TypeError('Existing "global_step" does not have integer type: %s' % global_step_tensor.dtype) - if global_step_tensor.get_shape().ndims != 0: + if (global_step_tensor.get_shape().ndims != 0 and + global_step_tensor.get_shape().is_fully_defined()): raise TypeError('Existing "global_step" is not scalar: %s' % global_step_tensor.get_shape()) |