aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/base_test.py
blob: 509ad5a7afb28143b821d133d1fb4a86001e6e9f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.layers.base."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.layers import base as base_layers
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test


class BaseLayerTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes()
  def testLayerProperties(self):
    """Checks the attribute values of freshly constructed layers."""
    fresh_layer = base_layers.Layer(name='my_layer')
    # No variable of any kind has been created yet.
    for attr in ('variables', 'trainable_variables', 'non_trainable_variables'):
      self.assertEqual(getattr(fresh_layer, attr), [])
    if context.in_graph_mode():
      # `updates` and `losses` are only supported in GRAPH mode.
      self.assertEqual(fresh_layer.updates, [])
      self.assertEqual(fresh_layer.losses, [])
    self.assertEqual(fresh_layer.built, False)

    # The `trainable` constructor argument is exposed as an attribute.
    frozen_layer = base_layers.Layer(name='my_layer', trainable=False)
    self.assertEqual(frozen_layer.trainable, False)

  @test_util.run_in_graph_and_eager_modes()
  def testAddWeight(self):
    """Checks that add_variable registers variables on the owning layer."""
    owner = base_layers.Layer(name='my_layer')

    # A plain add_variable call yields a trainable variable.
    trainable_var = owner.add_variable(
        'my_var', [2, 2], initializer=init_ops.zeros_initializer())
    self.assertEqual(trainable_var.name, 'my_layer/my_var:0')
    self.assertEqual(owner.variables, [trainable_var])
    self.assertEqual(owner.trainable_variables, [trainable_var])
    self.assertEqual(owner.non_trainable_variables, [])
    if context.in_graph_mode():
      graph_trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
      self.assertEqual(owner.variables, graph_trainables)

    # add_variable should work even outside `build` and `call`; this one is
    # created as non-trainable.
    frozen_var = owner.add_variable(
        'non_trainable_var', [2, 2],
        initializer=init_ops.zeros_initializer(),
        trainable=False)
    self.assertEqual(owner.variables, [trainable_var, frozen_var])
    self.assertEqual(owner.trainable_variables, [trainable_var])
    self.assertEqual(owner.non_trainable_variables, [frozen_var])
    if not context.in_graph_mode():
      return

    graph_trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    self.assertEqual(len(graph_trainables), 1)

    # Regularizers are only supported in GRAPH mode.
    owner.add_variable(
        'reg_var', [2, 2],
        initializer=init_ops.zeros_initializer(),
        regularizer=lambda x: math_ops.reduce_sum(x) * 1e-3)
    self.assertEqual(len(owner.losses), 1)

  def testGetVariable(self):
    """Checks which variables a layer owns and which scope it resolves to.

    Covers a named layer applied twice, and lazily scoped layers constructed
    with `_reuse=True` whose variables were created outside of the layer.
    """
    with self.test_session():

      class MyLayer(base_layers.Layer):

        def build(self, input_shape):
          self.my_var = self.add_variable(
              'my_var', [2, 2], initializer=init_ops.zeros_initializer())

        def call(self, inputs):
          return inputs * 2

      layer = MyLayer(name='my_layer')
      inputs = random_ops.random_uniform((5,), seed=1)
      # Apply twice; the layer must still list a single variable afterwards.
      layer.apply(inputs)
      layer.apply(inputs)
      self.assertEqual([v.name for v in layer.variables],
                       ['my_layer/my_var:0'])

      # Creating a layer with no scope leads to lazy construction of
      # the scope at apply() time.  It uses scope "<current scope>/base_name"
      lazy_layer = MyLayer(_reuse=True)
      with variable_scope.variable_scope('new_scope'):
        with variable_scope.variable_scope('my_layer'):
          variable_scope.get_variable('my_var', [2, 2])

        # Smoke test: it runs.
        lazy_layer.apply(inputs)
        # The variables were created outside of the Layer, and
        # reuse=True, so the Layer does not own them and they are not
        # stored in its collection.
        self.assertEqual(lazy_layer.variables, [])
        self.assertEqual(lazy_layer._scope.name, 'new_scope/my_layer')

      # Creating a layer with no scope leads to lazy construction of
      # the scope at apply() time. If 'scope' argument is passed to
      # apply(), it uses that scope when accessing variables.
      lazy_layer = MyLayer(_reuse=True)
      with variable_scope.variable_scope('new_scope') as new_scope:
        variable_scope.get_variable('my_var', [2, 2])

        # Smoke test: it runs.
        lazy_layer.apply(inputs, scope=new_scope)
        # The variables were created outside of the Layer, and
        # reuse=True, so the Layer does not own them and they are not
        # stored in its collection.
        self.assertEqual(lazy_layer.variables, [])
        self.assertEqual(lazy_layer._scope.name, 'new_scope')

      # Checking for graph equality is only done in GRAPH mode.
      with ops.Graph().as_default():
        # `inputs_ng` lives in a new graph, unlike the one `layer` was
        # first applied in.
        inputs_ng = random_ops.random_uniform((5,), seed=1)
        with self.assertRaisesRegexp(ValueError, r'graph are not the same'):
          layer.apply(inputs_ng)

  @test_util.run_in_graph_and_eager_modes()
  def testCall(self):
    """Applying a layer marks it as built and runs its `call` method."""

    # Layer whose `call` squares the input.
    class MyLayer(base_layers.Layer):

      def call(self, inputs):
        return math_ops.square(inputs)

    squarer = MyLayer(name='my_layer')
    squared = squarer.apply(random_ops.random_uniform((5,), seed=1))
    self.assertEqual(squarer.built, True)
    if not context.in_graph_mode():
      return
    # `op` is only supported in GRAPH mode.
    self.assertEqual(squared.op.name, 'my_layer/Square')

  def testFirstCallCanCreateVariablesButSecondCanNotWhenBuildEmpty(self):
    """Variables can be created in the first call but not in the second.

    This test is only run in Graph mode since with EAGER mode we can still
    create a new variable on second call.
    """

    class MyLayer(base_layers.Layer):

      def build(self, _):
        # Intentionally empty: do not mark the layer as built.
        pass

      def call(self, inputs):
        self.my_var = self.add_variable('my_var', [2, 2])
        if self.built:
          # Skipped on the first call; attempted once the layer is built.
          # This is expected to fail.
          self.add_variable('this_will_break_on_second_call', [2, 2])
        return inputs + math_ops.square(self.my_var)

    def owned_variable_names(some_layer):
      return [v.name for v in some_layer.variables]

    expected_names = ['my_layer/my_var:0']

    layer = MyLayer(name='my_layer')
    inputs = random_ops.random_uniform((2,), seed=1)
    first_result = layer.apply(inputs)
    self.assertEqual(layer.built, True)
    self.assertEqual(first_result.op.name, 'my_layer/add')
    self.assertEqual(owned_variable_names(layer), expected_names)

    with self.assertRaisesRegexp(ValueError,
                                 'my_layer/this_will_break_on_second_call'):
      layer.apply(inputs)
    # The failed second call must not have changed the list of variables.
    self.assertEqual(owned_variable_names(layer), expected_names)

  @test_util.run_in_graph_and_eager_modes()
  def testDeepCopy(self):
    """A deep copy of a built layer keeps name, scope, graph and attributes."""

    # Layer whose `call` squares the input.
    class MyLayer(base_layers.Layer):

      def call(self, inputs):
        return math_ops.square(inputs)

    original = MyLayer(name='my_layer')
    original._private_tensor = random_ops.random_uniform(())
    squared = original.apply(random_ops.random_uniform((5,), seed=1))
    self.assertEqual(original.built, True)
    if context.in_graph_mode():
      # `op` is only supported in GRAPH mode.
      self.assertEqual(squared.op.name, 'my_layer/Square')

    clone = copy.deepcopy(original)
    self.assertEqual(clone.name, original.name)
    self.assertEqual(clone._scope.name, original._scope.name)
    self.assertEqual(clone._graph, original._graph)
    self.assertEqual(clone._private_tensor, original._private_tensor)

  @test_util.run_in_graph_and_eager_modes()
  def testScopeNaming(self):
    """Checks the scope names assigned to layers as they are applied.

    Layers are created with default and explicit names, at top level, inside
    a name scope and inside a variable scope. The expected suffixes depend on
    the order in which the layers below are created and applied.
    """

    class PrivateLayer(base_layers.Layer):

      def call(self, inputs):
        return inputs

    inputs = random_ops.random_uniform((5,))
    # Default-named layers derive their scope from the class name.
    default_layer = PrivateLayer()
    _ = default_layer.apply(inputs)
    self.assertEqual(default_layer._scope.name, 'private_layer')
    default_layer1 = PrivateLayer()
    default_layer1.apply(inputs)
    self.assertEqual(default_layer1._scope.name, 'private_layer_1')
    # Explicitly named layers sharing a name get numeric suffixes.
    my_layer = PrivateLayer(name='my_layer')
    my_layer.apply(inputs)
    self.assertEqual(my_layer._scope.name, 'my_layer')
    my_layer1 = PrivateLayer(name='my_layer')
    my_layer1.apply(inputs)
    self.assertEqual(my_layer1._scope.name, 'my_layer_1')
    my_layer2 = PrivateLayer(name='my_layer')
    my_layer2.apply(inputs)
    self.assertEqual(my_layer2._scope.name, 'my_layer_2')
    # Name scope shouldn't affect names.
    with ops.name_scope('some_name_scope'):
      default_layer2 = PrivateLayer()
      default_layer2.apply(inputs)
      self.assertEqual(default_layer2._scope.name, 'private_layer_2')
      my_layer3 = PrivateLayer(name='my_layer')
      my_layer3.apply(inputs)
      self.assertEqual(my_layer3._scope.name, 'my_layer_3')
      other_layer = PrivateLayer(name='other_layer')
      other_layer.apply(inputs)
      self.assertEqual(other_layer._scope.name, 'other_layer')
    # Variable scope gets added to scope names.
    with variable_scope.variable_scope('var_scope'):
      default_layer_scoped = PrivateLayer()
      default_layer_scoped.apply(inputs)
      self.assertEqual(default_layer_scoped._scope.name,
                       'var_scope/private_layer')
      my_layer_scoped = PrivateLayer(name='my_layer')
      my_layer_scoped.apply(inputs)
      self.assertEqual(my_layer_scoped._scope.name, 'var_scope/my_layer')
      my_layer_scoped1 = PrivateLayer(name='my_layer')
      my_layer_scoped1.apply(inputs)
      self.assertEqual(my_layer_scoped1._scope.name, 'var_scope/my_layer_1')

  @test_util.run_in_graph_and_eager_modes()
  def testInputSpecNdimCheck(self):
    """Exercises InputSpec(ndim=2) with unranked, rank-1 and rank-2 inputs."""

    # Identity layer that declares an `ndim=2` input spec.
    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = base_layers.InputSpec(ndim=2)

      def call(self, inputs):
        return inputs

    if context.in_graph_mode():
      # A placeholder created without a shape is rejected.
      unranked_layer = CustomerLayer()
      with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
        unranked_layer.apply(array_ops.placeholder('int32'))

    # A rank-1 constant is rejected.
    rank1_layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected ndim=2'):
      rank1_layer.apply(constant_op.constant([1]))

    # A new layer is used for every check since in Eager mode, input spec
    # checks only happen on first call. A rank-2 constant works.
    CustomerLayer().apply(constant_op.constant([[1], [2]]))

  @test_util.run_in_graph_and_eager_modes()
  def testInputSpecMinNdimCheck(self):
    """Exercises InputSpec(min_ndim=2) with unranked and ranked inputs."""

    # Identity layer that declares a `min_ndim=2` input spec.
    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = base_layers.InputSpec(min_ndim=2)

      def call(self, inputs):
        return inputs

    if context.in_graph_mode():
      # A placeholder created without a shape is rejected.
      unranked_layer = CustomerLayer()
      with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
        unranked_layer.apply(array_ops.placeholder('int32'))

    # A rank-1 constant is rejected.
    rank1_layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected min_ndim=2'):
      rank1_layer.apply(constant_op.constant([1]))

    # Rank-2 and rank-3 constants work; each one goes through a new layer.
    for accepted_value in ([[1], [2]], [[[1], [2]]]):
      CustomerLayer().apply(constant_op.constant(accepted_value))

  @test_util.run_in_graph_and_eager_modes()
  def testInputSpecMaxNdimCheck(self):
    """Exercises InputSpec(max_ndim=2) with unranked and ranked inputs."""

    # Identity layer that declares a `max_ndim=2` input spec.
    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = base_layers.InputSpec(max_ndim=2)

      def call(self, inputs):
        return inputs

    if context.in_graph_mode():
      # A placeholder created without a shape is rejected.
      unranked_layer = CustomerLayer()
      with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
        unranked_layer.apply(array_ops.placeholder('int32'))

    # A rank-3 constant is rejected.
    rank3_layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected max_ndim=2'):
      rank3_layer.apply(constant_op.constant([[[1], [2]]]))

    # Rank-1 and rank-2 constants work; each one goes through a new layer.
    for accepted_value in ([1], [[1], [2]]):
      CustomerLayer().apply(constant_op.constant(accepted_value))

  @test_util.run_in_graph_and_eager_modes()
  def testInputSpecDtypeCheck(self):
    """Exercises InputSpec(dtype='float32') with int32 and float32 inputs."""

    # Identity layer that declares a `dtype='float32'` input spec.
    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = base_layers.InputSpec(dtype='float32')

      def call(self, inputs):
        return inputs

    # An int32 constant is rejected.
    int_input = dtypes.int32
    rejecting_layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected dtype=float32'):
      rejecting_layer.apply(constant_op.constant(1, dtype=int_input))

    # A float32 constant works.
    float_input = dtypes.float32
    CustomerLayer().apply(constant_op.constant(1.0, dtype=float_input))

  @test_util.run_in_graph_and_eager_modes()
  def testInputSpecAxesCheck(self):
    """Exercises InputSpec(axes={-1: 2}) with last dimensions of 3 and 2."""

    # Identity layer that declares an `axes={-1: 2}` input spec.
    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = base_layers.InputSpec(axes={-1: 2})

      def call(self, inputs):
        return inputs

    # A constant whose last dimension is 3 is rejected.
    rejecting_layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected axis'):
      rejecting_layer.apply(constant_op.constant([1, 2, 3]))

    # Constants whose last dimension is 2 work, whatever their rank; each
    # one goes through a new layer.
    for accepted_value in ([1, 2], [[1, 2], [3, 4], [5, 6]]):
      CustomerLayer().apply(constant_op.constant(accepted_value))

  @test_util.run_in_graph_and_eager_modes()
  def testInputSpecShapeCheck(self):
    """Exercises InputSpec(shape=(None, 3)) with (1, 2) and (2, 3) inputs."""

    # Identity layer that declares a `shape=(None, 3)` input spec.
    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = base_layers.InputSpec(shape=(None, 3))

      def call(self, inputs):
        return inputs

    # A constant with two columns is rejected.
    rejecting_layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected shape'):
      rejecting_layer.apply(constant_op.constant([[1, 2]]))

    # A constant with three columns works.
    CustomerLayer().apply(constant_op.constant([[1, 2, 3], [4, 5, 6]]))

  @test_util.run_in_graph_and_eager_modes()
  def testNoInputSpec(self):
    """A layer whose input_spec is None can be applied to various inputs."""

    # Identity layer that explicitly sets its input spec to None.
    class CustomerLayer(base_layers.Layer):

      def __init__(self):
        super(CustomerLayer, self).__init__()
        self.input_spec = None

      def call(self, inputs):
        return inputs

    unconstrained = CustomerLayer()

    # Smoke tests: applying the same layer repeatedly works.
    unconstrained.apply(constant_op.constant(1))
    if not context.in_graph_mode():
      return
    unconstrained.apply(array_ops.placeholder('int32'))
    unconstrained.apply(array_ops.placeholder('int32', shape=(2, 3)))

  def test_get_updates_for(self):
    """get_updates_for returns the updates registered for the given inputs."""
    layer_input = base_layers.Input(shape=(2,))
    layer = core_layers.Dense(1)
    # Register one update tied to `layer_input` and one tied to no input.
    for update, condition in ((0, layer_input), (1, None)):
      layer.add_update(update, inputs=condition)

    self.assertEqual(layer.get_updates_for(layer_input), [0])
    self.assertEqual(layer.get_updates_for(None), [1])

  def test_get_losses_for(self):
    """get_losses_for returns the losses registered for the given inputs."""
    layer_input = base_layers.Input(shape=(2,))
    layer = core_layers.Dense(1)
    # Register one loss tied to `layer_input` and one tied to no input.
    for loss, condition in ((0, layer_input), (1, None)):
      layer.add_loss(loss, inputs=condition)

    self.assertEqual(layer.get_losses_for(layer_input), [0])
    self.assertEqual(layer.get_losses_for(None), [1])

  def testTopologicalAttributes(self):
    """Tests layer attributes / methods related to cross-layer connectivity.

    Covers `input`, `output`, `input_shape`, `output_shape` and the
    `get_*_at` accessors, plus the error cases: an out-of-range node index,
    a layer that was never called, and a layer that was called on inputs of
    different shapes.
    """
    a = base_layers.Input(shape=(32,), name='input_a')
    b = base_layers.Input(shape=(32,), name='input_b')

    # input, output, input_shape, output_shape on a layer called once.
    test_layer = core_layers.Dense(16, name='test_layer')
    a_test = test_layer(a)
    self.assertEqual(test_layer.input, a)
    self.assertEqual(test_layer.output, a_test)
    self.assertEqual(test_layer.input_shape, (None, 32))
    self.assertEqual(test_layer.output_shape, (None, 16))

    # `get_*_at` methods on a layer called twice.
    dense = core_layers.Dense(16, name='dense_1')
    a_2 = dense(a)
    b_2 = dense(b)

    self.assertEqual(dense.get_input_at(0), a)
    self.assertEqual(dense.get_input_at(1), b)
    self.assertEqual(dense.get_output_at(0), a_2)
    self.assertEqual(dense.get_output_at(1), b_2)
    self.assertEqual(dense.get_input_shape_at(0), (None, 32))
    self.assertEqual(dense.get_input_shape_at(1), (None, 32))
    self.assertEqual(dense.get_output_shape_at(0), (None, 16))
    self.assertEqual(dense.get_output_shape_at(1), (None, 16))

    # Invalid node index for attribute retrieval.
    with self.assertRaises(ValueError):
      dense.get_input_at(2)

    # A layer that was never called has no connectivity attributes. The
    # layer is built outside of the `assertRaises` context so that only the
    # attribute access itself is allowed to raise.
    for attr in ('input', 'output', 'output_shape', 'input_shape'):
      never_called = core_layers.Dense(16)
      with self.assertRaises(AttributeError):
        getattr(never_called, attr)

    # A layer called on inputs of different shapes has no single
    # `input_shape` / `output_shape`. The layer under test is the one that
    # is called here, and the two inputs are distinct tensors.
    for attr in ('input_shape', 'output_shape'):
      multi_shape_dense = core_layers.Dense(16)
      multi_shape_dense(base_layers.Input(shape=(3, 32)))
      multi_shape_dense(base_layers.Input(shape=(5, 32)))
      with self.assertRaises(AttributeError):
        getattr(multi_shape_dense, attr)

  def testTopologicalAttributesMultiOutputLayer(self):
    """Checks connectivity attributes of a layer that returns two tensors."""

    # Layer returning the square and the cube of its input.
    class PowersLayer(base_layers.Layer):

      def call(self, inputs):
        return [inputs**2, inputs**3]

    layer_input = base_layers.Input(shape=(32,))
    powers = PowersLayer()
    squared, cubed = powers(layer_input)  # pylint: disable=not-callable

    self.assertEqual(powers.input, layer_input)
    self.assertEqual(powers.output, [squared, cubed])
    self.assertEqual(powers.input_shape, (None, 32))
    # One output shape is reported per returned tensor.
    self.assertEqual(powers.output_shape, [(None, 32), (None, 32)])

  def testTopologicalAttributesMultiInputLayer(self):
    """Checks connectivity attributes of a layer that takes two tensors."""

    # Layer returning the sum of its two inputs.
    class AddLayer(base_layers.Layer):

      def call(self, inputs):
        assert len(inputs) == 2
        return inputs[0] + inputs[1]

    first = base_layers.Input(shape=(32,))
    second = base_layers.Input(shape=(32,))
    adder = AddLayer()
    total = adder([first, second])  # pylint: disable=not-callable

    self.assertEqual(adder.input, [first, second])
    self.assertEqual(adder.output, total)
    # One input shape is reported per input tensor.
    self.assertEqual(adder.input_shape, [(None, 32), (None, 32)])
    self.assertEqual(adder.output_shape, (None, 32))

  @test_util.run_in_graph_and_eager_modes()
  def test_count_params(self):
    """count_params works on a built layer and raises on an unbuilt one."""
    built_dense = core_layers.Dense(16)
    built_dense.build((None, 4))
    expected_count = 16 * 4 + 16
    self.assertEqual(built_dense.count_params(), expected_count)

    # `build` was never called on this layer.
    unbuilt_dense = core_layers.Dense(16)
    with self.assertRaises(ValueError):
      unbuilt_dense.count_params()

  @test_util.run_in_graph_and_eager_modes()
  def testDictInputOutput(self):
    """A layer can take a dict of tensors as input and return a dict."""

    # Layer returning its input dict with 'l' prepended to every key.
    class DictLayer(base_layers.Layer):

      def call(self, inputs):
        return {'l' + key: inputs[key] for key in inputs}

    layer = DictLayer()
    graph_mode = context.in_graph_mode()
    if graph_mode:
      int_input = array_ops.placeholder('int32')
      float_input = array_ops.placeholder('float32')
    else:
      int_input = constant_op.constant(3)
      float_input = constant_op.constant(4.0)

    result = layer.apply({'abel': int_input, 'ogits': float_input})
    self.assertTrue(isinstance(result, dict))
    self.assertEqual(set(['label', 'logits']), set(result.keys()))
    if not graph_mode:
      # Values are only inspected outside of graph mode, via `numpy()`.
      self.assertEqual(3, result['label'].numpy())
      self.assertEqual(4.0, result['logits'].numpy())

  def testActivityRegularizer(self):
    """Applying a layer with an activity regularizer records one loss."""
    layer = base_layers.Layer(activity_regularizer=math_ops.reduce_sum)
    layer_input = array_ops.placeholder('int32')
    layer.apply(layer_input)
    input_losses = layer.get_losses_for(layer_input)
    self.assertEqual(len(input_losses), 1)


