aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/compiler/jit_test.py
blob: 3e631b59094b182384c27031b3c780abd57a1bcb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for contrib.compiler.jit."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.compiler import jit
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top


# Registered op definitions, fetched once at import time.
# enable_jit_nonstateful() indexes this by op name to read `is_stateful`.
_REGISTERED_OPS = op_def_registry.get_registered_ops()


def enable_jit_nonstateful(node_def):
  """Returns whether the op named by `node_def` is not stateful.

  Args:
    node_def: Node description; its `op` attribute is used as the key into
      `_REGISTERED_OPS`.

  Returns:
    The negation of the registered op definition's `is_stateful` flag.

  Raises:
    ValueError: If `node_def.op` has no entry in `_REGISTERED_OPS`.
  """
  try:
    registered = _REGISTERED_OPS[node_def.op]
  except KeyError:
    raise ValueError("Unregistered op being created: %s" % node_def)
  return not registered.is_stateful


class JITTest(test.TestCase):
  """Tests for the `_XlaCompile` / `_XlaScope` attrs set under a jit scope."""

  def compute(self, use_jit, compute_fn):
    """Builds `compute_fn` in a fresh graph under a jit scope and runs it.

    Args:
      use_jit: Value forwarded to `jit.experimental_jit_scope`. The tests in
        this class pass either a bool or a callable taking a node def.
      compute_fn: Zero-argument callable that builds ops and returns the
        tensor to evaluate.

    Returns:
      A `(tensor, value)` tuple: the tensor returned by `compute_fn` and the
      result of running that tensor in the session.
    """
    with self.session(graph=ops.Graph()) as sess:
      # set_random_seed() applies to the current default graph, so it must be
      # called after the fresh graph has become the default. Calling it
      # before entering the session would seed the ambient graph instead and
      # leave the graph under test without a graph-level seed.
      random_seed.set_random_seed(1234)
      with jit.experimental_jit_scope(use_jit):
        r = compute_fn()
      # Initialize any variables compute_fn created before fetching r.
      sess.run(variables.global_variables_initializer())
      return (r, sess.run(r))

  def testJITCreateOpsLambda(self):
    """Test several ways of customizing the compilation attribute."""
    def create_ops():
      with variable_scope.variable_scope(
          "root",
          initializer=init_ops.random_uniform_initializer(
              -0.1, 0.1, seed=2)):
        inputs = random_ops.random_uniform((1,), seed=1)
        return inputs
    # Build the same ops with jit disabled, enabled per op through the
    # enable_jit_nonstateful predicate, and enabled unconditionally.
    v_false_1_t, v_false_1 = self.compute(False, create_ops)
    _, v_false_2 = self.compute(False, create_ops)
    v_true_1_t, v_true_1 = self.compute(enable_jit_nonstateful, create_ops)
    _, v_true_2 = self.compute(enable_jit_nonstateful, create_ops)
    v_all_true_t, _ = self.compute(True, create_ops)
    self.assertFalse(v_false_1_t.op.get_attr("_XlaCompile"))
    v_true_1_t_sampler_op = v_true_1_t.graph.get_operation_by_name(
        "root/random_uniform/RandomUniform")
    v_all_true_t_sampler_op = v_all_true_t.graph.get_operation_by_name(
        "root/random_uniform/RandomUniform")

    # The predicate is expected to leave the sampler op uncompiled, whereas a
    # blanket True marks it for compilation as well.
    self.assertFalse(v_true_1_t_sampler_op.get_attr("_XlaCompile"))
    self.assertTrue(v_all_true_t_sampler_op.get_attr("_XlaCompile"))

    # The op producing the returned tensor is compiled in both jit modes.
    self.assertTrue(v_true_1_t.op.get_attr("_XlaCompile"))
    self.assertTrue(v_all_true_t.op.get_attr("_XlaCompile"))

    # Additionally ensure that where no JIT compilation happens on the
    # random_uniform op, the output values are identical to the case
    # where no JIT compilation happens anywhere.
    self.assertAllClose(v_false_1, v_false_2)
    self.assertAllClose(v_true_1, v_true_2)
    self.assertAllClose(v_false_1, v_true_1)

  def testJITXlaScope(self):
    """Nested jit scopes reuse the `_XlaScope` name of the outermost scope."""
    with self.session(graph=ops.Graph()):
      with jit.experimental_jit_scope(True):
        # XlaScope 0
        a1 = constant_op.constant(1)
      with jit.experimental_jit_scope(True):
        # XlaScope 1
        a2 = constant_op.constant(1)
        with jit.experimental_jit_scope(True):
          # XlaScope still 1, depth 1
          a3 = constant_op.constant(1)
          with jit.experimental_jit_scope(True):
            # XlaScope still 1, depth 2
            a4 = constant_op.constant(1)
          # XlaScope still 1, depth 1
          a5 = constant_op.constant(1)
      with jit.experimental_jit_scope(True):
        # XlaScope now 2, depth 0
        a6 = constant_op.constant(1)

    self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_1", a3.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_1", a4.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_1", a5.op.get_attr("_XlaScope"))
    self.assertEqual(b"jit_scope_2", a6.op.get_attr("_XlaScope"))

  def testJITVariableSeed(self):
    """Test that the stateful initializer is not marked for compilation.

    XLA does not currently support seeded initialization and XLA initializers
    therefore return different values than non-XLA counterparts.  Here
    we ensure that we can disable JIT compilation for the initializers and
    get the same variable values as if no JIT compilation happened.
    """
    def create_ops():
      with variable_scope.variable_scope(
          "root",
          initializer=init_ops.random_uniform_initializer(
              -0.1, 0.1, seed=2)):
        inputs = variable_scope.get_variable("var", (1,))
        return inputs
    _, v_false_1 = self.compute(False, create_ops)
    _, v_false_2 = self.compute(False, create_ops)
    _, v_true_1 = self.compute(enable_jit_nonstateful, create_ops)
    _, v_true_2 = self.compute(enable_jit_nonstateful, create_ops)
    # Values must be reproducible within each mode and equal across modes.
    self.assertAllClose(v_false_1, v_false_2)
    self.assertAllClose(v_true_1, v_true_2)
    self.assertAllClose(v_false_1, v_true_1)

  def testDefunNoJitScope(self):
    """A compiled Defun outside any jit scope sets its own `_XlaScope`."""
    with self.session(graph=ops.Graph()):

      @function.Defun(compiled=True, noinline=True)
      def mulop(x1, x2):
        return x1 * x2
      x = constant_op.constant(1.0)
      r = mulop(x, x)

      # Ensure the forward function is compiled.
      graph_def = r.graph.as_graph_def()
      func_attrs = graph_def.library.function[0].attr
      self.assertTrue(func_attrs["_XlaCompile"].b)
      # No enclosing jit scope so function sets its own value for _XlaScope.
      self.assertEqual(b"function_mulop", func_attrs["_XlaScope"].s)

  def testDefunInheritsJitScope(self):
    """A compiled Defun inside a jit scope takes that scope's `_XlaScope`."""
    with self.session(graph=ops.Graph()):
      with jit.experimental_jit_scope(True):
        @function.Defun(compiled=True, noinline=True)
        def mulop(x1, x2):
          return x1 * x2
        x = constant_op.constant(1.0)
        r = mulop(x, x)

      # Ensure the forward function is compiled.
      graph_def = r.graph.as_graph_def()
      func_attrs = graph_def.library.function[0].attr
      self.assertTrue(func_attrs["_XlaCompile"].b)
      # Ensure _XlaScope is inherited from enclosing context.
      self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s)


