aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager/python/network_test.py
blob: 14adbafe5735bd2a3d3961402e8ef3e6a7be333b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gc

from tensorflow.contrib.eager.python import network
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.layers import core
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import training_util


# pylint: disable=not-callable
class MyNetwork(network.Network):
  """Minimal Network holding one single-unit Dense layer without bias."""

  def __init__(self, name=None):
    super(MyNetwork, self).__init__(name=name)
    dense = core.Dense(1, use_bias=False)
    self.l1 = self.track_layer(dense)

  def call(self, x):
    """Applies the tracked Dense layer to `x`."""
    layer = self.l1
    return layer(x)


class NetworkTest(test.TestCase):

  def _save_modify_load_network_built(self, net, global_step=None):
    """Saves `net`, perturbs its variables, and checks restore undoes that.

    Restoration is exercised twice: once from the checkpoint directory and
    once from the explicit checkpoint path returned by `save`.

    Args:
      net: A Network that has already been called (its variables exist).
      global_step: Optional step forwarded to `net.save`.
    """
    save_dir = self.get_temp_dir()
    explicit_path = net.save(save_path=save_dir, global_step=global_step)
    probe = constant_op.constant([[42.0]])
    expected = self.evaluate(net(probe))

    def _shift_all_variables(delta):
      for variable in net.variables:
        self.evaluate(variable.assign(variable + delta))

    _shift_all_variables(1.)
    self.assertGreater(self.evaluate(net(probe)), expected)
    # Either the returned explicit checkpoint path or the directory should work.
    net.restore(save_path=save_dir)
    self.assertAllEqual(expected, self.evaluate(net(probe)))
    _shift_all_variables(2.)
    net.restore(save_path=explicit_path)
    self.assertAllEqual(expected, self.evaluate(net(probe)))

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testTrainableAttribute(self):
    """A Network is trainable and its `trainable` flag cannot be reassigned."""
    model = network.Network()
    self.assertTrue(model.trainable)
    with self.assertRaises(AttributeError):
      setattr(model, "trainable", False)
    # The rejected assignment must leave the flag unchanged.
    self.assertTrue(model.trainable)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testNetworkCall(self):
    """Calling a Network applies its tracked layer to the input."""
    model = MyNetwork(name="abcd")
    model(constant_op.constant([[2.0]]))  # The first call creates variables.
    weights = model.trainable_variables
    self.assertEqual(1, len(weights))
    self.evaluate(weights[0].assign([[17.0]]))
    # TODO(josh11b): Support passing Python values to networks.
    output = model(constant_op.constant([[2.0]]))
    self.assertEqual(34.0, self.evaluate(output))

  # TODO(allenl): This test creates garbage in some Python versions
  @test_util.run_in_graph_and_eager_modes()
  def testNetworkSaveRestoreAlreadyBuilt(self):
    """Save/restore round-trips on a built Network, with and without a step."""
    model = MyNetwork(name="abcd")
    # There is nothing to save before the first call creates variables.
    with self.assertRaisesRegexp(
        ValueError, "Attempt to save the Network before it was first called"):
      model.save(self.get_temp_dir())
    model(constant_op.constant([[2.0]]))
    self.evaluate(model.trainable_variables[0].assign([[17.0]]))
    for step in (None, 10):
      self._save_modify_load_network_built(model, global_step=step)

  # TODO(allenl): This test creates garbage in some Python versions
  @test_util.run_in_graph_and_eager_modes()
  def testSaveRestoreDefaultGlobalStep(self):
    """Without an explicit step, the checkpoint name uses the global step."""
    model = MyNetwork(name="abcd")
    model(constant_op.constant([[2.0]]))
    self.evaluate(model.variables[0].assign([[3.]]))
    step_variable = training_util.get_or_create_global_step()
    self.evaluate(step_variable.assign(4242))
    written_path = model.save(self.get_temp_dir())
    expected_fragment = "abcd-4242"
    self.assertIn(expected_fragment, written_path)

  # TODO(allenl): This test creates garbage in some Python versions
  @test_util.run_in_graph_and_eager_modes()
  def testNetworkSaveAndRestoreIntoUnbuilt(self):
    """Restoring into a not-yet-called Network still yields the saved values."""
    checkpoint_dir = self.get_temp_dir()
    source = MyNetwork()
    inputs = constant_op.constant([[2.0]])
    source(inputs)
    self.evaluate(source.trainable_variables[0].assign([[17.0]]))
    checkpoint = source.save(checkpoint_dir)
    # Restore before the target has ever been called.
    target = MyNetwork()
    target.restore(checkpoint)
    self.assertAllEqual(self.evaluate(source(inputs)),
                        self.evaluate(target(inputs)))
    source_var, target_var = source.variables[0], target.variables[0]
    self.assertIsNot(source_var, target_var)
    self.assertAllEqual(self.evaluate(source_var), self.evaluate(target_var))

  @test_util.run_in_graph_and_eager_modes()
  def testLoadIntoUnbuiltSharedLayer(self):
    """Restores into a Network whose tracked Layer is shared and still unbuilt.

    `User` tracks a Dense layer owned by a separate `Owner` Network. The
    checkpoint's variable names are remapped at save time so that the default
    restore mapping targets that shared, not-yet-built layer. A second scenario
    checks that restoring fails once the owning Network has been garbage
    collected.
    """

    class Owner(network.Network):

      def __init__(self, name=None):
        super(Owner, self).__init__(name=name)
        self.first = self.track_layer(core.Dense(
            1, name="first_layer", use_bias=False))

      def call(self, x):
        return self.first(x)

    # Owns the Dense layer that `User` shares below. It is not called until
    # after the restore.
    first_owner = Owner()

    # Tracks a Layer owned elsewhere (`use_layer`) plus its own second layer.
    class User(network.Network):

      def __init__(self, use_layer, name=None):
        super(User, self).__init__(name=name)
        self.first = self.track_layer(use_layer)
        self.second = self.track_layer(core.Dense(
            1, name="second_layer", use_bias=False))

      def call(self, x):
        return self.second(self.first(x))

    # Same layer structure as `User`, but owns both layers. Used only to
    # produce the checkpoint.
    class LikeUserButNotSharing(network.Network):

      def __init__(self, name=None):
        super(LikeUserButNotSharing, self).__init__(name=name)
        self.first = self.track_layer(core.Dense(
            1, name="first_layer", use_bias=False))
        self.second = self.track_layer(core.Dense(
            1, name="second_layer", use_bias=False))

      def call(self, x):
        return self.second(self.first(x))

    checkpoint_creator = LikeUserButNotSharing(name="checkpoint_creator")
    one = constant_op.constant([[1.0]])
    checkpoint_creator(one)
    self.assertEqual(2, len(checkpoint_creator.variables))
    self.evaluate(checkpoint_creator.variables[0].assign([[5.]]))
    self.evaluate(checkpoint_creator.variables[1].assign([[6.]]))
    # Re-map the variable names so that with default restore mapping we'll
    # attempt to restore into the unbuilt Layer.
    name_mapping = {
        "checkpoint_creator/first_layer/kernel": "owner_1/first_layer/kernel",
        "checkpoint_creator/second_layer/kernel": "second_layer/kernel",
    }
    save_path = checkpoint_creator.save(
        self.get_temp_dir(),
        map_func=lambda full_name: name_mapping[full_name])
    load_into = User(use_layer=first_owner.first)
    load_into.restore(save_path)
    # The restore alone does not create the shared layer's variables.
    self.assertEqual(0, len(first_owner.variables))
    self.assertAllEqual(self.evaluate(checkpoint_creator(one)),
                        self.evaluate(load_into(one)))
    # Calling `load_into` built the shared layer, which `first_owner` sees.
    self.assertEqual(1, len(first_owner.variables))
    self.assertAllEqual([[5.]], self.evaluate(load_into.variables[0]))
    self.assertAllEqual([[6.]], self.evaluate(load_into.variables[1]))
    first_owner(one)
    self.assertAllEqual([[5.]], self.evaluate(first_owner.variables[0]))

    # Try again with a garbage collected parent.
    first_owner = Owner()
    load_into = User(use_layer=first_owner.first)
    del first_owner
    gc.collect()
    # Rewrites "owner_1..." checkpoint names to "owner_2..." and prefixes every
    # other name with "user_2/".
    def _restore_map_func(original_name):
      if original_name.startswith("owner_1"):
        return original_name.replace("owner_1", "owner_2")
      else:
        return "user_2/" + original_name
    # The shared layer's owning Network is gone, so restoring into it fails.
    with self.assertRaisesRegexp(ValueError, "garbage collected"):
      load_into.restore(save_path, map_func=_restore_map_func)

  @test_util.run_in_graph_and_eager_modes()
  def testRestoreIntoSubNetwork(self):
    """Sub-Network restores compose with whole-Network restores.

    Restores requested before a Network is built are deferred until its first
    call. As with immediate restores, later requests overwrite earlier ones.
    """

    class Parent(network.Network):

      def __init__(self, name=None):
        super(Parent, self).__init__(name=name)
        self.first = self.track_layer(MyNetwork())
        self.second = self.track_layer(MyNetwork())

      def call(self, x):
        return self.first(self.second(x))

    one = constant_op.constant([[3.]])
    # Whole-Parent checkpoint holding 15 and 16.
    whole_model_saver = Parent()
    whole_model_saver(one)
    self.evaluate(whole_model_saver.variables[0].assign([[15.]]))
    self.evaluate(whole_model_saver.variables[1].assign([[16.]]))
    whole_model_checkpoint = whole_model_saver.save(self.get_temp_dir())

    # Checkpoint of a single standalone MyNetwork holding 5.
    save_from = MyNetwork()
    save_from(one)
    self.evaluate(save_from.variables[0].assign([[5.]]))
    checkpoint = save_from.save(self.get_temp_dir())
    save_into_parent = Parent()
    save_into_parent.restore(whole_model_checkpoint)
    save_into_parent.first.restore(checkpoint)
    # Requesting the same deferred restore twice is fine.
    save_into_parent.first.restore(checkpoint)
    save_into_parent(one)  # Deferred restores are applied on this first call.
    # The sub-Network restore was requested last, so it wins for `first`.
    self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0]))
    self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))

    # Try again with the opposite ordering, and we should get different results
    # (deferred restoration should happen the same way non-deferred happens,
    # with later restorations overwriting older ones).
    save_into_parent = Parent()
    save_into_parent.first.restore(checkpoint)  # Deferred sub-Network restore.
    save_into_parent.restore(whole_model_checkpoint)  # Requested afterwards.
    save_into_parent(one)  # Deferred restores are applied on this first call.
    # We've overwritten the sub-Network restore.
    self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0]))
    self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))

    # Once built, a sub-Network restore takes effect immediately.
    self.evaluate(save_into_parent.variables[0].assign([[3.]]))
    self.evaluate(save_into_parent.variables[1].assign([[4.]]))
    save_into_parent.second.restore(checkpoint)
    self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1]))
    with self.assertRaisesRegexp(errors_impl.NotFoundError,
                                 "not found in checkpoint"):
      # The single-MyNetwork checkpoint is incompatible with the whole Parent.
      save_into_parent.restore(checkpoint)

  @test_util.run_in_graph_and_eager_modes()
  def testCustomMapCollisionErrors(self):
    """A map_func sending several variables to one name is rejected."""

    class Parent(network.Network):

      def __init__(self, name=None):
        super(Parent, self).__init__(name=name)
        self.first = self.track_layer(MyNetwork())
        self.second = self.track_layer(MyNetwork())

      def call(self, x):
        return self.first(self.second(x))

    def _everything_is_foo(unused_name):
      return "foo"

    def _collision_message(method, network_name):
      return ("The map_func passed to Network.%s for the Network '%s' "
              "resulted in two variables named 'foo'" % (method, network_name))

    saver = Parent()
    ones = constant_op.constant([[1.]])
    saver(ones)
    self.evaluate(saver.variables[0].assign([[2.]]))
    self.evaluate(saver.variables[1].assign([[3.]]))
    with self.assertRaisesRegexp(ValueError,
                                 _collision_message("save", "parent_1")):
      saver.save(self.get_temp_dir(), map_func=_everything_is_foo)
    # A single MyNetwork has only one variable, so this save succeeds.
    checkpoint = saver.first.save(
        self.get_temp_dir(), map_func=_everything_is_foo)

    # Restore requested before the first call: the error surfaces on the call.
    deferred_loader = Parent()
    deferred_loader.restore(checkpoint, map_func=_everything_is_foo)
    with self.assertRaisesRegexp(ValueError,
                                 _collision_message("restore", "parent_2")):
      deferred_loader(ones)

    # Restore into an already-called Network: the error surfaces immediately.
    built_loader = Parent()
    built_loader(ones)
    with self.assertRaisesRegexp(ValueError,
                                 _collision_message("restore", "parent_3")):
      built_loader.restore(checkpoint, map_func=_everything_is_foo)

  @test_util.run_in_graph_and_eager_modes()
  def testDefaultMapCollisionErrors(self):
    """The default checkpoint name mapping reports naming conflicts.

    `Parent` tracks a Dense layer explicitly named "dense_1", which is built
    outside any Network, alongside an anonymous Dense layer. Both save and
    restore with the default mapping are expected to raise a naming-conflict
    ValueError.
    """

    one = constant_op.constant([[1.]])
    first = core.Dense(1, name="dense_1", use_bias=False)
    first(one)

    class Parent(network.Network):

      def __init__(self, name=None):
        super(Parent, self).__init__(name=name)
        self.first = self.track_layer(first)
        self.second = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        return self.first(self.second(x))

    # `one` from above is reused; re-creating an identical constant is needless.
    make_checkpoint = Parent()
    make_checkpoint(one)
    self.evaluate(make_checkpoint.variables[0].assign([[2.]]))
    self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
    with self.assertRaisesRegexp(
        ValueError,
        ("The default checkpoint variable name mapping strategy for Network "
         "'parent_1' resulted in a naming conflict.")):
      make_checkpoint.save(self.get_temp_dir())

    class Compatible(network.Network):

      def __init__(self, name=None):
        super(Compatible, self).__init__(name=name)
        self.first = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        return self.first(x)

    # A conflict-free checkpoint still cannot be restored into a Parent,
    # because the conflict is on the Parent side.
    successful_checkpoint = Compatible()
    successful_checkpoint(one)
    self.evaluate(successful_checkpoint.variables[0].assign([[-1.]]))
    checkpoint_path = successful_checkpoint.save(self.get_temp_dir())
    load_checkpoint = Parent()
    load_checkpoint(one)
    with self.assertRaisesRegexp(
        ValueError,
        ("The default checkpoint variable name mapping strategy for Network "
         "'parent_2' resulted in a naming conflict.")):
      load_checkpoint.restore(checkpoint_path)

  def testNoReferenceCyclesAfterCall(self):
    """Calling a nested Network must not leave reference cycles behind.

    Automatic collection is disabled and DEBUG_SAVEALL is set, so every
    unreachable object found by an explicit `gc.collect()` is kept in
    `gc.garbage` instead of being freed. The test checks that building,
    calling, and dropping a Network adds nothing to that list. The collector's
    previous state is restored in `finally`, so a failing assertion cannot
    leak a disabled collector or DEBUG_SAVEALL mode into later tests.
    """

    class ChildNetwork(network.Network):

      def __init__(self, name=None):
        super(ChildNetwork, self).__init__(name=name)

      def call(self, x):
        return x * 2.

    class ParentNetwork(network.Network):

      def __init__(self, name=None):
        super(ParentNetwork, self).__init__(name=name)
        self.l1 = self.track_layer(ChildNetwork())

      def call(self, x):
        return self.l1(x)

    one = constant_op.constant([[1.0]])
    gc_was_enabled = gc.isenabled()
    previous_gc_debug_flags = gc.get_debug()
    gc.disable()
    try:
      # Free (without saving) any garbage created before this point.
      gc.collect()
      gc.set_debug(gc.DEBUG_SAVEALL)
      preexisting = len(gc.garbage)
      net = ParentNetwork()
      net(one)
      del net
      gc.collect()
      # There should be no additional garbage requiring collection.
      self.assertEqual(preexisting, len(gc.garbage))
    finally:
      gc.set_debug(previous_gc_debug_flags)
      if gc_was_enabled:
        gc.enable()

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testAnonymousNoNameInitially(self):
    """An unnamed Network has no final name right after construction."""
    anonymous = MyNetwork()
    with self.assertRaisesRegexp(ValueError, "does not yet have a final name"):
      _ = anonymous.name

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testExplicitHasNameInitially(self):
    """An explicitly named Network reports that name immediately."""
    explicit = MyNetwork(name="abcd")
    reported_name = explicit.name
    self.assertEqual("abcd", reported_name)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testUsingResourceVariables(self):
    """Variables created through a Network are ResourceVariables."""
    model = MyNetwork()
    model(constant_op.constant([[0.]]))
    kernel = model.trainable_weights[0]
    self.assertIsInstance(kernel, resource_variable_ops.ResourceVariable)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testDuplicateNameError(self):
    """Reusing an explicit Network name that is already taken is an error."""
    ones = constant_op.constant([[1.]])
    original = MyNetwork(name="foo")
    original(ones)
    with self.assertRaisesRegexp(ValueError, "named 'foo' already exists"):
      duplicate = MyNetwork(name="foo")
      duplicate(ones)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testWrappingInVariableScope(self):
    """Calling a Network inside a named variable_scope is rejected."""
    with variable_scope.variable_scope("outside_scope"):
      model = MyNetwork()
      ones = constant_op.constant([[1.]])
      expected_error = ("Creating Networks inside named variable_scopes is "
                        "currently not supported")
      with self.assertRaisesRegexp(ValueError, expected_error):
        model(ones)
      # An alternative design would rename the Network to match the enclosing
      # variable_scope, in which case one would expect:
      #   self.assertEqual("outside_scope/my_network_1", model.name)
      #   self.assertStartsWith(
      #       expected_start="outside_scope/my_network_1/dense/",
      #       actual=model.trainable_weights[0].name)

  @test_util.run_in_graph_and_eager_modes()
  def testLayerNamesRespected(self):
    """An explicit Layer name is kept and appears in its variable names."""
    class ParentNetwork(network.Network):

      def __init__(self):
        super(ParentNetwork, self).__init__()
        named_layer = core.Dense(1, use_bias=False, name="explicit_name")
        self.first = self.track_layer(named_layer)

      def call(self, x):
        layer = self.first
        return layer(x)

    ones = constant_op.constant([[1.]])
    model = ParentNetwork()
    model(ones)
    kernel_name = model.trainable_weights[0].name
    self.assertStartsWith(expected_start="parent_network_1/explicit_name/",
                          actual=kernel_name)
    self.assertEqual("explicit_name", model.first.name)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testWrappingInAnonymousVariableScope(self):
    """A blank-named outer variable_scope may install a custom getter."""
    # Named outside variable_scopes are not supported at the moment. However,
    # blank-named top level variable scopes do not change variable names, and so
    # can be used to set the properties of Network variables.
    getter_calls = []

    def _custom_getter(getter, *args, **kwargs):
      getter_calls.append(True)
      return getter(*args, **kwargs)

    with variable_scope.variable_scope("", custom_getter=_custom_getter):
      model = MyNetwork()
      ones = constant_op.constant([[1.]])
      model(ones)
    self.assertTrue(getter_calls)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testReasonableSlashError(self):
    """Network names containing '/' are rejected at construction."""
    expected_error = "not allowed in Network names"
    with self.assertRaisesRegexp(ValueError, expected_error):
      MyNetwork(name="slash/slash")

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testNoVariableScopeNames(self):
    """A VariableScope object is not accepted as a Network name."""
    expected_error = "VariableScopes are not valid Network names"
    with self.assertRaisesRegexp(ValueError, expected_error):
      with variable_scope.variable_scope("some_scope") as scope:
        MyNetwork(name=scope)

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testVariableScopeNameCollision(self):
    """A Network may not take a name already used by a variable_scope."""
    with variable_scope.variable_scope("abcd"):
      pass  # Entering the scope is enough to claim the name.
    expected_error = "or a variable_scope was created with this name"
    with self.assertRaisesRegexp(ValueError, expected_error):
      model = MyNetwork(name="abcd")
      model(constant_op.constant([[1.]]))

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testNetworkVariablesDoNotInterfere(self):
    """Layers inside Networks get local names and independent variables."""
    core.Dense(1, use_bias=True)  # Should not interfere with naming.
    net1 = MyNetwork()
    net2 = MyNetwork()
    one = constant_op.constant([[1.]])
    net1(one)
    net2(one)
    # Layer names typically are globally unique rather than being unique within
    # the scope of their first use. However, within a Network they must be named
    # locally so that previous Layer construction does not interfere with
    # variable naming (e.g. add a Layer construction before the Network,
    # suddenly your previously saved checkpoint is incompatible).
    self.assertEqual("dense_1", net1.l1.name)
    self.assertEqual("dense_1", net2.l1.name)
    self.evaluate(net1.trainable_weights[0].assign([[1.]]))
    self.evaluate(net2.trainable_weights[0].assign([[2.]]))
    self.assertEqual(2., self.evaluate(net2.trainable_weights[0]))
    self.assertEqual(1., self.evaluate(net1.trainable_weights[0]))
    self.assertStartsWith(expected_start="my_network_1/dense_1/",
                          actual=net1.trainable_weights[0].name)
    self.assertStartsWith(expected_start="my_network_2/dense_1/",
                          actual=net2.trainable_weights[0].name)

  @test_util.run_in_graph_and_eager_modes()
  def testNestableAnonymous(self):
    """Anonymous nested Networks get unique names matching variable names."""

    # The case where no explicit names are specified. We make up unique names,
    # and these should match the variable names.
    class ParentNetwork(network.Network):

      def __init__(self):
        super(ParentNetwork, self).__init__()
        self.first = self.track_layer(MyNetwork())
        self.second = self.track_layer(MyNetwork())

      def call(self, x):
        hidden = self.first(x)
        return self.second(hidden)

    ones = constant_op.constant([[1.]])
    # Keep every parent alive for the whole test, as separate locals would.
    parents = []
    for number in (1, 2):
      parent = ParentNetwork()
      parents.append(parent)
      parent(ones)
      parent_name = "parent_network_%d" % number
      first_prefix = parent_name + "/my_network_1/dense"
      second_prefix = parent_name + "/my_network_2/dense"
      self.assertStartsWith(expected_start=first_prefix,
                            actual=parent.trainable_weights[0].name)
      self.assertStartsWith(expected_start=first_prefix,
                            actual=parent.first.trainable_weights[0].name)
      self.assertStartsWith(expected_start=second_prefix,
                            actual=parent.trainable_weights[1].name)
      self.assertStartsWith(expected_start=second_prefix,
                            actual=parent.second.trainable_weights[0].name)
      self.assertEqual(parent_name, parent.name)
      self.assertEqual("my_network_1", parent.first.name)
      self.assertEqual("my_network_2", parent.second.name)

  @test_util.run_in_graph_and_eager_modes()
  def testNestableExplicit(self):
    """Globally unique explicit names are used verbatim in variable names."""

    # We have explicit network names and everything is globally unique.
    class ParentNetwork(network.Network):

      def __init__(self):
        super(ParentNetwork, self).__init__(name="unique_parent_name")
        self.first = self.track_layer(
            MyNetwork(name="first_unique_child_name"))
        self.second = self.track_layer(
            MyNetwork(name="second_unique_child_name"))

      def call(self, x):
        hidden = self.first(x)
        return self.second(hidden)

    ones = constant_op.constant([[1.]])
    model = ParentNetwork()
    model(ones)
    child_names = ("first_unique_child_name", "second_unique_child_name")
    for index, child_name in enumerate(child_names):
      self.assertStartsWith(
          expected_start="unique_parent_name/%s/dense" % child_name,
          actual=model.trainable_weights[index].name)
    self.assertEqual("unique_parent_name", model.name)
    self.assertEqual("first_unique_child_name", model.first.name)
    self.assertEqual("second_unique_child_name", model.second.name)

  @test_util.run_in_graph_and_eager_modes()
  def testLayerNetworkNameInteractions(self):
    """Networks and plain Layers with one base name share a numbering."""

    # Same base name as core.Dense; Networks and non-Network Layers with the
    # same base name should use the same numbering system.
    class Dense(network.Network):

      def __init__(self):
        super(Dense, self).__init__()
        self.first = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        return self.first(x)

    class MixedLayerNetwork(network.Network):

      def __init__(self):
        super(MixedLayerNetwork, self).__init__()
        self.first = self.track_layer(core.Dense(1, use_bias=False))
        self.second = self.track_layer(core.Dense(1, use_bias=False))
        self.third = self.track_layer(Dense())
        self.fourth = self.track_layer(core.Dense(1, use_bias=False))
        self.fifth = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        for layer in (self.first, self.second, self.third, self.fourth,
                      self.fifth):
          x = layer(x)
        return x

    ones = constant_op.constant([[1.]])
    model = MixedLayerNetwork()
    model(ones)
    tracked = (model.first, model.second, model.third, model.fourth,
               model.fifth)
    for number, layer in enumerate(tracked, 1):
      self.assertEqual("dense_%d" % number, layer.name)
    # Note that this is _not_ the default naming behavior for Layers. Layers
    # which are added to Networks follow Network variable naming conventions
    # (i.e. variable names = network name unless variable sharing). Nested
    # Layers revert to Layer behavior.
    for index in range(5):
      self.assertStartsWith(
          expected_start="mixed_layer_network_1/dense_%d/" % (index + 1),
          actual=model.trainable_weights[index].name)
    self.assertEqual("mixed_layer_network_1", model.name)

  @test_util.run_in_graph_and_eager_modes()
  def testNestableExplicitCollisions(self):
    """A child may reuse its parent's explicit name without conflict."""

    # We have explicit network names and they are unique within the layer
    # they're added to.
    class ParentNetwork(network.Network):

      def __init__(self):
        super(ParentNetwork, self).__init__(name="nonunique_name")
        self.first = self.track_layer(MyNetwork(name="nonunique_name"))
        self.second = self.track_layer(
            MyNetwork(name="second_unique_child_name"))

      def call(self, x):
        intermediate = self.first(x)
        return self.second(intermediate)

    ones = constant_op.constant([[1.]])
    model = ParentNetwork()
    model(ones)
    expected_children = ("nonunique_name", "second_unique_child_name")
    for position, child in enumerate(expected_children):
      self.assertStartsWith(
          expected_start="nonunique_name/%s/dense" % child,
          actual=model.trainable_weights[position].name)
    self.assertEqual("nonunique_name", model.name)
    self.assertEqual("nonunique_name", model.first.name)
    self.assertEqual("second_unique_child_name", model.second.name)

  @test_util.run_in_graph_and_eager_modes()
  def testNestableExplicitWithAnonymousParent(self):
    """Explicitly named children under repeated anonymous parents are OK."""

    # A parent network is instantiated multiple times with explicitly named
    # children. We shouldn't throw any name errors.
    class ParentNetwork(network.Network):

      def __init__(self):
        super(ParentNetwork, self).__init__()
        self.first = self.track_layer(
            MyNetwork(name="first_unique_child_name"))
        self.second = self.track_layer(
            MyNetwork(name="second_unique_child_name"))

      def call(self, x):
        hidden = self.first(x)
        return self.second(hidden)

    def _check_names(parent, parent_name, dense_prefix):
      children = ("first_unique_child_name", "second_unique_child_name")
      for index, child_name in enumerate(children):
        self.assertStartsWith(
            expected_start="%s/%s/%s" % (parent_name, child_name,
                                         dense_prefix),
            actual=parent.trainable_weights[index].name)
      self.assertEqual(parent_name, parent.name)
      self.assertEqual("first_unique_child_name", parent.first.name)
      self.assertEqual("second_unique_child_name", parent.second.name)

    ones = constant_op.constant([[1.]])
    first_parent = ParentNetwork()
    first_parent(ones)
    _check_names(first_parent, "parent_network_1", "dense_1/")

    second_parent = ParentNetwork()
    second_parent(ones)
    _check_names(second_parent, "parent_network_2", "dense")

  @test_util.run_in_graph_and_eager_modes()
  def testNestableExplicitSameLayerCollisions(self):
    """Two siblings with the same explicit name are rejected."""

    # Both children request the same explicit name inside one parent; the
    # resulting ValueError is expected while the parent is being constructed.
    class ParentNetwork(network.Network):

      def __init__(self):
        super(ParentNetwork, self).__init__(name="unique_parent_name")
        duplicate_name = "nonunique_name"
        self.first = self.track_layer(MyNetwork(name=duplicate_name))
        self.second = self.track_layer(MyNetwork(name=duplicate_name))

      def call(self, x):
        intermediate = self.first(x)
        return self.second(intermediate)

    with self.assertRaisesRegexp(ValueError, "nonunique_name"):
      ParentNetwork()

  @test_util.run_in_graph_and_eager_modes()
  def testAnonymousVariableSharing(self):
    """Shared (non-owned) Networks keep the variable scope of their creator.

    A Network tracked by a second parent is only shared with it, so its
    variables stay under the scope of the parent that created it, while the
    second parent's owned children are named independently.
    """

    # Two "owned" Networks
    class FirstParentNetwork(network.Network):

      def __init__(self):
        super(FirstParentNetwork, self).__init__()
        self.first = self.track_layer(MyNetwork())
        self.second = self.track_layer(MyNetwork())

      def call(self, x):
        return self.second(self.first(x))

    one = constant_op.constant([[1.]])
    net = FirstParentNetwork()
    net(one)

    # One Network shared with FirstParentNetwork, one owned Network. Same name,
    # but this is OK because only one is owned. This name collision is
    # avoidable; we could have looked at the base_name of the non-owned Network
    # and incremented our naming based on that.
    class SecondParentNetwork(network.Network):

      def __init__(self):
        super(SecondParentNetwork, self).__init__()
        self.first = self.track_layer(net.first)
        self.second = self.track_layer(MyNetwork())

      def call(self, x):
        return self.second(self.first(x))

    net2 = SecondParentNetwork()
    net2(one)

    self.assertStartsWith(
        expected_start="first_parent_network_1/my_network_1/dense_1/",
        actual=net2.trainable_weights[0].name)
    self.assertStartsWith(
        expected_start="second_parent_network_1/my_network_1/dense_1/",
        actual=net2.trainable_weights[1].name)
    self.assertEqual("second_parent_network_1", net2.name)
    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(net2.first, net.first)
    self.assertEqual("my_network_1", net2.first.name)
    self.assertEqual("my_network_1", net2.second.name)

    # No name collision; the owned Network is added first and has a different
    # name than the shared Network.
    class ThirdParentNetwork(network.Network):

      def __init__(self):
        super(ThirdParentNetwork, self).__init__()
        self.first = self.track_layer(MyNetwork())
        self.second = self.track_layer(net.second)

      def call(self, x):
        return self.second(self.first(x))

    net3 = ThirdParentNetwork()
    net3(one)

    self.assertStartsWith(
        expected_start="third_parent_network_1/my_network_1/dense",
        actual=net3.trainable_weights[0].name)
    self.assertStartsWith(
        expected_start="first_parent_network_1/my_network_2/dense",
        actual=net3.trainable_weights[1].name)
    self.assertEqual("third_parent_network_1", net3.name)
    self.assertIs(net3.second, net.second)
    self.assertEqual("my_network_1", net3.first.name)
    self.assertEqual("my_network_2", net3.second.name)

    # "Unavoidable" same-name Layer. The owned name is added first (fixed), then
    # a shared Network is added with the same name.
    class FourthParentNetwork(network.Network):

      def __init__(self):
        super(FourthParentNetwork, self).__init__()
        self.first = self.track_layer(MyNetwork())
        self.second = self.track_layer(net.first)

      def call(self, x):
        return self.second(self.first(x))

    net4 = FourthParentNetwork()
    net4(one)

    self.assertStartsWith(
        expected_start="fourth_parent_network_1/my_network_1/dense_1/",
        actual=net4.trainable_weights[0].name)
    self.assertStartsWith(
        expected_start="first_parent_network_1/my_network_1/dense_1/",
        actual=net4.trainable_weights[1].name)
    self.assertEqual("fourth_parent_network_1", net4.name)
    self.assertIs(net4.second, net.first)
    self.assertEqual("my_network_1", net4.first.name)
    self.assertEqual("my_network_1", net4.second.name)

  @test_util.run_in_graph_and_eager_modes()
  def testRecursiveLayerRenaming(self):
    """Layer names restart at dense_1 inside each nested Network."""
    # Under default Layer naming, this would change subsequent names.
    core.Dense(1)

    class NetworkWithLayerChildren(network.Network):

      def __init__(self):
        super(NetworkWithLayerChildren, self).__init__()
        self.first = self.track_layer(core.Dense(1, use_bias=False))
        self.second = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        hidden = self.first(x)
        return self.second(hidden)

    class ParentNetwork(network.Network):

      def __init__(self):
        super(ParentNetwork, self).__init__()
        self.first = self.track_layer(NetworkWithLayerChildren())
        self.second = self.track_layer(NetworkWithLayerChildren())

      def call(self, x):
        hidden = self.first(x)
        return self.second(hidden)

    net = ParentNetwork()
    one = constant_op.constant([[1.]])
    net(one)

    # Weights are listed child by child, each child's Dense layers in order.
    expected_prefixes = (
        "parent_network_1/network_with_layer_children_1/dense_1/",
        "parent_network_1/network_with_layer_children_1/dense_2/",
        "parent_network_1/network_with_layer_children_2/dense_1/",
        "parent_network_1/network_with_layer_children_2/dense_2/",
    )
    for index, prefix in enumerate(expected_prefixes):
      self.assertStartsWith(
          expected_start=prefix, actual=net.trainable_weights[index].name)

    self.assertEqual("parent_network_1", net.name)
    self.assertEqual("network_with_layer_children_1", net.first.name)
    self.assertEqual("network_with_layer_children_2", net.second.name)
    for child in (net.first, net.second):
      self.assertEqual("dense_1", child.first.name)
      self.assertEqual("dense_2", child.second.name)

  @test_util.run_in_graph_and_eager_modes()
  def testCallInDifferentOrderThanConstruct(self):
    """A shared Network is scoped under the first parent that tracked it.

    The second parent is called first, yet the shared Network's variables
    still live under the first parent's scope and are the same objects in
    both parents.
    """
    shared_network = MyNetwork()

    class FirstNetwork(network.Network):

      def __init__(self):
        super(FirstNetwork, self).__init__()
        self.first = self.track_layer(shared_network)
        self.second = self.track_layer(MyNetwork())

      def call(self, x):
        return self.second(self.first(x))

    class SecondNetwork(network.Network):

      def __init__(self):
        super(SecondNetwork, self).__init__()
        self.first = self.track_layer(shared_network)
        self.second = self.track_layer(MyNetwork())

      def call(self, x):
        return self.second(self.first(x))

    net1 = FirstNetwork()
    net2 = SecondNetwork()

    one = constant_op.constant([[1.]])
    # Deliberately call in the opposite order from construction.
    net2(one)
    net1(one)

    self.assertStartsWith(
        expected_start="first_network_1/my_network_1/dense_1/",
        actual=net1.trainable_weights[0].name)
    self.assertStartsWith(
        expected_start="first_network_1/my_network_2/dense_1/",
        actual=net1.trainable_weights[1].name)
    self.assertStartsWith(
        expected_start="first_network_1/my_network_1/dense_1/",
        actual=net2.trainable_weights[0].name)
    self.assertStartsWith(
        expected_start="second_network_1/my_network_1/dense_1/",
        actual=net2.trainable_weights[1].name)
    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(net1.trainable_weights[0], net2.trainable_weights[0])
    self.assertEqual("first_network_1", net1.name)
    self.assertEqual("my_network_1", net1.first.name)
    self.assertEqual("my_network_2", net1.second.name)
    self.assertIs(net2.first, net1.first)
    self.assertEqual("my_network_1", net2.second.name)

  @test_util.run_in_graph_and_eager_modes()
  def testLayerCallInDifferentOrderThanConstruct(self):
    """A shared plain Layer is scoped under the first parent that tracked it."""
    # Same idea as testCallInDifferentOrderThanConstruct, but this time with a
    # non-Network Layer shared between two Networks rather than a
    # Network. Naming should follow the same rules.
    shared_layer = core.Dense(1, use_bias=False)

    class FirstNetwork(network.Network):

      def __init__(self):
        super(FirstNetwork, self).__init__()
        self.first = self.track_layer(shared_layer)
        self.second = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        return self.second(self.first(x))

    class SecondNetwork(network.Network):

      def __init__(self):
        super(SecondNetwork, self).__init__()
        self.first = self.track_layer(shared_layer)
        self.second = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        return self.second(self.first(x))

    net1 = FirstNetwork()
    net2 = SecondNetwork()

    one = constant_op.constant([[1.]])
    # Deliberately call in the opposite order from construction.
    net2(one)
    net1(one)

    self.assertStartsWith(
        expected_start="first_network_1/dense_1/",
        actual=net1.trainable_weights[0].name)
    self.assertStartsWith(
        expected_start="first_network_1/dense_2/",
        actual=net1.trainable_weights[1].name)
    self.assertStartsWith(
        expected_start="first_network_1/dense_1/",
        actual=net2.trainable_weights[0].name)
    self.assertStartsWith(
        expected_start="second_network_1/dense_1/",
        actual=net2.trainable_weights[1].name)
    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(net1.trainable_weights[0], net2.trainable_weights[0])
    self.assertEqual("first_network_1", net1.name)
    self.assertEqual("dense_1", net1.first.name)
    self.assertEqual("dense_2", net1.second.name)
    self.assertIs(net2.first, net1.first)
    self.assertEqual("dense_1", net2.second.name)

  @test_util.run_in_graph_and_eager_modes()
  def testLayerAlreadyBuilt(self):
    """A Layer built before being tracked keeps its original variable names.

    The Layer's own name is still reassigned by the tracking Network, so the
    variable prefix and the Layer name intentionally disagree.
    """
    core.Dense(1, use_bias=False)  # pre-built layers use global naming
    # A single input tensor is enough; it was previously created twice.
    one = constant_op.constant([[1.]])
    core.Dense(1, use_bias=False)(one)
    shared_layer = core.Dense(1, use_bias=False)
    shared_layer(one)

    class FirstNetwork(network.Network):

      def __init__(self):
        super(FirstNetwork, self).__init__()
        self.first = self.track_layer(shared_layer)
        self.second = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        return self.second(self.first(x))

    net = FirstNetwork()
    net(one)

    self.assertStartsWith(
        expected_start="dense_1/",  # Pre-built layers have variable names which
                                    # do not match their layer names.
        actual=net.trainable_weights[0].name)
    self.assertStartsWith(
        expected_start="first_network_1/dense_1/",
        actual=net.trainable_weights[1].name)
    # assertIs reports both objects on failure, unlike assertTrue(a is b).
    self.assertIs(net.trainable_weights[0], shared_layer.trainable_weights[0])
    self.assertEqual("first_network_1", net.name)
    self.assertEqual("dense_3", net.first.name)
    self.assertEqual("dense_1", net.second.name)


