diff options
author | 2018-03-28 08:25:14 -0700 | |
---|---|---|
committer | 2018-03-28 08:27:47 -0700 | |
commit | 119ed5aa2acb6df04595835f6dfa99f5422449f2 (patch) | |
tree | 532d10c5e27ad23bbeddd8211480f81fb9f4ea65 /tensorflow/c/eager | |
parent | 134f4ca0a70ef0373f5436b890be0f8585badb34 (diff) |
Move ExecuteNode and CopyToDevice_Internal
PiperOrigin-RevId: 190775681
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- | tensorflow/c/eager/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 139 |
2 files changed, 28 insertions, 112 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index e57011a08b..a2d96357ac 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -31,6 +31,7 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:execute_node", "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/common_runtime/eager:copy_to_device_node", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index ac7114f71e..028865d360 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" #include "tensorflow/core/common_runtime/eager/execute.h" +#include "tensorflow/core/common_runtime/eager/execute_node.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/node_def_util.h" @@ -435,39 +436,8 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, namespace { -// TODO(apassos) move to TensorHandle -tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal( - tensorflow::TensorHandle* h, TFE_Context* ctx, const char* device_name, - TF_Status* status) { - status->status = ctx->context.GetStatus(); - if (!status->status.ok()) { - return nullptr; - } - tensorflow::Device* dstd = ctx->context.HostCPU(); - if (device_name != nullptr && strlen(device_name) > 0) { - status->status = - ctx->context.device_mgr()->LookupDevice(device_name, &dstd); - if (!status->status.ok()) return nullptr; - } - if (ctx->context.Async()) { - // Note that `h` may not be currently ready. However execution order will - // make sure that `h` is ready before the copy is actually done. - tensorflow::CopyToDeviceNode* node = - new tensorflow::CopyToDeviceNode(h, dstd, &ctx->context); - tensorflow::TensorHandle* output = node->dst(); - // Note that calling Add makes `node` accessible by the EagerExecutor - // thread. So further accesses need to be thread-safe. - ctx->context.ExecutorAdd(node); - return output; - } else { - tensorflow::TensorHandle* output = nullptr; - status->status = h->CopyToDevice(&ctx->context, dstd, &output); - return output; - } -} - tensorflow::Status ValidateInputTypeAndPlacement( - TFE_Context* ctx, tensorflow::Device* host_device, + tensorflow::EagerContext* ctx, tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op, const tensorflow::OpKernel* kernel) { const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types(); @@ -484,8 +454,8 @@ tensorflow::Status ValidateInputTypeAndPlacement( const tensorflow::Device* actual_device = handle_device == nullptr ? host_device : handle_device; if (expected_device != actual_device) { - switch (TFE_ContextGetDevicePlacementPolicy(ctx)) { - case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32: + switch (ctx->GetDevicePlacementPolicy()) { + case tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32: // TODO(xpan): See if we could bubble python related error up // to python level. if (handle->dtype == tensorflow::DT_INT32) { @@ -494,7 +464,7 @@ tensorflow::Status ValidateInputTypeAndPlacement( break; } TF_FALLTHROUGH_INTENDED; - case TFE_DEVICE_PLACEMENT_EXPLICIT: + case tensorflow::DEVICE_PLACEMENT_EXPLICIT: return tensorflow::errors::InvalidArgument( "Tensors on conflicting devices:" " cannot compute ", @@ -506,7 +476,7 @@ tensorflow::Status ValidateInputTypeAndPlacement( " or transparently copied by using tfe.enable_eager_execution(" "tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices" " may slow down your model"); - case TFE_DEVICE_PLACEMENT_WARN: + case tensorflow::DEVICE_PLACEMENT_WARN: LOG(WARNING) << "before computing " << op->name << " input #" << i << " was expected to be on " << expected_device->name() << " but is actually on " << actual_device->name() @@ -514,17 +484,14 @@ tensorflow::Status ValidateInputTypeAndPlacement( << "). This triggers a copy which can be a performance " "bottleneck."; break; - case TFE_DEVICE_PLACEMENT_SILENT: // Do nothing. + case tensorflow::DEVICE_PLACEMENT_SILENT: // Do nothing. break; } // We are only here if the policy is warn or silent copies, so we should // trigger a copy. - TF_Status* s = TF_NewStatus(); - tensorflow::TensorHandle* copied_tensor = - TFE_TensorHandleCopyToDevice_Internal( - handle, ctx, expected_device->name().c_str(), s); - tensorflow::Status status = s->status; - TF_DeleteStatus(s); + tensorflow::TensorHandle* copied_tensor = nullptr; + tensorflow::Status status = tensorflow::EagerCopyToDevice( + handle, ctx, expected_device->name().c_str(), &copied_tensor); if (!status.ok()) { if (copied_tensor != nullptr) copied_tensor->Unref(); return tensorflow::errors::Internal( @@ -576,68 +543,6 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, } -// TODO(agarwal): move EagerExecutor and EagerNode related code to a separate -// file. -class ExecuteNode : public tensorflow::EagerNode { - public: - ExecuteNode(TFE_Op* op, tensorflow::KernelAndDevice* kernel, - tensorflow::NodeExecStats* maybe_stats, - const tensorflow::DataTypeVector& output_dtypes, - TFE_TensorHandle** retvals, int num_retvals) - : tensorflow::EagerNode(op->ctx->context.NextId()), - ctx_(op->ctx), - op_device_(op->device), - inputs_(op->inputs), - kernel_(kernel), - maybe_stats_(maybe_stats), - retvals_(num_retvals) { - for (auto handle : inputs_) { - handle->Ref(); - } - TFE_Context* ctx = op->ctx; - for (int i = 0; i < num_retvals; ++i) { - tensorflow::TensorHandle* h = - new tensorflow::TensorHandle(id, output_dtypes[i], &ctx->context); - h->Ref(); - retvals[i] = new TFE_TensorHandle(h); - retvals_[i] = h; - } - } - - ~ExecuteNode() override { - for (auto handle : inputs_) { - handle->Unref(); - } - for (auto handle : retvals_) { - handle->Unref(); - } - } - - tensorflow::Status Run() override { - const tensorflow::Status status = tensorflow::EagerExecute( - &ctx_->context, op_device_, inputs_, kernel_, maybe_stats_.get(), - retvals_.begin(), retvals_.size()); - if (status.ok()) { - return status; - } else { - return tensorflow::Status( - status.code(), - tensorflow::strings::StrCat("Got error, \"", status.error_message(), - "\" while executing kernel ", - kernel_->kernel()->def().DebugString())); - } - } - - private: - TFE_Context* ctx_; - tensorflow::Device* op_device_; - tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_; - tensorflow::KernelAndDevice* kernel_; - std::unique_ptr<tensorflow::NodeExecStats> maybe_stats_; - tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals_; -}; - - #ifdef TENSORFLOW_EAGER_USE_XLA // Synthesizes and returns a wrapper function over `op`, which must be a // primitive op (e.g. matmul). @@ -961,8 +866,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // device from the one requested above. device = kernel->device(); } - status->status = ValidateInputTypeAndPlacement(ctx, ctx->context.HostCPU(), - device, op, kernel->kernel()); + status->status = ValidateInputTypeAndPlacement( + &ctx->context, ctx->context.HostCPU(), device, op, kernel->kernel()); if (!status->status.ok()) return; std::unique_ptr<tensorflow::NodeExecStats> maybe_stats; if (ctx->context.ShouldStoreMetadata()) { @@ -977,9 +882,18 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // Note that for async mode, execution order will make sure that all // input handles are ready before executing them. // TODO(agarwal): Consider executing "cheap" kernels inline for performance. - tensorflow::EagerNode* node = - new ExecuteNode(op, kernel, maybe_stats.release(), output_dtypes, - retvals, *num_retvals); + tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals( + *num_retvals); + tensorflow::uint64 id = op->ctx->context.NextId(); + for (int i = 0; i < *num_retvals; ++i) { + tensorflow::TensorHandle* h = + new tensorflow::TensorHandle(id, output_dtypes[i], &op->ctx->context); + retvals[i] = new TFE_TensorHandle(h); + handle_retvals[i] = h; + } + tensorflow::EagerNode* node = new tensorflow::ExecuteNode( + id, &op->ctx->context, op->device, op->inputs, kernel, + maybe_stats.release(), output_dtypes, handle_retvals); ctx->context.ExecutorAdd(node); } else { // Execute checks if retvals[i] is nullptr or not to figure if it needs to @@ -999,8 +913,9 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status) { - tensorflow::TensorHandle* handle = TFE_TensorHandleCopyToDevice_Internal( - h->handle, ctx, device_name, status); + tensorflow::TensorHandle* handle; + status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context, + device_name, &handle); if (status->status.ok()) { return new TFE_TensorHandle(handle); } |