aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/callbacks_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/callbacks_test.py')
-rw-r--r--tensorflow/python/keras/callbacks_test.py208
1 files changed, 193 insertions, 15 deletions
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 92d66c95f6..7d830078ce 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -27,6 +27,7 @@ import unittest
import numpy as np
+from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
@@ -812,21 +813,6 @@ class KerasCallbacksTest(test.TestCase):
for cb in cbs:
cb.on_train_end()
- # fit generator with validation data generator should raise ValueError if
- # histogram_freq > 0
- cbs = callbacks_factory(histogram_freq=1)
- with self.assertRaises(ValueError):
- model.fit_generator(
- data_generator(True),
- len(x_train),
- epochs=2,
- validation_data=data_generator(False),
- validation_steps=1,
- callbacks=cbs)
-
- for cb in cbs:
- cb.on_train_end()
-
# Make sure file writer cache is clear to avoid failures during cleanup.
writer_cache.FileWriterCache.clear()
@@ -901,6 +887,130 @@ class KerasCallbacksTest(test.TestCase):
callbacks=callbacks_factory(histogram_freq=1))
assert os.path.isdir(filepath)
+ def test_Tensorboard_histogram_summaries_in_test_function(self):
+
+ class FileWriterStub(object):
+
+ def __init__(self, logdir, graph=None):
+ self.logdir = logdir
+ self.graph = graph
+ self.steps_seen = []
+
+ def add_summary(self, summary, global_step):
+ summary_obj = summary_pb2.Summary()
+
+ # ensure a valid Summary proto is being sent
+ if isinstance(summary, bytes):
+ summary_obj.ParseFromString(summary)
+ else:
+ assert isinstance(summary, summary_pb2.Summary)
+ summary_obj = summary
+
+ # keep track of steps seen for the merged_summary op,
+ # which contains the histogram summaries
+ if len(summary_obj.value) > 1:
+ self.steps_seen.append(global_step)
+
+ def flush(self):
+ pass
+
+ def close(self):
+ pass
+
+ np.random.seed(1337)
+ tmpdir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, tmpdir)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=TRAIN_SAMPLES,
+ test_samples=TEST_SAMPLES,
+ input_shape=(INPUT_DIM,),
+ num_classes=NUM_CLASSES)
+ y_test = keras.utils.to_categorical(y_test)
+ y_train = keras.utils.to_categorical(y_train)
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Dense(
+ NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
+ # non_trainable_weights: moving_variance, moving_mean
+ model.add(keras.layers.BatchNormalization())
+ model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ tsb = keras.callbacks.TensorBoard(
+ log_dir=tmpdir,
+ histogram_freq=1,
+ write_images=True,
+ write_grads=True,
+ batch_size=5)
+ tsb._writer_class = FileWriterStub
+ cbks = [tsb]
+
+ # fit with validation data
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=3,
+ verbose=0)
+
+ self.assertAllEqual(tsb.writer.steps_seen, [0, 0.5, 1, 1.5, 2, 2.5])
+
+ def test_Tensorboard_histogram_summaries_with_generator(self):
+ np.random.seed(1337)
+ tmpdir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, tmpdir)
+
+ def generator():
+ x = np.random.randn(10, 100).astype(np.float32)
+ y = np.random.randn(10, 10).astype(np.float32)
+ while True:
+ yield x, y
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(10, input_dim=100, activation='relu'))
+ model.add(keras.layers.Dense(10, activation='softmax'))
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ tsb = keras.callbacks.TensorBoard(
+ log_dir=tmpdir,
+ histogram_freq=1,
+ write_images=True,
+ write_grads=True,
+ batch_size=5)
+ cbks = [tsb]
+
+ # fit with validation generator
+ model.fit_generator(
+ generator(),
+ steps_per_epoch=2,
+ epochs=2,
+ validation_data=generator(),
+ validation_steps=2,
+ callbacks=cbks,
+ verbose=0)
+
+ with self.assertRaises(ValueError):
+ # fit with validation generator but no
+ # validation_steps
+ model.fit_generator(
+ generator(),
+ steps_per_epoch=2,
+ epochs=2,
+ validation_data=generator(),
+ callbacks=cbks,
+ verbose=0)
+
+ self.assertTrue(os.path.exists(tmpdir))
+
@unittest.skipIf(
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
@@ -986,6 +1096,74 @@ class KerasCallbacksTest(test.TestCase):
assert os.path.exists(temp_dir)
+ def test_Tensorboard_batch_logging(self):
+
+ class FileWriterStub(object):
+
+ def __init__(self, logdir, graph=None):
+ self.logdir = logdir
+ self.graph = graph
+ self.batches_logged = []
+ self.summary_values = []
+ self.summary_tags = []
+
+ def add_summary(self, summary, step):
+ self.summary_values.append(summary.value[0].simple_value)
+ self.summary_tags.append(summary.value[0].tag)
+ self.batches_logged.append(step)
+
+ def flush(self):
+ pass
+
+ def close(self):
+ pass
+
+ logdir = 'fake_dir'
+
+ # log every batch
+ tb_cbk = keras.callbacks.TensorBoard(logdir)
+ tb_cbk.writer = FileWriterStub(logdir)
+
+ for batch in range(5):
+ tb_cbk.on_batch_end(batch, {'acc': np.float32(batch)})
+ self.assertEqual(tb_cbk.writer.batches_logged, [0, 1, 2, 3, 4])
+ self.assertEqual(tb_cbk.writer.summary_values, [0., 1., 2., 3., 4.])
+ self.assertEqual(tb_cbk.writer.summary_tags, ['batch_acc'] * 5)
+
+ def test_Tensorboard_epoch_and_batch_logging(self):
+
+ class FileWriterStub(object):
+
+ def __init__(self, logdir, graph=None):
+ self.logdir = logdir
+ self.graph = graph
+
+ def add_summary(self, summary, step):
+ if 'batch_' in summary.value[0].tag:
+ self.batch_summary = (step, summary)
+ elif 'epoch_' in summary.value[0].tag:
+ self.epoch_summary = (step, summary)
+
+ def flush(self):
+ pass
+
+ def close(self):
+ pass
+
+ logdir = 'fake_dir'
+
+ tb_cbk = keras.callbacks.TensorBoard(logdir)
+ tb_cbk.writer = FileWriterStub(logdir)
+
+ tb_cbk.on_batch_end(0, {'acc': np.float32(5.0)})
+ tb_cbk.on_epoch_end(0, {'acc': np.float32(10.0)})
+ batch_step, batch_summary = tb_cbk.writer.batch_summary
+ self.assertEqual(batch_step, 0)
+ self.assertEqual(batch_summary.value[0].simple_value, 5.0)
+ epoch_step, epoch_summary = tb_cbk.writer.epoch_summary
+ self.assertEqual(epoch_step, 0)
+ self.assertEqual(epoch_summary.value[0].simple_value, 10.0)
+
def test_RemoteMonitorWithJsonPayload(self):
if requests is None:
self.skipTest('`requests` required to run this test')