author    Justine Tunney <jart@google.com>  2017-11-08 10:55:48 -0800
committer Andrew Selle <aselle@andyselle.com>  2017-11-10 16:14:36 -0800
commit    35cc51dc2a716c4b92429db60238e4f15fba1ed3 (patch)
tree      397908ffa876253ea4230a0c13e83775841b0201 /tensorflow/contrib/summary
parent    4a618e411af3f808eb0f65ce4f7151450f1f16a5 (diff)
Add database writer ops to contrib/summary
PiperOrigin-RevId: 175030602
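
In practice, the new ops added below (create_summary_db_writer and import_event) let eager-mode summaries be written straight into a SQLite database instead of event files. A minimal sketch of the intended use, mirroring the style of summary_ops_test.py further down; the database path and the logged value are made up for illustration, and eager execution is assumed to be enabled already:

    from tensorflow.contrib.summary import summary_ops
    from tensorflow.python.framework import constant_op
    from tensorflow.python.training import training_util

    # Summaries are keyed by the global step.
    step = training_util.create_global_step()
    writer = summary_ops.create_summary_db_writer(
        db_uri='/tmp/demo.sqlite',        # hypothetical path; schema is created on demand
        experiment_name='demo-experiment',
        run_name='run-1')
    with summary_ops.always_record_summaries():
      with writer.as_default():
        # Written as a row in the database rather than to an event file.
        summary_ops.scalar('loss', constant_op.constant(0.25))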
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r--  tensorflow/contrib/summary/BUILD               |   6
-rw-r--r--  tensorflow/contrib/summary/summary.py          |   2
-rw-r--r--  tensorflow/contrib/summary/summary_ops.py      | 125
-rw-r--r--  tensorflow/contrib/summary/summary_ops_test.py | 110
4 files changed, 232 insertions, 11 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD
index da23f1c380..3c60d2bb56 100644
--- a/tensorflow/contrib/summary/BUILD
+++ b/tensorflow/contrib/summary/BUILD
@@ -26,12 +26,18 @@ py_test(
deps = [
":summary_ops",
":summary_test_util",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:ops",
"//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
"//tensorflow/python:training",
"//tensorflow/python/eager:function",
"//tensorflow/python/eager:test",
+ "@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py
index ca82ea094c..813e8b2b09 100644
--- a/tensorflow/contrib/summary/summary.py
+++ b/tensorflow/contrib/summary/summary.py
@@ -28,11 +28,13 @@ from __future__ import print_function
from tensorflow.contrib.summary.summary_ops import all_summary_ops
from tensorflow.contrib.summary.summary_ops import always_record_summaries
from tensorflow.contrib.summary.summary_ops import audio
+from tensorflow.contrib.summary.summary_ops import create_summary_db_writer
from tensorflow.contrib.summary.summary_ops import create_summary_file_writer
from tensorflow.contrib.summary.summary_ops import eval_dir
from tensorflow.contrib.summary.summary_ops import generic
from tensorflow.contrib.summary.summary_ops import histogram
from tensorflow.contrib.summary.summary_ops import image
+from tensorflow.contrib.summary.summary_ops import import_event
from tensorflow.contrib.summary.summary_ops import never_record_summaries
from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps
from tensorflow.contrib.summary.summary_ops import scalar
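
The two names added to this export list are the only public entry points the commit introduces; everything else lives in summary_ops.py as private helpers. A short sketch of how they would be reached through the public namespace, assuming the usual tf.contrib.summary wiring of this era:

    import tensorflow as tf

    # Re-exported above, alongside scalar, histogram, image, audio, etc.
    writer = tf.contrib.summary.create_summary_db_writer('/tmp/demo.sqlite')
    import_event = tf.contrib.summary.import_event  # takes a serialized tf.Event string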
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index 9238671c4a..f6be99f6ae 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -19,7 +19,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import getpass
import os
+import re
+import time
+
+import six
from tensorflow.contrib.summary import gen_summary_ops
from tensorflow.python.eager import context
@@ -42,6 +47,10 @@ _SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries"
_SUMMARY_COLLECTION_NAME = "_SUMMARY_V2"
_SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2"
+_EXPERIMENT_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,256}$")
+_RUN_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,512}$")
+_USER_NAME_PATTERNS = re.compile(r"^[a-z]([-a-z0-9]{0,29}[a-z0-9])?$", re.I)
+
def should_record_summaries():
"""Returns boolean Tensor which is true if summaries should be recorded."""
@@ -132,7 +141,8 @@ def create_summary_file_writer(logdir,
flush once the queue gets bigger than this.
flush_millis: the largest interval between flushes.
filename_suffix: optional suffix for the event file name.
- name: name for the summary writer.
+ name: Shared name for this SummaryWriter resource stored to default
+ Graph.
Returns:
Either a summary writer or an empty object which can be used as a
@@ -147,14 +157,81 @@ def create_summary_file_writer(logdir,
flush_millis = constant_op.constant(2 * 60 * 1000)
if filename_suffix is None:
filename_suffix = constant_op.constant("")
- resource = gen_summary_ops.summary_writer(shared_name=name)
- # TODO(apassos) ensure the initialization op runs when in graph mode;
- # consider calling session.run here.
- ops.add_to_collection(
- _SUMMARY_WRITER_INIT_COLLECTION_NAME,
- gen_summary_ops.create_summary_file_writer(
- resource, logdir, max_queue, flush_millis, filename_suffix))
- return SummaryWriter(resource)
+ return _make_summary_writer(
+ name,
+ gen_summary_ops.create_summary_file_writer,
+ logdir=logdir,
+ max_queue=max_queue,
+ flush_millis=flush_millis,
+ filename_suffix=filename_suffix)
+
+
+def create_summary_db_writer(db_uri,
+ experiment_name=None,
+ run_name=None,
+ user_name=None,
+ name=None):
+ """Creates a summary database writer in the current context.
+
+ This can be used to write tensors from the execution graph directly
+ to a database. Only SQLite is supported right now. This function
+ will create the schema if it doesn't exist. Entries in the Users,
+ Experiments, and Runs tables will be created automatically if they
+ don't already exist.
+
+ Args:
+ db_uri: For example "file:/tmp/foo.sqlite".
+ experiment_name: Defaults to YYYY-MM-DD in local time if None.
+ Empty string means the Run will not be associated with an
+ Experiment. Can't contain ASCII control characters or <>. Case
+ sensitive.
+ run_name: Defaults to HH:MM:SS in local time if None. Empty string
+ means a Tag will not be associated with any Run. Can't contain
+ ASCII control characters or <>. Case sensitive.
+ user_name: Defaults to system username if None. Empty means the
+ Experiment will not be associated with a User. Must be valid as
+ both a DNS label and Linux username.
+ name: Shared name for this SummaryWriter resource stored to default
+ Graph.
+
+ Returns:
+ A new SummaryWriter instance.
+ """
+ with ops.device("cpu:0"):
+ if experiment_name is None:
+ experiment_name = time.strftime("%Y-%m-%d", time.localtime(time.time()))
+ if run_name is None:
+ run_name = time.strftime("%H:%M:%S", time.localtime(time.time()))
+ if user_name is None:
+ user_name = getpass.getuser()
+ experiment_name = _cleanse_string(
+ "experiment_name", _EXPERIMENT_NAME_PATTERNS, experiment_name)
+ run_name = _cleanse_string("run_name", _RUN_NAME_PATTERNS, run_name)
+ user_name = _cleanse_string("user_name", _USER_NAME_PATTERNS, user_name)
+ return _make_summary_writer(
+ name,
+ gen_summary_ops.create_summary_db_writer,
+ db_uri=db_uri,
+ experiment_name=experiment_name,
+ run_name=run_name,
+ user_name=user_name)
+
+
+def _make_summary_writer(name, factory, **kwargs):
+ resource = gen_summary_ops.summary_writer(shared_name=name)
+ # TODO(apassos): Consider doing this instead.
+ # node = factory(resource, **kwargs)
+ # if not context.in_eager_mode():
+ # ops.get_default_session().run(node)
+ ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME,
+ factory(resource, **kwargs))
+ return SummaryWriter(resource)
+
+
+def _cleanse_string(name, pattern, value):
+ if isinstance(value, six.string_types) and pattern.search(value) is None:
+ raise ValueError("%s (%s) must match %s" % (name, value, pattern.pattern))
+ return ops.convert_to_tensor(value, dtypes.string)
def _nothing():
@@ -206,16 +283,22 @@ def summary_writer_function(name, tensor, function, family=None):
return op
-def generic(name, tensor, metadata, family=None, global_step=None):
+def generic(name, tensor, metadata=None, family=None, global_step=None):
"""Writes a tensor summary if possible."""
if global_step is None:
global_step = training_util.get_global_step()
def function(tag, scope):
+ if metadata is None:
+ serialized_metadata = constant_op.constant("")
+ elif hasattr(metadata, "SerializeToString"):
+ serialized_metadata = constant_op.constant(metadata.SerializeToString())
+ else:
+ serialized_metadata = metadata
# Note the identity to move the tensor to the CPU.
return gen_summary_ops.write_summary(
context.context().summary_writer_resource,
global_step, array_ops.identity(tensor),
- tag, metadata, name=scope)
+ tag, serialized_metadata, name=scope)
return summary_writer_function(name, tensor, function, family=family)
@@ -284,6 +367,26 @@ def audio(name, tensor, sample_rate, max_outputs, family=None,
return summary_writer_function(name, tensor, function, family=family)
+def import_event(tensor, name=None):
+ """Writes a tf.Event binary proto.
+
+ When using create_summary_db_writer(), this can be used alongside
+ tf.TFRecordReader to load event logs into the database. Please note
+ that this is lower level than the other summary functions and will
+ ignore any conditions set by methods like should_record_summaries().
+
+ Args:
+ tensor: A `Tensor` of type `string` containing a serialized `Event`
+ proto.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created Operation.
+ """
+ return gen_summary_ops.import_event(
+ context.context().summary_writer_resource, tensor, name=name)
+
+
def eval_dir(model_dir, name=None):
"""Construct a logdir for an eval summary writer."""
return os.path.join(model_dir, "eval" if not name else "eval_" + name)
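
The import_event() op added above is meant for bulk-loading existing event logs into the database. A hedged sketch of that flow, using tf_record_iterator for brevity in place of the tf.TFRecordReader mentioned in the docstring; the event-file path is hypothetical and eager execution is assumed:

    from tensorflow.contrib.summary import summary_ops
    from tensorflow.python.lib.io import tf_record

    writer = summary_ops.create_summary_db_writer(
        '/tmp/demo.sqlite', experiment_name='imported', run_name='tfevents')
    with writer.as_default():
      # Each record in an event file is already a serialized tf.Event proto,
      # which is exactly what import_event() expects; per its docstring it
      # bypasses should_record_summaries() and similar conditions.
      for record in tf_record.tf_record_iterator('/tmp/logs/events.out.tfevents.1234.host'):
        summary_ops.import_event(record)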
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 466e194096..6e1a746815 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -17,14 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+import os
import tempfile
+import six
+import sqlite3
+
from tensorflow.contrib.summary import summary_ops
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import function
from tensorflow.python.eager import test
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util
@@ -99,5 +107,107 @@ class TargetTest(test_util.TensorFlowTestCase):
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'scalar')
+
+class DbTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
+ if os.path.exists(self.db_path):
+ os.unlink(self.db_path)
+ self.db = sqlite3.connect(self.db_path)
+ self.create_summary_db_writer = functools.partial(
+ summary_ops.create_summary_db_writer,
+ db_uri=self.db_path,
+ experiment_name='experiment',
+ run_name='run',
+ user_name='user')
+
+ def tearDown(self):
+ self.db.close()
+
+ def testIntegerSummaries(self):
+ step = training_util.create_global_step()
+
+ def adder(x, y):
+ state_ops.assign_add(step, 1)
+ summary_ops.generic('x', x)
+ summary_ops.generic('y', y)
+ sum_ = x + y
+ summary_ops.generic('sum', sum_)
+ return sum_
+
+ with summary_ops.always_record_summaries():
+ with self.create_summary_db_writer().as_default():
+ self.assertEqual(5, adder(int64(2), int64(3)).numpy())
+
+ six.assertCountEqual(self, [1, 1, 1],
+ get_all(self.db, 'SELECT step FROM Tensors'))
+ six.assertCountEqual(self, ['x', 'y', 'sum'],
+ get_all(self.db, 'SELECT tag_name FROM Tags'))
+ x_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "x"')
+ y_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "y"')
+ sum_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "sum"')
+
+ with summary_ops.always_record_summaries():
+ with self.create_summary_db_writer().as_default():
+ self.assertEqual(9, adder(int64(4), int64(5)).numpy())
+
+ six.assertCountEqual(self, [1, 1, 1, 2, 2, 2],
+ get_all(self.db, 'SELECT step FROM Tensors'))
+ six.assertCountEqual(self, [x_id, y_id, sum_id],
+ get_all(self.db, 'SELECT tag_id FROM Tags'))
+ self.assertEqual(2, get_tensor(self.db, x_id, 1))
+ self.assertEqual(3, get_tensor(self.db, y_id, 1))
+ self.assertEqual(5, get_tensor(self.db, sum_id, 1))
+ self.assertEqual(4, get_tensor(self.db, x_id, 2))
+ self.assertEqual(5, get_tensor(self.db, y_id, 2))
+ self.assertEqual(9, get_tensor(self.db, sum_id, 2))
+ six.assertCountEqual(
+ self, ['experiment'],
+ get_all(self.db, 'SELECT experiment_name FROM Experiments'))
+ six.assertCountEqual(self, ['run'],
+ get_all(self.db, 'SELECT run_name FROM Runs'))
+ six.assertCountEqual(self, ['user'],
+ get_all(self.db, 'SELECT user_name FROM Users'))
+
+ def testBadExperimentName(self):
+ with self.assertRaises(ValueError):
+ self.create_summary_db_writer(experiment_name='\0')
+
+ def testBadRunName(self):
+ with self.assertRaises(ValueError):
+ self.create_summary_db_writer(run_name='\0')
+
+ def testBadUserName(self):
+ with self.assertRaises(ValueError):
+ self.create_summary_db_writer(user_name='-hi')
+ with self.assertRaises(ValueError):
+ self.create_summary_db_writer(user_name='hi-')
+ with self.assertRaises(ValueError):
+ self.create_summary_db_writer(user_name='@')
+
+
+def get_one(db, q, *p):
+ return db.execute(q, p).fetchone()[0]
+
+
+def get_all(db, q, *p):
+ return unroll(db.execute(q, p).fetchall())
+
+
+def get_tensor(db, tag_id, step):
+ return get_one(
+ db, 'SELECT tensor FROM Tensors WHERE tag_id = ? AND step = ?', tag_id,
+ step)
+
+
+def int64(x):
+ return array_ops.constant(x, dtypes.int64)
+
+
+def unroll(list_of_tuples):
+ return sum(list_of_tuples, ())
+
+
if __name__ == '__main__':
test.main()