diff options
Diffstat (limited to 'tensorflow/contrib/eager/python/network_test.py')
-rw-r--r-- | tensorflow/contrib/eager/python/network_test.py | 108 |
1 files changed, 12 insertions, 96 deletions
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index 1127055c05..14adbafe57 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -410,103 +410,19 @@ class NetworkTest(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testWrappingInVariableScope(self): - one = constant_op.constant([[1.]]) - # Naming happens in the order of first build rather than the order of - # construction, but for clarity they're the same here and construction is - # annotated. - outside_net_before = MyNetwork() # name=my_network_1 - outside_net_before(one) - captured_scope = variable_scope.get_variable_scope() with variable_scope.variable_scope("outside_scope"): - net1 = MyNetwork() # name=outside_scope/my_network_1 - net1(one) - name_conflict1 = MyNetwork(name="name_conflict") # fine, unique so far - name_conflict2 = MyNetwork(name="name_conflict") # error on build - with variable_scope.variable_scope("inside_scope"): - # No issue here since the name is unique within its scope. - name_conflict3 = MyNetwork(name="name_conflict") - net2 = MyNetwork() # name=outside_scope/my_network_3 to avoid the - # variable_scope my_network_2 below. - vs_name_conflict = MyNetwork(name="vs_name_conflict") # conflict below - with variable_scope.variable_scope("intervening_scope"): - with variable_scope.variable_scope(captured_scope): - with variable_scope.variable_scope("outside_scope"): - name_conflict4 = MyNetwork(name="name_conflict") # error on build - with variable_scope.variable_scope("my_network_2"): - pass - with variable_scope.variable_scope("vs_name_conflict"): - pass - net3 = MyNetwork() # name=outside_scope/my_network_4 - name_conflict1(one) - with self.assertRaisesRegexp( - ValueError, "named 'name_conflict' already exists"): - name_conflict2(one) - name_conflict3(one) - net2(one) - with self.assertRaisesRegexp( - ValueError, "or a variable_scope was created with this name"): - vs_name_conflict(one) - with self.assertRaisesRegexp( - ValueError, "named 'name_conflict' already exists"): - name_conflict4(one) - self.assertEqual("outside_scope/name_conflict", - name_conflict1.name) - self.assertStartsWith( - expected_start="outside_scope/name_conflict/dense_1/", - actual=name_conflict1.variables[0].name) - self.assertEqual("outside_scope/inside_scope/name_conflict", - name_conflict3.name) - self.assertStartsWith( - expected_start="outside_scope/inside_scope/name_conflict/dense_1/", - actual=name_conflict3.variables[0].name) - self.assertEqual("outside_scope/my_network_1", net1.name) - self.assertStartsWith( - expected_start="outside_scope/my_network_1/dense_1/", - actual=net1.trainable_weights[0].name) - self.assertEqual("outside_scope/my_network_3", net2.name) - self.assertStartsWith( - expected_start="outside_scope/my_network_3/dense_1/", - actual=net2.trainable_weights[0].name) - net3(one) - self.assertEqual("outside_scope/my_network_4", net3.name) - self.assertStartsWith( - expected_start="outside_scope/my_network_4/dense_1/", - actual=net3.trainable_weights[0].name) - outside_net_after = MyNetwork() - outside_net_after(one) - self.assertEqual("my_network_1", outside_net_before.name) - self.assertStartsWith( - expected_start="my_network_1/dense_1/", - actual=outside_net_before.trainable_weights[0].name) - self.assertEqual("my_network_2", outside_net_after.name) - self.assertStartsWith( - expected_start="my_network_2/dense_1/", - actual=outside_net_after.trainable_weights[0].name) - - @test_util.run_in_graph_and_eager_modes() - def testVariableScopeStripping(self): - with variable_scope.variable_scope("scope1"): - with variable_scope.variable_scope("scope2"): - net = MyNetwork() - net(constant_op.constant([[2.0]])) - self.evaluate(net.variables[0].assign([[42.]])) - self.assertEqual(net.name, "scope1/scope2/my_network_1") - self.assertStartsWith( - expected_start="scope1/scope2/my_network_1/dense_1/", - actual=net.trainable_weights[0].name) - save_path = net.save(self.get_temp_dir()) - self.assertIn("scope1_scope2_my_network_1", save_path) - restore_net = MyNetwork() - # Delayed restoration - restore_net.restore(save_path) - restore_net(constant_op.constant([[1.0]])) - self.assertAllEqual([[42.]], - self.evaluate(restore_net.variables[0])) - self.evaluate(restore_net.variables[0].assign([[-1.]])) - # Immediate restoration - restore_net.restore(save_path) - self.assertAllEqual([[42.]], - self.evaluate(restore_net.variables[0])) + net = MyNetwork() + one = constant_op.constant([[1.]]) + with self.assertRaisesRegexp( + ValueError, + ("Creating Networks inside named variable_scopes is currently not " + "supported")): + net(one) + # Alternatively, we could re-name the Network to match the variable_scope: + # self.assertEqual("outside_scope/my_network_1", net.name) + # self.assertStartsWith( + # expected_start="outside_scope/my_network_1/dense/", + # actual=net.trainable_weights[0].name) @test_util.run_in_graph_and_eager_modes() def testLayerNamesRespected(self): |