aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/pywrap_tfe_src.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/eager/pywrap_tfe_src.cc')
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc89
1 files changed, 61 insertions, 28 deletions
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 57b4dab51c..4d28e98961 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1173,14 +1173,14 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
tensorflow::int64 id = EagerTensor_id(tensor);
- const tensorflow::Tensor* tensor = nullptr;
- const tensorflow::Status status = t->handle->Tensor(&tensor);
+ tensorflow::TensorShape tensor_shape;
+ const tensorflow::Status status = t->handle->Shape(&tensor_shape);
+
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
return tensorflow::eager::TapeTensor{id, t->handle->dtype,
tensorflow::TensorShape({})};
} else {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype,
- tensor->shape()};
+ return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape};
}
}
tensorflow::int64 id = FastTensorId(tensor);
@@ -1898,14 +1898,39 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
void MaybeWatchVariable(PyObject* input) {
DCHECK(CheckResourceVariable(input));
- DCHECK(PyObject_HasAttrString(input, "trainable"));
+ DCHECK(PyObject_HasAttrString(input, "_trainable"));
tensorflow::Safe_PyObjectPtr trainable(
- PyObject_GetAttrString(input, "trainable"));
+ PyObject_GetAttrString(input, "_trainable"));
if (trainable.get() == Py_False) return;
TFE_Py_TapeSetWatchVariable(input);
}
+bool CastTensor(const FastPathOpExecInfo& op_exec_info,
+ const TF_DataType& desired_dtype,
+ tensorflow::Safe_TFE_TensorHandlePtr* handle,
+ TF_Status* status) {
+ TF_DataType input_dtype = TFE_TensorHandleDataType(handle->get());
+ TF_DataType output_dtype = input_dtype;
+
+ if (desired_dtype >= 0 && desired_dtype != input_dtype) {
+ *handle = tensorflow::make_safe(
+ tensorflow::EagerCast(op_exec_info.ctx, handle->get(), input_dtype,
+ static_cast<TF_DataType>(desired_dtype), status));
+ if (!status->status.ok()) return false;
+ output_dtype = desired_dtype;
+ }
+
+ if (output_dtype != TF_INT32) {
+ // Note that this is a shallow copy and will share the underlying buffer
+ // if copying to the same device.
+ *handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
+ handle->get(), op_exec_info.ctx, op_exec_info.device_name, status));
+ if (!status->status.ok()) return false;
+ }
+ return true;
+}
+
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
PyObject* input, tensorflow::Safe_PyObjectPtr* output,
TF_Status* status) {
@@ -1938,9 +1963,31 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
TFE_Execute(op, &output_handle, &num_retvals, status);
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
- // Always create the py object (and correctly DECREF it) from the returned
- // value, else the data will leak.
- output->reset(EagerTensorFromHandle(output_handle));
+ if (!PyObject_HasAttrString(input, "_read_dtype")) {
+ // Always create the py object (and correctly DECREF it) from the returned
+ // value, else the data will leak.
+ output->reset(EagerTensorFromHandle(output_handle));
+ } else {
+ // This is a _MixedPrecisionVariable which potentially does casting when
+ // being read.
+ tensorflow::Safe_PyObjectPtr read_dtype(
+ PyObject_GetAttrString(input, "_read_dtype"));
+ int desired_dtype = -1;
+ if (!ParseTypeValue("_read_dtype", read_dtype.get(), status,
+ &desired_dtype)) {
+ return false;
+ }
+
+ auto safe_output_handle = tensorflow::make_safe(output_handle);
+      // output_handle is now owned by safe_output_handle; clear the raw
+      // pointer so it cannot be used (or freed) twice.
+ output_handle = nullptr;
+ if (!CastTensor(parent_op_exec_info,
+ static_cast<TF_DataType>(desired_dtype),
+ &safe_output_handle, status)) {
+ return false;
+ }
+ output->reset(EagerTensorFromHandle(safe_output_handle.release()));
+ }
// TODO(nareshmodi): Should we run post exec callbacks here?
if (parent_op_exec_info.run_gradient_callback) {
@@ -2010,27 +2057,13 @@ bool ConvertToTensor(
}
}
- TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
- if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
- handle = tensorflow::make_safe(
- tensorflow::EagerCast(op_exec_info.ctx, handle.get(), handle_dtype,
- static_cast<TF_DataType>(desired_dtype), status));
- if (!status->status.ok()) return false;
-
- handle_dtype = TFE_TensorHandleDataType(handle.get());
- }
-
- if (handle_dtype != TF_INT32) {
- // Note that this is a shallow copy and will share the underlying buffer
- // if copying to the same device.
- handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
- handle.get(), op_exec_info.ctx, op_exec_info.device_name, status));
- if (!status->status.ok()) return false;
+ if (!CastTensor(op_exec_info, static_cast<TF_DataType>(desired_dtype),
+ &handle, status)) {
+ return false;
}
-
+ TF_DataType output_dtype = TFE_TensorHandleDataType(handle.get());
output_handle->reset(EagerTensorFromHandle(handle.release()));
-
- dtype_setter(handle_dtype);
+ dtype_setter(output_dtype);
return true;
}