diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/resource_variable_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/resource_variable_ops_test.py | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 0fb0b8895c..c739cd2c0d 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -246,6 +246,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) self.assertEqual(self.evaluate(read), [[2]]) + def testUseResource(self): + v = variables.Variable(1.0, use_resource=True) + self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable)) + + def testEagerNoUseResource(self): + with context.eager_mode(): + v = variables.Variable(1.0) + self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable)) + @test_util.run_in_graph_and_eager_modes def testScatterMin(self): with ops.device("cpu:0"): @@ -852,5 +861,62 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): state_ops.scatter_update(v, [0, 1], [0, 1, 2]) +class _MixedPrecisionVariableTest(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_dense_var_to_tensor_read_dtype_same_as_var_dtype(self): + # read_dtype is same as dtype + v = resource_variable_ops.ResourceVariable(1.0, dtype=dtypes.float32) + v = resource_variable_ops._MixedPrecisionVariable(v, dtypes.float32) + if not context.executing_eagerly(): + v.initializer.run() + + # dtype is not read_dtype, return NotImplemented + self.assertEqual( + NotImplemented, v._dense_var_to_tensor(dtype=dtypes.float16)) + self.assertEqual(NotImplemented, + v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=True)) + + # as_ref is False + t = v._dense_var_to_tensor(as_ref=False) + self.assertTrue(isinstance(t, ops.Tensor)) + self.assertEqual(t.dtype, dtypes.float32) + self.assertEqual(self.evaluate(t), 1.0) + + t = v._dense_var_to_tensor(dtype=dtypes.float32, as_ref=False) + self.assertTrue(isinstance(t, ops.Tensor)) + self.assertEqual(t.dtype, dtypes.float32) + self.assertEqual(self.evaluate(t), 1.0) + + # as_ref is True + self.assertEqual(NotImplemented, v._dense_var_to_tensor(as_ref=True)) + self.assertEqual(NotImplemented, + v._dense_var_to_tensor(dtype=dtypes.float32, as_ref=True)) + + @test_util.run_in_graph_and_eager_modes() + def test_dense_var_to_tensor_read_dtype_different_from_var_dtype(self): + # read_dtype is different from dtype + v = resource_variable_ops.ResourceVariable(1.0, dtype=dtypes.float32) + v = resource_variable_ops._MixedPrecisionVariable(v, dtypes.float16) + if not context.executing_eagerly(): + v.initializer.run() + + # as_ref is False + t = v._dense_var_to_tensor(as_ref=False) + self.assertTrue(isinstance(t, ops.Tensor)) + self.assertEqual(t.dtype, dtypes.float16) + self.assertEqual(self.evaluate(t), 1.0) + + t = v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=False) + self.assertTrue(isinstance(t, ops.Tensor)) + self.assertEqual(t.dtype, dtypes.float16) + self.assertEqual(self.evaluate(t), 1.0) + + # as_ref is True + self.assertEqual(NotImplemented, v._dense_var_to_tensor(as_ref=True)) + self.assertEqual(NotImplemented, + v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=True)) + + if __name__ == "__main__": test.main() |