about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/distribute/python/values_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/values_test.py')
-rw-r--r-- tensorflow/contrib/distribute/python/values_test.py | 64
1 files changed, 50 insertions, 14 deletions
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index c5b246e804..91a43d4999 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
@@ -79,6 +80,30 @@ class DistributedValuesTest(test.TestCase):
with self.assertRaises(AssertionError):
v = values.DistributedValues({"/device:cpu:0": 42})
+ def testIsTensorLike(self):
+ with context.graph_mode(), \
+ ops.Graph().as_default(), \
+ ops.device("/device:CPU:0"):
+ one = constant_op.constant(1)
+ two = constant_op.constant(2)
+ v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two})
+ self.assertEqual(two, v.get("/device:GPU:0"))
+ self.assertEqual(one, v.get())
+ self.assertTrue(v.is_tensor_like)
+ self.assertTrue(tensor_util.is_tensor(v))
+
+ def testIsTensorLikeWithAConstant(self):
+ with context.graph_mode(), \
+ ops.Graph().as_default(), \
+ ops.device("/device:CPU:0"):
+ one = constant_op.constant(1)
+ two = 2.0
+ v = values.DistributedValues({"/device:CPU:0": one, "/device:GPU:0": two})
+ self.assertEqual(two, v.get("/device:GPU:0"))
+ self.assertEqual(one, v.get())
+ self.assertFalse(v.is_tensor_like)
+ self.assertFalse(tensor_util.is_tensor(v))
+
class DistributedDelegateTest(test.TestCase):
@@ -158,7 +183,8 @@ def _make_mirrored():
v.append(variable_scope.get_variable(
name=n, initializer=init, use_resource=True))
index[d] = v[-1]
- mirrored = values.MirroredVariable(index, v[0])
+ mirrored = values.MirroredVariable(index, v[0],
+ variable_scope.VariableAggregation.SUM)
return v, devices, mirrored
@@ -277,7 +303,8 @@ class RegroupAndSelectDeviceTest(test.TestCase):
v = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
index = {d: v}
- mirrored = values.MirroredVariable(index, v)
+ mirrored = values.MirroredVariable(index, v,
+ variable_scope.VariableAggregation.SUM)
result = values.regroup(index)
self.assertIs(mirrored, result)
@@ -581,7 +608,8 @@ class MirroredVariableTest(test.TestCase):
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
index = {"/job:foo/device:CPU:0": v}
- mirrored = values.MirroredVariable(index, v)
+ mirrored = values.MirroredVariable(index, v,
+ variable_scope.VariableAggregation.MEAN)
self.assertEquals(v.name, mirrored.name)
self.assertEquals(v.dtype, mirrored.dtype)
@@ -716,7 +744,9 @@ class MirroredVariableTest(test.TestCase):
with ops.device("/device:GPU:0"):
v = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
- mirrored = values.MirroredVariable({"/device:GPU:0": v}, v)
+ mirrored = values.MirroredVariable({
+ "/device:GPU:0": v
+ }, v, variable_scope.VariableAggregation.MEAN)
sess.run(variables_lib.global_variables_initializer())
sess.run({"complicated": mirrored})
@@ -746,24 +776,27 @@ class TowerLocalVariableTest(test.TestCase):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
- v, tower_local = _make_tower_local("sum")
+ v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
self.assertEquals(v[0].name, tower_local.name)
self.assertEquals(v[0].dtype, tower_local.dtype)
self.assertEquals(v[0].shape, tower_local.shape)
- self.assertEquals("sum", tower_local.reduce_method)
+ self.assertEquals(variable_scope.VariableAggregation.SUM,
+ tower_local.aggregation)
@test_util.run_in_graph_and_eager_modes(config=config)
def testVariableOnAnotherDevice(self):
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
index = {"/job:foo/device:CPU:0": v}
- tower_local = values.TowerLocalVariable(index, v, "mean")
+ tower_local = values.TowerLocalVariable(
+ index, v, variable_scope.VariableAggregation.MEAN)
self.assertEquals(v.name, tower_local.name)
self.assertEquals(v.dtype, tower_local.dtype)
self.assertEquals(v.shape, tower_local.shape)
- self.assertEquals("mean", tower_local.reduce_method)
+ self.assertEquals(variable_scope.VariableAggregation.MEAN,
+ tower_local.aggregation)
def _assign_tower_local(self, devices, v, new):
for d, var, n in zip(devices, v, new):
@@ -789,7 +822,7 @@ class TowerLocalVariableTest(test.TestCase):
self.skipTest("A GPU is not available for this test in eager mode.")
with self.test_session() as sess:
- v, tower_local = _make_tower_local("sum")
+ v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [3., 4.])
@@ -812,7 +845,8 @@ class TowerLocalVariableTest(test.TestCase):
self.skipTest("A GPU is not available for this test in eager mode.")
with self.test_session() as sess:
- v, tower_local = _make_tower_local("mean")
+ v, tower_local = _make_tower_local(
+ variable_scope.VariableAggregation.MEAN)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [3., 4.])
@@ -831,7 +865,8 @@ class TowerLocalVariableTest(test.TestCase):
def _save_tower_local_mean(self):
"""Save variables with mirroring, returns save_path."""
with self.test_session(graph=ops.Graph()) as sess:
- v, tower_local = _make_tower_local("mean")
+ v, tower_local = _make_tower_local(
+ variable_scope.VariableAggregation.MEAN)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [3., 4.])
@@ -893,7 +928,8 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_tower_local_mean(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
with self.test_session(graph=ops.Graph()) as sess:
- v, tower_local = _make_tower_local("mean")
+ v, tower_local = _make_tower_local(
+ variable_scope.VariableAggregation.MEAN)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [7., 8.])
@@ -907,7 +943,7 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_tower_local_sum(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
with self.test_session(graph=ops.Graph()) as sess:
- v, tower_local = _make_tower_local("sum")
+ v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [7., 8.])
@@ -968,7 +1004,7 @@ class TowerLocalVariableTest(test.TestCase):
def testTensorConversion(self):
with context.graph_mode():
- _, tower_local = _make_tower_local("sum")
+ _, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
converted = ops.internal_convert_to_tensor(tower_local, as_ref=False)
self.assertIsInstance(converted, ops.Tensor)
self.assertEqual(converted.dtype, tower_local.dtype)