aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/eager
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-03-23 16:15:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-25 04:24:15 -0700
commitf54f57337078c93877df5c9a1b126e879f5b33a5 (patch)
tree8230a8ab3ddd8ea3c30879ba271778841373099d /tensorflow/c/eager
parentdd3adb6165605c28f1a993f9093e8f7c99b357c5 (diff)
Moves TensorHandleCopyToDevice to TensorHandle::CopyToDevice.
PiperOrigin-RevId: 190291768
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r--tensorflow/c/eager/BUILD1
-rw-r--r--tensorflow/c/eager/c_api.cc125
2 files changed, 5 insertions, 121 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index d2d8d59323..8df7b56623 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -32,6 +32,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
+ "//tensorflow/core/common_runtime/eager:copy_to_device_node",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 59432f2ef8..c69635d529 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -213,82 +214,6 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
}
} // extern "C"
-namespace {
-
-tensorflow::Status TensorHandleCopyToDevice(tensorflow::TensorHandle* h,
- TFE_Context* ctx,
- tensorflow::Device* dstd,
- tensorflow::TensorHandle** output) {
- const tensorflow::Tensor* src = nullptr;
- tensorflow::Device* srcd = nullptr;
- // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept
- // nullptr.
- tensorflow::Device* src_opd = nullptr;
- TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd));
- if (srcd == nullptr) srcd = ctx->context.HostCPU();
- bool is_same_device =
- (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
- const bool dst_cpu = IsCPU(dstd);
- const bool src_cpu = IsCPU(srcd);
- // both_on_cpu can be true and yet is_same_device is false, if one of src/dst
- // has device type XLA_CPU, and the other CPU.
- const bool both_on_cpu = src_cpu && dst_cpu;
- if (is_same_device || both_on_cpu) {
- dstd = dst_cpu ? nullptr : dstd;
- *output = new tensorflow::TensorHandle(*src, dstd, dstd);
- return tensorflow::Status::OK();
- }
- if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
- !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
- return tensorflow::errors::InvalidArgument(
- "Can't copy Tensor with type ",
- tensorflow::DataTypeString(src->dtype()), " to device ",
- DeviceName(dstd), ".");
- }
- tensorflow::AllocatorAttributes attr;
- if (src->dtype() == tensorflow::DT_VARIANT) {
- attr.set_on_host(true);
- }
- tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
- if (src->shape().num_elements() == 0) {
- dstd = dst_cpu ? nullptr : dstd;
- *output = new tensorflow::TensorHandle(dst, dstd, dstd);
- return tensorflow::Status::OK();
- }
- tensorflow::DeviceContext* src_device_context = nullptr;
- if (!src_cpu) {
- src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
- }
- tensorflow::DeviceContext* dst_device_context = nullptr;
- if (!dst_cpu) {
- dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
- }
- // TODO(ashankar): The Sync() call below may be more aggressive than
- // necessary. It is based on knowledge of implementation details - that
- // GPU devices are implemented using 3 streams - one for host->device copies,
- // one for device->host copies and one for sending operations to the GPU.
- // With that setup, Sync()ing across all 3 streams should be sufficient
- // but more than necessary (since it waits for operations that might have
- // nothing to do with this tensor to complete).
- TF_RETURN_IF_ERROR(srcd->Sync());
- tensorflow::Notification n;
- tensorflow::Status status;
- tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
- srcd, dstd, tensorflow::AllocatorAttributes(),
- tensorflow::AllocatorAttributes(), src, &dst,
- [&status, &n](const tensorflow::Status& s) {
- status = s;
- n.Notify();
- });
- n.WaitForNotification();
- if (status.ok()) {
- dstd = dst_cpu ? nullptr : dstd;
- *output = new tensorflow::TensorHandle(dst, dstd, dstd);
- }
- return status;
-}
-} // namespace
-
extern "C" {
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
@@ -509,49 +434,6 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
namespace {
-class CopyToDeviceNode : public tensorflow::EagerNode {
- public:
- CopyToDeviceNode(tensorflow::TensorHandle* src, tensorflow::Device* dstd,
- TFE_Context* ctx)
- : tensorflow::EagerNode(ctx->context.NextId()),
- src_(src),
- dstd_(dstd),
- ctx_(ctx),
- dst_(new tensorflow::TensorHandle(id, src_->dtype, &ctx->context)) {
- src_->Ref();
- dst_->Ref();
- }
-
- ~CopyToDeviceNode() override {
- src_->Unref();
- dst_->Unref();
- }
-
- tensorflow::Status Run() override {
- tensorflow::TensorHandle* temp = nullptr;
- TF_RETURN_IF_ERROR(TensorHandleCopyToDevice(src_, ctx_, dstd_, &temp));
- const tensorflow::Tensor* tensor = nullptr;
- tensorflow::Device* device = nullptr;
- tensorflow::Device* op_device = nullptr;
- tensorflow::Status status =
- temp->TensorAndDevice(&tensor, &device, &op_device);
- // `temp` is a ready handle. So the following call should return OK.
- TF_DCHECK_OK(status) << status.error_message();
- DCHECK(tensor);
- dst_->SetTensorAndDevice(*tensor, device, op_device);
- temp->Unref();
- return tensorflow::Status::OK();
- }
-
- tensorflow::TensorHandle* dst() { return dst_; }
-
- private:
- tensorflow::TensorHandle* src_;
- tensorflow::Device* dstd_;
- TFE_Context* ctx_;
- tensorflow::TensorHandle* dst_;
-};
-
// TODO(apassos) move to TensorHandle
tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal(
tensorflow::TensorHandle* h, TFE_Context* ctx, const char* device_name,
@@ -569,7 +451,8 @@ tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal(
if (ctx->context.Async()) {
// Note that `h` may not be currently ready. However execution order will
// make sure that `h` is ready before the copy is actually done.
- CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx);
+ tensorflow::CopyToDeviceNode* node =
+ new tensorflow::CopyToDeviceNode(h, dstd, &ctx->context);
tensorflow::TensorHandle* output = node->dst();
// Note that calling Add makes `node` accessible by the EagerExecutor
// thread. So further accesses need to be thread-safe.
@@ -577,7 +460,7 @@ tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal(
return output;
} else {
tensorflow::TensorHandle* output = nullptr;
- status->status = TensorHandleCopyToDevice(h, ctx, dstd, &output);
+ status->status = h->CopyToDevice(&ctx->context, dstd, &output);
return output;
}
}