aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/client/tf_session.i26
-rw-r--r--tensorflow/python/eager/backprop.py11
-rw-r--r--tensorflow/python/eager/context.py56
-rw-r--r--tensorflow/python/eager/imperative_grad.py6
-rw-r--r--tensorflow/python/grappler/item.py4
-rw-r--r--tensorflow/python/lib/io/tf_record.py3
-rw-r--r--tensorflow/python/platform/base.i22
-rw-r--r--tensorflow/python/pywrap_tfe.i2
8 files changed, 54 insertions, 76 deletions
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index ee76e29c05..68768f2b4c 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/python/client/tf_session_helper.h"
-#include "tensorflow/python/lib/core/py_exception_registry.h"
// Helper function to convert a Python list of Tensors to a C++ vector of
// TF_Outputs.
@@ -353,27 +352,6 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
reinterpret_cast<const char*>($1.data), $1.length);
}
-// Typemaps to automatically raise a Python exception from bad output TF_Status.
-// TODO(b/77295559): expand this to all TF_Status* output params and deprecate
-// raise_exception_on_not_ok_status (currently it only affects the C API).
-%typemap(in, numinputs=0) TF_Status* status (TF_Status* status) {
- status = TF_NewStatus();
- $1 = status;
-}
-
-%typemap(argout) TF_Status* status {
- TF_Code code = TF_GetCode($1);
- if (code != TF_OK) {
- PyObject* exc = tensorflow::PyExceptionRegistry::Lookup(code);
- // Arguments to OpError.
- PyObject* exc_args = Py_BuildValue("sss", nullptr, nullptr, TF_Message($1));
- TF_DeleteStatus($1);
- SWIG_SetErrorObj(exc, exc_args);
- SWIG_fail;
- }
- TF_DeleteStatus($1);
-}
-
// Converts input Python list of wrapped TF_Outputs into a single array
%typemap(in) (const TF_Output* inputs, int num_inputs)
(std::vector<TF_Output> inputs) {
@@ -784,7 +762,3 @@ def TF_Reset(target, containers=None, config=None):
%include "tensorflow/python/client/tf_session_helper.h"
%unignoreall
-
-// Clear "TF_Status* status" typemap so it doesn't affect other modules and
-// unexpectedly remove the TF_Status* argument from wrappers.
-%clear TF_Status* status;
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 209b012621..92774d4d50 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -31,7 +31,6 @@ from tensorflow.python.eager import imperative_grad
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -50,12 +49,10 @@ def op_attr_type(op_type, attr_name):
try:
return _op_attr_type_cache[(op_type, attr_name)]
except KeyError:
- with errors.raise_exception_on_not_ok_status() as status:
- h = context.context()._handle # pylint: disable=protected-access
- attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(
- h, op_type, attr_name, status)
- _op_attr_type_cache[(op_type, attr_name)] = attr_type
- return attr_type
+ h = context.context()._handle # pylint: disable=protected-access
+ attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(h, op_type, attr_name)
+ _op_attr_type_cache[(op_type, attr_name)] = attr_type
+ return attr_type
def make_attr(attr_type, value):
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 99ec895b54..9e146f021e 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -28,7 +28,6 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
-from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode
from tensorflow.python.util import tf_contextlib
@@ -224,24 +223,21 @@ class Context(object):
assert self._context_devices is None
opts = pywrap_tensorflow.TFE_NewContextOptions()
try:
- with errors.raise_exception_on_not_ok_status() as status:
- if self._config is not None:
- config_str = self._config.SerializeToString()
- pywrap_tensorflow.TFE_ContextOptionsSetConfig(
- opts, config_str, len(config_str), status)
- if self._device_policy is not None:
- pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
- opts, self._device_policy)
- if self._execution_mode == ASYNC:
- pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
- self._context_handle = pywrap_tensorflow.TFE_NewContext(opts, status)
+ if self._config is not None:
+ config_str = self._config.SerializeToString()
+ pywrap_tensorflow.TFE_ContextOptionsSetConfig(opts, config_str)
+ if self._device_policy is not None:
+ pywrap_tensorflow.TFE_ContextOptionsSetDevicePlacementPolicy(
+ opts, self._device_policy)
+ if self._execution_mode == ASYNC:
+ pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
+ self._context_handle = pywrap_tensorflow.TFE_NewContext(opts)
finally:
pywrap_tensorflow.TFE_DeleteContextOptions(opts)
# Store list of devices
self._context_devices = []
- with errors.raise_exception_on_not_ok_status() as status:
- device_list = pywrap_tensorflow.TFE_ContextListDevices(
- self._context_handle, status)
+ device_list = pywrap_tensorflow.TFE_ContextListDevices(
+ self._context_handle)
try:
self._num_gpus = 0
for i in range(pywrap_tensorflow.TF_DeviceListCount(device_list)):
@@ -412,9 +408,7 @@ class Context(object):
if mode is None:
mode = SYNC
self._eager_context.execution_mode = mode
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._handle,
- mode == ASYNC, status)
+ pywrap_tensorflow.TFE_ContextSetAsyncForThread(self._handle, mode == ASYNC)
@tf_contextlib.contextmanager
def execution_mode(self, mode):
@@ -428,8 +422,7 @@ class Context(object):
def async_wait(self):
"""Waits for ops dispatched in ASYNC mode to finish."""
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TFE_ContextAsyncWait(self._handle, status)
+ pywrap_tensorflow.TFE_ContextAsyncWait(self._handle)
def async_clear_error(self):
"""Clears errors raised during ASYNC execution."""
@@ -449,11 +442,9 @@ class Context(object):
Args:
fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
"""
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TFE_ContextAddFunction(
- self._handle, # pylint: disable=protected-access
- fn,
- status)
+ pywrap_tensorflow.TFE_ContextAddFunction(
+ self._handle, # pylint: disable=protected-access
+ fn)
def add_function_def(self, fdef):
"""Add a function definition to the context.
@@ -465,12 +456,10 @@ class Context(object):
fdef: A FunctionDef protocol buffer message.
"""
fdef_string = fdef.SerializeToString()
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TFE_ContextAddFunctionDef(
- self._handle, # pylint: disable=protected-access
- fdef_string,
- len(fdef_string),
- status)
+ pywrap_tensorflow.TFE_ContextAddFunctionDef(
+ self._handle, # pylint: disable=protected-access
+ fdef_string,
+ len(fdef_string))
def add_post_execution_callback(self, callback):
"""Add a post-execution callback to the context.
@@ -545,9 +534,8 @@ class Context(object):
if not self._context_handle:
return None
with c_api_util.tf_buffer() as buffer_:
- with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.TFE_ContextExportRunMetadata(
- self._context_handle, buffer_, status)
+ pywrap_tensorflow.TFE_ContextExportRunMetadata(
+ self._context_handle, buffer_)
proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
run_metadata = config_pb2.RunMetadata()
run_metadata.ParseFromString(compat.as_bytes(proto_data))
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 837cad974a..000152855d 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import collections
from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.framework import errors
VSpace = collections.namedtuple(
@@ -60,6 +59,5 @@ def imperative_grad(
or if only non-differentiable functions of the source were used in the
computation of target.
"""
- with errors.raise_exception_on_not_ok_status() as status:
- return pywrap_tensorflow.TFE_Py_TapeGradient(
- tape._tape, vspace, target, sources, output_gradients, status) # pylint: disable=protected-access
+ return pywrap_tensorflow.TFE_Py_TapeGradient(
+ tape._tape, vspace, target, sources, output_gradients) # pylint: disable=protected-access
diff --git a/tensorflow/python/grappler/item.py b/tensorflow/python/grappler/item.py
index 4a083849bd..1748efdd13 100644
--- a/tensorflow/python/grappler/item.py
+++ b/tensorflow/python/grappler/item.py
@@ -51,9 +51,7 @@ class Item(object):
self._BuildTFItem()
def IdentifyImportantOps(self, sort_topologically=False):
- with errors.raise_exception_on_not_ok_status() as status:
- return tf_item.TF_IdentifyImportantOps(self.tf_item, sort_topologically,
- status)
+ return tf_item.TF_IdentifyImportantOps(self.tf_item, sort_topologically)
def GetOpProperties(self):
ret_from_swig = tf_item.TF_GetOpProperties(self.tf_item)
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 6fcf9c91d8..bf2d6f68b5 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -78,8 +78,7 @@ def tf_record_iterator(path, options=None):
try:
while True:
try:
- with errors.raise_exception_on_not_ok_status() as status:
- reader.GetNext(status)
+ reader.GetNext()
except errors.OutOfRangeError:
break
yield reader.record()
diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i
index dbefca2be9..478dd46f7e 100644
--- a/tensorflow/python/platform/base.i
+++ b/tensorflow/python/platform/base.i
@@ -229,3 +229,25 @@ _COPY_TYPEMAPS(unsigned int, mode_t);
%define final %enddef
%define override %enddef
#endif
+
+// Typemaps to automatically raise a Python exception from bad output TF_Status.
+// TODO(b/77295559): expand this to all TF_Status* output params and deprecate
+// raise_exception_on_not_ok_status (currently it only affects the C API).
+%typemap(in, numinputs=0) TF_Status* status (TF_Status* status) {
+ $1 = TF_NewStatus();
+}
+
+%typemap(freearg) (TF_Status* status) {
+ TF_DeleteStatus($1);
+}
+
+%typemap(argout) TF_Status* status {
+ TF_Code code = TF_GetCode($1);
+ if (code != TF_OK) {
+ PyObject* exc = tensorflow::PyExceptionRegistry::Lookup(code);
+ // Arguments to OpError.
+ PyObject* exc_args = Py_BuildValue("sss", nullptr, nullptr, TF_Message($1));
+ SWIG_SetErrorObj(exc, exc_args);
+ SWIG_fail;
+ }
+}
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 39fabb9c1b..7acb8eeb1a 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+%include "tensorflow/python/platform/base.i"
+
%ignore "";
%rename("%s") TFE_NewContext;