diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-29 21:20:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-29 21:46:49 -0700 |
commit | 9db93a8db8fafee8092b30ae33a26248867e9b26 (patch) | |
tree | 697ab050db8bb8e07c5002e4cbd3e21ff4a060f6 /tensorflow/python/client | |
parent | bf170839d2a8be1b16e0a6c6a74ac2f0dc427f96 (diff) |
Merged commit includes the following changes:
191029891 by xiejw:
Fix python script file.
--
191029336 by isaprykin:
Add .configure method to the whole hierarchy of DistributionStrategies.
--
191026971 by blamb:
Updates get_started nav and renames the beginner guides to include eager.
--
191025863 by mrry:
Add private Python API for accessing the C++ Session::*Callable API.
--
191025795 by mikecase:
Internal Change.
--
191024780 by isaprykin:
Internal change.
--
PiperOrigin-RevId: 191029891
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r-- | tensorflow/python/client/session.py | 57 | ||||
-rw-r--r-- | tensorflow/python/client/session_test.py | 12 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session.i | 30 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.cc | 150 | ||||
-rw-r--r-- | tensorflow/python/client/tf_session_helper.h | 25 |
5 files changed, 274 insertions, 0 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index da5dc6f599..5c9ed9ccaf 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1431,6 +1431,63 @@ class BaseSession(SessionInterface): return tf_session.TF_PRun( self._session, handle, feed_dict, fetch_list, status) + # pylint: disable=protected-access + class _Callable(object): + """Experimental wrapper for the C++ `Session::MakeCallable()` API.""" + + def __init__(self, session, callable_options): + self._session = session + self._handle = None + options_ptr = tf_session.TF_NewBufferFromString( + compat.as_bytes(callable_options.SerializeToString())) + try: + with errors.raise_exception_on_not_ok_status() as status: + if session._created_with_new_api: + self._handle = tf_session.TF_SessionMakeCallable( + session._session, options_ptr, status) + else: + self._handle = tf_session.TF_DeprecatedSessionMakeCallable( + session._session, options_ptr, status) + finally: + tf_session.TF_DeleteBuffer(options_ptr) + + def __call__(self, *args): + # TODO(b/74355905): Support argument and return value nested structures, + # and tensor-like objects such as SparseTensors. + with errors.raise_exception_on_not_ok_status() as status: + if self._session._created_with_new_api: + return tf_session.TF_SessionRunCallable( + self._session._session, self._handle, args, status, None) + else: + return tf_session.TF_DeprecatedSessionRunCallable( + self._session._session, self._handle, args, status, None) + + def __del__(self): + if self._handle is not None: + with errors.raise_exception_on_not_ok_status() as status: + if self._session._created_with_new_api: + tf_session.TF_SessionReleaseCallable( + self._session._session, self._handle, status) + else: + tf_session.TF_DeprecatedSessionReleaseCallable( + self._session._session, self._handle, status) + # pylint: enable=protected-access + + # TODO(b/74355905): Reimplement `Session.make_callable()` using this method + # where possible. + def _make_callable_from_options(self, callable_options): + """Returns a handle to a "callable" with the given options. + + Args: + callable_options: A `CallableOptions` protocol buffer message describing + the computation that will be performed by the callable. + + Returns: + A handle to the new callable. + """ + self._extend_graph() + return BaseSession._Callable(self, callable_options) + @tf_export('Session') class Session(BaseSession): diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 6e2640efd1..92497272c6 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -1371,6 +1371,18 @@ class SessionTest(test_util.TensorFlowTestCase): run_metadata=run_metadata)) self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) + def testOptimizedMakeCallable(self): + with session.Session() as sess: + ph = array_ops.placeholder(dtypes.float32) + a = math_ops.add(ph, 1.0) + callable_opts = config_pb2.CallableOptions() + callable_opts.feed.append(ph.name) + callable_opts.fetch.append(a.name) + for _ in range(3): + callable_fn = sess._make_callable_from_options(callable_opts) + for _ in range(5): + self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32))) + def testFeedError(self): with session.Session() as sess: feed_t = array_ops.placeholder(dtype=dtypes.float32) diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 70a3d032f4..77ce9195ee 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -419,6 +419,25 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{ $result = new_result; } +%typemap(in, numinputs=0) int64_t* out_handle (int64_t out_handle) { + $1 = &out_handle; +} + +%typemap(argout) int64_t* out_handle { + $result = PyLong_FromLongLong(*$1); +} + +%typemap(in) int64_t handle { + if (!PyLong_Check($input)) { + SWIG_exception_fail( + SWIG_TypeError, + tensorflow::strings::Printf( + "Expected a python long for conversion to callable handle but got %s", + Py_TYPE($input)->tp_name).c_str()); + } + $1 = PyLong_AsLongLong($input); +} + // TODO(skyewm): SWIG emits a warning for the const char* in TF_WhileParams, // skip for now %ignore TF_WhileParams; @@ -452,6 +471,17 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{ // See comment for "%noexception TF_SessionRun_wrapper;" %noexception TF_SessionPRun_wrapper; +%unignore TF_DeprecatedSessionMakeCallable; +%unignore TF_SessionMakeCallable; +%unignore TF_DeprecatedSessionRunCallable; +%unignore TF_SessionRunCallable; +%unignore TF_DeprecatedSessionReleaseCallable; +%unignore TF_SessionReleaseCallable; + +// See comment for "%noexception TF_SessionRun_wrapper;" +%noexception TF_DeprecatedSessionRunCallable; +%noexception TF_SessionRunCallable; + %rename("_TF_SetTarget") TF_SetTarget; %rename("_TF_SetConfig") TF_SetConfig; %rename("_TF_NewSessionOptions") TF_NewSessionOptions; diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index a8ab91749a..ca57abd712 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -155,6 +155,156 @@ void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options, ClearDecrefCache(); } +namespace { +void MakeCallableHelper(tensorflow::Session* session, + const TF_Buffer* callable_options, int64_t* out_handle, + TF_Status* out_status) { + tensorflow::CallableOptions callable_options_proto; + if (callable_options != nullptr && + !callable_options_proto.ParseFromArray(callable_options->data, + callable_options->length)) { + Set_TF_Status_from_Status( + out_status, + errors::InvalidArgument("Unparseable CallableOptions proto")); + return; + } + tensorflow::Session::CallableHandle handle; + Status s = session->MakeCallable(callable_options_proto, &handle); + if (!s.ok()) { + Set_TF_Status_from_Status(out_status, s); + return; + } + *out_handle = handle; +} +} // namespace + +void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session, + const TF_Buffer* callable_options, + int64_t* out_handle, + TF_Status* out_status) { + MakeCallableHelper(session->session, callable_options, out_handle, + out_status); +} +void TF_SessionMakeCallable(TF_Session* session, + const TF_Buffer* callable_options, + int64_t* out_handle, TF_Status* out_status) { + MakeCallableHelper(session->session, callable_options, out_handle, + out_status); +} + +namespace { +void RunCallableHelper(tensorflow::Session* session, int64_t handle, + PyObject* feed_values, TF_Status* out_status, + PyObjectVector* out_values, TF_Buffer* run_metadata) { + // Convert feed values to a vector of tensorflow::Tensor objects. + std::vector<Tensor> input_tensors; + Status s; + { + feed_values = + PySequence_Fast(feed_values, "feed_values must be a sequence"); + if (feed_values == nullptr) return; + Safe_PyObjectPtr feed_values_holder(make_safe(feed_values)); + Py_ssize_t len = PySequence_Fast_GET_SIZE(feed_values); + input_tensors.reserve(len); + for (Py_ssize_t i = 0; i < len; ++i) { + PyObject* elem = PySequence_Fast_GET_ITEM(feed_values, i); + if (!elem) { + Set_TF_Status_from_Status( + out_status, errors::Internal("Could not get feed value ", i)); + return; + } + Tensor t; + s = NdarrayToTensor(elem, &t); + if (!s.ok()) { + Set_TF_Status_from_Status(out_status, s); + return; + } + input_tensors.push_back(std::move(t)); + } + } + + // Allocate a RunMetadata protobuf object to receive the metadata, + // if the caller is expecting any. + std::unique_ptr<RunMetadata> run_metadata_proto; + if (run_metadata != nullptr) { + run_metadata_proto.reset(new RunMetadata); + } + + // Run the callable. + std::vector<Tensor> output_tensors; + Py_BEGIN_ALLOW_THREADS; + s = session->RunCallable(handle, input_tensors, &output_tensors, + run_metadata_proto.get()); + Py_END_ALLOW_THREADS; + + if (!s.ok()) { + Set_TF_Status_from_Status(out_status, s); + return; + } + + // If requested, serialize the RunMetadata to pass it back to the caller. + if (run_metadata != nullptr) { + s = MessageToBuffer(*run_metadata_proto, run_metadata); + if (!s.ok()) { + Set_TF_Status_from_Status(out_status, s); + return; + } + } + + // Convert results to NumPy arrays. Since this can fail, stage the + // results via a safe container that takes care of decreasing the + // reference count on failure. + std::vector<Safe_PyObjectPtr> py_outputs_safe; + py_outputs_safe.reserve(output_tensors.size()); + for (const Tensor& output : output_tensors) { + PyObject* py_array; + s = TensorToNdarray(output, &py_array); + if (!s.ok()) { + Set_TF_Status_from_Status(out_status, s); + return; + } + py_outputs_safe.push_back(make_safe(py_array)); + } + + // If we reach this point, we have successfully built a list of objects + // so we can release them from the safe container. + out_values->reserve(py_outputs_safe.size()); + for (auto& output : py_outputs_safe) { + out_values->push_back(output.release()); + } +} +} // namespace + +void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session, + int64_t handle, PyObject* feed_values, + TF_Status* out_status, + PyObjectVector* out_values, + TF_Buffer* run_metadata) { + RunCallableHelper(session->session, handle, feed_values, out_status, + out_values, run_metadata); + ClearDecrefCache(); +} +void TF_SessionRunCallable(TF_Session* session, int64_t handle, + PyObject* feed_values, TF_Status* out_status, + PyObjectVector* out_values, + TF_Buffer* run_metadata) { + RunCallableHelper(session->session, handle, feed_values, out_status, + out_values, run_metadata); + ClearDecrefCache(); +} + +void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session, + int64_t handle, + TF_Status* out_status) { + Set_TF_Status_from_Status(out_status, + session->session->ReleaseCallable(handle)); +} +void TF_SessionReleaseCallable(TF_Session* session, int64_t handle, + TF_Status* out_status) { + Set_TF_Status_from_Status(out_status, + session->session->ReleaseCallable(handle)); +} + // Wrapper for TF_PRunSetup that converts the arguments to appropriate types. // If *out_status is OK, the caller becomes the owner of *out_handle. void TF_PRunSetup_wrapper(TF_DeprecatedSession* session, diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 83318dc178..603d03e315 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -59,6 +59,31 @@ void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options, const NameVector& target_nodes, TF_Status* out_status, PyObjectVector* out_values, TF_Buffer* run_outputs); +// Python wrappers for the `Session::MakeCallable()` API. +void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session, + const TF_Buffer* callable_options, + int64_t* out_handle, + TF_Status* out_status); +void TF_SessionMakeCallable(TF_Session* session, + const TF_Buffer* callable_options, + int64_t* out_handle, TF_Status* out_status); + +// Python wrappers for the `Session::RunCallable()` API. +void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session, + int64_t handle, PyObject* feed_values, + TF_Status* out_status, + PyObjectVector* out_values, + TF_Buffer* run_metadata); +void TF_SessionRunCallable(TF_Session* session, int64_t handle, + PyObject* feed_values, TF_Status* out_status, + PyObjectVector* out_values, TF_Buffer* run_metadata); + +// Python wrappers for the `Session::ReleaseCallable()` API. +void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session, + int64_t handle, TF_Status* out_status); +void TF_SessionReleaseCallable(TF_Session* session, int64_t handle, + TF_Status* out_status); + // Set up the graph with the intended feeds and fetches for partial run. // *out_handle is owned by the caller. // |