author    Alexandre Passos <apassos@google.com>  2017-09-06 11:03:35 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-09-06 11:07:11 -0700
commit    86b94d8a17783773681b585df474bbda300b62f7 (patch)
tree      18ebf42967bb59a68b885824d6231f1b85558dbe
parent    c4cb861fd9b752f2d4ce9205489bc4325b8ead32 (diff)
contrib summaries work in eager-graph mode (with defun)
As a side effect, this fixes issues related to using eager-defined variables in graph mode.

PiperOrigin-RevId: 167744121
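For context, a minimal sketch of the usage pattern this change enables, mirroring the testDefunSummarys test added below (the logdir path and writer name here are illustrative):

from tensorflow.contrib.summary import summary_ops
from tensorflow.python.eager import function
from tensorflow.python.training import training_util

# A global step and a (now nameable) writer resource are set up before
# tracing, exactly as the new test does.
training_util.get_or_create_global_step()
summary_ops.create_summary_file_writer('/tmp/logs', max_queue=0, name='w0')
summary_ops.always_record_summaries()

@function.defun
def write():
  # Summary ops traced inside a defun now find the writer resource and
  # global step through the graph collections copied at trace time.
  summary_ops.scalar('scalar', 2.0)

write()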
-rw-r--r--  tensorflow/contrib/summary/BUILD               |  5
-rw-r--r--  tensorflow/contrib/summary/summary_ops.py      | 98
-rw-r--r--  tensorflow/contrib/summary/summary_ops_test.py | 27
-rw-r--r--  tensorflow/core/kernels/summary_kernels.cc     |  7
-rw-r--r--  tensorflow/python/eager/function.py            |  7
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py | 57
-rw-r--r--  tensorflow/python/training/training_util.py    |  3
7 files changed, 137 insertions(+), 67 deletions(-)
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD
index bc30502264..527deab86a 100644
--- a/tensorflow/contrib/summary/BUILD
+++ b/tensorflow/contrib/summary/BUILD
@@ -22,10 +22,12 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":summary_ops",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:lib",
"//tensorflow/python:platform",
"//tensorflow/python:training",
- "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:function",
"//tensorflow/python/eager:test",
],
)
@@ -38,6 +40,7 @@ py_library(
deps = [
":gen_summary_ops",
"//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:summary_op_util",
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index 05e627adf1..ceaf83b70a 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -68,7 +68,8 @@ def never_record_summaries():
def create_summary_file_writer(logdir,
max_queue=None,
flush_secs=None,
- filename_suffix=None):
+ filename_suffix=None,
+ name=None):
"""Creates a summary file writer in the current context."""
if max_queue is None:
max_queue = constant_op.constant(10)
@@ -76,7 +77,7 @@ def create_summary_file_writer(logdir,
flush_secs = constant_op.constant(120)
if filename_suffix is None:
filename_suffix = constant_op.constant("")
- resource = gen_summary_ops.summary_writer()
+ resource = gen_summary_ops.summary_writer(shared_name=name)
gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
flush_secs, filename_suffix)
context.context().summary_writer_resource = resource
@@ -84,76 +85,87 @@ def create_summary_file_writer(logdir,
def _nothing():
"""Convenient else branch for when summaries do not record."""
- return
+ return False
-def generic(name, tensor, metadata, family=None):
- """Writes a tensor summary if possible."""
+def summary_writer_function(name, tensor, function, family=None):
+ """Helper function to write summaries.
+ Args:
+ name: name of the summary
+ tensor: main tensor to form the summary
+ function: function taking a tag and a scope which writes the summary
+ family: optional, the summary's family
+
+ Returns:
+ The result of writing the summary.
+ """
def record():
with summary_op_util.summary_scope(
name, family, values=[tensor]) as (tag, scope):
- gen_summary_ops.write_summary(context.context().summary_writer_resource,
- training_util.get_global_step(), tensor,
- tag, metadata, name=scope)
+ function(tag, scope)
+ return True
+
return control_flow_ops.cond(should_record_summaries(), record, _nothing)
+def generic(name, tensor, metadata, family=None):
+ """Writes a tensor summary if possible."""
+
+ def function(tag, scope):
+ gen_summary_ops.write_summary(context.context().summary_writer_resource,
+ training_util.get_global_step(), tensor,
+ tag, metadata, name=scope)
+ return summary_writer_function(name, tensor, function, family=family)
+
+
def scalar(name, tensor, family=None):
"""Writes a scalar summary if possible."""
- def record():
- with summary_op_util.summary_scope(
- name, family, values=[tensor]) as (tag, scope):
- gen_summary_ops.write_scalar_summary(
- context.context().summary_writer_resource,
- training_util.get_global_step(), tag, tensor, name=scope)
+ def function(tag, scope):
+ gen_summary_ops.write_scalar_summary(
+ context.context().summary_writer_resource,
+ training_util.get_global_step(), tag, tensor, name=scope)
- return control_flow_ops.cond(should_record_summaries(), record, _nothing)
+ return summary_writer_function(name, tensor, function, family=family)
def histogram(name, tensor, family=None):
"""Writes a histogram summary if possible."""
- def record():
- with summary_op_util.summary_scope(
- name, family, values=[tensor]) as (tag, scope):
- gen_summary_ops.write_histogram_summary(
- context.context().summary_writer_resource,
- training_util.get_global_step(), tag, tensor, name=scope)
+ def function(tag, scope):
+ gen_summary_ops.write_histogram_summary(
+ context.context().summary_writer_resource,
+ training_util.get_global_step(), tag, tensor, name=scope)
- return control_flow_ops.cond(should_record_summaries(), record, _nothing)
+ return summary_writer_function(name, tensor, function, family=family)
def image(name, tensor, bad_color=None, max_images=3, family=None):
"""Writes an image summary if possible."""
- def record():
+ def function(tag, scope):
if bad_color is None:
bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
- with summary_op_util.summary_scope(
- name, family, values=[tensor]) as (tag, scope):
- gen_summary_ops.write_image_summary(
- context.context().summary_writer_resource,
- training_util.get_global_step(), tag, tensor, bad_color_, max_images,
- name=scope)
+ gen_summary_ops.write_image_summary(
+ context.context().summary_writer_resource,
+ training_util.get_global_step(), tag, tensor, bad_color_, max_images,
+ name=scope)
- return control_flow_ops.cond(should_record_summaries(), record, _nothing)
+ return summary_writer_function(name, tensor, function, family=family)
def audio(name, tensor, sample_rate, max_outputs, family=None):
"""Writes an audio summary if possible."""
- def record():
- with summary_op_util.summary_scope(
- name, family, values=[tensor]) as (tag, scope):
- gen_summary_ops.write_audio_summary(
- context.context().summary_writer_resource,
- training_util.get_global_step(),
- tag,
- tensor,
- sample_rate=sample_rate,
- max_outputs=max_outputs,
- name=scope)
-
- return control_flow_ops.cond(should_record_summaries(), record, _nothing)
+ def function(tag, scope):
+ gen_summary_ops.write_audio_summary(
+ context.context().summary_writer_resource,
+ training_util.get_global_step(),
+ tag,
+ tensor,
+ sample_rate=sample_rate,
+ max_outputs=max_outputs,
+ name=scope)
+
+ return summary_writer_function(name, tensor, function, family=family)
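A note on the refactor above: each summary op now supplies only its op-specific write function, while summary_writer_function owns the summary_scope and the should_record_summaries() gate. The no-record branch returns False (and the record branch True) because control_flow_ops.cond needs both branches to produce matching outputs once this code is traced into a graph by defun. A minimal standalone sketch of that gating pattern, with illustrative names and a generic tf alias:

import tensorflow as tf

def gated_write(should_record, write_fn):
  """Runs write_fn only when should_record is true; reports whether it ran."""
  def record():
    write_fn()  # stateful write op, created (and gated) inside the branch
    return tf.constant(True)
  def skip():
    # Must match the record branch's output structure and dtype.
    return tf.constant(False)
  return tf.cond(should_record, record, skip)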
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 56c1a16f7f..4b1f60ce4e 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -17,11 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import tempfile
from tensorflow.contrib.summary import summary_ops
+from tensorflow.core.util import event_pb2
+from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
+from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util
@@ -36,7 +40,7 @@ class TargetTest(test_util.TensorFlowTestCase):
def testSummaryOps(self):
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
- summary_ops.create_summary_file_writer(logdir, max_queue=0)
+ summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0')
summary_ops.always_record_summaries()
summary_ops.generic('tensor', 1, '')
summary_ops.scalar('scalar', 2.0)
@@ -47,6 +51,27 @@ class TargetTest(test_util.TensorFlowTestCase):
# test here that we're calling them correctly.
self.assertTrue(gfile.Exists(logdir))
+ def testDefunSummarys(self):
+ training_util.get_or_create_global_step()
+ logdir = tempfile.mkdtemp()
+ summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1')
+ summary_ops.always_record_summaries()
+
+ @function.defun
+ def write():
+ summary_ops.scalar('scalar', 2.0)
+
+ write()
+
+ self.assertTrue(gfile.Exists(logdir))
+ files = gfile.ListDirectory(logdir)
+ self.assertEqual(len(files), 1)
+ records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
+ self.assertEqual(len(records), 2)
+ event = event_pb2.Event()
+ event.ParseFromString(records[1])
+ self.assertEqual(event.summary.value[0].simple_value, 2.0)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index d0eca0f1e7..cfa707de71 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -40,12 +40,7 @@ class CreateSummaryFileWriterOp : public OpKernel {
SummaryWriterInterface* s;
OP_REQUIRES_OK(ctx, CreateSummaryWriter(max_queue, flush_millis, logdir,
filename_suffix, ctx->env(), &s));
- Status status = CreateResource(ctx, HandleFromInput(ctx, 0), s);
- if (!status.ok()) {
- s->Unref();
- ctx->SetStatus(status);
- return;
- }
+ OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s));
}
};
REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 980b6c883f..227520eea8 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -373,6 +373,13 @@ def _defun_internal(name, func, args, kwds):
"""Defines and returns graph-mode version of func."""
with context.graph_mode():
tmp_graph = ops.Graph()
+ # Copy the graph collections to ensure summaries and other things work. This
+ # lets the function access (but not mutate) collections of the containing
+ # graph, such as the global step and the summary writer collections.
+ curr_graph = ops.get_default_graph()
+ for collection in curr_graph.collections:
+ tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
+ collection)
with tmp_graph.as_default():
func_inputs = _get_defun_inputs(args)
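The collection copy above is what lets summary ops inside a defun find the global step and the summary writer created in the outer graph. A hedged standalone illustration of the same copy, using a string payload for simplicity:

from tensorflow.python.framework import ops

outer = ops.get_default_graph()
outer.add_to_collection('my_collection', 'configured-outside')

tmp_graph = ops.Graph()
for collection in outer.collections:
  # The same copy _defun_internal performs: the inner graph can read the
  # outer graph's collections, but mutating its copy leaves them intact.
  tmp_graph.get_collection_ref(collection)[:] = outer.get_collection(
      collection)

with tmp_graph.as_default():
  assert ops.get_default_graph().get_collection('my_collection') == [
      'configured-outside']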
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index b9bb9f3917..fdc8a5843f 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -30,6 +30,7 @@ from tensorflow.python.eager import tensor_node
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_resource_variable_ops
@@ -41,6 +42,29 @@ from tensorflow.python.ops.gen_resource_variable_ops import *
from tensorflow.python.util import compat
+def _eager_safe_variable_handle(shape, dtype, shared_name, name,
+ container=None):
+ """Creates a variable handle with information to do shape inference."""
+ handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
+ shared_name=shared_name,
+ name=name,
+ container=container)
+ if context.in_graph_mode():
+ return handle
+ with context.graph_mode(), ops.Graph().as_default():
+ h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
+ shared_name=shared_name,
+ name=name,
+ container=container)
+
+ # Tensor._handle_data contains information for the shape-inference code to
+ # know the shape and dtype of the variable pointed to by a handle. Since
+ # shape inference doesn't run in eager mode we copy this data here for when
+ # the handle is captured by an eager mode function.
+ handle._handle_data = h._handle_data # pylint: disable=protected-access
+ return handle
+
+
class ResourceVariable(variables.Variable):
"""Variable based on resource handles.
@@ -231,7 +255,7 @@ class ResourceVariable(variables.Variable):
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
self._save_slice_info = None
- in_graph_mode = context.in_graph_mode()
+ self._in_graph_mode = context.in_graph_mode()
with ops.control_dependencies(None):
with ops.name_scope(name, "Variable", []
if init_from_fn else [initial_value]) as name:
@@ -241,7 +265,7 @@ class ResourceVariable(variables.Variable):
# Use attr_scope and device(None) to simulate the behavior of
# colocate_with when the variable we want to colocate with doesn't
# yet exist.
- if in_graph_mode:
+ if self._in_graph_mode:
attr = attr_value_pb2.AttrValue(
list=attr_value_pb2.AttrValue.ListValue(
s=[compat.as_bytes("loc:@%s" % handle_name)]))
@@ -249,26 +273,28 @@ class ResourceVariable(variables.Variable):
with ops.name_scope("Initializer"), ops.device(None):
initial_value = ops.convert_to_tensor(
initial_value(), name="initial_value", dtype=dtype)
- self._handle = gen_resource_variable_ops.var_handle_op(
+ self._handle = _eager_safe_variable_handle(
shape=initial_value.get_shape(),
dtype=initial_value.dtype.base_dtype,
shared_name=handle_name,
name=name)
- self._handle_device = (self._handle.device if in_graph_mode else
- context.get_default_context().device_name)
+ self._handle_device = (
+ self._handle.device if self._in_graph_mode else
+ context.get_default_context().device_name)
else:
initial_value = initial_value()
with ops.name_scope("Initializer"):
initial_value = ops.convert_to_tensor(
initial_value, name="initial_value", dtype=dtype)
- self._handle = gen_resource_variable_ops.var_handle_op(
+ self._handle = _eager_safe_variable_handle(
shape=initial_value.get_shape(),
dtype=initial_value.dtype.base_dtype,
shared_name=handle_name,
name=name,
container="")
- self._handle_device = (self._handle.device if in_graph_mode else
- context.get_default_context().device_name)
+ self._handle_device = (
+ self._handle.device if self._in_graph_mode else
+ context.get_default_context().device_name)
# pylint: enable=protected-access
# Or get the initial value from a Tensor or Python object.
@@ -277,7 +303,7 @@ class ResourceVariable(variables.Variable):
initial_value = ops.convert_to_tensor(
initial_value, name="initial_value", dtype=dtype)
# pylint: disable=protected-access
- if (in_graph_mode and initial_value is not None and
+ if (self._in_graph_mode and initial_value is not None and
initial_value.op._get_control_flow_context() is not None):
raise ValueError(
"Initializer for variable %s is from inside a control-flow "
@@ -285,21 +311,21 @@ class ResourceVariable(variables.Variable):
"variable inside a loop or conditional, use a lambda as the "
"initializer." % name)
# pylint: enable=protected-access
- self._handle = gen_resource_variable_ops.var_handle_op(
+ self._handle = _eager_safe_variable_handle(
shape=initial_value.get_shape(),
dtype=initial_value.dtype.base_dtype,
shared_name=handle_name,
name=name,
container="")
- self._handle_device = (self._handle.device if in_graph_mode else
+ self._handle_device = (self._handle.device if self._in_graph_mode else
context.get_default_context().device_name)
- self._initial_value = initial_value if in_graph_mode else None
+ self._initial_value = initial_value if self._in_graph_mode else None
self._handle_name = handle_name + ":0"
self._dtype = initial_value.dtype.base_dtype
self._constraint = constraint
- if in_graph_mode:
+ if self._in_graph_mode:
with ops.name_scope("IsInitialized"):
self._is_initialized_op = (
gen_resource_variable_ops.var_is_initialized_op(self._handle))
@@ -399,10 +425,11 @@ class ResourceVariable(variables.Variable):
@property
def shape(self):
"""The shape of this variable."""
- if context.in_graph_mode():
+ if self._in_graph_mode:
return tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
return tensor_shape.TensorShape(
- gen_resource_variable_ops.variable_shape(self._handle).numpy())
+ tensor_util.constant_value(
+ gen_resource_variable_ops.variable_shape(self._handle)))
@property
def create(self):
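The handle changes above are the "side effect fix" from the commit message: a handle created in eager mode now carries _handle_data copied from a graph-mode twin op, so shape inference works when the handle is captured by a function. A hedged sketch of the scenario this enables (assumes eager execution is active; names are illustrative):

from tensorflow.python.eager import function
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops

v = resource_variable_ops.ResourceVariable(1.0)  # eager-defined variable

@function.defun
def double():
  # Shape inference in the traced graph can now determine v's shape and
  # dtype from the captured handle's _handle_data.
  return math_ops.multiply(v.read_value(), 2.0)

double()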
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index bf48c75997..9f2f9b7479 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -157,6 +157,7 @@ def assert_global_step(global_step_tensor):
raise TypeError('Existing "global_step" does not have integer type: %s' %
global_step_tensor.dtype)
- if global_step_tensor.get_shape().ndims != 0:
+ if (global_step_tensor.get_shape().ndims != 0 and
+ global_step_tensor.get_shape().is_fully_defined()):
raise TypeError('Existing "global_step" is not scalar: %s' %
global_step_tensor.get_shape())
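The relaxed check above stops rejecting a global step whose static shape is unknown (as can happen for a handle captured into a defun); only shapes known to be non-scalar still raise. A small illustration of the new condition in isolation:

from tensorflow.python.framework import tensor_shape

def rejects(shape):
  # Mirrors the rewritten condition in assert_global_step.
  return shape.ndims != 0 and shape.is_fully_defined()

assert not rejects(tensor_shape.TensorShape(None))  # unknown rank: accepted
assert not rejects(tensor_shape.TensorShape([]))    # scalar: accepted
assert rejects(tensor_shape.TensorShape([2]))       # known vector: rejected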