aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/eager
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-03-28 08:25:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-28 08:27:47 -0700
commit119ed5aa2acb6df04595835f6dfa99f5422449f2 (patch)
tree532d10c5e27ad23bbeddd8211480f81fb9f4ea65 /tensorflow/c/eager
parent134f4ca0a70ef0373f5436b890be0f8585badb34 (diff)
Move ExecuteNode and CopyToDevice_Internal
PiperOrigin-RevId: 190775681
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--tensorflow/c/eager/BUILD1
-rw-r--r--tensorflow/c/eager/c_api.cc139
2 files changed, 28 insertions, 112 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index e57011a08b..a2d96357ac 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -31,6 +31,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:execute",
+ "//tensorflow/core/common_runtime/eager:execute_node",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/common_runtime/eager:copy_to_device_node",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index ac7114f71e..028865d360 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
+#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -435,39 +436,8 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
namespace {
-// TODO(apassos) move to TensorHandle
-tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal(
- tensorflow::TensorHandle* h, TFE_Context* ctx, const char* device_name,
- TF_Status* status) {
- status->status = ctx->context.GetStatus();
- if (!status->status.ok()) {
- return nullptr;
- }
- tensorflow::Device* dstd = ctx->context.HostCPU();
- if (device_name != nullptr && strlen(device_name) > 0) {
- status->status =
- ctx->context.device_mgr()->LookupDevice(device_name, &dstd);
- if (!status->status.ok()) return nullptr;
- }
- if (ctx->context.Async()) {
- // Note that `h` may not be currently ready. However execution order will
- // make sure that `h` is ready before the copy is actually done.
- tensorflow::CopyToDeviceNode* node =
- new tensorflow::CopyToDeviceNode(h, dstd, &ctx->context);
- tensorflow::TensorHandle* output = node->dst();
- // Note that calling Add makes `node` accessible by the EagerExecutor
- // thread. So further accesses need to be thread-safe.
- ctx->context.ExecutorAdd(node);
- return output;
- } else {
- tensorflow::TensorHandle* output = nullptr;
- status->status = h->CopyToDevice(&ctx->context, dstd, &output);
- return output;
- }
-}
-
tensorflow::Status ValidateInputTypeAndPlacement(
- TFE_Context* ctx, tensorflow::Device* host_device,
+ tensorflow::EagerContext* ctx, tensorflow::Device* host_device,
tensorflow::Device* op_device, TFE_Op* op,
const tensorflow::OpKernel* kernel) {
const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
@@ -484,8 +454,8 @@ tensorflow::Status ValidateInputTypeAndPlacement(
const tensorflow::Device* actual_device =
handle_device == nullptr ? host_device : handle_device;
if (expected_device != actual_device) {
- switch (TFE_ContextGetDevicePlacementPolicy(ctx)) {
- case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32:
+ switch (ctx->GetDevicePlacementPolicy()) {
+ case tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32:
// TODO(xpan): See if we could bubble python related error up
// to python level.
if (handle->dtype == tensorflow::DT_INT32) {
@@ -494,7 +464,7 @@ tensorflow::Status ValidateInputTypeAndPlacement(
break;
}
TF_FALLTHROUGH_INTENDED;
- case TFE_DEVICE_PLACEMENT_EXPLICIT:
+ case tensorflow::DEVICE_PLACEMENT_EXPLICIT:
return tensorflow::errors::InvalidArgument(
"Tensors on conflicting devices:"
" cannot compute ",
@@ -506,7 +476,7 @@ tensorflow::Status ValidateInputTypeAndPlacement(
" or transparently copied by using tfe.enable_eager_execution("
"tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices"
" may slow down your model");
- case TFE_DEVICE_PLACEMENT_WARN:
+ case tensorflow::DEVICE_PLACEMENT_WARN:
LOG(WARNING) << "before computing " << op->name << " input #" << i
<< " was expected to be on " << expected_device->name()
<< " but is actually on " << actual_device->name()
@@ -514,17 +484,14 @@ tensorflow::Status ValidateInputTypeAndPlacement(
<< "). This triggers a copy which can be a performance "
"bottleneck.";
break;
- case TFE_DEVICE_PLACEMENT_SILENT: // Do nothing.
+ case tensorflow::DEVICE_PLACEMENT_SILENT: // Do nothing.
break;
}
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
- TF_Status* s = TF_NewStatus();
- tensorflow::TensorHandle* copied_tensor =
- TFE_TensorHandleCopyToDevice_Internal(
- handle, ctx, expected_device->name().c_str(), s);
- tensorflow::Status status = s->status;
- TF_DeleteStatus(s);
+ tensorflow::TensorHandle* copied_tensor = nullptr;
+ tensorflow::Status status = tensorflow::EagerCopyToDevice(
+ handle, ctx, expected_device->name().c_str(), &copied_tensor);
if (!status.ok()) {
if (copied_tensor != nullptr) copied_tensor->Unref();
return tensorflow::errors::Internal(
@@ -576,68 +543,6 @@ tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef,
}
-// TODO(agarwal): move EagerExecutor and EagerNode related code to a separate
-// file.
-class ExecuteNode : public tensorflow::EagerNode {
- public:
- ExecuteNode(TFE_Op* op, tensorflow::KernelAndDevice* kernel,
- tensorflow::NodeExecStats* maybe_stats,
- const tensorflow::DataTypeVector& output_dtypes,
- TFE_TensorHandle** retvals, int num_retvals)
- : tensorflow::EagerNode(op->ctx->context.NextId()),
- ctx_(op->ctx),
- op_device_(op->device),
- inputs_(op->inputs),
- kernel_(kernel),
- maybe_stats_(maybe_stats),
- retvals_(num_retvals) {
- for (auto handle : inputs_) {
- handle->Ref();
- }
- TFE_Context* ctx = op->ctx;
- for (int i = 0; i < num_retvals; ++i) {
- tensorflow::TensorHandle* h =
- new tensorflow::TensorHandle(id, output_dtypes[i], &ctx->context);
- h->Ref();
- retvals[i] = new TFE_TensorHandle(h);
- retvals_[i] = h;
- }
- }
-
- ~ExecuteNode() override {
- for (auto handle : inputs_) {
- handle->Unref();
- }
- for (auto handle : retvals_) {
- handle->Unref();
- }
- }
-
- tensorflow::Status Run() override {
- const tensorflow::Status status = tensorflow::EagerExecute(
- &ctx_->context, op_device_, inputs_, kernel_, maybe_stats_.get(),
- retvals_.begin(), retvals_.size());
- if (status.ok()) {
- return status;
- } else {
- return tensorflow::Status(
- status.code(),
- tensorflow::strings::StrCat("Got error, \"", status.error_message(),
- "\" while executing kernel ",
- kernel_->kernel()->def().DebugString()));
- }
- }
-
- private:
- TFE_Context* ctx_;
- tensorflow::Device* op_device_;
- tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 4> inputs_;
- tensorflow::KernelAndDevice* kernel_;
- std::unique_ptr<tensorflow::NodeExecStats> maybe_stats_;
- tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals_;
-};
-
-
#ifdef TENSORFLOW_EAGER_USE_XLA
// Synthesizes and returns a wrapper function over `op`, which must be a
// primitive op (e.g. matmul).
@@ -961,8 +866,8 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// device from the one requested above.
device = kernel->device();
}
- status->status = ValidateInputTypeAndPlacement(ctx, ctx->context.HostCPU(),
- device, op, kernel->kernel());
+ status->status = ValidateInputTypeAndPlacement(
+ &ctx->context, ctx->context.HostCPU(), device, op, kernel->kernel());
if (!status->status.ok()) return;
std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
if (ctx->context.ShouldStoreMetadata()) {
@@ -977,9 +882,18 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// Note that for async mode, execution order will make sure that all
// input handles are ready before executing them.
// TODO(agarwal): Consider executing "cheap" kernels inline for performance.
- tensorflow::EagerNode* node =
- new ExecuteNode(op, kernel, maybe_stats.release(), output_dtypes,
- retvals, *num_retvals);
+ tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
+ *num_retvals);
+ tensorflow::uint64 id = op->ctx->context.NextId();
+ for (int i = 0; i < *num_retvals; ++i) {
+ tensorflow::TensorHandle* h =
+ new tensorflow::TensorHandle(id, output_dtypes[i], &op->ctx->context);
+ retvals[i] = new TFE_TensorHandle(h);
+ handle_retvals[i] = h;
+ }
+ tensorflow::EagerNode* node = new tensorflow::ExecuteNode(
+ id, &op->ctx->context, op->device, op->inputs, kernel,
+ maybe_stats.release(), output_dtypes, handle_retvals);
ctx->context.ExecutorAdd(node);
} else {
// Execute checks if retvals[i] is nullptr or not to figure if it needs to
@@ -999,8 +913,9 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
TFE_Context* ctx,
const char* device_name,
TF_Status* status) {
- tensorflow::TensorHandle* handle = TFE_TensorHandleCopyToDevice_Internal(
- h->handle, ctx, device_name, status);
+ tensorflow::TensorHandle* handle;
+ status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context,
+ device_name, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);
}