aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-01-09 07:37:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-09 07:41:22 -0800
commitc5e2b0e5a3039fe98b0f22154c567c2eb425fb22 (patch)
treed9184516ef5d8fe14ac4abbff97c1812bf235666 /tensorflow/python
parent14fa431da0b7fd69ccf7bf4a60172a5745c1773c (diff)
Exposes runmetadata from tfe in python.
PiperOrigin-RevId: 181317960
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/eager/context.py59
-rw-r--r--tensorflow/python/eager/core_test.py14
-rw-r--r--tensorflow/python/pywrap_tfe.i3
3 files changed, 76 insertions, 0 deletions
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 3173afc424..e1ab1e7bc6 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -24,9 +24,12 @@ import copy
import random
import threading
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
+from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
GRAPH_MODE = 0
@@ -398,6 +401,36 @@ class Context(object):
"""Get the list of post-execution callbacks added to the context."""
return self._post_execution_callbacks
+ def enable_run_metadata(self):
+ """Enables tracing of op execution via RunMetadata.
+
+ To retrieve the accumulated metadata call context.export_run_metadata()
+ and to stop tracing call context.disable_run_metadata().
+ """
+ pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle)
+
+ def disable_run_metadata(self):
+ """Disables tracing of op execution via RunMetadata."""
+ pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle)
+
+ def export_run_metadata(self):
+ """Returns a RunMetadata proto with accumulated information.
+
+ The returned protocol buffer contains information since the most recent call
+ to either enable_run_metadata or export_run_metadata.
+
+ Returns:
+ A RunMetadata protocol buffer.
+ """
+ with c_api_util.tf_buffer() as buffer_:
+ with errors.raise_exception_on_not_ok_status() as status:
+ pywrap_tensorflow.TFE_ContextExportRunMetadata(
+ self._context_handle, buffer_, status)
+ proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
+ run_metadata = config_pb2.RunMetadata()
+ run_metadata.ParseFromString(compat.as_bytes(proto_data))
+ return run_metadata
+
_context = None
_context_lock = threading.Lock()
@@ -516,3 +549,29 @@ def num_gpus():
The number of available GPU devices.
"""
return context().num_gpus()
+
+
+def enable_run_metadata():
+ """Enables tracing of op execution via RunMetadata.
+
+ To retrieve the accumulated metadata call context.export_run_metadata()
+ and to stop tracing call context.disable_run_metadata().
+ """
+ context().enable_run_metadata()
+
+
+def disable_run_metadata():
+ """Disables tracing of op execution via RunMetadata."""
+ context().disable_run_metadata()
+
+
+def export_run_metadata():
+ """Returns a RunMetadata proto with accumulated information.
+
+ The returned protocol buffer contains information since the most recent call
+ to either enable_run_metadata or export_run_metadata.
+
+ Returns:
+ A RunMetadata protocol buffer.
+ """
+ return context().export_run_metadata()
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index 02694b34fe..a70fa72804 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -84,6 +84,20 @@ class TFETest(test_util.TensorFlowTestCase):
self.assertTrue(has_cpu_device)
del ctx
+ def testRunMetadata(self):
+ context.enable_run_metadata()
+ t = constant_op.constant(1.0)
+ _ = t + t # Runs an operation which will be in the RunMetadata
+ run_metadata = context.export_run_metadata()
+ context.disable_run_metadata()
+ step_stats = run_metadata.step_stats
+ self.assertGreater(len(step_stats.dev_stats), 0)
+ cpu_stats = step_stats.dev_stats[0]
+ self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0',
+ cpu_stats.device)
+ self.assertEqual(len(cpu_stats.node_stats), 1)
+ self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add')
+
def testContextStackContainsEagerMode(self):
# Eager execution has been enabled, and no other context
# switch has occurred, so `context_stack` should contain
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 42e4773df3..083931aa83 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -20,6 +20,9 @@ limitations under the License.
%rename("%s") TFE_ContextListDevices;
%rename("%s") TFE_ContextAddFunction;
%rename("%s") TFE_ContextAddFunctionDef;
+%rename("%s") TFE_ContextEnableRunMetadata;
+%rename("%s") TFE_ContextDisableRunMetadata;
+%rename("%s") TFE_ContextExportRunMetadata;
%rename("%s") TFE_ContextClearCaches;
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;