class NetworkTest(test.TestCase):

  def testBasicNetwork(self):
    """Builds a minimum viable network and checks its basic API."""
    net_input = base_layers.Input(shape=(32,))
    dense = core_layers.Dense(2)
    net_output = dense(net_input)
    network = base_layers.Network(
        net_input, net_output, name='dense_network')

    # Basic attributes.
    self.assertEqual(network.name, 'dense_network')
    self.assertEqual(len(network.layers), 2)  # InputLayer + Dense
    self.assertEqual(network.layers[1], dense)
    self.assertEqual(network.weights, dense.weights)
    self.assertEqual(network.trainable_weights, dense.trainable_weights)
    self.assertEqual(network.non_trainable_weights, dense.non_trainable_weights)

    # Callability on a new Input.
    other_input = base_layers.Input(shape=(32,))
    from_input = network(other_input)
    self.assertEqual(from_input.get_shape().as_list(), [None, 2])

    # Callability on a regular tensor.
    placeholder = array_ops.placeholder(dtype='float32', shape=(None, 32))
    from_placeholder = network(placeholder)
    self.assertEqual(from_placeholder.get_shape().as_list(), [None, 2])

    # Setting `trainable` to False moves every weight to non-trainable.
    network.trainable = False
    all_dense_weights = dense.trainable_weights + dense.non_trainable_weights
    self.assertEqual(network.weights, dense.weights)
    self.assertEqual(network.trainable_weights, [])
    self.assertEqual(network.non_trainable_weights, all_dense_weights)

  def test_node_construction(self):
    """Checks the nodes created by Input tensors and by calling a layer."""
    input_a = base_layers.Input(shape=(32,), name='input_a')
    input_b = base_layers.Input(shape=(32,), name='input_b')

    self.assertEqual(input_a.get_shape().as_list(), [None, 32])
    layer_a, node_index_a, tensor_index_a = input_a._keras_history
    layer_b = input_b._keras_history[0]
    self.assertEqual(len(layer_a._inbound_nodes), 1)
    self.assertEqual(tensor_index_a, 0)

    # The node of an Input has no inbound layer and maps the tensor to itself.
    input_node = layer_a._inbound_nodes[node_index_a]
    self.assertEqual(input_node.outbound_layer, layer_a)
    self.assertEqual(input_node.inbound_layers, [])
    self.assertEqual(input_node.input_tensors, [input_a])
    self.assertEqual(input_node.input_shapes, [(None, 32)])
    self.assertEqual(input_node.output_tensors, [input_a])
    self.assertEqual(input_node.output_shapes, [(None, 32)])

    # Calling a layer twice creates two inbound nodes, in call order.
    dense = core_layers.Dense(16, name='dense_1')
    dense(input_a)
    dense(input_b)

    self.assertEqual(len(dense._inbound_nodes), 2)
    self.assertEqual(len(dense._outbound_nodes), 0)
    first_node, second_node = dense._inbound_nodes[0], dense._inbound_nodes[1]
    self.assertEqual(first_node.inbound_layers, [layer_a])
    self.assertEqual(first_node.outbound_layer, dense)
    self.assertEqual(second_node.inbound_layers, [layer_b])
    self.assertEqual(second_node.outbound_layer, dense)
    self.assertEqual(first_node.input_tensors, [input_a])
    self.assertEqual(second_node.input_tensors, [input_b])

    # The node config refers to the outbound layer by name.
    first_config = dense._inbound_nodes[0].get_config()
    self.assertEqual(first_config['outbound_layer'], dense.name)

  def testMultiInputNetwork(self):
    """Builds and calls a network with two inputs and a single output."""
    input_a = base_layers.Input(shape=(32,), name='input_a')
    input_b = base_layers.Input(shape=(32,), name='input_b')

    # Layer returning the sum of its two inputs.
    class AddLayer(base_layers.Layer):

      def call(self, inputs):
        assert len(inputs) == 2
        return inputs[0] + inputs[1]

    total = AddLayer()([input_a, input_b])  # pylint: disable=not-callable
    network = base_layers.Network([input_a, input_b], total)
    self.assertEqual(len(network.layers), 3)  # 2 * InputLayer + AddLayer

    # The network is callable on a list of new Inputs.
    new_inputs = [base_layers.Input(shape=(32,)),
                  base_layers.Input(shape=(32,))]
    new_total = network(new_inputs)
    self.assertEqual(new_total.get_shape().as_list(), [None, 32])

  def testMultiOutputNetwork(self):
    """Builds and calls a network with a single input and two outputs."""
    net_input = base_layers.Input(shape=(32,))
    narrow_head = core_layers.Dense(2)(net_input)
    wide_head = core_layers.Dense(3)(net_input)
    network = base_layers.Network(net_input, [narrow_head, wide_head])

    self.assertEqual(len(network.layers), 3)  # InputLayer + 2 * Dense

    # Calling the network on a new Input returns a list with one tensor per
    # output.
    results = network(base_layers.Input(shape=(32,)))

    self.assertEqual(type(results), list)
    self.assertEqual(len(results), 2)
    self.assertEqual(results[0].get_shape().as_list(), [None, 2])
    self.assertEqual(results[1].get_shape().as_list(), [None, 3])

  def testMultiInputMultiOutputNetworkSharedLayer(self):
    """Builds and calls a network whose two branches share one Dense layer."""
    input_a = base_layers.Input(shape=(32,), name='input_a')
    input_b = base_layers.Input(shape=(32,), name='input_b')

    shared_dense = core_layers.Dense(2)

    out_a = shared_dense(input_a)
    out_b = shared_dense(input_b)
    network = base_layers.Network([input_a, input_b], [out_a, out_b])
    # The shared layer is counted once.
    self.assertEqual(len(network.layers), 3)  # 2 * InputLayer + Dense

    # The network is callable on a list of new Inputs.
    new_inputs = [base_layers.Input(shape=(32,)),
                  base_layers.Input(shape=(32,))]
    results = network(new_inputs)

    self.assertEqual(type(results), list)
    self.assertEqual(len(results), 2)
    self.assertEqual(results[0].get_shape().as_list(), [None, 2])
    self.assertEqual(results[1].get_shape().as_list(), [None, 2])

  def testCrossDataFlows(self):
    """Routes the outputs of a multi-output layer to separate layers."""

    # Layer returning the square and the cube of its input.
    class PowersLayer(base_layers.Layer):

      def call(self, inputs):
        return [inputs**2, inputs**3]

    net_input = base_layers.Input(shape=(32,))
    squared, cubed = PowersLayer()(net_input)  # pylint: disable=not-callable
    narrow_head = core_layers.Dense(2)(squared)
    wide_head = core_layers.Dense(3)(cubed)
    network = base_layers.Network(net_input, [narrow_head, wide_head])

    self.assertEqual(len(network.layers), 4)  # InputLayer + 2 * Dense + PLayer

    # Calling the network on a new Input returns a list with one tensor per
    # output.
    results = network(base_layers.Input(shape=(32,)))

    self.assertEqual(type(results), list)
    self.assertEqual(len(results), 2)
    self.assertEqual(results[0].get_shape().as_list(), [None, 2])
    self.assertEqual(results[1].get_shape().as_list(), [None, 3])

  def testNetworkAttributes(self):
    """Checks losses, updates, layer lookup and I/O attributes of a Network."""
    net_input = base_layers.Input(shape=(32,))
    hidden = core_layers.Dense(
        2, kernel_regularizer=lambda w: 0.01 * (w**2))(net_input)
    dense = core_layers.Dense(2, name='dense')
    dense.add_update(1)
    net_output = dense(hidden)
    net = base_layers.Network(net_input, net_output)

    # One loss (from the regularized layer) and one update (added to `dense`).
    self.assertEqual(len(net.losses), 1)
    self.assertEqual(len(net.updates), 1)

    # get_layer: lookup by name and by index.
    self.assertEqual(net.get_layer('dense'), dense)
    self.assertEqual(net.get_layer(index=2), dense)
    # get_layer: unknown name, missing arguments and out-of-range index.
    with self.assertRaises(ValueError):
      net.get_layer('dense_unknown')
    with self.assertRaises(ValueError):
      net.get_layer()
    with self.assertRaises(ValueError):
      net.get_layer(index=4)

    # input / output, directly and through get_*_at.
    self.assertEqual(net.input, net_input)
    self.assertEqual(net.output, net_output)

    # input_shape, output_shape
    self.assertEqual(net.input_shape, (None, 32))
    self.assertEqual(net.output_shape, (None, 2))

    self.assertEqual(net.get_input_at(0), net_input)
    self.assertEqual(net.get_output_at(0), net_output)

    # _compute_output_shape
    computed_shape = net._compute_output_shape((3, 32))
    self.assertEqual(computed_shape.as_list(), [3, 2])

  def testInvalidNetworks(self):
    """Network construction raises ValueError for malformed topologies."""
    # Redundant inputs: the same tensor is listed twice.
    duplicated = base_layers.Input(shape=(32,))
    duplicated_out = core_layers.Dense(2)(duplicated)
    with self.assertRaises(ValueError):
      base_layers.Network([duplicated, duplicated], duplicated_out)

    # Inputs that don't come from Input.
    raw_placeholder = array_ops.placeholder(dtype='float32', shape=(None, 32))
    placeholder_out = core_layers.Dense(2)(raw_placeholder)
    with self.assertRaises(ValueError):
      base_layers.Network(raw_placeholder, placeholder_out)

    # Inputs that don't come from Input but have a layer history.
    true_input = base_layers.Input(shape=(32,))
    intermediate = core_layers.Dense(32)(true_input)
    intermediate_out = core_layers.Dense(2)(intermediate)
    with self.assertRaises(ValueError):
      base_layers.Network(intermediate, intermediate_out)

    # Outputs that don't come from layers.
    plain_input = base_layers.Input(shape=(32,))
    scaled_out = 2 * core_layers.Dense(2)(plain_input)
    with self.assertRaises(ValueError):
      base_layers.Network(plain_input, scaled_out)

    # Disconnected graphs: the output depends on a different Input.
    connected_input = base_layers.Input(shape=(32,))
    unrelated_input = base_layers.Input(shape=(32,))
    connected_out = core_layers.Dense(2)(connected_input)
    with self.assertRaises(ValueError):
      base_layers.Network(unrelated_input, connected_out)

    # Redundant layer names.
    named_input = base_layers.Input(shape=(32,))
    first_dense_out = core_layers.Dense(2, name='dense')(named_input)
    second_dense_out = core_layers.Dense(2, name='dense')(first_dense_out)
    with self.assertRaises(ValueError):
      base_layers.Network(named_input, second_dense_out)

  def testInputTensorWrapping(self):
    """A placeholder wrapped with Input(tensor=...) is a valid network input."""
    placeholder = array_ops.placeholder(dtype='float32', shape=(None, 32))
    wrapped = base_layers.Input(tensor=placeholder)
    dense_out = core_layers.Dense(2)(wrapped)
    # Smoke test: construction does not raise.
    base_layers.Network(wrapped, dense_out)

  def testExplicitBatchSize(self):
    """The batch size given to Input shows up in the downstream shape."""
    batched_input = base_layers.Input(shape=(32,), batch_size=3)
    dense_out = core_layers.Dense(2)(batched_input)
    self.assertEqual(dense_out.get_shape().as_list(), [3, 2])

  def testNetworkRecursion(self):
    """Networks can be used as layers inside other networks."""
    inner_input = base_layers.Input(shape=(32,))
    inner_output = core_layers.Dense(2)(inner_input)
    inner_net = base_layers.Network(inner_input, inner_output)

    # Wrap the inner network as the only non-input layer of an outer one.
    outer_input = base_layers.Input(shape=(32,))
    outer_output = inner_net(outer_input)

    recursive_net = base_layers.Network(outer_input, outer_output)
    self.assertEqual(len(recursive_net.layers), 2)
    self.assertEqual(recursive_net.layers[1], inner_net)
    self.assertEqual(len(recursive_net.weights), 2)

    # The outer network is callable on a regular tensor.
    placeholder = array_ops.placeholder(dtype='float32', shape=(None, 32))
    result = recursive_net(placeholder)
    self.assertEqual(result.get_shape().as_list(), [None, 2])

  def testSparseInput(self):
    """A network can be built on top of an Input created with sparse=True."""

    # Layer applying `sparse_softmax` to its input.
    class SparseSoftmax(base_layers.Layer):

      def call(self, inputs):
        return sparse_ops.sparse_softmax(inputs)

    sparse_input = base_layers.Input(shape=(32,), sparse=True)
    softmaxed = SparseSoftmax()(sparse_input)  # pylint: disable=not-callable
    network = base_layers.Network(sparse_input, softmaxed)

    self.assertEqual(len(network.layers), 2)
    # The first layer of the network reports itself as sparse.
    self.assertEqual(network.layers[0].sparse, True)

  @test_util.run_in_graph_and_eager_modes()
  def testMaskingSingleInput(self):
    """Uses a mask-aware layer in a Network (graph) or directly (eager)."""

    class MaskedLayer(base_layers.Layer):

      def call(self, inputs, mask=None):
        # Without a mask the input is returned unchanged.
        if mask is None:
          return inputs
        return inputs * mask

      def compute_mask(self, inputs, mask=None):
        return array_ops.ones_like(inputs)

    if not context.in_graph_mode():
      # Attach a mask to the input tensor and apply the layer directly.
      values = constant_op.constant([2] * 32)
      mask = constant_op.constant([0, 1] * 16)
      values._keras_mask = mask
      masked = MaskedLayer().apply(values)
      self.assertTrue(hasattr(masked, '_keras_mask'))
      self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)),
                          self.evaluate(getattr(masked, '_keras_mask')))
      self.assertAllEqual(self.evaluate(values * mask), self.evaluate(masked))
      return

    net_input = base_layers.Input(shape=(32,))
    net_output = MaskedLayer()(net_input)  # pylint: disable=not-callable
    network = base_layers.Network(net_input, net_output)

    # Callability on a new Input.
    other_input = base_layers.Input(shape=(32,))
    from_input = network(other_input)
    self.assertEqual(from_input.get_shape().as_list(), [None, 32])

    # Callability on a regular tensor.
    placeholder = array_ops.placeholder(dtype='float32', shape=(None, 32))
    from_placeholder = network(placeholder)
    self.assertEqual(from_placeholder.get_shape().as_list(), [None, 32])


