diff options
author | Derek Murray <mrry@google.com> | 2018-06-12 21:20:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-12 21:23:31 -0700 |
commit | 574a85178942418f5531c215b74729b38b4499d2 (patch) | |
tree | 689a6d87103216005a55c8179c89098728d0b7a1 /tensorflow/python/client | |
parent | b2db6e8cbaddbdcc3bbdb05f376319fe6d5038cf (diff) |
Add a `run_metadata` keyword arg for `Session._make_callable_from_options()`.
All callables returned from this private API now accept a
"run_metadata" keyword argument whose behavior matches the
`run_metadata` argument accepted by `Session.run()`.
PiperOrigin-RevId: 200331667
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r-- | tensorflow/python/client/session.py | 20 | ||||
-rw-r--r-- | tensorflow/python/client/session_test.py | 14 |
2 files changed, 30 insertions, 4 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 648e35cdf2..35aa37ac6d 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1369,12 +1369,24 @@ class BaseSession(SessionInterface): finally: tf_session.TF_DeleteBuffer(options_ptr) - def __call__(self, *args): + def __call__(self, *args, **kwargs): # TODO(b/74355905): Support argument and return value nested structures, # and tensor-like objects such as SparseTensors. - with errors.raise_exception_on_not_ok_status() as status: - return tf_session.TF_SessionRunCallable( - self._session._session, self._handle, args, status, None) + run_metadata = kwargs.get('run_metadata', None) + try: + run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None + # TODO(mrry): Switch to raising an exception from the SWIG wrapper. + with errors.raise_exception_on_not_ok_status() as status: + ret = tf_session.TF_SessionRunCallable( + self._session._session, self._handle, args, status, + run_metadata_ptr) + if run_metadata: + proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) + run_metadata.ParseFromString(compat.as_bytes(proto_data)) + finally: + if run_metadata_ptr: + tf_session.TF_DeleteBuffer(run_metadata_ptr) + return ret def __del__(self): # NOTE(mrry): It is possible that `self._session.__del__()` could be diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 482497078c..e49d067105 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -1364,6 +1364,20 @@ class SessionTest(test_util.TensorFlowTestCase): for _ in range(5): self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32))) + def testOptimizedMakeCallableWithRunMetadata(self): + with session.Session() as sess: + ph = array_ops.placeholder(dtypes.float32) + a = math_ops.add(ph, 1.0) + callable_opts = config_pb2.CallableOptions() + callable_opts.feed.append(ph.name) + callable_opts.fetch.append(a.name) + callable_opts.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE + callable_fn = sess._make_callable_from_options(callable_opts) + run_metadata = config_pb2.RunMetadata() + self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32), + run_metadata=run_metadata)) + self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) + def testFeedError(self): with session.Session() as sess: feed_t = array_ops.placeholder(dtype=dtypes.float32) |