# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for make_template."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import traceback

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent


def variable_scoped_function(trainable=True):
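  """Creates a "dummy" variable in the caller's variable scope."""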
  return variable_scope.get_variable(
      "dummy", shape=[1], trainable=trainable,
      initializer=init_ops.zeros_initializer())


def internally_variable_scoped_function(scope_name):
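  """Creates a "dummy" variable inside an extra scope named scope_name."""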
  with variable_scope.variable_scope(scope_name):
    return variable_scope.get_variable(
        "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def function_with_create(trainable):
  """Creates a variable as a side effect using tf.Variable."""
  variables.Variable(0, trainable=trainable)
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def function_with_side_create(trainable, name="side"):
  """Creates a variable as a side effect using tf.get_variable."""
  variable_scope.get_variable(name, shape=[1], trainable=trainable)
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


def variable_scoped_function_with_local_variable():
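  """Creates a local variable as a side effect and returns a regular one."""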
  variable_scope.get_local_variable(
      "local", shape=[1], initializer=init_ops.zeros_initializer())
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())


class TemplateTest(test.TestCase):

  def test_end_to_end(self):
    """This test shows a very simple line model with test_loss.

    The template is used to share parameters between a training and test model.
    """
    # y = 2x + 1
    training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
    test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

    random_seed.set_random_seed(1234)

    def test_line(x):
      m = variable_scope.get_variable(
          "w", shape=[], initializer=init_ops.truncated_normal_initializer())
      b = variable_scope.get_variable(
          "b", shape=[], initializer=init_ops.truncated_normal_initializer())
      return x * m + b

    line_template = template.make_template("line", test_line)

    train_prediction = line_template(training_input)
    test_prediction = line_template(test_input)

    train_loss = math_ops.reduce_mean(
        math_ops.square(train_prediction - training_output))
    test_loss = math_ops.reduce_mean(
        math_ops.square(test_prediction - test_output))

    optimizer = gradient_descent.GradientDescentOptimizer(0.1)
    train_op = optimizer.minimize(train_loss)

    with session.Session() as sess:
      sess.run(variables.global_variables_initializer())
      initial_test_loss = sess.run(test_loss)
      sess.run(train_op)
      final_test_loss = sess.run(test_loss)

    # Parameters are tied, so the loss should have gone down when we trained it.
    self.assertLess(final_test_loss, initial_test_loss)

  def test_end_to_end_eager(self):
    """This test shows a very simple line model with test_loss in eager mode.

    The template is used to share parameters between a training and test model.
    """
    with context.eager_mode():
      # y = 2x + 1
      training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
      test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

      random_seed.set_random_seed(1234)

      def test_line(x):
        m = variable_scope.get_variable(
            "w", shape=[], initializer=init_ops.truncated_normal_initializer())
        b = variable_scope.get_variable(
            "b", shape=[], initializer=init_ops.truncated_normal_initializer())
        return x * m + b

      line_template = template.make_template("line", test_line)

      def train_loss():
        train_prediction = line_template(training_input)
        return math_ops.reduce_mean(
            math_ops.square(train_prediction - training_output))

      def test_loss():
        test_prediction = line_template(test_input)
        return math_ops.reduce_mean(
            math_ops.square(test_prediction - test_output))

      optimizer = gradient_descent.GradientDescentOptimizer(0.1)
      initial_test_loss = test_loss()
      optimizer.minimize(train_loss)
      final_test_loss = test_loss()

      # Parameters are tied, so the loss should have gone down after training.
      self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy())

  @test_util.run_in_graph_and_eager_modes
  def test_skip_stack_frames(self):
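    # _skip_common_stack_elements should drop the stack frames shared by both
    # tracebacks, leaving only the frames where they diverge.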
    first = traceback.format_stack()
    second = traceback.format_stack()
    result = template._skip_common_stack_elements(first, second)
    self.assertEqual(1, len(result))
    self.assertNotEqual(len(first), len(result))

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_name(self):
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)
    self.assertEqual("s1_1/dummy:0", v3.name)

  def test_same_unique_name_raise_error(self):
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    tmpl1()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    with self.assertRaisesRegexp(
        ValueError, "Variable s1/dummy already exists, disallowed.*"):
      tmpl2()

  def test_unique_name_raise_error_in_eager(self):
    with context.eager_mode():
      with self.assertRaisesRegexp(
          ValueError,
          "unique_name_ cannot be used when eager exeuction is enabled."):
        template.make_template(
            "_", variable_scoped_function, unique_name_="s1")

  def test_unique_name_and_reuse(self):
    tmpl1 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v1 = tmpl1()
    v2 = tmpl1()

    variable_scope.get_variable_scope().reuse_variables()
    tmpl2 = template.make_template(
        "_", variable_scoped_function, unique_name_="s1")
    v3 = tmpl2()

    self.assertEqual(v1, v2)
    self.assertEqual(v1, v3)
    self.assertEqual("s1/dummy:0", v1.name)

  @test_util.run_in_graph_and_eager_modes
  def test_template_in_scope(self):
    tmpl1 = template.make_template("s1", variable_scoped_function)
    tmpl2 = template.make_template("s1", variable_scoped_function)

    with variable_scope.variable_scope("scope"):
      v1 = tmpl1()
      v3 = tmpl2()

    # The template contract requires the following call to ignore scope2.
    with variable_scope.variable_scope("scope2"):
      v2 = tmpl1()
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("scope/s1/dummy:0", v1.name)
    self.assertEqual("scope/s1_1/dummy:0", v3.name)

  @test_util.run_in_graph_and_eager_modes
  def test_template_with_internal_reuse(self):
    tmpl1 = template.make_template("s1", internally_variable_scoped_function)
    tmpl2 = template.make_template("s1", internally_variable_scoped_function)

    v1 = tmpl1("test")
    v2 = tmpl1("test")
    v3 = tmpl2("test")
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

    with self.assertRaises(ValueError):
      tmpl1("not_test")

  @test_util.run_in_graph_and_eager_modes
  def test_template_without_name(self):
    with self.assertRaisesRegexp(
        ValueError, "name cannot be None."):
      template.make_template(None, variable_scoped_function)

  @test_util.run_in_graph_and_eager_modes
  def test_make_template(self):
    # Test both that we can call it with positional and keywords.
    tmpl1 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")
    tmpl2 = template.make_template(
        "s1", internally_variable_scoped_function, scope_name="test")

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/test/dummy:0", v1.name)
    self.assertEqual("s1_1/test/dummy:0", v3.name)

  def test_enforces_no_extra_trainable_variables(self):
    tmpl = template.make_template("s", function_with_create, trainable=True)

    tmpl()
    with self.assertRaises(ValueError):
      tmpl()

  @test_util.run_in_graph_and_eager_modes
  def test_enforces_no_extra_trainable_variables_eager(self):
    tmpl = template.make_template("s",
                                  function_with_side_create,
                                  trainable=True)

    tmpl(name="1")
    with self.assertRaises(ValueError):
      tmpl(name="2")

  def test_permits_extra_non_trainable_variables(self):
    tmpl = template.make_template("s", function_with_create, trainable=False)
    self.assertEqual(tmpl(), tmpl())

  def test_permits_extra_non_trainable_variables_eager(self):
    with context.eager_mode():
      tmpl = template.make_template("s",
                                    function_with_side_create,
                                    trainable=False)
      self.assertEqual(tmpl(name="1"), tmpl(name="2"))

  @test_util.run_in_graph_and_eager_modes
  def test_internal_variable_reuse(self):

    def nested():
      with variable_scope.variable_scope("nested") as vs:
        v1 = variable_scope.get_variable(
            "x", initializer=init_ops.zeros_initializer(), shape=[])
      with variable_scope.variable_scope(vs, reuse=True):
        v2 = variable_scope.get_variable("x")
      self.assertEqual(v1, v2)
      return v1

    tmpl1 = template.make_template("s1", nested)
    tmpl2 = template.make_template("s1", nested)

    v1 = tmpl1()
    v2 = tmpl1()
    v3 = tmpl2()
    self.assertEqual(v1, v2)
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/nested/x:0", v1.name)
    self.assertEqual("s1_1/nested/x:0", v3.name)

  @test_util.run_in_graph_and_eager_modes
  def test_nested_templates(self):

    def nested_template():
      nested1 = template.make_template("nested", variable_scoped_function)
      nested2 = template.make_template("nested", variable_scoped_function)
      v1 = nested1()
      v2 = nested2()

      # nested1 and nested2 should not share variables
      self.assertNotEqual(v1, v2)

      # Variables created by nested1 should be isolated from variables
      # created by nested2.
      self.assertEqual(nested1.variables, [v1])
      self.assertEqual(nested2.variables, [v2])
      self.assertEqual(nested1.trainable_variables, [v1])
      self.assertEqual(nested2.trainable_variables, [v2])
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)
      return v1, v2

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    v1, v2 = tmpl1()
    v3, v4 = tmpl1()
    v5, v6 = tmpl2()

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    self.assertEqual([v1, v2], [v3, v4])
    self.assertEqual(tmpl1.variables, [v1, v2])
    self.assertEqual(tmpl1.trainable_variables, [v1, v2])
    self.assertEqual(len(tmpl1.non_trainable_variables), 0)

    # tmpl1 and tmpl2 should not share variables.
    self.assertNotEqual([v1, v2], [v5, v6])
    self.assertSequenceEqual(tmpl2.variables, [v5, v6])
    self.assertSequenceEqual(tmpl2.trainable_variables, [v5, v6])
    self.assertEqual(len(tmpl2.non_trainable_variables), 0)
    self.assertEqual("s1/nested/dummy:0", v1.name)
    self.assertEqual("s1/nested_1/dummy:0", v2.name)
    self.assertEqual("s1_1/nested/dummy:0", v5.name)
    self.assertEqual("s1_1/nested_1/dummy:0", v6.name)

    self.assertEqual(2, len(tmpl1._checkpoint_dependencies))
    self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name)
    self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name)
    model = training.Model()
    model.template = tmpl1
    self.assertEqual(model.variables, [v1, v2])
    self.assertEqual(model.trainable_variables, [v1, v2])
    self.assertEqual(len(model.non_trainable_variables), 0)
    model.templates = [tmpl2]
    self.assertEqual(model.variables, [v1, v2, v5, v6])
    self.assertEqual(model.trainable_variables, [v1, v2, v5, v6])
    self.assertEqual(len(model.non_trainable_variables), 0)
    # Make sure losses, layers, and updates aren't broken by having a Template
    # in the mix, which does not expose any updates or losses.
    self.assertEqual([], model.layers)
    self.assertEqual([], model.updates)
    self.assertEqual([], model.losses)
    self.assertEqual([], model.templates.layers)
    self.assertEqual([], model.templates.updates)
    self.assertEqual([], model.templates.losses)

  @test_util.run_in_graph_and_eager_modes
  def test_nested_templates_with_defun(self):

    def variable_scoped_function_no_return_value(trainable=True):
      # defun cannot compile functions that return non-Tensor objects
      _ = variable_scope.get_variable(
          "dummy",
          shape=[1],
          trainable=trainable,
          initializer=init_ops.zeros_initializer())

    def nested_template():
      nested1 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested2 = template.make_template_internal(
          "nested",
          variable_scoped_function_no_return_value,
          create_graph_function_=True)
      nested1()
      nested2()
      v1 = nested1.variables
      v2 = nested2.variables

      # nested1 and nested2 should not share variables
      self.assertNotEqual(v1, v2)

      # Variables created by nested1 should be isolated from variables
      # created by nested2.
      self.assertEqual(nested1.variables, v1)
      self.assertEqual(nested2.variables, v2)
      self.assertEqual(nested1.trainable_variables, v1)
      self.assertEqual(nested2.trainable_variables, v2)
      self.assertEqual(len(nested1.non_trainable_variables), 0)
      self.assertEqual(len(nested2.non_trainable_variables), 0)

    tmpl1 = template.make_template("s1", nested_template)
    tmpl2 = template.make_template("s1", nested_template)

    tmpl1()
    v1 = tmpl1.variables
    tmpl1()
    v2 = tmpl1.variables
    tmpl2()
    v3 = tmpl2.variables

    # The second invocation of tmpl1 should reuse the variables
    # created in the first invocation.
    self.assertSequenceEqual(v1, v2)

    # tmpl1 and tmpl2 should not share variables.
    self.assertNotEqual(v1, v3)
    self.assertEqual("s1/nested/dummy:0", v1[0].name)
    self.assertEqual("s1/nested_1/dummy:0", v1[1].name)
    self.assertEqual("s1_1/nested/dummy:0", v3[0].name)
    self.assertEqual("s1_1/nested_1/dummy:0", v3[1].name)

  def test_graph_function_no_name(self):
    with context.eager_mode():

      def f(_, y):
        return y + 1

      partial = functools.partial(f, 1.0)
      tmpl = template.make_template_internal(
          "a", partial, create_graph_function_=True)
      self.assertAllEqual(tmpl(ops.convert_to_tensor(1.0)), 2.0)

  @test_util.run_in_graph_and_eager_modes
  def test_immediate_scope_creation(self):
    # Create templates in scope a, then call them in scope b. By default,
    # make_template captures the scope the first time the template is called;
    # with create_scope_now_=True it captures the scope at construction time.
    with variable_scope.variable_scope("ctor_scope"):
      # Create scope here:
      tmpl_immed = template.make_template("a", variable_scoped_function,
                                          True)
      # default: create scope at __call__
      tmpl_defer = template.make_template(
          "b", variable_scoped_function, False)
    with variable_scope.variable_scope("call_scope"):
      inner_imm_var = tmpl_immed()
      inner_defer_var = tmpl_defer()
    outer_imm_var = tmpl_immed()
    outer_defer_var = tmpl_defer()

    self.assertNotEqual(inner_imm_var, inner_defer_var)
    self.assertEqual(outer_imm_var, inner_imm_var)
    self.assertEqual(outer_defer_var, inner_defer_var)

    self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name)
    self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name)

  @test_util.run_in_graph_and_eager_modes
  def test_scope_access(self):
    # Ensure that we can access the scope inside the template, because the name
    # of that scope may be different from the name we pass to make_template, due
    # to having been made unique by variable_scope.
    with variable_scope.variable_scope("foo"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Ensure we can get the scopes before either template is actually called.
    self.assertEqual(ta.variable_scope.name, "foo/bar")
    self.assertEqual(tb.variable_scope.name, "foo/bar_1")

    with variable_scope.variable_scope("foo_2"):
      # Create a template which defers scope creation.
      tc = template.make_template("blah", variable_scoped_function, False)

    # Before we call the template, the scope property will be set to None.
    self.assertEqual(tc.variable_scope, None)
    tc()

    # Template is called at the top level, so there is no preceding "foo_2".
    self.assertEqual(tc.variable_scope.name, "blah")

  @test_util.run_in_graph_and_eager_modes
  def test_custom_getter(self):
    # Custom getter that maintains call count and forwards to true getter
    custom_getter_count = [0]

    def custom_getter(getter, name, *args, **kwargs):
      custom_getter_count[0] += 1
      return getter(name, *args, **kwargs)

    # Test that custom getter is called both when variables are created and
    # subsequently accessed
    tmpl1 = template.make_template(
        "s1", variable_scoped_function, custom_getter_=custom_getter)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl1()
    self.assertEqual(custom_getter_count[0], 2)

    # Test that custom getter is called when the variable scope is created
    # during construction
    custom_getter_count[0] = 0
    tmpl2 = template.make_template(
        "s2",
        variable_scoped_function,
        custom_getter_=custom_getter,
        create_scope_now_=True)
    self.assertEqual(custom_getter_count[0], 0)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 1)
    tmpl2()
    self.assertEqual(custom_getter_count[0], 2)

  @test_util.run_in_graph_and_eager_modes
  def test_fails_gracefully(self):
    for create_scope_now in [True, False]:
      def module_function_with_one_arg(inputs):
        w = variable_scope.get_variable(
            "w", shape=[1], initializer=init_ops.zeros_initializer())
        return inputs * w

      templatized_function = template.make_template(
          "f1", module_function_with_one_arg,
          create_scope_now_=create_scope_now)
      data = array_ops.zeros([1])
      try:
        # Try to connect with a kwarg which is unsupported.
        templatized_function(data, is_training=True)
      except TypeError:
        pass

      # The failed __call__ hasn't modified the inner state.
      self.assertFalse(templatized_function._variables_created)
      templatized_function(data)
      self.assertTrue(templatized_function._variables_created)

  @test_util.run_in_graph_and_eager_modes
  def test_name_scopes_for_variable_scopes(self):
    # Test that name scopes are not unnecessarily uniquified (but are
    # still uniquified when necessary).
    def linear_module(x, output_size):
      w = variable_scope.get_variable(
          "w", shape=[x.get_shape()[1], output_size],
          initializer=init_ops.zeros_initializer())
      b = variable_scope.get_variable(
          "b", shape=[output_size],
          initializer=init_ops.zeros_initializer())
      return (math_ops.matmul(x, w) + b), w

    def make_linear_module(output_size, name):
      return template.make_template(
          name,
          linear_module,
          output_size=output_size,
          create_scope_now_=True)

    inputs = array_ops.ones((3, 4))

    linear1 = make_linear_module(output_size=2, name="foo")
    outputs_a, w1 = linear1(inputs)
    outputs_b, _ = linear1(inputs)
    self.assertEquals("foo", linear1.variable_scope.name)
    self.assertEquals("foo/w:0", w1.name)
    if not context.executing_eagerly():
      self.assertEquals("foo/add:0", outputs_a.name,
                        "First application of template should get "
                        "same name scope as variables.")
      self.assertEquals("foo_1/add:0", outputs_b.name,
                        "Second application of template should get "
                        "a freshly uniquified name scope.")

    linear2 = make_linear_module(output_size=2, name="foo")
    outputs_c, w2 = linear2(inputs)
    outputs_d, _ = linear2(inputs)
    self.assertEquals("foo_1", linear2.variable_scope.name,
                      "New template gets a freshly uniquified variable scope "
                      "because 'foo' is already taken.")
    self.assertEquals("foo_1/w:0", w2.name)
    if not context.executing_eagerly():
      self.assertEquals("foo_1_1/add:0", outputs_c.name,
                        "First application of template would get "
                        "same name scope as variables, but 'foo_1' is already "
                        "a name scope.")
      self.assertEquals("foo_1_2/add:0", outputs_d.name,
                        "Second application of template should also get "
                        "a freshly uniquified name scope.")

  @test_util.run_in_graph_and_eager_modes
  def test_global_variables(self):
    # Make sure global_variables are created.
    with variable_scope.variable_scope("foo"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      if context.executing_eagerly():
        tb = template.make_template("s", function_with_side_create,
                                    trainable=False)
      else:
        tb = template.make_template("s", function_with_create, trainable=False)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.global_variables))
    self.assertEqual([], list(tb.global_variables))
    # After calling there are variables created.
    ta()
    tb()
    # Check how many global variables each template created.
    self.assertEqual(1, len(ta.global_variables))
    self.assertEqual(2, len(tb.global_variables))

  @test_util.run_in_graph_and_eager_modes
  def test_trainable_variables(self):
    # Make sure trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar", variable_scoped_function, True)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.trainable_variables))
    self.assertEqual([], list(tb.trainable_variables))
    # After calling there are variables created.
    ta()
    tb()
    # Check how many trainable variables each template created.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual(1, len(tb.trainable_variables))
    # No non-trainable variables were created.
    self.assertEqual([], list(ta.non_trainable_variables))
    self.assertEqual([], list(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  @test_util.run_in_graph_and_eager_modes
  def test_non_trainable_variables(self):
    # Make sure non_trainable_variables are created.
    with variable_scope.variable_scope("foo2"):
      ta = template.make_template("a", variable_scoped_function,
                                  trainable=True)
      tb = template.make_template("b", variable_scoped_function,
                                  trainable=False)
    # Initially no variables have been created.
    self.assertEqual([], list(ta.variables))
    self.assertEqual([], list(tb.variables))
    # After calling there are variables created.
    ta()
    tb()
    # Check the trainable and non_trainable variables.
    self.assertEqual(1, len(ta.trainable_variables))
    self.assertEqual([], list(ta.non_trainable_variables))

    self.assertEqual([], list(tb.trainable_variables))
    self.assertEqual(1, len(tb.non_trainable_variables))
    # Ensure variables returns all the variables.
    self.assertEqual(1, len(ta.variables))
    self.assertEqual(1, len(tb.variables))

  # TODO(apassos) handle local variables in Eager
  def test_local_variables(self):
    # Make sure local_variables are created.
    with variable_scope.variable_scope("foo3"):
      # Create two templates with the same name, ensure scopes are made unique.
      ta = template.make_template("bar", variable_scoped_function, True)
      tb = template.make_template("bar",
                                  variable_scoped_function_with_local_variable)

    # Initially no variables have been created.
    self.assertEqual([], list(ta.local_variables))
    self.assertEqual([], list(tb.local_variables))
    # After calling there are variables created.
    ta()
    tb()
    # Check how many local variables each template created.
    self.assertEqual(0, len(ta.local_variables))
    self.assertEqual(1, len(tb.local_variables))

  @test_util.run_in_graph_and_eager_modes
  def test_make_template_with_defun(self):

    def variable_scoped_function_no_return_value(scope_name):
      # defun cannot compile functions that return non-Tensor objects
      with variable_scope.variable_scope(scope_name):
        _ = variable_scope.get_variable(
            "dummy", shape=[1], initializer=init_ops.zeros_initializer())

    tmpl = template.make_template_internal(
        "s1",
        variable_scoped_function_no_return_value,
        create_graph_function_=True,
        scope_name="test")

    # The first invocation of tmpl creates variables; the second should
    # be executed as a graph function.
    tmpl()
    v1 = tmpl.variables
    tmpl()
    v2 = tmpl.variables

    self.assertSequenceEqual(v1, v2)
    self.assertEqual("s1/test/dummy:0", v1[0].name)


if __name__ == "__main__":
  test.main()