aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager/python/network_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/eager/python/network_test.py')
-rw-r--r--tensorflow/contrib/eager/python/network_test.py108
1 files changed, 12 insertions, 96 deletions
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py
index 1127055c05..14adbafe57 100644
--- a/tensorflow/contrib/eager/python/network_test.py
+++ b/tensorflow/contrib/eager/python/network_test.py
@@ -410,103 +410,19 @@ class NetworkTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testWrappingInVariableScope(self):
- one = constant_op.constant([[1.]])
- # Naming happens in the order of first build rather than the order of
- # construction, but for clarity they're the same here and construction is
- # annotated.
- outside_net_before = MyNetwork() # name=my_network_1
- outside_net_before(one)
- captured_scope = variable_scope.get_variable_scope()
with variable_scope.variable_scope("outside_scope"):
- net1 = MyNetwork() # name=outside_scope/my_network_1
- net1(one)
- name_conflict1 = MyNetwork(name="name_conflict") # fine, unique so far
- name_conflict2 = MyNetwork(name="name_conflict") # error on build
- with variable_scope.variable_scope("inside_scope"):
- # No issue here since the name is unique within its scope.
- name_conflict3 = MyNetwork(name="name_conflict")
- net2 = MyNetwork() # name=outside_scope/my_network_3 to avoid the
- # variable_scope my_network_2 below.
- vs_name_conflict = MyNetwork(name="vs_name_conflict") # conflict below
- with variable_scope.variable_scope("intervening_scope"):
- with variable_scope.variable_scope(captured_scope):
- with variable_scope.variable_scope("outside_scope"):
- name_conflict4 = MyNetwork(name="name_conflict") # error on build
- with variable_scope.variable_scope("my_network_2"):
- pass
- with variable_scope.variable_scope("vs_name_conflict"):
- pass
- net3 = MyNetwork() # name=outside_scope/my_network_4
- name_conflict1(one)
- with self.assertRaisesRegexp(
- ValueError, "named 'name_conflict' already exists"):
- name_conflict2(one)
- name_conflict3(one)
- net2(one)
- with self.assertRaisesRegexp(
- ValueError, "or a variable_scope was created with this name"):
- vs_name_conflict(one)
- with self.assertRaisesRegexp(
- ValueError, "named 'name_conflict' already exists"):
- name_conflict4(one)
- self.assertEqual("outside_scope/name_conflict",
- name_conflict1.name)
- self.assertStartsWith(
- expected_start="outside_scope/name_conflict/dense_1/",
- actual=name_conflict1.variables[0].name)
- self.assertEqual("outside_scope/inside_scope/name_conflict",
- name_conflict3.name)
- self.assertStartsWith(
- expected_start="outside_scope/inside_scope/name_conflict/dense_1/",
- actual=name_conflict3.variables[0].name)
- self.assertEqual("outside_scope/my_network_1", net1.name)
- self.assertStartsWith(
- expected_start="outside_scope/my_network_1/dense_1/",
- actual=net1.trainable_weights[0].name)
- self.assertEqual("outside_scope/my_network_3", net2.name)
- self.assertStartsWith(
- expected_start="outside_scope/my_network_3/dense_1/",
- actual=net2.trainable_weights[0].name)
- net3(one)
- self.assertEqual("outside_scope/my_network_4", net3.name)
- self.assertStartsWith(
- expected_start="outside_scope/my_network_4/dense_1/",
- actual=net3.trainable_weights[0].name)
- outside_net_after = MyNetwork()
- outside_net_after(one)
- self.assertEqual("my_network_1", outside_net_before.name)
- self.assertStartsWith(
- expected_start="my_network_1/dense_1/",
- actual=outside_net_before.trainable_weights[0].name)
- self.assertEqual("my_network_2", outside_net_after.name)
- self.assertStartsWith(
- expected_start="my_network_2/dense_1/",
- actual=outside_net_after.trainable_weights[0].name)
-
- @test_util.run_in_graph_and_eager_modes()
- def testVariableScopeStripping(self):
- with variable_scope.variable_scope("scope1"):
- with variable_scope.variable_scope("scope2"):
- net = MyNetwork()
- net(constant_op.constant([[2.0]]))
- self.evaluate(net.variables[0].assign([[42.]]))
- self.assertEqual(net.name, "scope1/scope2/my_network_1")
- self.assertStartsWith(
- expected_start="scope1/scope2/my_network_1/dense_1/",
- actual=net.trainable_weights[0].name)
- save_path = net.save(self.get_temp_dir())
- self.assertIn("scope1_scope2_my_network_1", save_path)
- restore_net = MyNetwork()
- # Delayed restoration
- restore_net.restore(save_path)
- restore_net(constant_op.constant([[1.0]]))
- self.assertAllEqual([[42.]],
- self.evaluate(restore_net.variables[0]))
- self.evaluate(restore_net.variables[0].assign([[-1.]]))
- # Immediate restoration
- restore_net.restore(save_path)
- self.assertAllEqual([[42.]],
- self.evaluate(restore_net.variables[0]))
+ net = MyNetwork()
+ one = constant_op.constant([[1.]])
+ with self.assertRaisesRegexp(
+ ValueError,
+ ("Creating Networks inside named variable_scopes is currently not "
+ "supported")):
+ net(one)
+ # Alternatively, we could re-name the Network to match the variable_scope:
+ # self.assertEqual("outside_scope/my_network_1", net.name)
+ # self.assertStartsWith(
+ # expected_start="outside_scope/my_network_1/dense/",
+ # actual=net.trainable_weights[0].name)
@test_util.run_in_graph_and_eager_modes()
def testLayerNamesRespected(self):