class SequentialTest(test.TestCase):
  """Tests for `network.Sequential`."""

  @test_util.assert_no_garbage_created
  def testTwoLayers(self):
    """Layers appended with `add` run after the existing ones."""
    # Create a sequential network with one layer.
    net = network.Sequential([core.Dense(1, use_bias=False)])

    # Set that layer's weights so it multiplies by 3
    l1 = net.get_layer(index=0)
    net(constant_op.constant([[2.0]]))  # Create l1's variables
    self.assertEqual(1, len(l1.trainable_variables))
    l1.trainable_variables[0].assign([[3.0]])
    self.assertEqual(21.0, net(constant_op.constant([[7.0]])).numpy())

    # Add a second layer to the network.
    l2 = core.Dense(1, use_bias=False)
    net.add(l2)

    # Set the second layer's weights so it multiplies by 11 (7 * 3 * 11).
    net(constant_op.constant([[2.0]]))  # Create l2's variables
    self.assertEqual(1, len(l2.trainable_variables))
    l2.trainable_variables[0].assign([[11.0]])
    self.assertEqual(231.0, net(constant_op.constant([[7.0]])).numpy())

  @test_util.assert_no_garbage_created
  def testFunctions(self):
    """Plain callables can be used (and appended) as Sequential elements."""
    # Create a sequential network with one function.
    net = network.Sequential([nn_ops.relu])
    two = constant_op.constant(2.0)
    self.assertEqual(2.0, net(two).numpy())
    self.assertEqual(0.0, net(-two).numpy())
    # Add a second function.
    net.add(math_ops.negative)
    self.assertEqual(-2.0, net(two).numpy())

  @test_util.assert_no_garbage_created
  def testTrainingLayer(self):
    """`training` reaches Layers: dropout only happens when it is True."""
    net = network.Sequential([core.Dropout(0.99999)])
    two = constant_op.constant(2.0)
    self.assertEqual(2.0, net(two).numpy())
    self.assertEqual(2.0, net(two, training=False).numpy())
    for _ in range(20):
      with_dropout = net(two, training=True).numpy()
      if with_dropout == 0.0:
        return
      # A kept element is scaled up by 1 / keep_prob (about 1e5 here), so in
      # training mode the output is never the unscaled input 2.0.
      self.assertGreater(with_dropout, 2.0)
    # Each try keeps the element with probability 1e-5, so this should only
    # fail spuriously 1 in 10^100 runs.
    self.fail("Didn't see dropout happen after 20 tries.")

  @test_util.assert_no_garbage_created
  def testTrainingFunction(self):
    """`training` is passed only to callables that accept it."""
    # Output depends on value of "training".
    def add_training(input_value, training=None):
      if training is None:
        return input_value
      elif training:
        return input_value + 1
      return input_value - 1

    # Passing a "training" argument to double would cause an error.
    def double(input_value):
      return 2 * input_value

    net = network.Sequential([add_training, double])
    two = constant_op.constant(2)
    self.assertEqual(4, net(two).numpy())
    self.assertEqual(2, net(two, training=False).numpy())
    self.assertEqual(6, net(two, training=True).numpy())


if __name__ == "__main__":
  # Entry point when this file is executed directly; delegates to the test
  # runner provided by the imported `test` module.
  test.main()