aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-03-30 14:56:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-30 14:58:45 -0700
commit97731cb122f53552bd15351e046a256f78cca444 (patch)
treedbec22e40bc41f690953b7aad5498516cf457908 /tensorflow/python/client
parent15c10899c9c0e1717251b380330cc248b2c76c9c (diff)
Raise exception in SWIG on bad TF_Status from C API.
This change provides an alternative mechanism to tf.raise_exception_on_not_ok_status(), which is inefficient and error-prone (people often use the status multiple times in the with block, but it's only checked when the context manager exits). Instead, it uses SWIG to automatically raise an exception when a C API method fails. Note that this removes the status argument from affected methods. For now, I've only applied this typemap to C API methods. It would be good to expand this to all uses of raise_exception_on_not_ok_status. PiperOrigin-RevId: 191121016
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r--tensorflow/python/client/session.py91
-rw-r--r--tensorflow/python/client/session_list_devices_test.py19
-rw-r--r--tensorflow/python/client/tf_session.i33
-rw-r--r--tensorflow/python/client/tf_session_helper.h12
4 files changed, 80 insertions, 75 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 5c9ed9ccaf..4c84d78f2e 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -27,7 +27,6 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session
-from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -629,14 +628,12 @@ class BaseSession(SessionInterface):
self._session = None
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
try:
- with errors.raise_exception_on_not_ok_status() as status:
- if self._created_with_new_api:
- # pylint: disable=protected-access
- self._session = tf_session.TF_NewSession(self._graph._c_graph, opts,
- status)
- # pylint: enable=protected-access
- else:
- self._session = tf_session.TF_NewDeprecatedSession(opts, status)
+ if self._created_with_new_api:
+ # pylint: disable=protected-access
+ self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
+ # pylint: enable=protected-access
+ else:
+ self._session = tf_session.TF_NewDeprecatedSession(opts)
finally:
tf_session.TF_DeleteSessionOptions(opts)
@@ -663,22 +660,20 @@ class BaseSession(SessionInterface):
Returns:
A list of devices in the session.
"""
- with errors.raise_exception_on_not_ok_status() as status:
- if self._created_with_new_api:
- raw_device_list = tf_session.TF_SessionListDevices(
- self._session, status)
- else:
- raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
- self._session, status)
- device_list = []
- size = tf_session.TF_DeviceListCount(raw_device_list)
- for i in range(size):
- name = tf_session.TF_DeviceListName(raw_device_list, i, status)
- device_type = tf_session.TF_DeviceListType(raw_device_list, i, status)
- memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i, status)
- device_list.append(_DeviceAttributes(name, device_type, memory))
- tf_session.TF_DeleteDeviceList(raw_device_list)
- return device_list
+ if self._created_with_new_api:
+ raw_device_list = tf_session.TF_SessionListDevices(self._session)
+ else:
+ raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
+ self._session)
+ device_list = []
+ size = tf_session.TF_DeviceListCount(raw_device_list)
+ for i in range(size):
+ name = tf_session.TF_DeviceListName(raw_device_list, i)
+ device_type = tf_session.TF_DeviceListType(raw_device_list, i)
+ memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
+ device_list.append(_DeviceAttributes(name, device_type, memory))
+ tf_session.TF_DeleteDeviceList(raw_device_list)
+ return device_list
def close(self):
"""Closes this session.
@@ -692,15 +687,13 @@ class BaseSession(SessionInterface):
if self._created_with_new_api:
if self._session and not self._closed:
self._closed = True
- with errors.raise_exception_on_not_ok_status() as status:
- tf_session.TF_CloseSession(self._session, status)
+ tf_session.TF_CloseSession(self._session)
else:
with self._extend_lock:
if self._opened and not self._closed:
self._closed = True
- with errors.raise_exception_on_not_ok_status() as status:
- tf_session.TF_CloseDeprecatedSession(self._session, status)
+ tf_session.TF_CloseDeprecatedSession(self._session)
def __del__(self):
# cleanly ignore all exceptions
@@ -710,11 +703,10 @@ class BaseSession(SessionInterface):
pass
if self._session is not None:
try:
- status = c_api_util.ScopedTFStatus()
if self._created_with_new_api:
- tf_session.TF_DeleteSession(self._session, status)
+ tf_session.TF_DeleteSession(self._session)
else:
- tf_session.TF_DeleteDeprecatedSession(self._session, status)
+ tf_session.TF_DeleteDeprecatedSession(self._session)
except AttributeError:
# At shutdown, `c_api_util` or `tf_session` may have been garbage
# collected, causing the above method calls to fail. In this case,
@@ -1031,11 +1023,11 @@ class BaseSession(SessionInterface):
# Set up a graph with feeds and fetches for partial run.
def _setup_fn(session, feed_list, fetch_list, target_list):
self._extend_graph()
- with errors.raise_exception_on_not_ok_status() as status:
- if self._created_with_new_api:
- return tf_session.TF_SessionPRunSetup_wrapper(
- session, feed_list, fetch_list, target_list, status)
- else:
+ if self._created_with_new_api:
+ return tf_session.TF_SessionPRunSetup_wrapper(
+ session, feed_list, fetch_list, target_list)
+ else:
+ with errors.raise_exception_on_not_ok_status() as status:
return tf_session.TF_PRunSetup(session, feed_list, fetch_list,
target_list, status)
@@ -1345,8 +1337,7 @@ class BaseSession(SessionInterface):
def _extend_graph(self):
if self._created_with_new_api:
with self._graph._lock: # pylint: disable=protected-access
- with errors.raise_exception_on_not_ok_status() as status:
- tf_session.ExtendSession(self._session, status)
+ tf_session.ExtendSession(self._session)
else:
# Ensure any changes to the graph are reflected in the runtime.
with self._extend_lock:
@@ -1412,22 +1403,22 @@ class BaseSession(SessionInterface):
def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
run_metadata):
- with errors.raise_exception_on_not_ok_status() as status:
- if self._created_with_new_api:
- return tf_session.TF_SessionRun_wrapper(
- self._session, options, feed_dict, fetch_list, target_list,
- run_metadata, status)
- else:
+ if self._created_with_new_api:
+ return tf_session.TF_SessionRun_wrapper(
+ self._session, options, feed_dict, fetch_list, target_list,
+ run_metadata)
+ else:
+ with errors.raise_exception_on_not_ok_status() as status:
return tf_session.TF_Run(
self._session, options, feed_dict, fetch_list, target_list,
status, run_metadata)
def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
- with errors.raise_exception_on_not_ok_status() as status:
- if self._created_with_new_api:
- return tf_session.TF_SessionPRun_wrapper(
- self._session, handle, feed_dict, fetch_list, status)
- else:
+ if self._created_with_new_api:
+ return tf_session.TF_SessionPRun_wrapper(
+ self._session, handle, feed_dict, fetch_list)
+ else:
+ with errors.raise_exception_on_not_ok_status() as status:
return tf_session.TF_PRun(
self._session, handle, feed_dict, fetch_list, status)
diff --git a/tensorflow/python/client/session_list_devices_test.py b/tensorflow/python/client/session_list_devices_test.py
index 5a7413c12e..38a3acb2dc 100644
--- a/tensorflow/python/client/session_list_devices_test.py
+++ b/tensorflow/python/client/session_list_devices_test.py
@@ -23,7 +23,6 @@ from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.client import session
-from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@@ -42,21 +41,13 @@ class SessionListDevicesTestMethods(object):
def testInvalidDeviceNumber(self):
opts = tf_session.TF_NewSessionOptions()
- with errors.raise_exception_on_not_ok_status() as status:
- c_session = tf_session.TF_NewSession(
- ops.get_default_graph()._c_graph, opts, status)
- raw_device_list = tf_session.TF_SessionListDevices(
- c_session, status)
+ c_session = tf_session.TF_NewSession(ops.get_default_graph()._c_graph, opts)
+ raw_device_list = tf_session.TF_SessionListDevices(c_session)
size = tf_session.TF_DeviceListCount(raw_device_list)
- # Test that invalid device numbers return -1 rather than a Swig-wrapped
- # pointer.
- status_no_exception = c_api_util.ScopedTFStatus()
- memory = tf_session.TF_DeviceListMemoryBytes(
- raw_device_list, size, status_no_exception)
- self.assertEqual(memory, -1)
+ with self.assertRaises(errors.InvalidArgumentError):
+ tf_session.TF_DeviceListMemoryBytes(raw_device_list, size)
tf_session.TF_DeleteDeviceList(raw_device_list)
- with errors.raise_exception_on_not_ok_status() as status:
- tf_session.TF_CloseSession(c_session, status)
+ tf_session.TF_CloseSession(c_session)
def testListDevicesGrpcSession(self):
server = server_lib.Server.create_local_server()
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 77ce9195ee..5dcd0c192e 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -18,11 +18,12 @@ limitations under the License.
%{
#include "tensorflow/c/python_api.h"
-#include "tensorflow/python/client/tf_session_helper.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
+#include "tensorflow/python/client/tf_session_helper.h"
+#include "tensorflow/python/lib/core/py_exception_registry.h"
// Helper function to convert a Python list of Tensors to a C++ vector of
// TF_Outputs.
@@ -352,6 +353,27 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
reinterpret_cast<const char*>($1.data), $1.length);
}
+// Typemaps to automatically raise a Python exception from bad output TF_Status.
+// TODO(b/77295559): expand this to all TF_Status* output params and deprecate
+// raise_exception_on_not_ok_status (currently it only affects the C API).
+%typemap(in, numinputs=0) TF_Status* status (TF_Status* status) {
+ status = TF_NewStatus();
+ $1 = status;
+}
+
+%typemap(argout) TF_Status* status {
+ TF_Code code = TF_GetCode($1);
+ if (code != TF_OK) {
+ PyObject* exc = tensorflow::PyExceptionRegistry::Lookup(code);
+ // Arguments to OpError.
+ PyObject* exc_args = Py_BuildValue("sss", nullptr, nullptr, TF_Message($1));
+ TF_DeleteStatus($1);
+ SWIG_SetErrorObj(exc, exc_args);
+ SWIG_fail;
+ }
+ TF_DeleteStatus($1);
+}
+
// Converts input Python list of wrapped TF_Outputs into a single array
%typemap(in) (const TF_Output* inputs, int num_inputs)
(std::vector<TF_Output> inputs) {
@@ -499,9 +521,8 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
_TF_SetTarget(opts, target)
if config is not None:
from tensorflow.python.framework import errors
- with errors.raise_exception_on_not_ok_status() as status:
- config_str = config.SerializeToString()
- _TF_SetConfig(opts, config_str, status)
+ config_str = config.SerializeToString()
+ _TF_SetConfig(opts, config_str)
return opts
%}
@@ -758,3 +779,7 @@ def TF_Reset(target, containers=None, config=None):
%include "tensorflow/python/client/tf_session_helper.h"
%unignoreall
+
+// Clear "TF_Status* status" typemap so it doesn't affect other modules and
+// unexpectedly remove the TF_Status* argument from wrappers.
+%clear TF_Status* status;
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 603d03e315..5416d41376 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -136,8 +136,7 @@ string EqualAttrValueWrapper(const string& actual, const string& expected);
//
// If shape is unknown, sets unknown_shape to true.
tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
- TF_Graph* graph, TF_Output output, TF_Status* out_status,
- bool* unknown_shape);
+ TF_Graph* graph, TF_Output output, TF_Status* status, bool* unknown_shape);
// Runs the graph associated with the session starting with the supplied inputs.
// On success, `py_outputs` is populated with a numpy ndarray for each output
@@ -149,7 +148,7 @@ void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
const std::vector<PyObject*>& input_ndarrays,
const std::vector<TF_Output>& outputs,
const std::vector<TF_Operation*>& targets,
- TF_Buffer* run_metadata, TF_Status* out_status,
+ TF_Buffer* run_metadata, TF_Status* status,
std::vector<PyObject*>* py_outputs);
// Set up the graph with the intended feeds (inputs) and fetches (output) for
@@ -165,8 +164,7 @@ void TF_SessionPRunSetup_wrapper(TF_Session* session,
const std::vector<TF_Output>& inputs,
const std::vector<TF_Output>& outputs,
const std::vector<TF_Operation*>& targets,
- const char** out_handle,
- TF_Status* out_status);
+ const char** out_handle, TF_Status* status);
// Continue to run the graph with additional feeds and fetches. The
// execution state is uniquely identified by the handle.
@@ -182,7 +180,7 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
const std::vector<TF_Output>& inputs,
const std::vector<PyObject*>& input_ndarrays,
const std::vector<TF_Output>& outputs,
- TF_Status* out_status,
+ TF_Status* status,
std::vector<PyObject*>* py_outputs);
// Retrieves the inputs of this operation.
@@ -204,7 +202,7 @@ TF_Function* TF_GraphToFunction_wrapper(
const std::vector<TF_Operation*>* opers,
const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
const NameVector& output_names, const TF_FunctionOptions* opts,
- const char* description, TF_Status* out_status);
+ const char* description, TF_Status* status);
// Set the shapes and types for the output's handle.
//