diff options
Diffstat (limited to 'tensorflow/contrib/rate/rate_test.py')
-rw-r--r-- | tensorflow/contrib/rate/rate_test.py | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/tensorflow/contrib/rate/rate_test.py b/tensorflow/contrib/rate/rate_test.py new file mode 100644 index 0000000000..08908104f4 --- /dev/null +++ b/tensorflow/contrib/rate/rate_test.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Rate.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.rate import rate +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class RateTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testBuildRate(self): + m = rate.Rate() + m.build( + constant_op.constant([1], dtype=dtypes.float32), + constant_op.constant([2], dtype=dtypes.float32)) + old_numer = m.numer + m( + constant_op.constant([2], dtype=dtypes.float32), + constant_op.constant([2], dtype=dtypes.float32)) + self.assertTrue(old_numer is m.numer) + + @test_util.run_in_graph_and_eager_modes() + def testBasic(self): + with self.test_session(): + r_ = rate.Rate() + a = r_(array_ops.ones([1]), denominator=array_ops.ones([1])) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(variables.local_variables_initializer()) + self.assertEqual([[1]], self.evaluate(a)) + b = r_(constant_op.constant([2]), denominator=constant_op.constant([2])) + self.assertEqual([[1]], self.evaluate(b)) + c = r_(constant_op.constant([4]), denominator=constant_op.constant([3])) + self.assertEqual([[2]], self.evaluate(c)) + d = r_(constant_op.constant([16]), denominator=constant_op.constant([3])) + self.assertEqual([[0]], self.evaluate(d)) # divide by 0 + + def testNamesWithSpaces(self): + m1 = rate.Rate(name="has space") + m1(array_ops.ones([1]), array_ops.ones([1])) + self.assertEqual(m1.name, "has space") + self.assertEqual(m1.prev_values.name, "has_space_1/prev_values:0") + + @test_util.run_in_graph_and_eager_modes() + def testWhileLoop(self): + with self.test_session(): + r_ = rate.Rate() + + def body(value, denom, i, ret_rate): + i += 1 + ret_rate = r_(value, denom) + with ops.control_dependencies([ret_rate]): + value = math_ops.add(value, 2) + denom = math_ops.add(denom, 1) + return [value, denom, i, ret_rate] + + def condition(v, d, i, r): + del v, d, r # unused vars by condition + return math_ops.less(i, 100) + + i = constant_op.constant(0) + value = constant_op.constant([1], dtype=dtypes.float64) + denom = constant_op.constant([1], dtype=dtypes.float64) + ret_rate = r_(value, denom) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(variables.local_variables_initializer()) + loop = control_flow_ops.while_loop(condition, body, + [value, denom, i, ret_rate]) + self.assertEqual([[2]], self.evaluate(loop[3])) + + +if __name__ == "__main__": + test.main() |