From 3acd57c2ffff6055b322ba08ba74fa1885fbba19 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 6 Oct 2017 09:37:33 -0700 Subject: Fuse TFE_NewOp and TFE_OpGetAttrType to avoid leaking memory. Removes TFE_NewOp and TFE_OpGetAttrType from pywrap_tensorflow, adds TFE_OpNameGetAttrType. PiperOrigin-RevId: 171302338 --- tensorflow/c/eager/c_api.cc | 14 ++++++++++++++ tensorflow/c/eager/c_api.h | 6 ++++++ tensorflow/python/eager/backprop.py | 4 ++-- tensorflow/python/pywrap_tfe.i | 3 +-- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 74f2e4f342..514a4010bc 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -273,6 +273,20 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, return ret; } +TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, + const char* op_or_function_name, + const char* attr_name, unsigned char* is_list, + TF_Status* status) { + TF_AttrType ret; + TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status); + if (!status->status.ok()) { + return TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType. + } + ret = TFE_OpGetAttrType(op, attr_name, is_list, status); + TFE_DeleteOp(op); + return ret; +} + void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { op->attrs.Set(attr_name, value); } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index a4f7d308fb..9bfa63711b 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -107,6 +107,12 @@ TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_St TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status); +// Get an attribute type given an op name; a fusion of TFE_NewOp and +// TFE_OpGetAttrType for use from Python without the overhead of the individual +// calls and memory management of TFE_Op. +TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType( + TFE_Context* ctx, const char* op_or_function_name, const char* attr_name, + unsigned char* is_list, TF_Status* status); TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value); diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 3c84cbbd6f..cca8e47044 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -49,8 +49,8 @@ def op_attr_type(op_type, attr_name): except KeyError: with errors.raise_exception_on_not_ok_status() as status: h = context.context()._handle # pylint: disable=protected-access - op = pywrap_tensorflow.TFE_NewOp(h, op_type, status) - attr_type = pywrap_tensorflow.TFE_OpGetAttrType(op, attr_name, status) + attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType( + h, op_type, attr_name, status) _op_attr_type_cache[(op_type, attr_name)] = attr_type return attr_type diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 128e46e6ce..d5b7294c82 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -19,8 +19,7 @@ limitations under the License. %rename("%s") TFE_DeleteContext; %rename("%s") TFE_ContextListDevices; %rename("%s") TFE_ContextAddFunctionDef; -%rename("%s") TFE_NewOp; -%rename("%s") TFE_OpGetAttrType; +%rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; %rename("%s") TFE_Py_RegisterExceptionClass; %rename("%s") TFE_Py_Execute; -- cgit v1.2.3