aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/resource_variable_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/resource_variable_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py66
1 files changed, 66 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 0fb0b8895c..c739cd2c0d 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -246,6 +246,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[2]])
+ def testUseResource(self):
+ v = variables.Variable(1.0, use_resource=True)
+ self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable))
+
+ def testEagerNoUseResource(self):
+ with context.eager_mode():
+ v = variables.Variable(1.0)
+ self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable))
+
@test_util.run_in_graph_and_eager_modes
def testScatterMin(self):
with ops.device("cpu:0"):
@@ -852,5 +861,62 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
state_ops.scatter_update(v, [0, 1], [0, 1, 2])
+class _MixedPrecisionVariableTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_dense_var_to_tensor_read_dtype_same_as_var_dtype(self):
+ # read_dtype is same as dtype
+ v = resource_variable_ops.ResourceVariable(1.0, dtype=dtypes.float32)
+ v = resource_variable_ops._MixedPrecisionVariable(v, dtypes.float32)
+ if not context.executing_eagerly():
+ v.initializer.run()
+
+ # dtype is not read_dtype, return NotImplemented
+ self.assertEqual(
+ NotImplemented, v._dense_var_to_tensor(dtype=dtypes.float16))
+ self.assertEqual(NotImplemented,
+ v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=True))
+
+ # as_ref is False
+ t = v._dense_var_to_tensor(as_ref=False)
+ self.assertTrue(isinstance(t, ops.Tensor))
+ self.assertEqual(t.dtype, dtypes.float32)
+ self.assertEqual(self.evaluate(t), 1.0)
+
+ t = v._dense_var_to_tensor(dtype=dtypes.float32, as_ref=False)
+ self.assertTrue(isinstance(t, ops.Tensor))
+ self.assertEqual(t.dtype, dtypes.float32)
+ self.assertEqual(self.evaluate(t), 1.0)
+
+ # as_ref is True
+ self.assertEqual(NotImplemented, v._dense_var_to_tensor(as_ref=True))
+ self.assertEqual(NotImplemented,
+ v._dense_var_to_tensor(dtype=dtypes.float32, as_ref=True))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_dense_var_to_tensor_read_dtype_different_from_var_dtype(self):
+ # read_dtype is different from dtype
+ v = resource_variable_ops.ResourceVariable(1.0, dtype=dtypes.float32)
+ v = resource_variable_ops._MixedPrecisionVariable(v, dtypes.float16)
+ if not context.executing_eagerly():
+ v.initializer.run()
+
+ # as_ref is False
+ t = v._dense_var_to_tensor(as_ref=False)
+ self.assertTrue(isinstance(t, ops.Tensor))
+ self.assertEqual(t.dtype, dtypes.float16)
+ self.assertEqual(self.evaluate(t), 1.0)
+
+ t = v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=False)
+ self.assertTrue(isinstance(t, ops.Tensor))
+ self.assertEqual(t.dtype, dtypes.float16)
+ self.assertEqual(self.evaluate(t), 1.0)
+
+ # as_ref is True
+ self.assertEqual(NotImplemented, v._dense_var_to_tensor(as_ref=True))
+ self.assertEqual(NotImplemented,
+ v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=True))
+
+
if __name__ == "__main__":
test.main()