aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/pywrap_tensor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/pywrap_tensor.cc')
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc33
1 files changed, 29 insertions, 4 deletions
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index ea604647fa..15d2ccf9d2 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -154,6 +154,7 @@ TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
TFE_OpSetAttrType(op, "SrcT", src_type_enum);
TFE_OpSetAttrType(op, "DstT", dst_type_enum);
+ TFE_OpSetAttrBool(op, "Truncate", false);
TFE_TensorHandle* output = nullptr;
int num_outputs = 1;
TFE_Execute(op, &output, &num_outputs, out_status);
@@ -620,10 +621,6 @@ static PyType_Slot EagerTensor_Type_slots[] = {
{Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
{0, nullptr},
};
-
-PyType_Spec EagerTensor_Type_spec = {"EagerTensor", sizeof(EagerTensor), 0,
- Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
- EagerTensor_Type_slots};
#else
// TODO(agarwal): support active_trace.
static PyTypeObject _EagerTensorType = {
@@ -754,6 +751,34 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
#if PY_MAJOR_VERSION >= 3
PyObject* bases = PyTuple_New(1);
PyTuple_SET_ITEM(bases, 0, base_class);
+
+ tensorflow::Safe_PyObjectPtr base_class_module(
+ PyObject_GetAttrString(base_class, "__module__"));
+ const char* module = nullptr;
+ if (PyErr_Occurred()) {
+ PyErr_Clear();
+ module = "__builtin__";
+ } else {
+ module = PyBytes_AsString(base_class_module.get());
+ if (module == nullptr) {
+ PyErr_Clear();
+ module = PyUnicode_AsUTF8(base_class_module.get());
+ if (module == nullptr) {
+ PyErr_Clear();
+ module = "__builtin__";
+ }
+ }
+ }
+
+ // NOTE: The c_str from this string needs to outlast the function, hence is
+ // static.
+ static tensorflow::string fully_qualified_name =
+ tensorflow::strings::StrCat(module, ".EagerTensor");
+
+ static PyType_Spec EagerTensor_Type_spec = {
+ fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};
+
EagerTensorType = reinterpret_cast<PyTypeObject*>(
PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
if (PyErr_Occurred()) {