about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/summary
diff options
context:
space:
mode:
author: Justine Tunney <jart@google.com> 2018-01-11 16:08:50 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-01-11 16:12:46 -0800
commit: febdd26ae594133d24f82544706b1e012a5cf1ea (patch)
tree: dd325008019ab10ce35f98368bf392ce4a118ec9 /tensorflow/contrib/summary
parent: fc252eb976c98c95a625ea6e6a0486334d3c5b6e (diff)
Add reservoir sampling to DB summary writer
This thing is kind of cool. It's able to turn a 350 MB event log into a 35 MB SQLite file at 80 MB/s with one MacBook core. Best of all, this was accomplished using a normalized schema without the embedded protos.

PiperOrigin-RevId: 181676380
Diffstat (limited to 'tensorflow/contrib/summary')
-rw-r--r-- tensorflow/contrib/summary/summary_ops_test.py 41
1 files changed, 34 insertions, 7 deletions
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 4ef03434b7..dfaa4182bb 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -18,12 +18,14 @@ from __future__ import print_function
import tempfile
+import numpy as np
import six
from tensorflow.contrib.summary import summary_ops
from tensorflow.contrib.summary import summary_test_util
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
+from tensorflow.core.framework import types_pb2
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
@@ -37,6 +39,23 @@ from tensorflow.python.training import training_util
get_all = summary_test_util.get_all
get_one = summary_test_util.get_one
+_NUMPY_NUMERIC_TYPES = {
+ types_pb2.DT_HALF: np.float16,
+ types_pb2.DT_FLOAT: np.float32,
+ types_pb2.DT_DOUBLE: np.float64,
+ types_pb2.DT_INT8: np.int8,
+ types_pb2.DT_INT16: np.int16,
+ types_pb2.DT_INT32: np.int32,
+ types_pb2.DT_INT64: np.int64,
+ types_pb2.DT_UINT8: np.uint8,
+ types_pb2.DT_UINT16: np.uint16,
+ types_pb2.DT_UINT32: np.uint32,
+ types_pb2.DT_UINT64: np.uint64,
+ types_pb2.DT_COMPLEX64: np.complex64,
+ types_pb2.DT_COMPLEX128: np.complex128,
+ types_pb2.DT_BOOL: np.bool_,
+}
+
class TargetTest(test_util.TensorFlowTestCase):
@@ -154,8 +173,9 @@ class DbTest(summary_test_util.SummaryDbTest):
with writer.as_default():
self.assertEqual(5, adder(int64(2), int64(3)).numpy())
- six.assertCountEqual(self, [1, 1, 1],
- get_all(self.db, 'SELECT step FROM Tensors'))
+ six.assertCountEqual(
+ self, [1, 1, 1],
+ get_all(self.db, 'SELECT step FROM Tensors WHERE dtype IS NOT NULL'))
six.assertCountEqual(self, ['x', 'y', 'sum'],
get_all(self.db, 'SELECT tag_name FROM Tags'))
x_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "x"')
@@ -166,8 +186,9 @@ class DbTest(summary_test_util.SummaryDbTest):
with writer.as_default():
self.assertEqual(9, adder(int64(4), int64(5)).numpy())
- six.assertCountEqual(self, [1, 1, 1, 2, 2, 2],
- get_all(self.db, 'SELECT step FROM Tensors'))
+ six.assertCountEqual(
+ self, [1, 1, 1, 2, 2, 2],
+ get_all(self.db, 'SELECT step FROM Tensors WHERE dtype IS NOT NULL'))
six.assertCountEqual(self, [x_id, y_id, sum_id],
get_all(self.db, 'SELECT tag_id FROM Tags'))
self.assertEqual(2, get_tensor(self.db, x_id, 1))
@@ -212,9 +233,15 @@ class DbTest(summary_test_util.SummaryDbTest):
def get_tensor(db, tag_id, step):
- return get_one(
- db, 'SELECT tensor FROM Tensors WHERE tag_id = ? AND step = ?', tag_id,
- step)
+ cursor = db.execute(
+ 'SELECT dtype, shape, data FROM Tensors WHERE series = ? AND step = ?',
+ (tag_id, step))
+ dtype, shape, data = cursor.fetchone()
+ assert dtype in _NUMPY_NUMERIC_TYPES
+ buf = np.frombuffer(data, dtype=_NUMPY_NUMERIC_TYPES[dtype])
+ if not shape:
+ return buf[0]
+ return buf.reshape([int(i) for i in shape.split(',')])
def int64(x):