diff options
Diffstat (limited to 'tensorflow/python/eager/pywrap_tfe_src.cc')
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 89 |
1 files changed, 61 insertions, 28 deletions
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 57b4dab51c..4d28e98961 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1173,14 +1173,14 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { if (EagerTensor_CheckExact(tensor)) { TFE_TensorHandle* t = EagerTensor_Handle(tensor); tensorflow::int64 id = EagerTensor_id(tensor); - const tensorflow::Tensor* tensor = nullptr; - const tensorflow::Status status = t->handle->Tensor(&tensor); + tensorflow::TensorShape tensor_shape; + const tensorflow::Status status = t->handle->Shape(&tensor_shape); + if (MaybeRaiseExceptionFromStatus(status, nullptr)) { return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensorflow::TensorShape({})}; } else { - return tensorflow::eager::TapeTensor{id, t->handle->dtype, - tensor->shape()}; + return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape}; } } tensorflow::int64 id = FastTensorId(tensor); @@ -1898,14 +1898,39 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, void MaybeWatchVariable(PyObject* input) { DCHECK(CheckResourceVariable(input)); - DCHECK(PyObject_HasAttrString(input, "trainable")); + DCHECK(PyObject_HasAttrString(input, "_trainable")); tensorflow::Safe_PyObjectPtr trainable( - PyObject_GetAttrString(input, "trainable")); + PyObject_GetAttrString(input, "_trainable")); if (trainable.get() == Py_False) return; TFE_Py_TapeSetWatchVariable(input); } +bool CastTensor(const FastPathOpExecInfo& op_exec_info, + const TF_DataType& desired_dtype, + tensorflow::Safe_TFE_TensorHandlePtr* handle, + TF_Status* status) { + TF_DataType input_dtype = TFE_TensorHandleDataType(handle->get()); + TF_DataType output_dtype = input_dtype; + + if (desired_dtype >= 0 && desired_dtype != input_dtype) { + *handle = tensorflow::make_safe( + tensorflow::EagerCast(op_exec_info.ctx, handle->get(), input_dtype, + static_cast<TF_DataType>(desired_dtype), status)); + if (!status->status.ok()) return false; + output_dtype = desired_dtype; + } + + if (output_dtype != TF_INT32) { + // Note that this is a shallow copy and will share the underlying buffer + // if copying to the same device. + *handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice( + handle->get(), op_exec_info.ctx, op_exec_info.device_name, status)); + if (!status->status.ok()) return false; + } + return true; +} + bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info, PyObject* input, tensorflow::Safe_PyObjectPtr* output, TF_Status* status) { @@ -1938,9 +1963,31 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info, TFE_Execute(op, &output_handle, &num_retvals, status); if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; - // Always create the py object (and correctly DECREF it) from the returned - // value, else the data will leak. - output->reset(EagerTensorFromHandle(output_handle)); + if (!PyObject_HasAttrString(input, "_read_dtype")) { + // Always create the py object (and correctly DECREF it) from the returned + // value, else the data will leak. + output->reset(EagerTensorFromHandle(output_handle)); + } else { + // This is a _MixedPrecisionVariable which potentially does casting when + // being read. + tensorflow::Safe_PyObjectPtr read_dtype( + PyObject_GetAttrString(input, "_read_dtype")); + int desired_dtype = -1; + if (!ParseTypeValue("_read_dtype", read_dtype.get(), status, + &desired_dtype)) { + return false; + } + + auto safe_output_handle = tensorflow::make_safe(output_handle); + // Retires output_handle in the future. + output_handle = nullptr; + if (!CastTensor(parent_op_exec_info, + static_cast<TF_DataType>(desired_dtype), + &safe_output_handle, status)) { + return false; + } + output->reset(EagerTensorFromHandle(safe_output_handle.release())); + } // TODO(nareshmodi): Should we run post exec callbacks here? if (parent_op_exec_info.run_gradient_callback) { @@ -2010,27 +2057,13 @@ bool ConvertToTensor( } } - TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get()); - if (desired_dtype >= 0 && desired_dtype != handle_dtype) { - handle = tensorflow::make_safe( - tensorflow::EagerCast(op_exec_info.ctx, handle.get(), handle_dtype, - static_cast<TF_DataType>(desired_dtype), status)); - if (!status->status.ok()) return false; - - handle_dtype = TFE_TensorHandleDataType(handle.get()); - } - - if (handle_dtype != TF_INT32) { - // Note that this is a shallow copy and will share the underlying buffer - // if copying to the same device. - handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice( - handle.get(), op_exec_info.ctx, op_exec_info.device_name, status)); - if (!status->status.ok()) return false; + if (!CastTensor(op_exec_info, static_cast<TF_DataType>(desired_dtype), + &handle, status)) { + return false; } - + TF_DataType output_dtype = TFE_TensorHandleDataType(handle.get()); output_handle->reset(EagerTensorFromHandle(handle.release())); - - dtype_setter(handle_dtype); + dtype_setter(output_dtype); return true; } |