diff options
-rw-r--r-- | tensorflow/c/eager/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 272 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api_internal.h | 85 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/BUILD | 23 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/tensor_handle.cc | 107 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/tensor_handle.h | 130 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 7 | ||||
-rw-r--r-- | tensorflow/python/lib/core/py_func.cc | 2 |
8 files changed, 393 insertions, 235 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index bea5a121b3..d2d8d59323 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -31,6 +31,7 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -68,6 +69,7 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:kernel_and_device", + "//tensorflow/core/common_runtime/eager:tensor_handle", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 2402a6d044..59432f2ef8 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -161,29 +161,32 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { DCHECK(h); - h->Unref(); + if (h->handle) { + h->handle->Unref(); + } + delete h; } TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { - return static_cast<TF_DataType>(h->dtype); + return static_cast<TF_DataType>(h->handle->dtype); } int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { const tensorflow::Tensor* t = nullptr; - status->status = h->Tensor(&t); + status->status = h->handle->Tensor(&t); return t == nullptr ? 0 : t->dims(); } int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, TF_Status* status) { const tensorflow::Tensor* t = nullptr; - status->status = h->Tensor(&t); + status->status = h->handle->Tensor(&t); return t == nullptr ? 0 : t->dim_size(dim_index); } const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { tensorflow::Device* d = nullptr; - status->status = h->OpDevice(&d); + status->status = h->handle->OpDevice(&d); return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" : d->name().c_str(); } @@ -193,7 +196,7 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { tensorflow::Device* d = nullptr; tensorflow::Device* op_device = nullptr; const tensorflow::Tensor* t = nullptr; - status->status = h->TensorAndDevice(&t, &d, &op_device); + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); if (!status->status.ok()) return nullptr; if (!IsCPU(d)) { TF_SetStatus(status, TF_UNIMPLEMENTED, @@ -212,10 +215,10 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { namespace { -tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h, +tensorflow::Status TensorHandleCopyToDevice(tensorflow::TensorHandle* h, TFE_Context* ctx, tensorflow::Device* dstd, - TFE_TensorHandle** output) { + tensorflow::TensorHandle** output) { const tensorflow::Tensor* src = nullptr; tensorflow::Device* srcd = nullptr; // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept @@ -232,7 +235,7 @@ tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h, const bool both_on_cpu = src_cpu && dst_cpu; if (is_same_device || both_on_cpu) { dstd = dst_cpu ? nullptr : dstd; - *output = new TFE_TensorHandle(*src, dstd, dstd); + *output = new tensorflow::TensorHandle(*src, dstd, dstd); return tensorflow::Status::OK(); } if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT && @@ -249,7 +252,7 @@ tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h, tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape()); if (src->shape().num_elements() == 0) { dstd = dst_cpu ? nullptr : dstd; - *output = new TFE_TensorHandle(dst, dstd, dstd); + *output = new tensorflow::TensorHandle(dst, dstd, dstd); return tensorflow::Status::OK(); } tensorflow::DeviceContext* src_device_context = nullptr; @@ -280,7 +283,7 @@ tensorflow::Status TensorHandleCopyToDevice(TFE_TensorHandle* h, n.WaitForNotification(); if (status.ok()) { dstd = dst_cpu ? nullptr : dstd; - *output = new TFE_TensorHandle(dst, dstd, dstd); + *output = new tensorflow::TensorHandle(dst, dstd, dstd); } return status; } @@ -335,12 +338,12 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { tensorflow::Device* d = nullptr; // TODO(agarwal): This call may block if h is not ready. Avoid this if // possible. - status->status = h->Device(&d); + status->status = h->handle->Device(&d); if (!status->status.ok()) return; if (!IsCPU(d)) op->device = d; } - h->Ref(); - op->inputs.push_back(h); + h->handle->Ref(); + op->inputs.push_back(h->handle); op->attrs.NumInputs(op->inputs.size()); } @@ -506,6 +509,79 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, namespace { +class CopyToDeviceNode : public tensorflow::EagerNode { + public: + CopyToDeviceNode(tensorflow::TensorHandle* src, tensorflow::Device* dstd, + TFE_Context* ctx) + : tensorflow::EagerNode(ctx->context.NextId()), + src_(src), + dstd_(dstd), + ctx_(ctx), + dst_(new tensorflow::TensorHandle(id, src_->dtype, &ctx->context)) { + src_->Ref(); + dst_->Ref(); + } + + ~CopyToDeviceNode() override { + src_->Unref(); + dst_->Unref(); + } + + tensorflow::Status Run() override { + tensorflow::TensorHandle* temp = nullptr; + TF_RETURN_IF_ERROR(TensorHandleCopyToDevice(src_, ctx_, dstd_, &temp)); + const tensorflow::Tensor* tensor = nullptr; + tensorflow::Device* device = nullptr; + tensorflow::Device* op_device = nullptr; + tensorflow::Status status = + temp->TensorAndDevice(&tensor, &device, &op_device); + // `temp` is a ready handle. So the following call should return OK. + TF_DCHECK_OK(status) << status.error_message(); + DCHECK(tensor); + dst_->SetTensorAndDevice(*tensor, device, op_device); + temp->Unref(); + return tensorflow::Status::OK(); + } + + tensorflow::TensorHandle* dst() { return dst_; } + + private: + tensorflow::TensorHandle* src_; + tensorflow::Device* dstd_; + TFE_Context* ctx_; + tensorflow::TensorHandle* dst_; +}; + +// TODO(apassos) move to TensorHandle +tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal( + tensorflow::TensorHandle* h, TFE_Context* ctx, const char* device_name, + TF_Status* status) { + status->status = ctx->context.GetStatus(); + if (!status->status.ok()) { + return nullptr; + } + tensorflow::Device* dstd = ctx->context.HostCPU(); + if (device_name != nullptr && strlen(device_name) > 0) { + status->status = + ctx->context.device_mgr()->LookupDevice(device_name, &dstd); + if (!status->status.ok()) return nullptr; + } + if (ctx->context.Async()) { + // Note that `h` may not be currently ready. However execution order will + // make sure that `h` is ready before the copy is actually done. + CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx); + tensorflow::TensorHandle* output = node->dst(); + // Note that calling Add makes `node` accessible by the EagerExecutor + // thread. So further accesses need to be thread-safe. + ctx->context.ExecutorAdd(node); + return output; + } else { + tensorflow::TensorHandle* output = nullptr; + status->status = TensorHandleCopyToDevice(h, ctx, dstd, &output); + return output; + } +} + tensorflow::Status ValidateInputTypeAndPlacement( TFE_Context* ctx, tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op, @@ -518,7 +594,7 @@ tensorflow::Status ValidateInputTypeAndPlacement( for (int i = 0; i < op->inputs.size(); ++i) { const tensorflow::Device* expected_device = memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device; - TFE_TensorHandle* handle = op->inputs[i]; + tensorflow::TensorHandle* handle = op->inputs[i]; tensorflow::Device* handle_device = nullptr; TF_RETURN_IF_ERROR(handle->Device(&handle_device)); const tensorflow::Device* actual_device = @@ -560,8 +636,9 @@ tensorflow::Status ValidateInputTypeAndPlacement( // We are only here if the policy is warn or silent copies, so we should // trigger a copy. TF_Status* s = TF_NewStatus(); - TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice( - handle, ctx, expected_device->name().c_str(), s); + tensorflow::TensorHandle* copied_tensor = + TFE_TensorHandleCopyToDevice_Internal( + handle, ctx, expected_device->name().c_str(), s); tensorflow::Status status = s->status; TF_DeleteStatus(s); if (!status.ok()) { @@ -616,9 +693,10 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, tensorflow::Status Execute( TFE_Context* ctx, tensorflow::Device* device, - const tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4>& op_inputs, + const tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4>& + op_inputs, tensorflow::KernelAndDevice* kernel, tensorflow::NodeExecStats* maybe_stats, - TFE_TensorHandle** retvals, int num_retvals) { + tensorflow::TensorHandle** retvals, int num_retvals) { if (!ctx->context.SoftPlacement() && device == nullptr) { device = ctx->context.HostCPU(); } @@ -683,7 +761,7 @@ tensorflow::Status Execute( d = nullptr; } if (retvals[i] == nullptr) { - retvals[i] = new TFE_TensorHandle(outputs[i], d, op_device); + retvals[i] = new tensorflow::TensorHandle(outputs[i], d, op_device); } else { retvals[i]->SetTensorAndDevice(outputs[i], d, op_device); } @@ -711,9 +789,10 @@ class ExecuteNode : public tensorflow::EagerNode { } TFE_Context* ctx = op->ctx; for (int i = 0; i < num_retvals; ++i) { - TFE_TensorHandle* h = new TFE_TensorHandle(id, output_dtypes[i], ctx); + tensorflow::TensorHandle* h = + new tensorflow::TensorHandle(id, output_dtypes[i], &ctx->context); h->Ref(); - retvals[i] = h; + retvals[i] = new TFE_TensorHandle(h); retvals_[i] = h; } } @@ -745,54 +824,12 @@ class ExecuteNode : public tensorflow::EagerNode { private: TFE_Context* ctx_; tensorflow::Device* op_device_; - tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> inputs_; + tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_; tensorflow::KernelAndDevice* kernel_; std::unique_ptr<tensorflow::NodeExecStats> maybe_stats_; - tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals_; + tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals_; }; -class CopyToDeviceNode : public tensorflow::EagerNode { - public: - CopyToDeviceNode(TFE_TensorHandle* src, tensorflow::Device* dstd, - TFE_Context* ctx) - : tensorflow::EagerNode(ctx->context.NextId()), - src_(src), - dstd_(dstd), - ctx_(ctx), - dst_(new TFE_TensorHandle(id, src_->dtype, ctx)) { - src_->Ref(); - dst_->Ref(); - } - - ~CopyToDeviceNode() override { - src_->Unref(); - dst_->Unref(); - } - - tensorflow::Status Run() override { - TFE_TensorHandle* temp = nullptr; - TF_RETURN_IF_ERROR(TensorHandleCopyToDevice(src_, ctx_, dstd_, &temp)); - const tensorflow::Tensor* tensor = nullptr; - tensorflow::Device* device = nullptr; - tensorflow::Device* op_device = nullptr; - tensorflow::Status status = - temp->TensorAndDevice(&tensor, &device, &op_device); - // `temp` is a ready handle. So the following call should return OK. - TF_DCHECK_OK(status) << status.error_message(); - DCHECK(tensor); - dst_->SetTensorAndDevice(*tensor, device, op_device); - temp->Unref(); - return tensorflow::Status::OK(); - } - - TFE_TensorHandle* dst() { return dst_; } - - private: - TFE_TensorHandle* src_; - tensorflow::Device* dstd_; - TFE_Context* ctx_; - TFE_TensorHandle* dst_; -}; #ifdef TENSORFLOW_EAGER_USE_XLA // Synthesizes and returns a wrapper function over `op`, which must be a @@ -1140,11 +1177,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } else { // Execute checks if retvals[i] is nullptr or not to figure if it needs to // allocate it. + std::vector<tensorflow::TensorHandle*> handle_retvals(*num_retvals, + nullptr); + status->status = + Execute(op->ctx, op->device, op->inputs, kernel, maybe_stats.get(), + handle_retvals.data(), *num_retvals); for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = nullptr; + retvals[i] = new TFE_TensorHandle(handle_retvals[i]); } - status->status = Execute(op->ctx, op->device, op->inputs, kernel, - maybe_stats.get(), retvals, *num_retvals); } } @@ -1152,30 +1192,12 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status) { - status->status = ctx->context.GetStatus(); - if (!status->status.ok()) { - return nullptr; - } - tensorflow::Device* dstd = ctx->context.HostCPU(); - if (device_name != nullptr && strlen(device_name) > 0) { - status->status = - ctx->context.device_mgr()->LookupDevice(device_name, &dstd); - if (!status->status.ok()) return nullptr; - } - if (ctx->context.Async()) { - // Note that `h` may not be currently ready. However execution order will - // make sure that `h` is ready before the copy is actually done. - CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx); - TFE_TensorHandle* output = node->dst(); - // Note that calling Add makes `node` accessible by the EagerExecutor - // thread. So further accesses need to be thread-safe. - ctx->context.ExecutorAdd(node); - return output; - } else { - TFE_TensorHandle* output = nullptr; - status->status = TensorHandleCopyToDevice(h, ctx, dstd, &output); - return output; + tensorflow::TensorHandle* handle = TFE_TensorHandleCopyToDevice_Internal( + h->handle, ctx, device_name, status); + if (status->status.ok()) { + return new TFE_TensorHandle(handle); } + return nullptr; } void TFE_ContextAddFunctionDef(TFE_Context* ctx, @@ -1214,7 +1236,7 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( tensorflow::Device* d = nullptr; tensorflow::Device* op_device = nullptr; const tensorflow::Tensor* t = nullptr; - status->status = h->TensorAndDevice(&t, &d, &op_device); + status->status = h->handle->TensorAndDevice(&t, &d, &op_device); if (!status->status.ok()) return nullptr; if (d != nullptr) { status->status = tensorflow::errors::FailedPrecondition( @@ -1306,70 +1328,8 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, } // namespace tensorflow - -bool TFE_TensorHandle::IsReady() { - if (node_id == 0) return true; - tensorflow::mutex_lock l(ctx_mutex_); - return ctx_ == nullptr; -} - -tensorflow::Status TFE_TensorHandle::WaitReady() { - if (node_id == 0) return tensorflow::Status::OK(); - tensorflow::EagerExecutor* executor = nullptr; - { - tensorflow::mutex_lock l(ctx_mutex_); - if (ctx_ == nullptr) return tensorflow::Status::OK(); - executor = ctx_->context.Executor(); - } - return executor->WaitFor(node_id); -} - -tensorflow::Status TFE_TensorHandle::Tensor(const tensorflow::Tensor** t) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); - *t = &tensor_; - return tensorflow::Status::OK(); -} - -tensorflow::Status TFE_TensorHandle::Device(tensorflow::Device** d) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); - *d = device_; - return tensorflow::Status::OK(); -} - -tensorflow::Status TFE_TensorHandle::OpDevice(tensorflow::Device** d) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); - *d = op_device_; - return tensorflow::Status::OK(); -} - -tensorflow::Status TFE_TensorHandle::TensorAndDevice( - const tensorflow::Tensor** tensor, tensorflow::Device** device, - tensorflow::Device** op_device) { - TF_RETURN_IF_ERROR(WaitReady()); - DCHECK(IsReady()); - *tensor = &tensor_; - *device = device_; - *op_device = op_device_; - return tensorflow::Status::OK(); -} - -void TFE_TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor, - tensorflow::Device* device, - tensorflow::Device* op_device) { - tensorflow::mutex_lock l(ctx_mutex_); - DCHECK(node_id > 0 && ctx_) << "SetTensorAndDevice should be only called " - << "on non-ready handles."; - ctx_ = nullptr; - tensor_ = tensor; - device_ = device; - op_device_ = op_device; -} - TFE_Op::~TFE_Op() { - for (TFE_TensorHandle* h : inputs) { + for (tensorflow::TensorHandle* h : inputs) { h->Unref(); } } diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 5b29120b40..e6d2ab75ff 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/rendezvous.h" @@ -67,84 +68,18 @@ struct TFE_Context { tensorflow::EagerContext context; }; -struct TFE_TensorHandle : public tensorflow::core::RefCounted { - public: +struct TFE_TensorHandle { TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d, tensorflow::Device* op_device) - : dtype(t.dtype()), - node_id(0), - tensor_(t), - device_(d), - op_device_(op_device), - ctx_(nullptr) {} + : handle(new tensorflow::TensorHandle(t, d, op_device)) {} TFE_TensorHandle(tensorflow::uint64 node_id, tensorflow::DataType dtype, - TFE_Context* ctx) - : dtype(dtype), - node_id(node_id), - tensor_(dtype), - device_(nullptr), - op_device_(nullptr), - ctx_(ctx) { - DCHECK_GT(node_id, 0); - } - - ~TFE_TensorHandle() override {} - - tensorflow::Status Tensor(const tensorflow::Tensor** t); - - tensorflow::Status Device(tensorflow::Device** d); - - tensorflow::Status OpDevice(tensorflow::Device** d); - - tensorflow::Status TensorAndDevice(const tensorflow::Tensor** tensor, - tensorflow::Device** device, - tensorflow::Device** op_device); - - // Note that this can be called at most once, and only on non-ready handles, - // and makes them ready. - void SetTensorAndDevice(const tensorflow::Tensor& tensor, - tensorflow::Device* device, - tensorflow::Device* op_device); - - // dtype for the handle. It must be the same as t.dtype() once the handle is - // ready. - const tensorflow::DataType dtype; - - private: - // If the contents of the Tensor pointed to by this handle is yet to be - // computed by a EagerNode, this function will block till that compuatation is - // done and the handle is "ready". - tensorflow::Status WaitReady(); - - bool IsReady(); - - // Id for the EagerNode that will compute the value pointed to by this handle. - // If the value is 0, the handle is already ready, but not vice-versa. - const tensorflow::uint64 node_id; - - tensorflow::Tensor tensor_; - - // TODO(ashankar): device_ == nullptr iff local CPU - // This was expedient, but perhaps worth revisiting ('device_' should always - // be a valid pointer?) - // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are - // provided with the appropriate TFE_Context. - // - // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a - // TFE_TensorHandle does not outlive the TFE_Context from which it came? - tensorflow::Device* device_; - - // Device in which the op producing this tensor was executed. Equals to - // device_ for constant tensors. - tensorflow::Device* op_device_; - - tensorflow::mutex ctx_mutex_; - - // `ctx` is only guaranteed to be set if the handle is not "ready". This is - // typically true when the handle was produced during async execution. - // `ctx` object is not owned and should outlive this handle. - TFE_Context* ctx_ GUARDED_BY(ctx_mutex_); + tensorflow::EagerContext* ctx) + : handle(new tensorflow::TensorHandle(node_id, dtype, ctx)) {} + + TFE_TensorHandle(tensorflow::TensorHandle* handle) : handle(handle) {} + + tensorflow::TensorHandle* handle; }; struct TFE_Op { @@ -161,7 +96,7 @@ struct TFE_Op { const tensorflow::string name; tensorflow::AttrBuilder attrs; const tensorflow::AttrTypeMap* attr_types; - tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 4> inputs; + tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs; tensorflow::Device* device; bool use_xla = false; }; diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index de10b10b7e..02fb83200a 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -55,6 +55,29 @@ tf_cuda_library( ) tf_cuda_library( + name = "tensor_handle", + srcs = [ + "tensor_handle.cc", + ], + hdrs = [ + "tensor_handle.h", + ], + visibility = ["//tensorflow:internal"], + deps = [ + ":context", + ":eager_executor", + ":kernel_and_device", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + ], +) + +tf_cuda_library( name = "kernel_and_device", srcs = [ "kernel_and_device.cc", diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc new file mode 100644 index 0000000000..5bc1700627 --- /dev/null +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -0,0 +1,107 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" + +#include <algorithm> +#include <cstddef> +#include <map> +#include <memory> +#include <queue> +#include <string> +#include <vector> + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +bool TensorHandle::IsReady() { + if (node_id == 0) return true; + mutex_lock l(ctx_mutex_); + return ctx_ == nullptr; +} + +Status TensorHandle::WaitReady() { + if (node_id == 0) return Status::OK(); + EagerExecutor* executor = nullptr; + { + mutex_lock l(ctx_mutex_); + if (ctx_ == nullptr) return Status::OK(); + executor = ctx_->Executor(); + } + return executor->WaitFor(node_id); +} + +Status TensorHandle::Tensor(const tensorflow::Tensor** t) { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *t = &tensor_; + return Status::OK(); +} + +Status TensorHandle::Device(tensorflow::Device** d) { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *d = device_; + return Status::OK(); +} + +Status TensorHandle::OpDevice(tensorflow::Device** d) { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *d = op_device_; + return Status::OK(); +} + +Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor, + tensorflow::Device** device, + tensorflow::Device** op_device) { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + *tensor = &tensor_; + *device = device_; + *op_device = op_device_; + return Status::OK(); +} + +void TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor, + tensorflow::Device* device, + tensorflow::Device* op_device) { + mutex_lock l(ctx_mutex_); + DCHECK(node_id > 0 && ctx_) << "SetTensorAndDevice should be only called " + << "on non-ready handles."; + ctx_ = nullptr; + tensor_ = tensor; + device_ = device; + op_device_ = op_device; +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h new file mode 100644 index 0000000000..97e67e4652 --- /dev/null +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ + +#include <algorithm> +#include <cstddef> +#include <map> +#include <memory> +#include <queue> +#include <string> +#include <vector> + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +// Associates a Tensor and a Device, used in the eager runtime. Internal version +// executor_of the TFE_TensorHandle struct and the python EagerTensor class +// (unrelated to python TensorHandle). +class TensorHandle : public core::RefCounted { + public: + TensorHandle(const Tensor& t, Device* d, Device* op_device) + : dtype(t.dtype()), + node_id(0), + tensor_(t), + device_(d), + op_device_(op_device), + ctx_(nullptr) {} + + TensorHandle(uint64 node_id, DataType dtype, EagerContext* ctx) + : dtype(dtype), + node_id(node_id), + tensor_(dtype), + device_(nullptr), + op_device_(nullptr), + ctx_(ctx) { + DCHECK_GT(node_id, 0); + } + + ~TensorHandle() override {} + + Status Tensor(const tensorflow::Tensor** t); + + Status Device(tensorflow::Device** d); + + Status OpDevice(tensorflow::Device** d); + + Status TensorAndDevice(const tensorflow::Tensor** tensor, + tensorflow::Device** device, + tensorflow::Device** op_device); + + // Note that this can be called at most once, and only on non-ready handles, + // and makes them ready. + void SetTensorAndDevice(const tensorflow::Tensor& tensor, + tensorflow::Device* device, + tensorflow::Device* op_device); + + // dtype for the handle. It must be the same as t.dtype() once the handle is + // ready. + const DataType dtype; + + private: + // If the contents of the Tensor pointed to by this handle is yet to be + // computed by a EagerNode, this function will block till that compuatation is + // done and the handle is "ready". + Status WaitReady(); + + bool IsReady(); + + // Id for the EagerNode that will compute the value pointed to by this handle. + // If the value is 0, the handle is already ready, but not vice-versa. + const uint64 node_id; + + tensorflow::Tensor tensor_; + + // TODO(ashankar): device_ == nullptr iff local CPU + // This was expedient, but perhaps worth revisiting ('device_' should always + // be a valid pointer?) + // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are + // provided with the appropriate TFE_Context. + // + // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a + // TFE_TensorHandle does not outlive the TFE_Context from which it came? + tensorflow::Device* device_; + + // Device in which the op producing this tensor was executed. Equals to + // device_ for constant tensors. + tensorflow::Device* op_device_; + + mutex ctx_mutex_; + + // `ctx` is only guaranteed to be set if the handle is not "ready". This is + // typically true when the handle was produced during async execution. + // `ctx` object is not owned and should outlive this handle. + EagerContext* ctx_ GUARDED_BY(ctx_mutex_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 701f68b8f7..55ba509065 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1013,12 +1013,13 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) { TFE_TensorHandle* t = EagerTensor_Handle(tensor); tensorflow::int64 id = EagerTensor_id(tensor); const tensorflow::Tensor* tensor = nullptr; - const tensorflow::Status status = t->Tensor(&tensor); + const tensorflow::Status status = t->handle->Tensor(&tensor); if (MaybeRaiseExceptionFromStatus(status, nullptr)) { - return tensorflow::eager::TapeTensor{id, t->dtype, + return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensorflow::TensorShape({})}; } else { - return tensorflow::eager::TapeTensor{id, t->dtype, tensor->shape()}; + return tensorflow::eager::TapeTensor{id, t->handle->dtype, + tensor->shape()}; } } tensorflow::int64 id = FastTensorId(tensor); diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 02eafd42b3..22317a348c 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -166,7 +166,7 @@ bool IsSingleNone(PyObject* obj) { // Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`. tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, const Tensor** output_tensor) { - return EagerTensor_Handle(eager_tensor)->Tensor(output_tensor); + return EagerTensor_Handle(eager_tensor)->handle->Tensor(output_tensor); } // Calls the registered py function through the trampoline. |