aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/logging_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/logging_ops.py')
-rw-r--r--tensorflow/python/ops/logging_ops.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
new file mode 100644
index 0000000000..0fad4a2dde
--- /dev/null
+++ b/tensorflow/python/ops/logging_ops.py
@@ -0,0 +1,58 @@
+"""Logging Operations."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_logging_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.gen_logging_ops import *
+# pylint: enable=wildcard-import
+
+
+# Assert and Print are special symbols in python, so we must
+# use an upper-case version of them.
+def Assert(condition, data, summarize=None, name=None):
+ """Asserts that the given condition is true.
+
+ If `condition` evaluates to false, print the list of tensors in `data`.
+ `summarize` determines how many entries of the tensors to print.
+
+ Args:
+ condition: The condition to evaluate.
+ data: The tensors to print out when condition is false.
+ summarize: Print this many entries of each tensor.
+ name: A name for this operation (optional).
+ """
+ return gen_logging_ops._assert(condition, data, summarize, name)
+
+
+def Print(input_, data, message=None, first_n=None, summarize=None,
+ name=None):
+ """Prints a list of tensors.
+
+ This is an identity op with the side effect of printing `data` when
+ evaluating.
+
+ Args:
+ input_: A tensor passed through this op.
+ data: A list of tensors to print out when op is evaluated.
+ message: A string, prefix of the error message.
+ first_n: Only log `first_n` number of times. Negative numbers log always;
+ this is the default.
+ summarize: Only print this many entries of each tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ Same tensor as `input_`.
+ """
+ return gen_logging_ops._print(input_, data, message, first_n, summarize, name)
+
+
+@ops.RegisterGradient("Print")
+def _PrintGrad(op, *grad):
+ return list(grad) + [None] * (len(op.inputs) - 1)
+
+
+# NOTE(mrry): Assert and Print produce an empty output, which is
+# presumably never read.
+ops.RegisterShape("Assert")(common_shapes.unknown_shape)
+ops.RegisterShape("Print")(common_shapes.unknown_shape)