aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-06-12 21:20:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-12 21:23:31 -0700
commit574a85178942418f5531c215b74729b38b4499d2 (patch)
tree689a6d87103216005a55c8179c89098728d0b7a1 /tensorflow/python/client
parentb2db6e8cbaddbdcc3bbdb05f376319fe6d5038cf (diff)
Add a `run_metadata` keyword arg for `Session._make_callable_from_options()`.
All callables returned from this private API now accept a "run_metadata" keyword argument whose behavior matches the `run_metadata` argument accepted by `Session.run()`. PiperOrigin-RevId: 200331667
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r--tensorflow/python/client/session.py20
-rw-r--r--tensorflow/python/client/session_test.py14
2 files changed, 30 insertions, 4 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 648e35cdf2..35aa37ac6d 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -1369,12 +1369,24 @@ class BaseSession(SessionInterface):
finally:
tf_session.TF_DeleteBuffer(options_ptr)
- def __call__(self, *args):
+ def __call__(self, *args, **kwargs):
# TODO(b/74355905): Support argument and return value nested structures,
# and tensor-like objects such as SparseTensors.
- with errors.raise_exception_on_not_ok_status() as status:
- return tf_session.TF_SessionRunCallable(
- self._session._session, self._handle, args, status, None)
+ run_metadata = kwargs.get('run_metadata', None)
+ try:
+ run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
+ # TODO(mrry): Switch to raising an exception from the SWIG wrapper.
+ with errors.raise_exception_on_not_ok_status() as status:
+ ret = tf_session.TF_SessionRunCallable(
+ self._session._session, self._handle, args, status,
+ run_metadata_ptr)
+ if run_metadata:
+ proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
+ run_metadata.ParseFromString(compat.as_bytes(proto_data))
+ finally:
+ if run_metadata_ptr:
+ tf_session.TF_DeleteBuffer(run_metadata_ptr)
+ return ret
def __del__(self):
# NOTE(mrry): It is possible that `self._session.__del__()` could be
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 482497078c..e49d067105 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -1364,6 +1364,20 @@ class SessionTest(test_util.TensorFlowTestCase):
for _ in range(5):
self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32)))
+ def testOptimizedMakeCallableWithRunMetadata(self):
+ with session.Session() as sess:
+ ph = array_ops.placeholder(dtypes.float32)
+ a = math_ops.add(ph, 1.0)
+ callable_opts = config_pb2.CallableOptions()
+ callable_opts.feed.append(ph.name)
+ callable_opts.fetch.append(a.name)
+ callable_opts.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
+ callable_fn = sess._make_callable_from_options(callable_opts)
+ run_metadata = config_pb2.RunMetadata()
+ self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32),
+ run_metadata=run_metadata))
+ self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
+
def testFeedError(self):
with session.Session() as sess:
feed_t = array_ops.placeholder(dtype=dtypes.float32)