aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers/python/layers/summaries_test.py
blob: 2ec2af9d442ba0ca65abb8a710c4d270b6869f1c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the tensorflow.contrib.layers summaries library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.layers.python.layers import summaries as summaries_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class SummariesTest(test.TestCase):
  """Tests for `summarize_tensor`, `summarize_activation` and
  `summarize_collection` from the contrib.layers summaries library."""

  def _summary_names(self):
    """Returns the op names of all tensors in the SUMMARIES collection.

    Uses `ops.get_collection`, i.e. the collection of the current default
    graph.
    """
    return [
        summary.op.name
        for summary in ops.get_collection(ops.GraphKeys.SUMMARIES)
    ]

  def test_summarize_scalar_tensor(self):
    """A scalar variable is summarized with a ScalarSummary op."""
    with self.cached_session():
      scalar_var = variables.Variable(1)
      summary_op = summaries_lib.summarize_tensor(scalar_var)
      self.assertEqual(summary_op.op.type, 'ScalarSummary')

  def test_summarize_multidim_tensor(self):
    """A non-scalar (here 1-D) variable is summarized with a histogram."""
    with self.cached_session():
      tensor_var = variables.Variable([1, 2, 3])
      summary_op = summaries_lib.summarize_tensor(tensor_var)
      self.assertEqual(summary_op.op.type, 'HistogramSummary')

  def test_summarize_activation(self):
    """A plain (identity) activation gets exactly one histogram summary."""
    with self.cached_session():
      var = variables.Variable(1)
      activation = array_ops.identity(var, name='SummaryTest')
      summary_op = summaries_lib.summarize_activation(activation)

      self.assertEqual(summary_op.op.type, 'HistogramSummary')
      names = self._summary_names()
      self.assertEqual(len(names), 1)
      self.assertIn(u'SummaryTest/activation', names)

  def test_summarize_activation_relu(self):
    """A Relu activation additionally gets a 'zeros' summary."""
    with self.cached_session():
      var = variables.Variable(1)
      activation = nn_ops.relu(var, name='SummaryTest')
      summary_op = summaries_lib.summarize_activation(activation)

      self.assertEqual(summary_op.op.type, 'HistogramSummary')
      names = self._summary_names()
      self.assertEqual(len(names), 2)
      self.assertIn(u'SummaryTest/zeros', names)
      self.assertIn(u'SummaryTest/activation', names)

  def test_summarize_activation_relu6(self):
    """A Relu6 activation additionally gets 'zeros' and 'sixes' summaries."""
    with self.cached_session():
      var = variables.Variable(1)
      activation = nn_ops.relu6(var, name='SummaryTest')
      summary_op = summaries_lib.summarize_activation(activation)

      self.assertEqual(summary_op.op.type, 'HistogramSummary')
      names = self._summary_names()
      self.assertEqual(len(names), 3)
      self.assertIn(u'SummaryTest/zeros', names)
      self.assertIn(u'SummaryTest/sixes', names)
      self.assertIn(u'SummaryTest/activation', names)

  def test_summarize_collection_regex(self):
    """Only collection members whose names match the regex are summarized."""
    with self.cached_session():
      var = variables.Variable(1)
      # 'Test1' matches the regex but is never added to the 'foo' collection,
      # so it must not be summarized.
      array_ops.identity(var, name='Test1')
      ops.add_to_collection('foo', array_ops.identity(var, name='Test2'))
      # 'Foobar' is in the collection but does not match the regex.
      ops.add_to_collection('foo', array_ops.identity(var, name='Foobar'))
      ops.add_to_collection('foo', array_ops.identity(var, name='Test3'))
      summaries = summaries_lib.summarize_collection('foo', r'Test[123]')
      names = [summary.op.name for summary in summaries]
      self.assertEqual(len(names), 2)
      self.assertIn(u'Test2_summary', names)
      self.assertIn(u'Test3_summary', names)


def _main():
  """Runs this module's tests through TensorFlow's test entry point."""
  test.main()


if __name__ == '__main__':
  _main()