aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rate/rate_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rate/rate_test.py')
-rw-r--r--tensorflow/contrib/rate/rate_test.py97
1 files changed, 97 insertions, 0 deletions
diff --git a/tensorflow/contrib/rate/rate_test.py b/tensorflow/contrib/rate/rate_test.py
new file mode 100644
index 0000000000..08908104f4
--- /dev/null
+++ b/tensorflow/contrib/rate/rate_test.py
@@ -0,0 +1,97 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Rate."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.rate import rate
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class RateTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testBuildRate(self):
+ m = rate.Rate()
+ m.build(
+ constant_op.constant([1], dtype=dtypes.float32),
+ constant_op.constant([2], dtype=dtypes.float32))
+ old_numer = m.numer
+ m(
+ constant_op.constant([2], dtype=dtypes.float32),
+ constant_op.constant([2], dtype=dtypes.float32))
+ self.assertTrue(old_numer is m.numer)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testBasic(self):
+ with self.test_session():
+ r_ = rate.Rate()
+ a = r_(array_ops.ones([1]), denominator=array_ops.ones([1]))
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.local_variables_initializer())
+ self.assertEqual([[1]], self.evaluate(a))
+ b = r_(constant_op.constant([2]), denominator=constant_op.constant([2]))
+ self.assertEqual([[1]], self.evaluate(b))
+ c = r_(constant_op.constant([4]), denominator=constant_op.constant([3]))
+ self.assertEqual([[2]], self.evaluate(c))
+ d = r_(constant_op.constant([16]), denominator=constant_op.constant([3]))
+ self.assertEqual([[0]], self.evaluate(d)) # divide by 0
+
+ def testNamesWithSpaces(self):
+ m1 = rate.Rate(name="has space")
+ m1(array_ops.ones([1]), array_ops.ones([1]))
+ self.assertEqual(m1.name, "has space")
+ self.assertEqual(m1.prev_values.name, "has_space_1/prev_values:0")
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testWhileLoop(self):
+ with self.test_session():
+ r_ = rate.Rate()
+
+ def body(value, denom, i, ret_rate):
+ i += 1
+ ret_rate = r_(value, denom)
+ with ops.control_dependencies([ret_rate]):
+ value = math_ops.add(value, 2)
+ denom = math_ops.add(denom, 1)
+ return [value, denom, i, ret_rate]
+
+ def condition(v, d, i, r):
+ del v, d, r # unused vars by condition
+ return math_ops.less(i, 100)
+
+ i = constant_op.constant(0)
+ value = constant_op.constant([1], dtype=dtypes.float64)
+ denom = constant_op.constant([1], dtype=dtypes.float64)
+ ret_rate = r_(value, denom)
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.local_variables_initializer())
+ loop = control_flow_ops.while_loop(condition, body,
+ [value, denom, i, ret_rate])
+ self.assertEqual([[2]], self.evaluate(loop[3]))
+
+
+if __name__ == "__main__":
+ test.main()