# path: root/tensorflow/contrib/rate/rate_test.py
# blob: 3dee16388182fd2a74aa4e7b42b29f08e7829195 (plain)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Rate."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.rate import rate
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class RateTest(test.TestCase):
  """Unit tests for `rate.Rate`.

  Covers state reuse after an explicit `build()`, basic rate values across
  successive calls, variable naming when the metric name contains a space,
  and use of the metric inside a `while_loop` body.
  """

  @test_util.run_in_graph_and_eager_modes()
  def testBuildRate(self):
    """Calling the metric after `build()` keeps the same `numer` object."""
    m = rate.Rate()
    m.build(
        constant_op.constant([1], dtype=dtypes.float32),
        constant_op.constant([2], dtype=dtypes.float32))
    old_numer = m.numer
    m(
        constant_op.constant([2], dtype=dtypes.float32),
        constant_op.constant([2], dtype=dtypes.float32))
    # Identity (not equality) check: the call must not replace the state that
    # build() created. assertIs reports both objects on failure.
    self.assertIs(old_numer, m.numer)

  @test_util.run_in_graph_and_eager_modes()
  def testBasic(self):
    """Checks the value returned by successive calls on one `Rate` instance.

    The expected values are consistent with the metric reporting
    (change in value) / (change in denominator) relative to the previous
    call -- NOTE(review): inferred from the assertions below, confirm against
    the `Rate` implementation.
    """
    with self.cached_session():
      r_ = rate.Rate()
      a = r_(array_ops.ones([1]), denominator=array_ops.ones([1]))
      # The first call creates the metric's variables, so initialize only
      # after it.
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(variables.local_variables_initializer())
      self.assertEqual([[1]], self.evaluate(a))
      # value 1 -> 2, denominator 1 -> 2.
      b = r_(constant_op.constant([2]), denominator=constant_op.constant([2]))
      self.assertEqual([[1]], self.evaluate(b))
      # value 2 -> 4, denominator 2 -> 3.
      c = r_(constant_op.constant([4]), denominator=constant_op.constant([3]))
      self.assertEqual([[2]], self.evaluate(c))
      # Denominator stays at 3, so its change is zero; the expected result is
      # 0 rather than an error or inf/nan.
      d = r_(constant_op.constant([16]), denominator=constant_op.constant([3]))
      self.assertEqual([[0]], self.evaluate(d))  # divide by 0

  # NOTE(review): unlike the other tests this one is not run in both graph and
  # eager modes; presumably because the asserted variable name is
  # mode-specific -- confirm before adding the decorator.
  def testNamesWithSpaces(self):
    """A metric name containing a space is kept, but sanitized for variables."""
    m1 = rate.Rate(name="has space")
    m1(array_ops.ones([1]), array_ops.ones([1]))
    # The metric's own name is reported unchanged ...
    self.assertEqual(m1.name, "has space")
    # ... while the variable's name uses an underscore in place of the space.
    self.assertEqual(m1.prev_values.name, "has_space_1/prev_values:0")

  @test_util.run_in_graph_and_eager_modes()
  def testWhileLoop(self):
    """Uses one `Rate` instance inside a `while_loop` body.

    Each iteration adds 2 to the value and 1 to the denominator, and the test
    expects the rate returned by the loop to be 2.
    """
    with self.cached_session():
      r_ = rate.Rate()

      def body(value, denom, i, ret_rate):
        """Loop body: update the rate, then advance value and denom."""
        i += 1
        # The incoming ret_rate is discarded; only the fresh rate is returned.
        ret_rate = r_(value, denom)
        # Force the rate to be computed from the current value/denom before
        # they are advanced for the next iteration.
        with ops.control_dependencies([ret_rate]):
          value = math_ops.add(value, 2)
          denom = math_ops.add(denom, 1)
        return [value, denom, i, ret_rate]

      def condition(v, d, i, r):
        """Loop condition: continue while the counter is below 100."""
        del v, d, r  # unused vars by condition
        return math_ops.less(i, 100)

      i = constant_op.constant(0)
      value = constant_op.constant([1], dtype=dtypes.float64)
      denom = constant_op.constant([1], dtype=dtypes.float64)
      # Call the metric once outside the loop: this creates its variables (so
      # they can be initialized below) and supplies the initial ret_rate loop
      # variable.
      ret_rate = r_(value, denom)
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(variables.local_variables_initializer())
      loop = control_flow_ops.while_loop(condition, body,
                                         [value, denom, i, ret_rate])
      # loop[3] is the final ret_rate loop variable.
      self.assertEqual([[2]], self.evaluate(loop[3]))


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()