aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-29 21:20:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-29 21:46:49 -0700
commit9db93a8db8fafee8092b30ae33a26248867e9b26 (patch)
tree697ab050db8bb8e07c5002e4cbd3e21ff4a060f6 /tensorflow/python/client
parentbf170839d2a8be1b16e0a6c6a74ac2f0dc427f96 (diff)
Merged commit includes the following changes:
191029891 by xiejw: Fix python script file. -- 191029336 by isaprykin: Add .configure method to the whole hierarchy of DistributionStrategies. -- 191026971 by blamb: Updates get_started nav and renames the beginner guides to include eager. -- 191025863 by mrry: Add private Python API for accessing the C++ Session::*Callable API. -- 191025795 by mikecase: Internal Change. -- 191024780 by isaprykin: Internal change. -- PiperOrigin-RevId: 191029891
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r--tensorflow/python/client/session.py57
-rw-r--r--tensorflow/python/client/session_test.py12
-rw-r--r--tensorflow/python/client/tf_session.i30
-rw-r--r--tensorflow/python/client/tf_session_helper.cc150
-rw-r--r--tensorflow/python/client/tf_session_helper.h25
5 files changed, 274 insertions, 0 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index da5dc6f599..5c9ed9ccaf 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -1431,6 +1431,63 @@ class BaseSession(SessionInterface):
return tf_session.TF_PRun(
self._session, handle, feed_dict, fetch_list, status)
+ # pylint: disable=protected-access
+ class _Callable(object):
+ """Experimental wrapper for the C++ `Session::MakeCallable()` API."""
+
+ def __init__(self, session, callable_options):
+ self._session = session
+ self._handle = None
+ options_ptr = tf_session.TF_NewBufferFromString(
+ compat.as_bytes(callable_options.SerializeToString()))
+ try:
+ with errors.raise_exception_on_not_ok_status() as status:
+ if session._created_with_new_api:
+ self._handle = tf_session.TF_SessionMakeCallable(
+ session._session, options_ptr, status)
+ else:
+ self._handle = tf_session.TF_DeprecatedSessionMakeCallable(
+ session._session, options_ptr, status)
+ finally:
+ tf_session.TF_DeleteBuffer(options_ptr)
+
+ def __call__(self, *args):
+ # TODO(b/74355905): Support argument and return value nested structures,
+ # and tensor-like objects such as SparseTensors.
+ with errors.raise_exception_on_not_ok_status() as status:
+ if self._session._created_with_new_api:
+ return tf_session.TF_SessionRunCallable(
+ self._session._session, self._handle, args, status, None)
+ else:
+ return tf_session.TF_DeprecatedSessionRunCallable(
+ self._session._session, self._handle, args, status, None)
+
+ def __del__(self):
+ if self._handle is not None:
+ with errors.raise_exception_on_not_ok_status() as status:
+ if self._session._created_with_new_api:
+ tf_session.TF_SessionReleaseCallable(
+ self._session._session, self._handle, status)
+ else:
+ tf_session.TF_DeprecatedSessionReleaseCallable(
+ self._session._session, self._handle, status)
+ # pylint: enable=protected-access
+
+ # TODO(b/74355905): Reimplement `Session.make_callable()` using this method
+ # where possible.
+ def _make_callable_from_options(self, callable_options):
+ """Returns a handle to a "callable" with the given options.
+
+ Args:
+ callable_options: A `CallableOptions` protocol buffer message describing
+ the computation that will be performed by the callable.
+
+ Returns:
+ A handle to the new callable.
+ """
+ self._extend_graph()
+ return BaseSession._Callable(self, callable_options)
+
@tf_export('Session')
class Session(BaseSession):
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 6e2640efd1..92497272c6 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -1371,6 +1371,18 @@ class SessionTest(test_util.TensorFlowTestCase):
run_metadata=run_metadata))
self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
+ def testOptimizedMakeCallable(self):
+ with session.Session() as sess:
+ ph = array_ops.placeholder(dtypes.float32)
+ a = math_ops.add(ph, 1.0)
+ callable_opts = config_pb2.CallableOptions()
+ callable_opts.feed.append(ph.name)
+ callable_opts.fetch.append(a.name)
+ for _ in range(3):
+ callable_fn = sess._make_callable_from_options(callable_opts)
+ for _ in range(5):
+ self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32)))
+
def testFeedError(self):
with session.Session() as sess:
feed_t = array_ops.placeholder(dtype=dtypes.float32)
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 70a3d032f4..77ce9195ee 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -419,6 +419,25 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
$result = new_result;
}
+%typemap(in, numinputs=0) int64_t* out_handle (int64_t out_handle) {
+ $1 = &out_handle;
+}
+
+%typemap(argout) int64_t* out_handle {
+ $result = PyLong_FromLongLong(*$1);
+}
+
+%typemap(in) int64_t handle {
+ if (!PyLong_Check($input)) {
+ SWIG_exception_fail(
+ SWIG_TypeError,
+ tensorflow::strings::Printf(
+ "Expected a python long for conversion to callable handle but got %s",
+ Py_TYPE($input)->tp_name).c_str());
+ }
+ $1 = PyLong_AsLongLong($input);
+}
+
// TODO(skyewm): SWIG emits a warning for the const char* in TF_WhileParams,
// skip for now
%ignore TF_WhileParams;
@@ -452,6 +471,17 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
// See comment for "%noexception TF_SessionRun_wrapper;"
%noexception TF_SessionPRun_wrapper;
+%unignore TF_DeprecatedSessionMakeCallable;
+%unignore TF_SessionMakeCallable;
+%unignore TF_DeprecatedSessionRunCallable;
+%unignore TF_SessionRunCallable;
+%unignore TF_DeprecatedSessionReleaseCallable;
+%unignore TF_SessionReleaseCallable;
+
+// See comment for "%noexception TF_SessionRun_wrapper;"
+%noexception TF_DeprecatedSessionRunCallable;
+%noexception TF_SessionRunCallable;
+
%rename("_TF_SetTarget") TF_SetTarget;
%rename("_TF_SetConfig") TF_SetConfig;
%rename("_TF_NewSessionOptions") TF_NewSessionOptions;
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index a8ab91749a..ca57abd712 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -155,6 +155,156 @@ void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
ClearDecrefCache();
}
+namespace {
+void MakeCallableHelper(tensorflow::Session* session,
+ const TF_Buffer* callable_options, int64_t* out_handle,
+ TF_Status* out_status) {
+ tensorflow::CallableOptions callable_options_proto;
+ if (callable_options != nullptr &&
+ !callable_options_proto.ParseFromArray(callable_options->data,
+ callable_options->length)) {
+ Set_TF_Status_from_Status(
+ out_status,
+ errors::InvalidArgument("Unparseable CallableOptions proto"));
+ return;
+ }
+ tensorflow::Session::CallableHandle handle;
+ Status s = session->MakeCallable(callable_options_proto, &handle);
+ if (!s.ok()) {
+ Set_TF_Status_from_Status(out_status, s);
+ return;
+ }
+ *out_handle = handle;
+}
+} // namespace
+
+void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
+ const TF_Buffer* callable_options,
+ int64_t* out_handle,
+ TF_Status* out_status) {
+ MakeCallableHelper(session->session, callable_options, out_handle,
+ out_status);
+}
+void TF_SessionMakeCallable(TF_Session* session,
+ const TF_Buffer* callable_options,
+ int64_t* out_handle, TF_Status* out_status) {
+ MakeCallableHelper(session->session, callable_options, out_handle,
+ out_status);
+}
+
+namespace {
+void RunCallableHelper(tensorflow::Session* session, int64_t handle,
+ PyObject* feed_values, TF_Status* out_status,
+ PyObjectVector* out_values, TF_Buffer* run_metadata) {
+ // Convert feed values to a vector of tensorflow::Tensor objects.
+ std::vector<Tensor> input_tensors;
+ Status s;
+ {
+ feed_values =
+ PySequence_Fast(feed_values, "feed_values must be a sequence");
+ if (feed_values == nullptr) return;
+ Safe_PyObjectPtr feed_values_holder(make_safe(feed_values));
+ Py_ssize_t len = PySequence_Fast_GET_SIZE(feed_values);
+ input_tensors.reserve(len);
+ for (Py_ssize_t i = 0; i < len; ++i) {
+ PyObject* elem = PySequence_Fast_GET_ITEM(feed_values, i);
+ if (!elem) {
+ Set_TF_Status_from_Status(
+ out_status, errors::Internal("Could not get feed value ", i));
+ return;
+ }
+ Tensor t;
+ s = NdarrayToTensor(elem, &t);
+ if (!s.ok()) {
+ Set_TF_Status_from_Status(out_status, s);
+ return;
+ }
+ input_tensors.push_back(std::move(t));
+ }
+ }
+
+ // Allocate a RunMetadata protobuf object to receive the metadata,
+ // if the caller is expecting any.
+ std::unique_ptr<RunMetadata> run_metadata_proto;
+ if (run_metadata != nullptr) {
+ run_metadata_proto.reset(new RunMetadata);
+ }
+
+ // Run the callable.
+ std::vector<Tensor> output_tensors;
+ Py_BEGIN_ALLOW_THREADS;
+ s = session->RunCallable(handle, input_tensors, &output_tensors,
+ run_metadata_proto.get());
+ Py_END_ALLOW_THREADS;
+
+ if (!s.ok()) {
+ Set_TF_Status_from_Status(out_status, s);
+ return;
+ }
+
+ // If requested, serialize the RunMetadata to pass it back to the caller.
+ if (run_metadata != nullptr) {
+ s = MessageToBuffer(*run_metadata_proto, run_metadata);
+ if (!s.ok()) {
+ Set_TF_Status_from_Status(out_status, s);
+ return;
+ }
+ }
+
+ // Convert results to NumPy arrays. Since this can fail, stage the
+ // results via a safe container that takes care of decreasing the
+ // reference count on failure.
+ std::vector<Safe_PyObjectPtr> py_outputs_safe;
+ py_outputs_safe.reserve(output_tensors.size());
+ for (const Tensor& output : output_tensors) {
+ PyObject* py_array;
+ s = TensorToNdarray(output, &py_array);
+ if (!s.ok()) {
+ Set_TF_Status_from_Status(out_status, s);
+ return;
+ }
+ py_outputs_safe.push_back(make_safe(py_array));
+ }
+
+ // If we reach this point, we have successfully built a list of objects
+ // so we can release them from the safe container.
+ out_values->reserve(py_outputs_safe.size());
+ for (auto& output : py_outputs_safe) {
+ out_values->push_back(output.release());
+ }
+}
+} // namespace
+
+void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
+ int64_t handle, PyObject* feed_values,
+ TF_Status* out_status,
+ PyObjectVector* out_values,
+ TF_Buffer* run_metadata) {
+ RunCallableHelper(session->session, handle, feed_values, out_status,
+ out_values, run_metadata);
+ ClearDecrefCache();
+}
+void TF_SessionRunCallable(TF_Session* session, int64_t handle,
+ PyObject* feed_values, TF_Status* out_status,
+ PyObjectVector* out_values,
+ TF_Buffer* run_metadata) {
+ RunCallableHelper(session->session, handle, feed_values, out_status,
+ out_values, run_metadata);
+ ClearDecrefCache();
+}
+
+void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
+ int64_t handle,
+ TF_Status* out_status) {
+ Set_TF_Status_from_Status(out_status,
+ session->session->ReleaseCallable(handle));
+}
+void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
+ TF_Status* out_status) {
+ Set_TF_Status_from_Status(out_status,
+ session->session->ReleaseCallable(handle));
+}
+
// Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
// If *out_status is OK, the caller becomes the owner of *out_handle.
void TF_PRunSetup_wrapper(TF_DeprecatedSession* session,
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 83318dc178..603d03e315 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -59,6 +59,31 @@ void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
const NameVector& target_nodes, TF_Status* out_status,
PyObjectVector* out_values, TF_Buffer* run_outputs);
+// Python wrappers for the `Session::MakeCallable()` API.
+void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
+ const TF_Buffer* callable_options,
+ int64_t* out_handle,
+ TF_Status* out_status);
+void TF_SessionMakeCallable(TF_Session* session,
+ const TF_Buffer* callable_options,
+ int64_t* out_handle, TF_Status* out_status);
+
+// Python wrappers for the `Session::RunCallable()` API.
+void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
+ int64_t handle, PyObject* feed_values,
+ TF_Status* out_status,
+ PyObjectVector* out_values,
+ TF_Buffer* run_metadata);
+void TF_SessionRunCallable(TF_Session* session, int64_t handle,
+ PyObject* feed_values, TF_Status* out_status,
+ PyObjectVector* out_values, TF_Buffer* run_metadata);
+
+// Python wrappers for the `Session::ReleaseCallable()` API.
+void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
+ int64_t handle, TF_Status* out_status);
+void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
+ TF_Status* out_status);
+
// Set up the graph with the intended feeds and fetches for partial run.
// *out_handle is owned by the caller.
//