diff options
Diffstat (limited to 'tensorflow/python/pywrap_tfe.i')
-rw-r--r-- | tensorflow/python/pywrap_tfe.i | 67 |
1 files changed, 28 insertions, 39 deletions
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index d1e2ab3e9c..128e46e6ce 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -15,24 +15,16 @@ limitations under the License. %ignore ""; -%rename("%s") TFE_Py_RegisterExceptionClass; -%rename("%s") TFE_Py_NumpyToTensorHandle; -%rename("%s") TFE_Py_SequenceToTensorHandle; -%rename("%s") TFE_Py_AllEqualInt64; %rename("%s") TFE_NewContext; %rename("%s") TFE_DeleteContext; %rename("%s") TFE_ContextListDevices; -%rename("%s") TFE_TensorHandleDataType; -%rename("%s") TFE_TensorHandleNumDims; -%rename("%s") TFE_DeleteTensorHandle; -%rename("%s") TFE_Py_Execute; %rename("%s") TFE_ContextAddFunctionDef; -%rename("%s") TFE_TensorHandleDim; -%rename("%s") TFE_TensorHandleDeviceName; -%rename("%s") TFE_TensorHandleCopyToDevice; %rename("%s") TFE_NewOp; -%rename("%s") TFE_Py_TensorHandleToNumpy; %rename("%s") TFE_OpGetAttrType; +%rename("%s") TFE_Py_InitEagerTensor; +%rename("%s") TFE_Py_RegisterExceptionClass; +%rename("%s") TFE_Py_Execute; +%rename("%s") TFE_Py_UID; %{ @@ -79,6 +71,18 @@ limitations under the License. $1 = TFE_GetPythonString($input); } +%typemap(in) (TFE_Context*) { + $1 = (TFE_Context*)PyCapsule_GetPointer($input, nullptr); + +} +%typemap(out) (TFE_Context*) { + if ($1 == nullptr) { + SWIG_fail; + } else { + $result = PyCapsule_New($1, nullptr, TFE_DeleteContextCapsule); + } +} + %include "tensorflow/c/eager/c_api.h" %typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) { @@ -95,15 +99,13 @@ limitations under the License. if (!elem) { SWIG_fail; } - void* thp = nullptr; - int res = SWIG_ConvertPtr(elem, &thp, - $descriptor(TFE_TensorHandle*), 0 | 0); - if (!SWIG_IsOK(res)) { - SWIG_exception_fail(SWIG_ArgError(res), + if (EagerTensor_CheckExact(elem)) { + (*$1)[i] = EagerTensorHandle(elem); + } else { + SWIG_exception_fail(SWIG_TypeError, "provided list of inputs contains objects other " - "than 'TFE_TensorHandle*'"); + "than 'EagerTensor'"); } - (*$1)[i] = reinterpret_cast<TFE_TensorHandle*>(thp); } } } @@ -129,45 +131,32 @@ limitations under the License. } %typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) { - if (TFE_Py_MaybeRaiseException($2)) { + if (MaybeRaiseExceptionFromTFStatus($2, nullptr)) { SWIG_fail; } else { int num_outputs = $1->size(); $result = PyList_New(num_outputs); for (int i = 0; i < num_outputs; ++i) { - PyList_SetItem($result, i, SWIG_NewPointerObj(SWIG_as_voidptr($1->at(i)), - $descriptor(TFE_TensorHandle*), - 0 | 0)); + PyObject *output; + output = EagerTensorFromHandle($1->at(i)); + PyList_SetItem($result, i, output); } } } -// Note that we need to use a typemap for TFE_TensorHandle* so that we can call -// SWIG_fail in case the value is nullptr. Otherwise SWIG will wrap the -// nullptr and return it to python as an opaque object, and python does not -// know that it needs to check if an Exception has been raised. -// TODO(agarwal): check if we can get rid of this typemap. -%typemap(out) (TFE_TensorHandle*) { - if ($1 == nullptr) { - SWIG_fail; - } else { - $result = SWIG_NewPointerObj(SWIG_as_voidptr($1), - $descriptor(TFE_TensorHandle*), 0 | 0); - } -} %include "tensorflow/python/eager/pywrap_tfe.h" -// Clear all typemaps127 +// Clear all typemaps. %typemap(out) TF_DataType; %typemap(out) int64_t; %typemap(out) TF_AttrType; %typemap(in, numinputs=0) TF_Status *out_status; %typemap(argout) unsigned char* is_list; -%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp); +%typemap(in) (TFE_Context*); +%typemap(out) (TFE_Context*); %typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp); %typemap(in, numinputs=0) TF_Status *out_status; %typemap(freearg) (TF_Status* out_status); %typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status); -%typemap(out) (TFE_TensorHandle*); |