aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/execution_callbacks.py
blob: 80ff4459d60a33d1a02f14acaafb8370a48fb6ca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Execution Callbacks for Eager Mode."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import numpy as np

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import execute
from tensorflow.python.platform import tf_logging as logging

# Action used by the inf/nan callbacks when the caller does not choose one.
_DEFAULT_CALLBACK_ACTION = "raise"
# Values accepted by `seterr(inf_or_nan=...)`; `None` means "keep the current
# setting".
_VALID_CALLBACK_ACTIONS = (None,) + ("ignore", "print", "raise", "warn")


# TODO(cais): Consider moving this exception class to errors_impl.py.
class InfOrNanError(Exception):
  """Signals that an eager op produced a tensor containing `inf` or `nan`."""

  def __init__(self,
               op_type,
               op_name,
               output_index,
               num_outputs,
               value):
    """Constructor of InfOrNanError.

    Args:
      op_type: Type name of the op whose output contains the `inf`(s) or
        `nan`(s) (e.g., `Div`).
      op_name: Client-assigned name of that op, or `None` if none was set.
      output_index: 0-based index of the offending tensor among the op's
        outputs.
      num_outputs: Total number of outputs of the op.
      value: Value of the offending tensor; its elements are tallied with
        `np.size`, `np.isinf` and `np.isnan`.
    """
    # Element tallies reported in the error message.
    self._total_count = np.size(value)
    self._inf_count = np.count_nonzero(np.isinf(value))
    self._nan_count = np.count_nonzero(np.isnan(value))

    self._op_type = op_type
    self._op_name = op_name
    self._output_index = output_index
    self._num_outputs = num_outputs
    self._value = value

    super(InfOrNanError, self).__init__(self._get_error_message())

  def _get_error_message(self):
    """Formats the message handed to the `Exception` base class."""
    if self._op_name is None:
      quoted_name = "None"
    else:
      quoted_name = "'%s'" % self._op_name
    header = "Output %d of %d of TFE operation %s (name: %s) contains " % (
        self._output_index + 1, self._num_outputs, self._op_type, quoted_name)

    # Infs are mentioned only when present; nans are mentioned when present
    # or when there are no infs to report at all.
    tallies = []
    if self._inf_count:
      tallies.append("%d inf(s)" % self._inf_count)
    if self._nan_count or not self._inf_count:
      tallies.append("%d nan(s)" % self._nan_count)

    footer = "out of a total of %d element(s). Tensor value: %s" % (
        self._total_count, self._value)
    return header + " and ".join(tallies) + " " + footer

  @property
  def op_type(self):
    """Type name of the op that produced the offending tensor."""
    return self._op_type

  @property
  def op_name(self):
    """Client-assigned op name, or `None` if unset."""
    return self._op_name

  @property
  def output_index(self):
    """0-based index of the offending output."""
    return self._output_index

  @property
  def num_outputs(self):
    """Total number of outputs of the op."""
    return self._num_outputs

  @property
  def value(self):
    """The offending tensor value, as passed to the constructor."""
    return self._value


