aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py52
1 files changed, 31 insertions, 21 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index e064cfe37d..9a4cc0a897 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -40,7 +40,7 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
GPU_TEST = "test_gpu" in sys.argv[0]
@@ -164,7 +164,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
# This variable should be created only once across the threads because of
# special variable_creator functions used by `dist.call_for_each_tower`.
v = variable_scope.variable(1.0, name="foo")
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -181,7 +181,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
v = variable_scope.variable(1.0)
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -201,7 +201,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
vs = []
for i in range(5):
vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return vs
dist = mirrored_strategy.MirroredStrategy(
@@ -223,7 +223,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return vs
dist = mirrored_strategy.MirroredStrategy(
@@ -245,7 +245,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn(device_id):
v = variable_scope.variable(1.0, name="foo_" + str(device_id))
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -268,7 +268,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
layer2 = core.Dense(1)
layer2(features)
# This will pause the current thread, and execute the other thread.
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
layer3 = core.Dense(1)
layer3(features)
return [(layer1.kernel, layer1.bias),
@@ -300,7 +301,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with variable_scope.variable_scope("common"):
v1 = variable_scope.variable(1.0, name="var1")
# This will pause the current thread, and execute the other thread.
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
v2 = variable_scope.variable(
1.0,
name="var2",
@@ -343,7 +345,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with variable_scope.variable_scope("common"):
v1 = variable_scope.get_variable("var1", [1])
# This will pause the current thread, and execute the other thread.
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
v2 = variable_scope.get_variable(
"var2", [1],
synchronization=variable_scope.VariableSynchronization.ON_READ,
@@ -453,7 +456,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
v = variable_scope.variable(1.0, name="foo")
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -470,7 +473,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn(name):
v = variable_scope.variable(1.0, name=name)
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(lambda _: _)
return v
dist = mirrored_strategy.MirroredStrategy(
@@ -570,7 +573,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
with ops.name_scope("foo"):
a = constant_op.constant(1.0, name="a")
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
b = constant_op.constant(1.0, name="b")
return a, b
@@ -591,7 +595,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
with ops.name_scope(None, "foo"):
a = constant_op.constant(1.0, name="a")
- distribute_lib.get_tower_context().merge_call(lambda _: _)
+ distribution_strategy_context.get_tower_context().merge_call(
+ lambda _: _)
b = constant_op.constant(2.0, name="b")
return a, b
@@ -619,7 +624,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
b = variable_scope.variable(1.0, name="b")
with ops.name_scope("foo"):
- c = distribute_lib.get_tower_context().merge_call(in_cross_tower)
+ c = distribution_strategy_context.get_tower_context().merge_call(
+ in_cross_tower)
return b, c
dist = mirrored_strategy.MirroredStrategy(
@@ -651,7 +657,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
def model_fn():
b = variable_scope.get_variable("b", [1])
with ops.name_scope("foo"):
- c = distribute_lib.get_tower_context().merge_call(in_cross_tower)
+ c = distribution_strategy_context.get_tower_context().merge_call(
+ in_cross_tower)
return b, c
dist = mirrored_strategy.MirroredStrategy(
@@ -833,8 +840,9 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(1.0, self.evaluate(mirrored_var))
def model_fn():
- value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
- mirrored_var.dtype)
+ value = math_ops.cast(
+ distribution_strategy_context.get_tower_context().tower_id,
+ mirrored_var.dtype)
return mirrored_var.assign(value)
self.evaluate(dist.unwrap(dist.call_for_each_tower(
@@ -898,8 +906,9 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(1.0, self.evaluate(mirrored_var))
def model_fn():
- value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
- mirrored_var.dtype)
+ value = math_ops.cast(
+ distribution_strategy_context.get_tower_context().tower_id,
+ mirrored_var.dtype)
return mirrored_var.assign_add(value)
self.evaluate(dist.unwrap(dist.call_for_each_tower(
@@ -963,8 +972,9 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(5.0, self.evaluate(mirrored_var))
def model_fn():
- value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
- mirrored_var.dtype)
+ value = math_ops.cast(
+ distribution_strategy_context.get_tower_context().tower_id,
+ mirrored_var.dtype)
return mirrored_var.assign_sub(value)
self.evaluate(dist.unwrap(dist.call_for_each_tower(