From 574a85178942418f5531c215b74729b38b4499d2 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 12 Jun 2018 21:20:52 -0700 Subject: Add a `run_metadata` keyword arg for `Session._make_callable_from_options()`. All callables returned from this private API now accept a "run_metadata" keyword argument whose behavior matches the `run_metadata` argument accepted by `Session.run()`. PiperOrigin-RevId: 200331667 --- tensorflow/python/client/session.py | 20 ++++++++++++++++---- tensorflow/python/client/session_test.py | 14 ++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 648e35cdf2..35aa37ac6d 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1369,12 +1369,24 @@ class BaseSession(SessionInterface): finally: tf_session.TF_DeleteBuffer(options_ptr) - def __call__(self, *args): + def __call__(self, *args, **kwargs): # TODO(b/74355905): Support argument and return value nested structures, # and tensor-like objects such as SparseTensors. - with errors.raise_exception_on_not_ok_status() as status: - return tf_session.TF_SessionRunCallable( - self._session._session, self._handle, args, status, None) + run_metadata = kwargs.get('run_metadata', None) + try: + run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None + # TODO(mrry): Switch to raising an exception from the SWIG wrapper. + with errors.raise_exception_on_not_ok_status() as status: + ret = tf_session.TF_SessionRunCallable( + self._session._session, self._handle, args, status, + run_metadata_ptr) + if run_metadata: + proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) + run_metadata.ParseFromString(compat.as_bytes(proto_data)) + finally: + if run_metadata_ptr: + tf_session.TF_DeleteBuffer(run_metadata_ptr) + return ret def __del__(self): # NOTE(mrry): It is possible that `self._session.__del__()` could be diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 482497078c..e49d067105 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -1364,6 +1364,20 @@ class SessionTest(test_util.TensorFlowTestCase): for _ in range(5): self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32))) + def testOptimizedMakeCallableWithRunMetadata(self): + with session.Session() as sess: + ph = array_ops.placeholder(dtypes.float32) + a = math_ops.add(ph, 1.0) + callable_opts = config_pb2.CallableOptions() + callable_opts.feed.append(ph.name) + callable_opts.fetch.append(a.name) + callable_opts.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE + callable_fn = sess._make_callable_from_options(callable_opts) + run_metadata = config_pb2.RunMetadata() + self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32), + run_metadata=run_metadata)) + self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) + def testFeedError(self): with session.Session() as sess: feed_t = array_ops.placeholder(dtype=dtypes.float32) -- cgit v1.2.3