def inf_nan_callback(op_type,
                     inputs,
                     attrs,
                     outputs,
                     op_name,
                     check_inf=True,
                     check_nan=True,
                     action=_DEFAULT_CALLBACK_ACTION):
  """An execution callback that checks for `inf`s and `nan`s in output tensors.

  This callback can be used with `tfe.add_execution_callback` to check for
  invalid numeric values. E.g.,
  ```python
  tfe.add_execution_callback(tfe.inf_nan_callback)
  ```

  Only outputs with a floating-point or complex dtype are inspected; other
  dtypes (e.g. integers) cannot hold `inf` or `nan` and are skipped.

  Args:
    op_type: Name of the TFE operation type (e.g., `MatMul`).
    inputs: The `list` of input tensors to the operation, currently unused by
      this callback.
    attrs: Attributes of the TFE operation, as a tuple of alternating attribute
      names and attribute values.
    outputs: The `list` of output tensors from the operation, checked by this
      callback for `inf` and `nan` values.
    op_name: Name of the TFE operation. This name is set by client and can be
      `None` if it is unset.
    check_inf: (`bool`) Whether this callback should check for `inf` values in
      the output tensor values.
    check_nan: (`bool`) Whether this callback should check for `nan` values in
      the output tensor values.
    action: (`str`) Action to be taken by the callback when `inf` or `nan`
      values are detected. Possible values {"raise", "warn", "print"}
      `"raise"`: Raise a `InfOrNanError`.
      `"warn"`: Log a warning using `tf.logging.warn`.
      `"print"`: Print a message to `sys.stdout`.

  Raises:
    InfOrNanError: iff `inf` or `nan` values are seen in any of `outputs` and
      `action` is `"raise"`.
    ValueError: iff the value of `action` is invalid. This is checked on every
      call, before any output is inspected.
  """
  # Validate eagerly so a bad `action` is reported immediately rather than
  # only on the first op that happens to produce an inf/nan.
  if action not in ("print", "warn", "raise"):
    raise ValueError(
        "Invalid action for inf_nan_callback: %s. Valid actions are: "
        "{print | warn | raise}" % action)
  del attrs, inputs  # Not used.
  if not (check_inf or check_nan):
    return  # Nothing to look for.

  ctx = context.context()

  for index, output in enumerate(outputs):
    if not output.dtype.is_numpy_compatible:
      continue

    # Only inexact dtypes (the common base of `np.floating` and
    # `np.complexfloating`) can represent inf/nan; `np.complexfloating` is the
    # supported spelling of the removed `np.complex` alias.
    if not np.issubdtype(output.dtype.as_numpy_dtype, np.inexact):
      continue

    try:
      check_numerics_op_attrs = (
          "message", "Eager-mode inf/nan check",
          # Use this output's own dtype: outputs of a multi-output op may
          # have different dtypes.
          "T", output.dtype.as_datatype_enum)
      # TODO(cais): Consider moving this into execute.py.
      # pylint: disable=protected-access
      pywrap_tensorflow.TFE_Py_Execute(
          ctx._handle, output.device, "CheckNumerics", [output],
          check_numerics_op_attrs, 1)
      # pylint: enable=protected-access
    except core._NotOkStatusException:  # pylint: disable=protected-access
      # CheckNumerics failed (an abnormal value, or a dtype/device it does not
      # support); confirm on the host with NumPy before reporting.
      value = output.numpy()
      inf_detected = check_inf and np.any(np.isinf(value))
      nan_detected = check_nan and np.any(np.isnan(value))
      if not inf_detected and not nan_detected:
        continue

      error = InfOrNanError(op_type, op_name, index, len(outputs), value)
      if action == "print":
        print("Warning: %s" % str(error))
      elif action == "warn":
        logging.warn(str(error))
      else:  # action == "raise"; validated on entry.
        raise error


def inf_callback(op_type,
                 inputs,
                 attrs,
                 outputs,
                 op_name,
                 action=_DEFAULT_CALLBACK_ACTION):
  """Execution callback that flags `inf`s (but not `nan`s) in op outputs.

  Thin wrapper around `inf_nan_callback` with `check_inf=True` and
  `check_nan=False`; see that function for the meaning of every argument.
  """
  only_infs = dict(check_inf=True, check_nan=False, action=action)
  inf_nan_callback(op_type, inputs, attrs, outputs, op_name, **only_infs)


def nan_callback(op_type,
                 inputs,
                 attrs,
                 outputs,
                 op_name,
                 action=_DEFAULT_CALLBACK_ACTION):
  """Execution callback that flags `nan`s (but not `inf`s) in op outputs.

  Thin wrapper around `inf_nan_callback` with `check_inf=False` and
  `check_nan=True`; see that function for the meaning of every argument.
  """
  only_nans = dict(check_inf=False, check_nan=True, action=action)
  inf_nan_callback(op_type, inputs, attrs, outputs, op_name, **only_nans)


