aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/pywrap_tfe.i
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/pywrap_tfe.i')
-rw-r--r--tensorflow/python/pywrap_tfe.i67
1 files changed, 28 insertions, 39 deletions
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index d1e2ab3e9c..128e46e6ce 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -15,24 +15,16 @@ limitations under the License.
%ignore "";
-%rename("%s") TFE_Py_RegisterExceptionClass;
-%rename("%s") TFE_Py_NumpyToTensorHandle;
-%rename("%s") TFE_Py_SequenceToTensorHandle;
-%rename("%s") TFE_Py_AllEqualInt64;
%rename("%s") TFE_NewContext;
%rename("%s") TFE_DeleteContext;
%rename("%s") TFE_ContextListDevices;
-%rename("%s") TFE_TensorHandleDataType;
-%rename("%s") TFE_TensorHandleNumDims;
-%rename("%s") TFE_DeleteTensorHandle;
-%rename("%s") TFE_Py_Execute;
%rename("%s") TFE_ContextAddFunctionDef;
-%rename("%s") TFE_TensorHandleDim;
-%rename("%s") TFE_TensorHandleDeviceName;
-%rename("%s") TFE_TensorHandleCopyToDevice;
%rename("%s") TFE_NewOp;
-%rename("%s") TFE_Py_TensorHandleToNumpy;
%rename("%s") TFE_OpGetAttrType;
+%rename("%s") TFE_Py_InitEagerTensor;
+%rename("%s") TFE_Py_RegisterExceptionClass;
+%rename("%s") TFE_Py_Execute;
+%rename("%s") TFE_Py_UID;
%{
@@ -79,6 +71,18 @@ limitations under the License.
$1 = TFE_GetPythonString($input);
}
+%typemap(in) (TFE_Context*) {
+ $1 = (TFE_Context*)PyCapsule_GetPointer($input, nullptr);
+
+}
+%typemap(out) (TFE_Context*) {
+ if ($1 == nullptr) {
+ SWIG_fail;
+ } else {
+ $result = PyCapsule_New($1, nullptr, TFE_DeleteContextCapsule);
+ }
+}
+
%include "tensorflow/c/eager/c_api.h"
%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp) {
@@ -95,15 +99,13 @@ limitations under the License.
if (!elem) {
SWIG_fail;
}
- void* thp = nullptr;
- int res = SWIG_ConvertPtr(elem, &thp,
- $descriptor(TFE_TensorHandle*), 0 | 0);
- if (!SWIG_IsOK(res)) {
- SWIG_exception_fail(SWIG_ArgError(res),
+ if (EagerTensor_CheckExact(elem)) {
+ (*$1)[i] = EagerTensorHandle(elem);
+ } else {
+ SWIG_exception_fail(SWIG_TypeError,
"provided list of inputs contains objects other "
- "than 'TFE_TensorHandle*'");
+ "than 'EagerTensor'");
}
- (*$1)[i] = reinterpret_cast<TFE_TensorHandle*>(thp);
}
}
}
@@ -129,45 +131,32 @@ limitations under the License.
}
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status) {
- if (TFE_Py_MaybeRaiseException($2)) {
+ if (MaybeRaiseExceptionFromTFStatus($2, nullptr)) {
SWIG_fail;
} else {
int num_outputs = $1->size();
$result = PyList_New(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
- PyList_SetItem($result, i, SWIG_NewPointerObj(SWIG_as_voidptr($1->at(i)),
- $descriptor(TFE_TensorHandle*),
- 0 | 0));
+ PyObject *output;
+ output = EagerTensorFromHandle($1->at(i));
+ PyList_SetItem($result, i, output);
}
}
}
-// Note that we need to use a typemap for TFE_TensorHandle* so that we can call
-// SWIG_fail in case the value is nullptr. Otherwise SWIG will wrap the
-// nullptr and return it to python as an opaque object, and python does not
-// know that it needs to check if an Exception has been raised.
-// TODO(agarwal): check if we can get rid of this typemap.
-%typemap(out) (TFE_TensorHandle*) {
- if ($1 == nullptr) {
- SWIG_fail;
- } else {
- $result = SWIG_NewPointerObj(SWIG_as_voidptr($1),
- $descriptor(TFE_TensorHandle*), 0 | 0);
- }
-}
%include "tensorflow/python/eager/pywrap_tfe.h"
-// Clear all typemaps127
+// Clear all typemaps.
%typemap(out) TF_DataType;
%typemap(out) int64_t;
%typemap(out) TF_AttrType;
%typemap(in, numinputs=0) TF_Status *out_status;
%typemap(argout) unsigned char* is_list;
-%typemap(in) TFE_InputTensorHandles* inputs (TFE_InputTensorHandles temp);
+%typemap(in) (TFE_Context*);
+%typemap(out) (TFE_Context*);
%typemap(in) TFE_OutputTensorHandles* outputs (TFE_OutputTensorHandles temp);
%typemap(in, numinputs=0) TF_Status *out_status;
%typemap(freearg) (TF_Status* out_status);
%typemap(argout) (TFE_OutputTensorHandles* outputs, TF_Status* out_status);
-%typemap(out) (TFE_TensorHandle*);