class CompilationEnabledInGradientTest(test.TestCase):
  """Checks which jit attrs gradient ops receive from their forward ops."""

  def testCompilationInGradient(self):
    """Only gradients of ops built under a jit scope carry `_XlaCompile`."""
    with self.cached_session():
      x = constant_op.constant([[3.]])
      outside = math_ops.matmul(x, x, name="not_compiled")
      with jit.experimental_jit_scope():
        inside = math_ops.matmul(outside, outside, name="compiled")
      dx = gradients.gradients([inside], [x])[0]

      # Partition the graph's ops by which forward matmul they differentiate.
      compiled_grads = []
      uncompiled_grads = []
      for graph_op in x.graph.get_operations():
        if "gradients/compiled" in graph_op.name:
          compiled_grads.append(graph_op)
        if "gradients/not_compiled" in graph_op.name:
          uncompiled_grads.append(graph_op)

      self.assertGreater(len(compiled_grads), 0)
      self.assertGreater(len(uncompiled_grads), 0)
      for grad_op in compiled_grads:
        self.assertTrue(grad_op.get_attr("_XlaCompile"))
      for grad_op in uncompiled_grads:
        # The attr must be absent altogether, not merely False.
        with self.assertRaisesRegexp(ValueError, "[Nn]o attr named"):
          grad_op.get_attr("_XlaCompile")

      # The result is x ** 4 for a 1x1 input, so the gradient is
      # 4 * (x ** 3) = 108 at x = 3.
      self.assertAllClose([[108]], dx.eval())

  def testCompilationGradientScopeNames(self):
    """By default, gradient ops reuse the `_XlaScope` of the forward op."""
    with self.session(graph=ops.Graph()):
      with jit.experimental_jit_scope():
        # First scope opened in this graph.
        first = constant_op.constant([[1.]])
        first_sq = math_ops.matmul(first, first)
      with jit.experimental_jit_scope():
        # Second scope opened in this graph.
        second = constant_op.constant([[1.]])
        second_sq = math_ops.matmul(second, second)

      self.assertEqual(b"jit_scope_0", first.op.get_attr("_XlaScope"))
      self.assertEqual(b"jit_scope_1", second.op.get_attr("_XlaScope"))
      d_first = gradients.gradients(first_sq, first, name="GA")[0]
      d_second = gradients.gradients(second_sq, second, name="GB")[0]
      # Inspect the op feeding the first input of each returned gradient.
      first_grad_op = d_first.op.inputs[0].op
      second_grad_op = d_second.op.inputs[0].op
      self.assertTrue(first_grad_op.get_attr("_XlaCompile"))
      self.assertTrue(second_grad_op.get_attr("_XlaCompile"))
      self.assertEqual(b"jit_scope_0", first_grad_op.get_attr("_XlaScope"))
      self.assertEqual(b"jit_scope_1", second_grad_op.get_attr("_XlaScope"))

  def testCompilationSeparateGradientScopeNames(self):
    """`separate_compiled_gradients` gives gradient ops their own scope."""
    with self.session(graph=ops.Graph()):
      with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
        # First scope opened in this graph.
        first = constant_op.constant([[1.]])
        first_sq = math_ops.matmul(first, first)
      with jit.experimental_jit_scope(True, separate_compiled_gradients=True):
        # Second scope opened in this graph.
        second = constant_op.constant([[1.]])
        second_sq = math_ops.matmul(second, second)

      self.assertEqual(b"jit_scope_0", first.op.get_attr("_XlaScope"))
      self.assertEqual(b"jit_scope_1", second.op.get_attr("_XlaScope"))
      d_first = gradients.gradients(first_sq, first, name="GA")[0]
      d_second = gradients.gradients(second_sq, second, name="GB")[0]
      # Inspect the op feeding the first input of each returned gradient.
      first_grad_op = d_first.op.inputs[0].op
      second_grad_op = d_second.op.inputs[0].op
      self.assertTrue(first_grad_op.get_attr("_XlaCompile"))
      self.assertTrue(second_grad_op.get_attr("_XlaCompile"))
      # The scope name is suffixed with the name passed to gradients().
      self.assertEqual(
          b"jit_scope_0_grad_GA", first_grad_op.get_attr("_XlaScope"))
      self.assertEqual(
          b"jit_scope_1_grad_GB", second_grad_op.get_attr("_XlaScope"))

  def testPlaysNicelyWithDefun(self):
    """A compiled Defun and its gradient share the enclosing jit scope."""
    with self.session(graph=ops.Graph()) as sess:
      with jit.experimental_jit_scope(True):
        @function.Defun(compiled=True, noinline=True)
        def mulop(x1, x2):
          return x1 * x2
        one = constant_op.constant(1.0)
        product = mulop(one, one)
        d_product = gradients.gradients(product, one, name="GA")[0]

      # The forward function must be compiled under the enclosing scope.
      library = product.graph.as_graph_def().library
      forward_attrs = library.function[0].attr
      self.assertTrue(forward_attrs["_XlaCompile"].b)
      self.assertEqual(b"jit_scope_0", forward_attrs["_XlaScope"].s)

      # The gradient (SymbolicGradient) must be compiled too, under the same
      # _XlaScope as the function itself.
      symbolic_grad = d_product.op.inputs[0].op
      self.assertTrue(symbolic_grad.get_attr("_XlaCompile"))
      self.assertEqual(b"jit_scope_0", symbolic_grad.get_attr("_XlaScope"))

      # Finally run everything: grad(x1*x1) = 2*x1
      fetched = sess.run([one, product, d_product])
      self.assertAllClose([1.0, 1.0, 2.0], fetched)

  def testPlaysNicelyWithDefunSeparateGradientScope(self):
    """A Defun with separate gradients gets a distinct gradient scope."""
    with self.session(graph=ops.Graph()) as sess:
      with jit.experimental_jit_scope(True):

        @function.Defun(
            compiled=True,
            noinline=True,
            separate_compiled_gradients=True)
        def mulop(x1, x2):
          return x1 * x2

        one = constant_op.constant(1.0)
        product = mulop(one, one)
        d_product = gradients.gradients(product, one, name="GA")[0]

      # The forward function must be compiled under the enclosing scope.
      library = product.graph.as_graph_def().library
      forward_attrs = library.function[0].attr
      self.assertTrue(forward_attrs["_XlaCompile"].b)
      self.assertEqual(b"jit_scope_0", forward_attrs["_XlaScope"].s)

      # The gradient (SymbolicGradient) must be compiled too, but under an
      # _XlaScope that differs from the function's own.
      symbolic_grad = d_product.op.inputs[0].op
      self.assertTrue(symbolic_grad.get_attr("_XlaCompile"))
      self.assertEqual(
          b"jit_scope_0_grad_GA", symbolic_grad.get_attr("_XlaScope"))

      # Finally run everything: grad(x1*x1) = 2*x1
      fetched = sess.run([one, product, d_product])
      self.assertAllClose([1.0, 1.0, 2.0], fetched)


# Hand control to `test.main()` when this file is executed directly.
if __name__ == "__main__":
  test.main()