def add_execution_callback(callback):
  """Add an execution callback to the default eager context.

  An execution callback is invoked immediately after an eager operation or
  function has finished execution, providing access to the op's type, name,
  input and output tensors. Multiple execution callbacks can be added, in
  which case the callbacks will be invoked in the order in which they are
  added. To clear all execution callbacks that have been added, use
  `clear_execution_callbacks()`.

  Example:
  ```python
  def print_even_callback(op_type, inputs, attrs, outputs, op_name):
    # A callback that prints only the even output values.
    if outputs[0].numpy() % 2 == 0:
      print("Even output from %s: %s" % (op_name or op_type,  outputs))
  tfe.add_execution_callback(print_even_callback)

  x = tf.pow(2.0, 3.0) - 3.0
  y = tf.multiply(x, tf.add(1.0, 5.0))
  # When the line above is run, you will see all intermediate outputs that are
  # even numbers printed to the console.

  tfe.clear_execution_callbacks()
  ```

  Args:
    callback: a callable of the signature
      `f(op_type, inputs, attrs, outputs, op_name)` -- the same positional
      order used by `inf_nan_callback`, `inf_callback` and `nan_callback`.
      `op_type` is the type of the operation that was just executed (e.g.,
        `MatMul`).
      `inputs` is the `list` of input `Tensor`(s) to the op.
      `attrs` contains the attributes of the operation as a `tuple` of
        alternating attribute name and attribute value.
      `outputs` is the `list` of output `Tensor`(s) from the op.
      `op_name` is the name of the operation that was just executed. This
        name is set by the client who created the operation and can be `None` if
        it is unset.
       Return value(s) from the callback are ignored.
  """
  # Swap in the callback-aware executor before registering. NOTE: assumes
  # `execute_with_callbacks` is what invokes the context's post-execution
  # callbacks (see execute.py).
  execute.execute = execute.execute_with_callbacks
  context.context().add_post_execution_callback(callback)


def clear_execution_callbacks():
  """Remove every execution callback from the default eager context.

  This drops callbacks added with `add_execution_callback` as well as the
  inf/nan checker installed by `seterr`.
  """
  default_context = context.context()
  default_context.clear_post_execution_callbacks()


def seterr(inf_or_nan=None):
  """Set how abnormal conditions are handled by the default eager context.

  Example:
  ```python
  tfe.seterr(inf_or_nan="raise")
  a = tf.constant(10.0)
  b = tf.constant(0.0)
  try:
    c = a / b  # <-- Raises InfOrNanError.
  except Exception as e:
    print("Caught Exception: %s" % e)

  tfe.seterr(inf_or_nan="ignore")
  c = a / b  # <-- Does NOT raise exception anymore.
  ```

  Args:
    inf_or_nan: Set action for infinity (`inf`) and NaN (`nan`) values.
      Possible values: `{"ignore", "print", "raise", "warn"}`.
      `"ignore"`: take no action when `inf` or `nan` values appear.
      `"print"`: print a warning to `stdout`.
      `"raise"`: raise an `InfOrNanError`.
      `"warn"`: print a warning using `tf.logging.warn`.
      A value of `None` leads to no change in the action of the condition.

  Returns:
    A dictionary of old actions, of the form `{"inf_or_nan": old_action}`,
    where `old_action` is `"ignore"` if no inf/nan callback was registered.

  Raises:
    ValueError: If the value of any keyword arguments is invalid.
  """
  if inf_or_nan not in _VALID_CALLBACK_ACTIONS:
    raise ValueError(
        "Invalid action value for inf_or_nan: %s. "
        "Valid actions are %s." % (inf_or_nan, _VALID_CALLBACK_ACTIONS))

  old_settings = {"inf_or_nan": "ignore"}
  default_context = context.context()

  carryover_callbacks = []
  for callback in default_context.post_execution_callbacks:
    if callback == inf_nan_callback:
      # Registered bare, so it runs with the default action.
      old_settings["inf_or_nan"] = _DEFAULT_CALLBACK_ACTION
    elif (isinstance(callback, functools.partial) and
          callback.func == inf_nan_callback):
      # `partial.keywords` may be None (Python 2) when no keywords were bound.
      old_settings["inf_or_nan"] = (callback.keywords or {}).get(
          "action", _DEFAULT_CALLBACK_ACTION)
    elif inf_or_nan is not None:
      # Unrelated callbacks survive the reset below.
      carryover_callbacks.append(callback)

  if inf_or_nan is not None:
    # Rebuild the callback list without any previously installed inf/nan
    # checker, preserving the order of the remaining callbacks.
    default_context.clear_post_execution_callbacks()
    for callback in carryover_callbacks:
      default_context.add_post_execution_callback(callback)
    if inf_or_nan != "ignore":
      # Registering on the context directly bypasses add_execution_callback,
      # so perform the same executor swap it does; otherwise the checker may
      # never be invoked if no callback was added through that function.
      execute.execute = execute.execute_with_callbacks
      default_context.add_post_execution_callback(
          functools.partial(inf_nan_callback, action=inf_or_nan))

  return old_settings