aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/numerics.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/numerics.py')
-rw-r--r--tensorflow/python/ops/numerics.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py
new file mode 100644
index 0000000000..93f5d5db20
--- /dev/null
+++ b/tensorflow/python/ops/numerics.py
@@ -0,0 +1,50 @@
+"""Connects all float and double tensors to CheckNumericsOp."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+
+
+def verify_tensor_all_finite(t, msg, name=None):
+ """Assert that the tensor does not contain any NaN's or Inf's.
+
+ Args:
+ t: Tensor to check.
+ msg: Message to log on failure.
+ name: A name for this operation (optional).
+
+ Returns:
+ Same tensor as `t`.
+ """
+ with ops.op_scope([t], name, "VerifyFinite") as name:
+ t = ops.convert_to_tensor(t, name="t")
+ with ops.device(t.device or t.graph.get_default_device()):
+ verify_input = array_ops.check_numerics(t, message=msg)
+ out = control_flow_ops.with_dependencies([verify_input], t)
+ return out
+
+
+def add_check_numerics_ops():
+ """Connect a check_numerics to every floating point tensor.
+
+ `check_numerics` operations themselves are added for each `float` or `double`
+ tensor in the graph. For all ops in the graph, the `check_numerics` op for
+ all of its (`float` or `double`) inputs is guaranteed to run before the
+ `check_numerics` op on any of its outputs.
+
+ Returns:
+ A `group` op depending on all `check_numerics` ops added.
+ """
+ check_op = []
+ # This code relies on the ordering of ops in get_operations().
+ # The consumer of a tensor always comes before that tensor's producer in
+ # this list. This is true because get_operations() returns ops in the order
+ # added, and ops can only be added once its inputs are added.
+ for op in ops.get_default_graph().get_operations():
+ for output in op.outputs:
+ if output.dtype in [types.float32, types.float64]:
+ message = op.name + ":" + str(output.value_index)
+ with ops.control_dependencies(check_op):
+ check_op = [array_ops.check_numerics(output, message=message)]
+ return control_flow_ops.group(*check_op)