From c5e2b0e5a3039fe98b0f22154c567c2eb425fb22 Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Tue, 9 Jan 2018 07:37:55 -0800 Subject: Exposes runmetadata from tfe in python. PiperOrigin-RevId: 181317960 --- tensorflow/python/eager/context.py | 59 ++++++++++++++++++++++++++++++++++++ tensorflow/python/eager/core_test.py | 14 +++++++++ tensorflow/python/pywrap_tfe.i | 3 ++ 3 files changed, 76 insertions(+) (limited to 'tensorflow/python') diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 3173afc424..e1ab1e7bc6 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -24,9 +24,12 @@ import copy import random import threading +from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow +from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device as pydev from tensorflow.python.framework import errors +from tensorflow.python.util import compat from tensorflow.python.util import tf_contextlib GRAPH_MODE = 0 @@ -398,6 +401,36 @@ class Context(object): """Get the list of post-execution callbacks added to the context.""" return self._post_execution_callbacks + def enable_run_metadata(self): + """Enables tracing of op execution via RunMetadata. + + To retrieve the accumulated metadata call context.export_run_metadata() + and to stop tracing call context.disable_run_metadata(). + """ + pywrap_tensorflow.TFE_ContextEnableRunMetadata(self._context_handle) + + def disable_run_metadata(self): + """Disables tracing of op execution via RunMetadata.""" + pywrap_tensorflow.TFE_ContextDisableRunMetadata(self._context_handle) + + def export_run_metadata(self): + """Returns a RunMetadata proto with accumulated information. + + The returned protocol buffer contains information since the most recent call + to either enable_run_metadata or export_run_metadata. + + Returns: + A RunMetadata protocol buffer. + """ + with c_api_util.tf_buffer() as buffer_: + with errors.raise_exception_on_not_ok_status() as status: + pywrap_tensorflow.TFE_ContextExportRunMetadata( + self._context_handle, buffer_, status) + proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) + run_metadata = config_pb2.RunMetadata() + run_metadata.ParseFromString(compat.as_bytes(proto_data)) + return run_metadata + _context = None _context_lock = threading.Lock() @@ -516,3 +549,29 @@ def num_gpus(): The number of available GPU devices. """ return context().num_gpus() + + +def enable_run_metadata(): + """Enables tracing of op execution via RunMetadata. + + To retrieve the accumulated metadata call context.export_run_metadata() + and to stop tracing call context.disable_run_metadata(). + """ + context().enable_run_metadata() + + +def disable_run_metadata(): + """Disables tracing of op execution via RunMetadata.""" + context().disable_run_metadata() + + +def export_run_metadata(): + """Returns a RunMetadata proto with accumulated information. + + The returned protocol buffer contains information since the most recent call + to either enable_run_metadata or export_run_metadata. + + Returns: + A RunMetadata protocol buffer. + """ + return context().export_run_metadata() diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 02694b34fe..a70fa72804 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -84,6 +84,20 @@ class TFETest(test_util.TensorFlowTestCase): self.assertTrue(has_cpu_device) del ctx + def testRunMetadata(self): + context.enable_run_metadata() + t = constant_op.constant(1.0) + _ = t + t # Runs an operation which will be in the RunMetadata + run_metadata = context.export_run_metadata() + context.disable_run_metadata() + step_stats = run_metadata.step_stats + self.assertGreater(len(step_stats.dev_stats), 0) + cpu_stats = step_stats.dev_stats[0] + self.assertEqual('/job:localhost/replica:0/task:0/device:CPU:0', + cpu_stats.device) + self.assertEqual(len(cpu_stats.node_stats), 1) + self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add') + def testContextStackContainsEagerMode(self): # Eager execution has been enabled, and no other context # switch has occurred, so `context_stack` should contain diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 42e4773df3..083931aa83 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -20,6 +20,9 @@ limitations under the License. %rename("%s") TFE_ContextListDevices; %rename("%s") TFE_ContextAddFunction; %rename("%s") TFE_ContextAddFunctionDef; +%rename("%s") TFE_ContextEnableRunMetadata; +%rename("%s") TFE_ContextDisableRunMetadata; +%rename("%s") TFE_ContextExportRunMetadata; %rename("%s") TFE_ContextClearCaches; %rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; -- cgit v1.2.3