diff options
author | 2017-11-08 10:55:48 -0800 | |
---|---|---|
committer | 2017-11-10 16:14:36 -0800 | |
commit | 35cc51dc2a716c4b92429db60238e4f15fba1ed3 (patch) | |
tree | 397908ffa876253ea4230a0c13e83775841b0201 /tensorflow/contrib/summary | |
parent | 4a618e411af3f808eb0f65ce4f7151450f1f16a5 (diff) |
Add database writer ops to contrib/summary
PiperOrigin-RevId: 175030602
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- | tensorflow/contrib/summary/BUILD | 6 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops.py | 125 | ||||
-rw-r--r-- | tensorflow/contrib/summary/summary_ops_test.py | 110 |
4 files changed, 232 insertions, 11 deletions
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index da23f1c380..3c60d2bb56 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -26,12 +26,18 @@ py_test( deps = [ ":summary_ops", ":summary_test_util", + "//tensorflow/python:array_ops", "//tensorflow/python:errors", + "//tensorflow/python:framework", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:ops", "//tensorflow/python:platform", + "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python/eager:function", "//tensorflow/python/eager:test", + "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py index ca82ea094c..813e8b2b09 100644 --- a/tensorflow/contrib/summary/summary.py +++ b/tensorflow/contrib/summary/summary.py @@ -28,11 +28,13 @@ from __future__ import print_function from tensorflow.contrib.summary.summary_ops import all_summary_ops from tensorflow.contrib.summary.summary_ops import always_record_summaries from tensorflow.contrib.summary.summary_ops import audio +from tensorflow.contrib.summary.summary_ops import create_summary_db_writer from tensorflow.contrib.summary.summary_ops import create_summary_file_writer from tensorflow.contrib.summary.summary_ops import eval_dir from tensorflow.contrib.summary.summary_ops import generic from tensorflow.contrib.summary.summary_ops import histogram from tensorflow.contrib.summary.summary_ops import image +from tensorflow.contrib.summary.summary_ops import import_event from tensorflow.contrib.summary.summary_ops import never_record_summaries from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps from tensorflow.contrib.summary.summary_ops import scalar diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index 9238671c4a..f6be99f6ae 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -19,7 +19,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import getpass import os +import re +import time + +import six from tensorflow.contrib.summary import gen_summary_ops from tensorflow.python.eager import context @@ -42,6 +47,10 @@ _SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries" _SUMMARY_COLLECTION_NAME = "_SUMMARY_V2" _SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2" +_EXPERIMENT_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,256}$") +_RUN_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,512}$") +_USER_NAME_PATTERNS = re.compile(r"^[a-z]([-a-z0-9]{0,29}[a-z0-9])?$", re.I) + def should_record_summaries(): """Returns boolean Tensor which is true if summaries should be recorded.""" @@ -132,7 +141,8 @@ def create_summary_file_writer(logdir, flush once the queue gets bigger than this. flush_millis: the largest interval between flushes. filename_suffix: optional suffix for the event file name. - name: name for the summary writer. + name: Shared name for this SummaryWriter resource stored to default + Graph. Returns: Either a summary writer or an empty object which can be used as a @@ -147,14 +157,81 @@ def create_summary_file_writer(logdir, flush_millis = constant_op.constant(2 * 60 * 1000) if filename_suffix is None: filename_suffix = constant_op.constant("") - resource = gen_summary_ops.summary_writer(shared_name=name) - # TODO(apassos) ensure the initialization op runs when in graph mode; - # consider calling session.run here. - ops.add_to_collection( - _SUMMARY_WRITER_INIT_COLLECTION_NAME, - gen_summary_ops.create_summary_file_writer( - resource, logdir, max_queue, flush_millis, filename_suffix)) - return SummaryWriter(resource) + return _make_summary_writer( + name, + gen_summary_ops.create_summary_file_writer, + logdir=logdir, + max_queue=max_queue, + flush_millis=flush_millis, + filename_suffix=filename_suffix) + + +def create_summary_db_writer(db_uri, + experiment_name=None, + run_name=None, + user_name=None, + name=None): + """Creates a summary database writer in the current context. + + This can be used to write tensors from the execution graph directly + to a database. Only SQLite is supported right now. This function + will create the schema if it doesn't exist. Entries in the Users, + Experiments, and Runs tables will be created automatically if they + don't already exist. + + Args: + db_uri: For example "file:/tmp/foo.sqlite". + experiment_name: Defaults to YYYY-MM-DD in local time if None. + Empty string means the Run will not be associated with an + Experiment. Can't contain ASCII control characters or <>. Case + sensitive. + run_name: Defaults to HH:MM:SS in local time if None. Empty string + means a Tag will not be associated with any Run. Can't contain + ASCII control characters or <>. Case sensitive. + user_name: Defaults to system username if None. Empty means the + Experiment will not be associated with a User. Must be valid as + both a DNS label and Linux username. + name: Shared name for this SummaryWriter resource stored to default + Graph. + + Returns: + A new SummaryWriter instance. + """ + with ops.device("cpu:0"): + if experiment_name is None: + experiment_name = time.strftime("%Y-%m-%d", time.localtime(time.time())) + if run_name is None: + run_name = time.strftime("%H:%M:%S", time.localtime(time.time())) + if user_name is None: + user_name = getpass.getuser() + experiment_name = _cleanse_string( + "experiment_name", _EXPERIMENT_NAME_PATTERNS, experiment_name) + run_name = _cleanse_string("run_name", _RUN_NAME_PATTERNS, run_name) + user_name = _cleanse_string("user_name", _USER_NAME_PATTERNS, user_name) + return _make_summary_writer( + name, + gen_summary_ops.create_summary_db_writer, + db_uri=db_uri, + experiment_name=experiment_name, + run_name=run_name, + user_name=user_name) + + +def _make_summary_writer(name, factory, **kwargs): + resource = gen_summary_ops.summary_writer(shared_name=name) + # TODO(apassos): Consider doing this instead. + # node = factory(resource, **kwargs) + # if not context.in_eager_mode(): + # ops.get_default_session().run(node) + ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, + factory(resource, **kwargs)) + return SummaryWriter(resource) + + +def _cleanse_string(name, pattern, value): + if isinstance(value, six.string_types) and pattern.search(value) is None: + raise ValueError("%s (%s) must match %s" % (name, value, pattern.pattern)) + return ops.convert_to_tensor(value, dtypes.string) def _nothing(): @@ -206,16 +283,22 @@ def summary_writer_function(name, tensor, function, family=None): return op -def generic(name, tensor, metadata, family=None, global_step=None): +def generic(name, tensor, metadata=None, family=None, global_step=None): """Writes a tensor summary if possible.""" if global_step is None: global_step = training_util.get_global_step() def function(tag, scope): + if metadata is None: + serialized_metadata = constant_op.constant("") + elif hasattr(metadata, "SerializeToString"): + serialized_metadata = constant_op.constant(metadata.SerializeToString()) + else: + serialized_metadata = metadata # Note the identity to move the tensor to the CPU. return gen_summary_ops.write_summary( context.context().summary_writer_resource, global_step, array_ops.identity(tensor), - tag, metadata, name=scope) + tag, serialized_metadata, name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -284,6 +367,26 @@ def audio(name, tensor, sample_rate, max_outputs, family=None, return summary_writer_function(name, tensor, function, family=family) +def import_event(tensor, name=None): + """Writes a tf.Event binary proto. + + When using create_summary_db_writer(), this can be used alongside + tf.TFRecordReader to load event logs into the database. Please note + that this is lower level than the other summary functions and will + ignore any conditions set by methods like should_record_summaries(). + + Args: + tensor: A `Tensor` of type `string` containing a serialized `Event` + proto. + name: A name for the operation (optional). + + Returns: + The created Operation. + """ + return gen_summary_ops.import_event( + context.context().summary_writer_resource, tensor, name=name) + + def eval_dir(model_dir, name=None): """Construct a logdir for an eval summary writer.""" return os.path.join(model_dir, "eval" if not name else "eval_" + name) diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index 466e194096..6e1a746815 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -17,14 +17,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools +import os import tempfile +import six +import sqlite3 + from tensorflow.contrib.summary import summary_ops from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import function from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import state_ops from tensorflow.python.platform import gfile from tensorflow.python.training import training_util @@ -99,5 +107,107 @@ class TargetTest(test_util.TensorFlowTestCase): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'scalar') + +class DbTest(test_util.TensorFlowTestCase): + + def setUp(self): + self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') + if os.path.exists(self.db_path): + os.unlink(self.db_path) + self.db = sqlite3.connect(self.db_path) + self.create_summary_db_writer = functools.partial( + summary_ops.create_summary_db_writer, + db_uri=self.db_path, + experiment_name='experiment', + run_name='run', + user_name='user') + + def tearDown(self): + self.db.close() + + def testIntegerSummaries(self): + step = training_util.create_global_step() + + def adder(x, y): + state_ops.assign_add(step, 1) + summary_ops.generic('x', x) + summary_ops.generic('y', y) + sum_ = x + y + summary_ops.generic('sum', sum_) + return sum_ + + with summary_ops.always_record_summaries(): + with self.create_summary_db_writer().as_default(): + self.assertEqual(5, adder(int64(2), int64(3)).numpy()) + + six.assertCountEqual(self, [1, 1, 1], + get_all(self.db, 'SELECT step FROM Tensors')) + six.assertCountEqual(self, ['x', 'y', 'sum'], + get_all(self.db, 'SELECT tag_name FROM Tags')) + x_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "x"') + y_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "y"') + sum_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "sum"') + + with summary_ops.always_record_summaries(): + with self.create_summary_db_writer().as_default(): + self.assertEqual(9, adder(int64(4), int64(5)).numpy()) + + six.assertCountEqual(self, [1, 1, 1, 2, 2, 2], + get_all(self.db, 'SELECT step FROM Tensors')) + six.assertCountEqual(self, [x_id, y_id, sum_id], + get_all(self.db, 'SELECT tag_id FROM Tags')) + self.assertEqual(2, get_tensor(self.db, x_id, 1)) + self.assertEqual(3, get_tensor(self.db, y_id, 1)) + self.assertEqual(5, get_tensor(self.db, sum_id, 1)) + self.assertEqual(4, get_tensor(self.db, x_id, 2)) + self.assertEqual(5, get_tensor(self.db, y_id, 2)) + self.assertEqual(9, get_tensor(self.db, sum_id, 2)) + six.assertCountEqual( + self, ['experiment'], + get_all(self.db, 'SELECT experiment_name FROM Experiments')) + six.assertCountEqual(self, ['run'], + get_all(self.db, 'SELECT run_name FROM Runs')) + six.assertCountEqual(self, ['user'], + get_all(self.db, 'SELECT user_name FROM Users')) + + def testBadExperimentName(self): + with self.assertRaises(ValueError): + self.create_summary_db_writer(experiment_name='\0') + + def testBadRunName(self): + with self.assertRaises(ValueError): + self.create_summary_db_writer(run_name='\0') + + def testBadUserName(self): + with self.assertRaises(ValueError): + self.create_summary_db_writer(user_name='-hi') + with self.assertRaises(ValueError): + self.create_summary_db_writer(user_name='hi-') + with self.assertRaises(ValueError): + self.create_summary_db_writer(user_name='@') + + +def get_one(db, q, *p): + return db.execute(q, p).fetchone()[0] + + +def get_all(db, q, *p): + return unroll(db.execute(q, p).fetchall()) + + +def get_tensor(db, tag_id, step): + return get_one( + db, 'SELECT tensor FROM Tensors WHERE tag_id = ? AND step = ?', tag_id, + step) + + +def int64(x): + return array_ops.constant(x, dtypes.int64) + + +def unroll(list_of_tuples): + return sum(list_of_tuples, ()) + + if __name__ == '__main__': test.main() |