diff options
author | Alexandre Passos <apassos@google.com> | 2018-03-23 16:15:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-25 04:24:15 -0700 |
commit | f54f57337078c93877df5c9a1b126e879f5b33a5 (patch) | |
tree | 8230a8ab3ddd8ea3c30879ba271778841373099d /tensorflow/c/eager | |
parent | dd3adb6165605c28f1a993f9093e8f7c99b357c5 (diff) |
Moves TensorHandleCopyToDevice to TensorHandle::CopyToDevice.
PiperOrigin-RevId: 190291768
Diffstat (limited to 'tensorflow/c/eager')
-rw-r--r-- | tensorflow/c/eager/BUILD | 1 |
-rw-r--r-- | tensorflow/c/eager/c_api.cc | 125 |
2 files changed, 5 insertions, 121 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index d2d8d59323..8df7b56623 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -32,6 +32,7 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/common_runtime/eager:copy_to_device_node", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 59432f2ef8..c69635d529 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/node_def_util.h" @@ -213,82 +214,6 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { } } // extern "C" -namespace { - -tensorflow::Status TensorHandleCopyToDevice(tensorflow::TensorHandle* h, - TFE_Context* ctx, - tensorflow::Device* dstd, - tensorflow::TensorHandle** output) { - const tensorflow::Tensor* src = nullptr; - tensorflow::Device* srcd = nullptr; - // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept - // nullptr. 
- tensorflow::Device* src_opd = nullptr; - TF_RETURN_IF_ERROR(h->TensorAndDevice(&src, &srcd, &src_opd)); - if (srcd == nullptr) srcd = ctx->context.HostCPU(); - bool is_same_device = - (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd)); - const bool dst_cpu = IsCPU(dstd); - const bool src_cpu = IsCPU(srcd); - // both_on_cpu can be true and yet is_same_device is false, if one of src/dst - // has device type XLA_CPU, and the other CPU. - const bool both_on_cpu = src_cpu && dst_cpu; - if (is_same_device || both_on_cpu) { - dstd = dst_cpu ? nullptr : dstd; - *output = new tensorflow::TensorHandle(*src, dstd, dstd); - return tensorflow::Status::OK(); - } - if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT && - !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) { - return tensorflow::errors::InvalidArgument( - "Can't copy Tensor with type ", - tensorflow::DataTypeString(src->dtype()), " to device ", - DeviceName(dstd), "."); - } - tensorflow::AllocatorAttributes attr; - if (src->dtype() == tensorflow::DT_VARIANT) { - attr.set_on_host(true); - } - tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape()); - if (src->shape().num_elements() == 0) { - dstd = dst_cpu ? nullptr : dstd; - *output = new tensorflow::TensorHandle(dst, dstd, dstd); - return tensorflow::Status::OK(); - } - tensorflow::DeviceContext* src_device_context = nullptr; - if (!src_cpu) { - src_device_context = srcd->tensorflow_gpu_device_info()->default_context; - } - tensorflow::DeviceContext* dst_device_context = nullptr; - if (!dst_cpu) { - dst_device_context = dstd->tensorflow_gpu_device_info()->default_context; - } - // TODO(ashankar): The Sync() call below may be more aggressive than - // necessary. It is based on knowledge of implementation details - that - // GPU devices are implemented using 3 streams - one for host->device copies, - // one for device->host copies and one for sending operations to the GPU. 
- // With that setup, Sync()ing across all 3 streams should be sufficient - // but more than necessary (since it waits for operations that might have - // nothing to do with this tensor to complete). - TF_RETURN_IF_ERROR(srcd->Sync()); - tensorflow::Notification n; - tensorflow::Status status; - tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context, - srcd, dstd, tensorflow::AllocatorAttributes(), - tensorflow::AllocatorAttributes(), src, &dst, - [&status, &n](const tensorflow::Status& s) { - status = s; - n.Notify(); - }); - n.WaitForNotification(); - if (status.ok()) { - dstd = dst_cpu ? nullptr : dstd; - *output = new tensorflow::TensorHandle(dst, dstd, dstd); - } - return status; -} -} // namespace - extern "C" { TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, @@ -509,49 +434,6 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, namespace { -class CopyToDeviceNode : public tensorflow::EagerNode { - public: - CopyToDeviceNode(tensorflow::TensorHandle* src, tensorflow::Device* dstd, - TFE_Context* ctx) - : tensorflow::EagerNode(ctx->context.NextId()), - src_(src), - dstd_(dstd), - ctx_(ctx), - dst_(new tensorflow::TensorHandle(id, src_->dtype, &ctx->context)) { - src_->Ref(); - dst_->Ref(); - } - - ~CopyToDeviceNode() override { - src_->Unref(); - dst_->Unref(); - } - - tensorflow::Status Run() override { - tensorflow::TensorHandle* temp = nullptr; - TF_RETURN_IF_ERROR(TensorHandleCopyToDevice(src_, ctx_, dstd_, &temp)); - const tensorflow::Tensor* tensor = nullptr; - tensorflow::Device* device = nullptr; - tensorflow::Device* op_device = nullptr; - tensorflow::Status status = - temp->TensorAndDevice(&tensor, &device, &op_device); - // `temp` is a ready handle. So the following call should return OK. 
- TF_DCHECK_OK(status) << status.error_message(); - DCHECK(tensor); - dst_->SetTensorAndDevice(*tensor, device, op_device); - temp->Unref(); - return tensorflow::Status::OK(); - } - - tensorflow::TensorHandle* dst() { return dst_; } - - private: - tensorflow::TensorHandle* src_; - tensorflow::Device* dstd_; - TFE_Context* ctx_; - tensorflow::TensorHandle* dst_; -}; - // TODO(apassos) move to TensorHandle tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal( tensorflow::TensorHandle* h, TFE_Context* ctx, const char* device_name, @@ -569,7 +451,8 @@ tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal( if (ctx->context.Async()) { // Note that `h` may not be currently ready. However execution order will // make sure that `h` is ready before the copy is actually done. - CopyToDeviceNode* node = new CopyToDeviceNode(h, dstd, ctx); + tensorflow::CopyToDeviceNode* node = + new tensorflow::CopyToDeviceNode(h, dstd, &ctx->context); tensorflow::TensorHandle* output = node->dst(); // Note that calling Add makes `node` accessible by the EagerExecutor // thread. So further accesses need to be thread-safe. @@ -577,7 +460,7 @@ tensorflow::TensorHandle* TFE_TensorHandleCopyToDevice_Internal( return output; } else { tensorflow::TensorHandle* output = nullptr; - status->status = TensorHandleCopyToDevice(h, ctx, dstd, &output); + status->status = h->CopyToDevice(&ctx->context, dstd, &output); return output; } } |