aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/eager/BUILD2
-rw-r--r--tensorflow/c/eager/c_api.cc272
-rw-r--r--tensorflow/c/eager/c_api_internal.h85
-rw-r--r--tensorflow/core/common_runtime/eager/BUILD23
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.cc107
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.h130
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc7
-rw-r--r--tensorflow/python/lib/core/py_func.cc2
8 files changed, 393 insertions, 235 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index bea5a121b3..d2d8d59323 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -31,6 +31,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
+ "//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@@ -68,6 +69,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
+ "//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 2402a6d044..59432f2ef8 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -161,29 +161,32 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
DCHECK(h);
- h->Unref();
+ if (h->handle) {
+ h->handle->Unref();
+ }
+ delete h;
}
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
- return static_cast<TF_DataType>(h->dtype);
+ return static_cast<TF_DataType>(h->handle->dtype);
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
const tensorflow::Tensor* t = nullptr;
- status->status = h->Tensor(&t);
+ status->status = h->handle->Tensor(&t);
return t == nullptr ? 0 : t->dims();
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
const tensorflow::Tensor* t = nullptr;
- status->status = h->Tensor(&t);
+ status->status = h->handle->Tensor(&t);
return t == nullptr ? 0 : t->dim_size(dim_index);
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
tensorflow::Device* d = nullptr;
- status->status = h->OpDevice(&d);
+ status->status = h->handle->OpDevice(&d);
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
@@ -193,7 +196,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
tensorflow::Device* d = nullptr;
tensorflow::Device* op_device = nullptr;
const tensorflow::Tensor* t = nullptr;
- status->status = h->TensorAndDevice(&t, &d, &op_device);
+ status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
if (!status->status.ok()) return nullptr;
if (!IsCPU(d)) {
TF_SetStatus(status, TF_UNIMPLEMENTED,
@@ -212,10 +215,10 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
namespace {
-tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h,
+tensorflow::Status TensorHandleCopyToDevice(tensorflow::TensorHandle* h,
TFE_Context* ctx,
tensorflow::Device* dstd,
- TFE_TensorHandle** output) {
+ tensorflow::TensorHandle** output) {
const tensorflow::Tensor* src = nullptr;
tensorflow::Device* srcd = nullptr;
// TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept
@@ -232,7 +235,7 @@ tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h,
const bool both_on_cpu = src_cpu && dst_cpu;
if (is_same_device || both_on_cpu) {
dstd = dst_cpu ? nullptr : dstd;
- *output = new TFE_TensorHandle(*src, dstd, dstd);
+ *output = new tensorflow::TensorHandle(*src, dstd, dstd);
return tensorflow::Status::OK();
}
if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
@@ -249,7 +252,7 @@ tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h,
tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
if (src->shape().num_elements() == 0) {
dstd = dst_cpu ? nullptr : dstd;
- *output = new TFE_TensorHandle(dst, dstd, dstd);
+ *output = new tensorflow::TensorHandle(dst, dstd, dstd);
return tensorflow::Status::OK();
}
tensorflow::DeviceContext* src_device_context = nullptr;
@@ -280,7 +283,7 @@ tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h,
n.WaitForNotification();
if (status.ok()) {
dstd = dst_cpu ? nullptr : dstd;
- *output = new TFE_TensorHandle(dst, dstd, dstd);
+ *output = new tensorflow::TensorHandle(dst, dstd, dstd);
}
return status;
}
@@ -335,12 +338,12 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
tensorflow::Device* d = nullptr;
// TODO(agarwal): This call may block if h is not ready. Avoid this if
// possible.
- status->status = h->Device(&d);
+ status->status = h->handle->Device(&d);
if (!status->status.ok()) return;
if (!IsCPU(d)) op->device = d;
}
- h->Ref();
- op->inputs.push_back(h);
+ h->handle->Ref();
+ op->inputs.push_back(h->handle);
op->attrs.NumInputs(op->inputs.size());
}
@@ -506,6 +509,79 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
namespace {
+class CopyToDeviceNode : public tensorflow::EagerNode {
+ public:
+ CopyToDeviceNode(tensorflow::TensorHandle* src, tensorflow::Device* dstd,
+ TFE_Context* ctx)
+ : tensorflow::EagerNode(ctx->context.NextId()),
+ src_(src),
+ dstd_(dstd),
+ ctx_(ctx),
+ dst_(new tensorflow::TensorHandle(id, src_->dtype, &ctx->context)) {
+ src_->Ref();
+ dst_->Ref();
+ }
+
+ ~CopyToDeviceNode() override {
+ src_->Unref();
+ dst_->Unref();
+ }
+
+ tensorflow::Status Run() override {
+ tensorflow::TensorHandle* temp = nullptr;
+ TF_RETURN_IF_ERROR(TensorHandleCopyToDevice(src_, ctx_, dstd_, &temp));
+ const tensorflow::Tensor* tensor = nullptr;
+ tensorflow::Device* device = nullptr;
+ tensorflow::Device* op_device = nullptr;
+ tensorflow::Status status =
+ temp->TensorAndDevice(&tensor, &device, &op_device);
+ // `temp` is a ready handle. So the following call should return OK.
+ TF_DCHECK_OK(status) << status.error_message();
+ DCHECK(tensor);
+ dst_->SetTensorAndDevice(*tensor, device, op_device);
+ temp->Unref();
+ return tensorflow::Status::OK();
+ }
+
+ tensorflow::TensorHandle* dst() { return dst_; }
+
+ private:
+ tensorflow::TensorHandle* src_;
+ tensorflow::Device* dstd_;
+ TFE_Context* ctx_;
+ tensorflow::TensorHandle* dst_;
+};
+
+// TODO(apassos) move to TensorHandle
+tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal(
+ tensorflow::TensorHandle* h, TFE_Context* ctx, const char* device_name,
+ TF_Status* status) {
+ status->status = ctx->context.GetStatus();
+ if (!status->status.ok()) {
+ return nullptr;
+ }
+ tensorflow::Device* dstd = ctx->context.HostCPU();
+ if (device_name != nullptr && strlen(device_name) > 0) {
+ status->status =
+ ctx->context.device_mgr()->LookupDevice(device_name, &dstd);
+ if (!status->status.ok()) return nullptr;
+ }
+ if (ctx->context.Async()) {
+ // Note that `h` may not be currently ready. However execution order will
+ // make sure that `h` is ready before the copy is actually done.
+ CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
+ tensorflow::TensorHandle* output = node->dst();
+ // Note that calling Add makes `node` accessible by the EagerExecutor
+ // thread. So further accesses need to be thread-safe.
+ ctx->context.ExecutorAdd(node);
+ return output;
+ } else {
+ tensorflow::TensorHandle* output = nullptr;
+ status->status = TensorHandleCopyToDevice(h, ctx, dstd, &output);
+ return output;
+ }
+}
+
tensorflow::Status ValidateInputTypeAndPlacement(
TFE_Context* ctx, tensorflow::Device* host_device,
tensorflow::Device* op_device, TFE_Op* op,
@@ -518,7 +594,7 @@ tensorflow::Status ValidateInputTypeAndPlacement(
for (int i = 0; i < op->inputs.size(); ++i) {
const tensorflow::Device* expected_device =
memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
- TFE_TensorHandle* handle = op->inputs[i];
+ tensorflow::TensorHandle* handle = op->inputs[i];
tensorflow::Device* handle_device = nullptr;
TF_RETURN_IF_ERROR(handle->Device(&handle_device));
const tensorflow::Device* actual_device =
@@ -560,8 +636,9 @@ tensorflow::Status ValidateInputTypeAndPlacement(
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
TF_Status* s = TF_NewStatus();
- TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
- handle, ctx, expected_device->name().c_str(), s);
+ tensorflow::TensorHandle* copied_tensor =
+ TFE_TensorHandleCopyToDevice_Internal(
+ handle, ctx, expected_device->name().c_str(), s);
tensorflow::Status status = s->status;
TF_DeleteStatus(s);
if (!status.ok()) {
@@ -616,9 +693,10 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
tensorflow::Status Execute(
TFE_Context* ctx, tensorflow::Device* device,
- const tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>& op_inputs,
+ const tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>&
+ op_inputs,
tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats,
- TFE_TensorHandle** retvals, int num_retvals) {
+ tensorflow::TensorHandle** retvals, int num_retvals) {
if (!ctx->context.SoftPlacement() && device == nullptr) {
device = ctx->context.HostCPU();
}
@@ -683,7 +761,7 @@ tensorflow::Status Execute(
d = nullptr;
}
if (retvals[i] == nullptr) {
- retvals[i] = new TFE_TensorHandle(outputs[i], d, op_device);
+ retvals[i] = new tensorflow::TensorHandle(outputs[i], d, op_device);
} else {
retvals[i]->SetTensorAndDevice(outputs[i], d, op_device);
}
@@ -711,9 +789,10 @@ class ExecuteNode : public tensorflow::EagerNode {
}
TFE_Context* ctx = op->ctx;
for (int i = 0; i < num_retvals; ++i) {
- TFE_TensorHandle* h = new TFE_TensorHandle(id, output_dtypes[i], ctx);
+ tensorflow::TensorHandle* h =
+ new tensorflow::TensorHandle(id, output_dtypes[i], &ctx->context);
h->Ref();
- retvals[i] = h;
+ retvals[i] = new TFE_TensorHandle(h);
retvals_[i] = h;
}
}
@@ -745,54 +824,12 @@ class ExecuteNode : public tensorflow::EagerNode {
private:
TFE_Context* ctx_;
tensorflow::Device* op_device_;
- tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> inputs_;
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_;
tensorflow::KernelAndDevice* kernel_;
std::unique_ptr<tensorflow::NodeExecStats> maybe_stats_;
- tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals_;
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals_;
};
-class CopyToDeviceNode : public tensorflow::EagerNode {
- public:
- CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd,
- TFE_Context* ctx)
- : tensorflow::EagerNode(ctx->context.NextId()),
- src_(src),
- dstd_(dstd),
- ctx_(ctx),
- dst_(new TFE_TensorHandle(id, src_->dtype, ctx)) {
- src_->Ref();
- dst_->Ref();
- }
-
- ~CopyToDeviceNode() override {
- src_->Unref();
- dst_->Unref();
- }
-
- tensorflow::Status Run() override {
- TFE_TensorHandle* temp = nullptr;
- TF_RETURN_IF_ERROR(TensorHandleCopyToDevice(src_, ctx_, dstd_, &temp));
- const tensorflow::Tensor* tensor = nullptr;
- tensorflow::Device* device = nullptr;
- tensorflow::Device* op_device = nullptr;
- tensorflow::Status status =
- temp->TensorAndDevice(&tensor, &device, &op_device);
- // `temp` is a ready handle. So the following call should return OK.
- TF_DCHECK_OK(status) << status.error_message();
- DCHECK(tensor);
- dst_->SetTensorAndDevice(*tensor, device, op_device);
- temp->Unref();
- return tensorflow::Status::OK();
- }
-
- TFE_TensorHandle* dst() { return dst_; }
-
- private:
- TFE_TensorHandle* src_;
- tensorflow::Device* dstd_;
- TFE_Context* ctx_;
- TFE_TensorHandle* dst_;
-};
#ifdef TENSORFLOW_EAGER_USE_XLA
// Synthesizes and returns a wrapper function over `op`, which must be a
@@ -1140,11 +1177,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
} else {
// Execute checks if retvals[i] is nullptr or not to figure if it needs to
// allocate it.
+ std::vector<tensorflow::TensorHandle*> handle_retvals(*num_retvals,
+ nullptr);
+ status->status =
+ Execute(op->ctx, op->device, op->inputs, kernel, maybe_stats.get(),
+ handle_retvals.data(), *num_retvals);
for (int i = 0; i < *num_retvals; ++i) {
- retvals[i] = nullptr;
+ retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
}
- status->status = Execute(op->ctx, op->device, op->inputs, kernel,
- maybe_stats.get(), retvals, *num_retvals);
}
}
@@ -1152,30 +1192,12 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status) {
- status->status = ctx->context.GetStatus();
- if (!status->status.ok()) {
- return nullptr;
- }
- tensorflow::Device* dstd = ctx->context.HostCPU();
- if (device_name != nullptr && strlen(device_name) > 0) {
- status->status =
- ctx->context.device_mgr()->LookupDevice(device_name, &dstd);
- if (!status->status.ok()) return nullptr;
- }
- if (ctx->context.Async()) {
- // Note that `h` may not be currently ready. However execution order will
- // make sure that `h` is ready before the copy is actually done.
- CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
- TFE_TensorHandle* output = node->dst();
- // Note that calling Add makes `node` accessible by the EagerExecutor
- // thread. So further accesses need to be thread-safe.
- ctx->context.ExecutorAdd(node);
- return output;
- } else {
- TFE_TensorHandle* output = nullptr;
- status->status = TensorHandleCopyToDevice(h, ctx, dstd, &output);
- return output;
+ tensorflow::TensorHandle* handle = TFE_TensorHandleCopyToDevice_Internal(
+ h->handle, ctx, device_name, status);
+ if (status->status.ok()) {
+ return new TFE_TensorHandle(handle);
}
+ return nullptr;
}
void TFE_ContextAddFunctionDef(TFE_Context* ctx,
@@ -1214,7 +1236,7 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
tensorflow::Device* d = nullptr;
tensorflow::Device* op_device = nullptr;
const tensorflow::Tensor* t = nullptr;
- status->status = h->TensorAndDevice(&t, &d, &op_device);
+ status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
if (!status->status.ok()) return nullptr;
if (d != nullptr) {
status->status = tensorflow::errors::FailedPrecondition(
@@ -1306,70 +1328,8 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
} // namespace tensorflow
-
-bool TFE_TensorHandle::IsReady() {
- if (node_id == 0) return true;
- tensorflow::mutex_lock l(ctx_mutex_);
- return ctx_ == nullptr;
-}
-
-tensorflow::Status TFE_TensorHandle::WaitReady() {
- if (node_id == 0) return tensorflow::Status::OK();
- tensorflow::EagerExecutor* executor = nullptr;
- {
- tensorflow::mutex_lock l(ctx_mutex_);
- if (ctx_ == nullptr) return tensorflow::Status::OK();
- executor = ctx_->context.Executor();
- }
- return executor->WaitFor(node_id);
-}
-
-tensorflow::Status TFE_TensorHandle::Tensor(const tensorflow::Tensor** t) {
- TF_RETURN_IF_ERROR(WaitReady());
- DCHECK(IsReady());
- *t = &tensor_;
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status TFE_TensorHandle::Device(tensorflow::Device** d) {
- TF_RETURN_IF_ERROR(WaitReady());
- DCHECK(IsReady());
- *d = device_;
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status TFE_TensorHandle::OpDevice(tensorflow::Device** d) {
- TF_RETURN_IF_ERROR(WaitReady());
- DCHECK(IsReady());
- *d = op_device_;
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status TFE_TensorHandle::TensorAndDevice(
- const tensorflow::Tensor** tensor, tensorflow::Device** device,
- tensorflow::Device** op_device) {
- TF_RETURN_IF_ERROR(WaitReady());
- DCHECK(IsReady());
- *tensor = &tensor_;
- *device = device_;
- *op_device = op_device_;
- return tensorflow::Status::OK();
-}
-
-void TFE_TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor,
- tensorflow::Device* device,
- tensorflow::Device* op_device) {
- tensorflow::mutex_lock l(ctx_mutex_);
- DCHECK(node_id > 0 && ctx_) << "SetTensorAndDevice should be only called "
- << "on non-ready handles.";
- ctx_ = nullptr;
- tensor_ = tensor;
- device_ = device;
- op_device_ = op_device;
-}
-
TFE_Op::~TFE_Op() {
- for (TFE_TensorHandle* h : inputs) {
+ for (tensorflow::TensorHandle* h : inputs) {
h->Unref();
}
}
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 5b29120b40..e6d2ab75ff 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
@@ -67,84 +68,18 @@ struct TFE_Context {
tensorflow::EagerContext context;
};
-struct TFE_TensorHandle : public tensorflow::core::RefCounted {
- public:
+struct TFE_TensorHandle {
TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d,
tensorflow::Device* op_device)
- : dtype(t.dtype()),
- node_id(0),
- tensor_(t),
- device_(d),
- op_device_(op_device),
- ctx_(nullptr) {}
+ : handle(new tensorflow::TensorHandle(t, d, op_device)) {}
TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype,
- TFE_Context* ctx)
- : dtype(dtype),
- node_id(node_id),
- tensor_(dtype),
- device_(nullptr),
- op_device_(nullptr),
- ctx_(ctx) {
- DCHECK_GT(node_id, 0);
- }
-
- ~TFE_TensorHandle() override {}
-
- tensorflow::Status Tensor(const tensorflow::Tensor** t);
-
- tensorflow::Status Device(tensorflow::Device** d);
-
- tensorflow::Status OpDevice(tensorflow::Device** d);
-
- tensorflow::Status TensorAndDevice(const tensorflow::Tensor** tensor,
- tensorflow::Device** device,
- tensorflow::Device** op_device);
-
- // Note that this can be called at most once, and only on non-ready handles,
- // and makes them ready.
- void SetTensorAndDevice(const tensorflow::Tensor& tensor,
- tensorflow::Device* device,
- tensorflow::Device* op_device);
-
- // dtype for the handle. It must be the same as t.dtype() once the handle is
- // ready.
- const tensorflow::DataType dtype;
-
- private:
- // If the contents of the Tensor pointed to by this handle is yet to be
- // computed by a EagerNode, this function will block till that compuatation is
- // done and the handle is "ready".
- tensorflow::Status WaitReady();
-
- bool IsReady();
-
- // Id for the EagerNode that will compute the value pointed to by this handle.
- // If the value is 0, the handle is already ready, but not vice-versa.
- const tensorflow::uint64 node_id;
-
- tensorflow::Tensor tensor_;
-
- // TODO(ashankar): device_ == nullptr iff local CPU
- // This was expedient, but perhaps worth revisiting ('device_' should always
- // be a valid pointer?)
- // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
- // provided with the appropriate TFE_Context.
- //
- // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
- // TFE_TensorHandle does not outlive the TFE_Context from which it came?
- tensorflow::Device* device_;
-
- // Device in which the op producing this tensor was executed. Equals to
- // device_ for constant tensors.
- tensorflow::Device* op_device_;
-
- tensorflow::mutex ctx_mutex_;
-
- // `ctx` is only guaranteed to be set if the handle is not "ready". This is
- // typically true when the handle was produced during async execution.
- // `ctx` object is not owned and should outlive this handle.
- TFE_Context* ctx_ GUARDED_BY(ctx_mutex_);
+ tensorflow::EagerContext* ctx)
+ : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {}
+
+ TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {}
+
+ tensorflow::TensorHandle* handle;
};
struct TFE_Op {
@@ -161,7 +96,7 @@ struct TFE_Op {
const tensorflow::string name;
tensorflow::AttrBuilder attrs;
const tensorflow::AttrTypeMap* attr_types;
- tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> inputs;
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs;
tensorflow::Device* device;
bool use_xla = false;
};
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index de10b10b7e..02fb83200a 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -55,6 +55,29 @@ tf_cuda_library(
)
tf_cuda_library(
+ name = "tensor_handle",
+ srcs = [
+ "tensor_handle.cc",
+ ],
+ hdrs = [
+ "tensor_handle.h",
+ ],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":context",
+ ":eager_executor",
+ ":kernel_and_device",
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ ],
+)
+
+tf_cuda_library(
name = "kernel_and_device",
srcs = [
"kernel_and_device.cc",
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
new file mode 100644
index 0000000000..5bc1700627
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -0,0 +1,107 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <map>
+#include <memory>
+#include <queue>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+bool TensorHandle::IsReady() {
+ if (node_id == 0) return true;
+ mutex_lock l(ctx_mutex_);
+ return ctx_ == nullptr;
+}
+
+Status TensorHandle::WaitReady() {
+ if (node_id == 0) return Status::OK();
+ EagerExecutor* executor = nullptr;
+ {
+ mutex_lock l(ctx_mutex_);
+ if (ctx_ == nullptr) return Status::OK();
+ executor = ctx_->Executor();
+ }
+ return executor->WaitFor(node_id);
+}
+
+Status TensorHandle::Tensor(const tensorflow::Tensor** t) {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *t = &tensor_;
+ return Status::OK();
+}
+
+Status TensorHandle::Device(tensorflow::Device** d) {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *d = device_;
+ return Status::OK();
+}
+
+Status TensorHandle::OpDevice(tensorflow::Device** d) {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *d = op_device_;
+ return Status::OK();
+}
+
+Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor,
+ tensorflow::Device** device,
+ tensorflow::Device** op_device) {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ *tensor = &tensor_;
+ *device = device_;
+ *op_device = op_device_;
+ return Status::OK();
+}
+
+void TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor,
+ tensorflow::Device* device,
+ tensorflow::Device* op_device) {
+ mutex_lock l(ctx_mutex_);
+ DCHECK(node_id > 0 && ctx_) << "SetTensorAndDevice should be only called "
+ << "on non-ready handles.";
+ ctx_ = nullptr;
+ tensor_ = tensor;
+ device_ = device;
+ op_device_ = op_device;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
new file mode 100644
index 0000000000..97e67e4652
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -0,0 +1,130 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
+
+#include <algorithm>
+#include <cstddef>
+#include <map>
+#include <memory>
+#include <queue>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+// Associates a Tensor and a Device, used in the eager runtime. Internal version
+// of the TFE_TensorHandle struct and the python EagerTensor class
+// (unrelated to python TensorHandle).
+class TensorHandle : public core::RefCounted {
+ public:
+ TensorHandle(const Tensor& t, Device* d, Device* op_device)
+ : dtype(t.dtype()),
+ node_id(0),
+ tensor_(t),
+ device_(d),
+ op_device_(op_device),
+ ctx_(nullptr) {}
+
+ TensorHandle(uint64 node_id, DataType dtype, EagerContext* ctx)
+ : dtype(dtype),
+ node_id(node_id),
+ tensor_(dtype),
+ device_(nullptr),
+ op_device_(nullptr),
+ ctx_(ctx) {
+ DCHECK_GT(node_id, 0);
+ }
+
+ ~TensorHandle() override {}
+
+ Status Tensor(const tensorflow::Tensor** t);
+
+ Status Device(tensorflow::Device** d);
+
+ Status OpDevice(tensorflow::Device** d);
+
+ Status TensorAndDevice(const tensorflow::Tensor** tensor,
+ tensorflow::Device** device,
+ tensorflow::Device** op_device);
+
+ // Note that this can be called at most once, and only on non-ready handles,
+ // and makes them ready.
+ void SetTensorAndDevice(const tensorflow::Tensor& tensor,
+ tensorflow::Device* device,
+ tensorflow::Device* op_device);
+
+ // dtype for the handle. It must be the same as t.dtype() once the handle is
+ // ready.
+ const DataType dtype;
+
+ private:
+ // If the contents of the Tensor pointed to by this handle is yet to be
+ // computed by an EagerNode, this function will block till that computation is
+ // done and the handle is "ready".
+ Status WaitReady();
+
+ bool IsReady();
+
+ // Id for the EagerNode that will compute the value pointed to by this handle.
+ // If the value is 0, the handle is already ready, but not vice-versa.
+ const uint64 node_id;
+
+ tensorflow::Tensor tensor_;
+
+ // TODO(ashankar): device_ == nullptr iff local CPU
+ // This was expedient, but perhaps worth revisiting ('device_' should always
+ // be a valid pointer?)
+ // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
+ // provided with the appropriate TFE_Context.
+ //
+ // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
+ // TFE_TensorHandle does not outlive the TFE_Context from which it came?
+ tensorflow::Device* device_;
+
+ // Device in which the op producing this tensor was executed. Equals to
+ // device_ for constant tensors.
+ tensorflow::Device* op_device_;
+
+ mutex ctx_mutex_;
+
+ // `ctx` is only guaranteed to be set if the handle is not "ready". This is
+ // typically true when the handle was produced during async execution.
+ // `ctx` object is not owned and should outlive this handle.
+ EagerContext* ctx_ GUARDED_BY(ctx_mutex_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 701f68b8f7..55ba509065 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1013,12 +1013,13 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
tensorflow::int64 id = EagerTensor_id(tensor);
const tensorflow::Tensor* tensor = nullptr;
- const tensorflow::Status status = t->Tensor(&tensor);
+ const tensorflow::Status status = t->handle->Tensor(&tensor);
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- return tensorflow::eager::TapeTensor{id, t->dtype,
+ return tensorflow::eager::TapeTensor{id, t->handle->dtype,
tensorflow::TensorShape({})};
} else {
- return tensorflow::eager::TapeTensor{id, t->dtype, tensor->shape()};
+ return tensorflow::eager::TapeTensor{id, t->handle->dtype,
+ tensor->shape()};
}
}
tensorflow::int64 id = FastTensorId(tensor);
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index 02eafd42b3..22317a348c 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -166,7 +166,7 @@ bool IsSingleNone(PyObject* obj) {
// Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
const Tensor** output_tensor) {
- return EagerTensor_Handle(eager_tensor)->Tensor(output_tensor);
+ return EagerTensor_Handle(eager_tensor)->handle->Tensor(output_tensor);
}
// Calls the registered py function through the trampoline.