diff options
-rw-r--r-- | tensorflow/python/client/tf_session.i | 26 | ||||
-rw-r--r-- | tensorflow/python/eager/backprop.py | 11 | ||||
-rw-r--r-- | tensorflow/python/eager/context.py | 56 | ||||
-rw-r--r-- | tensorflow/python/eager/imperative_grad.py | 6 | ||||
-rw-r--r-- | tensorflow/python/grappler/item.py | 4 | ||||
-rw-r--r-- | tensorflow/python/lib/io/tf_record.py | 3 | ||||
-rw-r--r-- | tensorflow/python/platform/base.i | 22 | ||||
-rw-r--r-- | tensorflow/python/pywrap_tfe.i | 2 |
8 files changed, 54 insertions, 76 deletions
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index ee76e29c05..68768f2b4c 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" #include "tensorflow/python/client/tf_session_helper.h" -#include "tensorflow/python/lib/core/py_exception_registry.h" // Helper function to convert a Python list of Tensors to a C++ vector of // TF_Outputs. @@ -353,27 +352,6 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{ reinterpret_cast<const char*>($1.data), $1.length); } -// Typemaps to automatically raise a Python exception from bad output TF_Status. -// TODO(b/77295559): expand this to all TF_Status* output params and deprecate -// raise_exception_on_not_ok_status (currently it only affects the C API). -%typemap(in, numinputs=0) TF_Status* status (TF_Status* status) { - status = TF_NewStatus(); - $1 = status; -} - -%typemap(argout) TF_Status* status { - TF_Code code = TF_GetCode($1); - if (code != TF_OK) { - PyObject* exc = tensorflow::PyExceptionRegistry::Lookup(code); - // Arguments to OpError. - PyObject* exc_args = Py_BuildValue("sss", nullptr, nullptr, TF_Message($1)); - TF_DeleteStatus($1); - SWIG_SetErrorObj(exc, exc_args); - SWIG_fail; - } - TF_DeleteStatus($1); -} - // Converts input Python list of wrapped TF_Outputs into a single array %typemap(in) (const TF_Output* inputs, int num_inputs) (std::vector<TF_Output> inputs) { @@ -784,7 +762,3 @@ def TF_Reset(target, containers=None, config=None): %include "tensorflow/python/client/tf_session_helper.h" %unignoreall - -// Clear "TF_Status* status" typemap so it doesn't affect other modules and -// unexpectedly remove the TF_Status* argument from wrappers. -%clear TF_Status* status; diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 209b012621..92774d4d50 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -31,7 +31,6 @@ from tensorflow.python.eager import imperative_grad from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -50,12 +49,10 @@ def op_attr_type(op_type, attr_name): try: return _op_attr_type_cache[(op_type, attr_name)] except KeyError: - with errors.raise_exception_on_not_ok_status() as status: - h = context.context()._handle # pylint: disable=protected-access - attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType( - h, op_type, attr_name, status) - _op_attr_type_cache[(op_type, attr_name)] = attr_type - return attr_type + h = context.context()._handle # pylint: disable=protected-access + attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(h, op_type, attr_name) + _op_attr_type_cache[(op_type, attr_name)] = attr_type + return attr_type def make_attr(attr_type, value): diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 99ec895b54..9e146f021e 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -28,7 +28,6 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device as pydev -from tensorflow.python.framework import errors from tensorflow.python.util import compat from tensorflow.python.util import is_in_graph_mode from tensorflow.python.util import tf_contextlib @@ -224,24 +223,21 @@ class Context(object): assert self._context_devices is None opts = pywrap_tensorflow.TFE_NewContextOptions() try: - with errors.raise_exception_on_not_ok_status() as status: - if self._config is not None: - config_str = self._config.SerializeToString() - pywrap_tensorflow.TFE_ContextOptionsSetConfig( - opts, config_str, len(config_str), status) - if self._device_policy is not None: - pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy( - opts, self._device_policy) - if self._execution_mode == ASYNC: - pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True) - self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status) + if self._config is not None: + config_str = self._config.SerializeToString() + pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str) + if self._device_policy is not None: + pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy( + opts, self._device_policy) + if self._execution_mode == ASYNC: + pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True) + self._context_handle = pywrap_tensorflow.TFE_NewContext(opts) finally: pywrap_tensorflow.TFE_DeleteContextOptions(opts) # Store list of devices self._context_devices = [] - with errors.raise_exception_on_not_ok_status() as status: - device_list = pywrap_tensorflow.TFE_ContextListDevices( - self._context_handle, status) + device_list = pywrap_tensorflow.TFE_ContextListDevices( + self._context_handle) try: self._num_gpus = 0 for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)): @@ -412,9 +408,7 @@ class Context(object): if mode is None: mode = SYNC self._eager_context.execution_mode = mode - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._handle, - mode == ASYNC, status) + pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._handle, mode == ASYNC) @tf_contextlib.contextmanager def execution_mode(self, mode): @@ -428,8 +422,7 @@ class Context(object): def async_wait(self): """Waits for ops dispatched in ASYNC mode to finish.""" - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TFE_ContextAsyncWait(self._handle, status) + pywrap_tensorflow.TFE_ContextAsyncWait(self._handle) def async_clear_error(self): """Clears errors raised during ASYNC execution.""" @@ -449,11 +442,9 @@ class Context(object): Args: fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). """ - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TFE_ContextAddFunction( - self._handle, # pylint: disable=protected-access - fn, - status) + pywrap_tensorflow.TFE_ContextAddFunction( + self._handle, # pylint: disable=protected-access + fn) def add_function_def(self, fdef): """Add a function definition to the context. @@ -465,12 +456,10 @@ class Context(object): fdef: A FunctionDef protocol buffer message. """ fdef_string = fdef.SerializeToString() - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TFE_ContextAddFunctionDef( - self._handle, # pylint: disable=protected-access - fdef_string, - len(fdef_string), - status) + pywrap_tensorflow.TFE_ContextAddFunctionDef( + self._handle, # pylint: disable=protected-access + fdef_string, + len(fdef_string)) def add_post_execution_callback(self, callback): """Add a post-execution callback to the context. @@ -545,9 +534,8 @@ class Context(object): if not self._context_handle: return None with c_api_util.tf_buffer() as buffer_: - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TFE_ContextExportRunMetadata( - self._context_handle, buffer_, status) + pywrap_tensorflow.TFE_ContextExportRunMetadata( + self._context_handle, buffer_) proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_) run_metadata = config_pb2.RunMetadata() run_metadata.ParseFromString(compat.as_bytes(proto_data)) diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py index 837cad974a..000152855d 100644 --- a/tensorflow/python/eager/imperative_grad.py +++ b/tensorflow/python/eager/imperative_grad.py @@ -21,7 +21,6 @@ from __future__ import print_function import collections from tensorflow.python import pywrap_tensorflow -from tensorflow.python.framework import errors VSpace = collections.namedtuple( @@ -60,6 +59,5 @@ def imperative_grad( or if only non-differentiable functions of the source were used in the computation of target. """ - with errors.raise_exception_on_not_ok_status() as status: - return pywrap_tensorflow.TFE_Py_TapeGradient( - tape._tape, vspace, target, sources, output_gradients, status) # pylint: disable=protected-access + return pywrap_tensorflow.TFE_Py_TapeGradient( + tape._tape, vspace, target, sources, output_gradients) # pylint: disable=protected-access diff --git a/tensorflow/python/grappler/item.py b/tensorflow/python/grappler/item.py index 4a083849bd..1748efdd13 100644 --- a/tensorflow/python/grappler/item.py +++ b/tensorflow/python/grappler/item.py @@ -51,9 +51,7 @@ class Item(object): self._BuildTFItem() def IdentifyImportantOps(self, sort_topologically=False): - with errors.raise_exception_on_not_ok_status() as status: - return tf_item.TF_IdentifyImportantOps(self.tf_item, sort_topologically, - status) + return tf_item.TF_IdentifyImportantOps(self.tf_item, sort_topologically) def GetOpProperties(self): ret_from_swig = tf_item.TF_GetOpProperties(self.tf_item) diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index 6fcf9c91d8..bf2d6f68b5 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -78,8 +78,7 @@ def tf_record_iterator(path, options=None): try: while True: try: - with errors.raise_exception_on_not_ok_status() as status: - reader.GetNext(status) + reader.GetNext() except errors.OutOfRangeError: break yield reader.record() diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i index dbefca2be9..478dd46f7e 100644 --- a/tensorflow/python/platform/base.i +++ b/tensorflow/python/platform/base.i @@ -229,3 +229,25 @@ _COPY_TYPEMAPS(unsigned int, mode_t); %define final %enddef %define override %enddef #endif + +// Typemaps to automatically raise a Python exception from bad output TF_Status. +// TODO(b/77295559): expand this to all TF_Status* output params and deprecate +// raise_exception_on_not_ok_status (currently it only affects the C API). +%typemap(in, numinputs=0) TF_Status* status (TF_Status* status) { + $1 = TF_NewStatus(); +} + +%typemap(freearg) (TF_Status* status) { + TF_DeleteStatus($1); +} + +%typemap(argout) TF_Status* status { + TF_Code code = TF_GetCode($1); + if (code != TF_OK) { + PyObject* exc = tensorflow::PyExceptionRegistry::Lookup(code); + // Arguments to OpError. + PyObject* exc_args = Py_BuildValue("sss", nullptr, nullptr, TF_Message($1)); + SWIG_SetErrorObj(exc, exc_args); + SWIG_fail; + } +} diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 39fabb9c1b..7acb8eeb1a 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +%include "tensorflow/python/platform/base.i" + %ignore ""; %rename("%s") TFE_NewContext; |