class DeferredModeTest(test.TestCase):
  """Tests `_DeferredTensor` and networks built from `Input` tensors."""

  def _assertDeferredFloat32(self, tensor, expected_shape):
    """Asserts `tensor` is a float32 `_DeferredTensor` of `expected_shape`."""
    self.assertIsInstance(tensor, base_layers._DeferredTensor)
    self.assertEqual(tensor.dtype.name, 'float32')
    self.assertEqual(tensor.shape.as_list(), expected_shape)

  def testDeferredTensorAttributes(self):
    """`str` and `repr` of a `_DeferredTensor` show name, shape and dtype."""
    deferred = base_layers._DeferredTensor(
        shape=(None, 2), dtype='float32', name='x')
    self.assertEqual(str(deferred),
                     'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)')
    self.assertEqual(repr(deferred),
                     '<_DeferredTensor \'x\' shape=(?, 2) dtype=float32>')

  @test_util.run_in_graph_and_eager_modes()
  def testSimpleNetworkBuilding(self):
    """Builds a two-Dense network and, in eager mode, calls it on real data."""
    eager = context.in_eager_mode()

    inputs = base_layers.Input(shape=(32,))
    if eager:
      self._assertDeferredFloat32(inputs, [None, 32])

    hidden = core_layers.Dense(2)(inputs)
    if eager:
      self._assertDeferredFloat32(hidden, [None, 2])

    outputs = core_layers.Dense(4)(hidden)
    network = base_layers.Network(inputs, outputs)
    self.assertIsInstance(network, base_layers.Network)

    if not eager:
      return

    # It should be possible to call such a network on EagerTensors.
    features = np.random.random((10, 32)).astype('float32')
    result = network(constant_op.constant(features))
    self.assertEqual(result.shape.as_list(), [10, 4])

  @test_util.run_in_graph_and_eager_modes()
  def testMultiIONetworkbuilding(self):
    """Builds a two-input, two-output network and checks eager output shapes."""

    class AddLayer(base_layers.Layer):
      """Adds the first two entries of its input list."""

      def call(self, inputs):
        return inputs[0] + inputs[1]

      def _compute_output_shape(self, input_shape):
        # The output shape is reported as the shape of the first input.
        return input_shape[0]

    input_a = base_layers.Input(shape=(32,))
    input_b = base_layers.Input(shape=(16,))
    branch_a = core_layers.Dense(16)(input_a)

    summed = AddLayer()([branch_a, input_b])  # pylint: disable=not-callable
    merged = core_layers.Dense(2)(summed)

    network = base_layers.Network([input_a, input_b], [branch_a, merged])
    if not context.in_eager_mode():
      return

    a_val = constant_op.constant(
        np.random.random((10, 32)).astype('float32'))
    b_val = constant_op.constant(
        np.random.random((10, 16)).astype('float32'))
    outputs = network([a_val, b_val])
    self.assertEqual(len(outputs), 2)
    for index, expected_shape in enumerate(([10, 16], [10, 2])):
      self.assertEqual(outputs[index].shape.as_list(), expected_shape)

# Call `test.main()` only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
  test.main()