diff options
Diffstat (limited to 'tensorflow/python/ops/numerics.py')
-rw-r--r-- | tensorflow/python/ops/numerics.py | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py new file mode 100644 index 0000000000..93f5d5db20 --- /dev/null +++ b/tensorflow/python/ops/numerics.py @@ -0,0 +1,50 @@ +"""Connects all float and double tensors to CheckNumericsOp.""" + +from tensorflow.python.framework import ops +from tensorflow.python.framework import types +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops + + +def verify_tensor_all_finite(t, msg, name=None): + """Assert that the tensor does not contain any NaN's or Inf's. + + Args: + t: Tensor to check. + msg: Message to log on failure. + name: A name for this operation (optional). + + Returns: + Same tensor as `t`. + """ + with ops.op_scope([t], name, "VerifyFinite") as name: + t = ops.convert_to_tensor(t, name="t") + with ops.device(t.device or t.graph.get_default_device()): + verify_input = array_ops.check_numerics(t, message=msg) + out = control_flow_ops.with_dependencies([verify_input], t) + return out + + +def add_check_numerics_ops(): + """Connect a check_numerics to every floating point tensor. + + `check_numerics` operations themselves are added for each `float` or `double` + tensor in the graph. For all ops in the graph, the `check_numerics` op for + all of its (`float` or `double`) inputs is guaranteed to run before the + `check_numerics` op on any of its outputs. + + Returns: + A `group` op depending on all `check_numerics` ops added. + """ + check_op = [] + # This code relies on the ordering of ops in get_operations(). + # The consumer of a tensor always comes before that tensor's producer in + # this list. This is true because get_operations() returns ops in the order + # added, and ops can only be added once its inputs are added. + for op in ops.get_default_graph().get_operations(): + for output in op.outputs: + if output.dtype in [types.float32, types.float64]: + message = op.name + ":" + str(output.value_index) + with ops.control_dependencies(check_op): + check_op = [array_ops.check_numerics(output, message=message)] + return control_flow_ops.group(*check_op) |