aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/verbs
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-26 12:54:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-26 12:57:46 -0700
commitf3c89936e97c99dead1ca3310246691c1b221adf (patch)
tree3c99b66936ed59028b32609115a239f52798907d /tensorflow/contrib/verbs
parent0b9b09a8531004b44b133a52c3fcc67bc6759bd8 (diff)
Merge changes from github.
END_PUBLIC Note: this CL will break builds. cl/159887762 to follow to fix all the breakages. --- Commit 2336cdf7f authored by Maxwell Paul Brickner<mbrickn@users.noreply.github.com> Committed by gunan<gunan@google.com>: Updated link to use HTTPS (#10998) Howdy! I just updated a link to use https instead of http. Thanks! --- Commit ad0892df1 authored by Luke Iwanski<luke@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] Fixes run_metadata_test for SYCL This test is designed to test CUDA specific behavior --- Commit 6b37a0725 authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Update comments --- Commit 1699d904a authored by John Lawson<john@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] Fixes CUDA specific test run on SYCL (#56) The testBadParentValuesOnGPU should only be run on CUDA devices, as the test checks for particular CUDA behaviour. We don't actually provide a SYCL kernel for GatherTree and so it's not a problem that the tests don't target SYCL. --- Commit 3c1946230 authored by myPrecious<Moriadry@users.noreply.github.com> Committed by Shanqing Cai<cais@google.com>: Java API to get the size of specified input list of operations. (#10865) * Java API to get the size of specified input list of operations * remove unnecessary explain to avoid bring a new term to users. 
--- Commit e911c7480 authored by Luke Iwanski<luke@codeplay.com> Committed by Luke Iwanski<luke@codeplay.com>: [OpenCL] REGISTER -> REGISTER6 --- Commit fbf6c4cec authored by superryanguo<superryanguo@gmail.com> Committed by superryanguo<superryanguo@gmail.com>: Simplify the Quickstart section with the weblink is better --- Commit 72e2918cc authored by Taehoon Lee<taehoonlee@snu.ac.kr> Committed by Taehoon Lee<taehoonlee@snu.ac.kr>: Fix typos --- Commit 90c4406b7 authored by Rishabh Patel<patelrishabh@users.noreply.github.com> Committed by GitHub<noreply@github.com>: Correct the learning rate as per the code snippet --- Commit 03da61134 authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Update ir_array.cc --- Commit 2df6cd3ac authored by Todd Wang<toddwang@gmail.com> Committed by GitHub<noreply@github.com>: Another try --- Commit af0cbace1 authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Transpose to go through Eigen (#10321) --- Commit fc7361081 authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Registers RGBToHSV and HSVToRGB (#91) (#10848) * [OpenCL] Added RGBToHSV and HSVToRGB * Aligning '\' --- Commit 832894ef8 authored by Luke Iwanski<luke@codeplay.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: [OpenCL] Registers AdjustContrastv2 (#10949) * [OpenCL] Registers AdjustContrastv2 (#93) * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL (#96) * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL * simplified to #ifndef * Changed to "#if GOOGLE_CUDA" * Update adjust_contrast_op_benchmark_test.cc * Added comments --- Commit cb4c2f8d1 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Make TransferBufferToInFeed not virtual so it compiles. 
--- Commit e89f04d80 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix calling Literal member functions. --- Commit 15a8df724 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix mac build clone from meheff's change: [XLA] Change return type of DeviceAssignment::Deserialize to fix build breakage on mac. The mac build had the following error: error: incomplete type 'xla::DeviceAssignment' used in type trait expression This was due to a static method returning a StatusOr<DeviceAssignment> inside of the definition of DeviceAssignment. --- Commit a54d43fa4 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Replace LiteralUtil to Literal in compiler/plugin/executor --- Commit 88a6bb80c authored by Guenther Schmuelling<guschmue@microsoft.com> Committed by Guenther Schmuelling<guschmue@microsoft.com>: expand inline for debug builds to limit number of symbols --- Commit 62fb49d31 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Fix visibility error for contrib/remote_fused_graph/pylib/BUILD. --- Commit 4c75252f2 authored by Mark Neumann<markn@allenai.org> Committed by Mark Neumann<markn@allenai.org>: fix initial test values to avoid numerical instability --- Commit b58d98353 authored by sj6077<epik03sj@gmail.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: Fixes of AutoParallel bug (#10368) * Fix the bug that auto_parallel could replicate variable snapshot name * Use NodeName in grappler:utils instead of substr, convert variables->variable_def of grappler item * remove variable_def from grappler item, exclude snapshot nodes from dont_replicate_nodes in auto_parallel --- Commit a286b7db8 authored by Yifei Feng<yifeif@google.com> Committed by Yifei Feng<yifeif@google.com>: Make debug_test slice integer. 
--- Commit 97fcfdfa6 authored by Toby Boyd<tobyboyd@google.com> Committed by GitHub<noreply@github.com>: Fixed path to seq2seq.py and minor formatting --- Commit 63c1befb8 authored by Anish Shah<shah.anish07@gmail.com> Committed by Anish Shah<shah.anish07@gmail.com>: Improve docs for tf.nn.depthwise_conv2d_native --- Commit 8d42202b2 authored by Yong Tang<yong.tang.github@outlook.com> Committed by Yong Tang<yong.tang.github@outlook.com>: Fix mismatched delete in mkl_tfconv_op.cc This fix fixes mismatched new[]-delete in mkl_tfconv_op.cc (the file went through clang-format so there are some additional changes) Signed-off-by: Yong Tang <yong.tang.github@outlook.com> --- Commit 26301bd55 authored by Danny Goodman<goodman.danny@gmail.com> Committed by Danny Goodman<goodman.danny@gmail.com>: fix error format --- Commit b3f33ad46 authored by Yao Zhang<yaozhang@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Make changes to prepare for the fused option of batch norm to be set to None (None means using fused batch norm if possible). PiperOrigin-RevId: 159649743 --- Commit a4a469832 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA] Add tests for select ops and while loops that produce tuples that contain predicates. PiperOrigin-RevId: 159645900 --- Commit 980d3f2be authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Use C API to implement Operation.name property This name property is used in many existing tests including those that already run with C API enabled (math_ops_test, framework_ops_test, session_test, session_partial_run_test, math_ops_test_gpu, etc). PiperOrigin-RevId: 159645767 --- Commit 26239c706 authored by A. 
Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Previously we didn't have an implementation of BatchNormInference and BatchNormTraining, which gives a linker error if anyone ever tries to call that. A dummy implementation is friendlier than a linker error. PiperOrigin-RevId: 159645612 --- Commit f671c5caa authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 159570549 PiperOrigin-RevId: 160182040
Diffstat (limited to 'tensorflow/contrib/verbs')
-rw-r--r--tensorflow/contrib/verbs/rdma.cc55
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc41
-rw-r--r--tensorflow/contrib/verbs/verbs_util.cc34
-rw-r--r--tensorflow/contrib/verbs/verbs_util.h10
4 files changed, 124 insertions, 16 deletions
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index bc687be0ab..6f3a616fe8 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
@@ -683,7 +684,6 @@ void RdmaTensorBuffer::SendNextItem() {
<< " error message: " << status.error_message();
size_t buffer_size = RdmaMessage::kMessageTotalBytes;
size_t tensor_bytes = 0;
- TensorProto proto;
// Figures out which device the tensor is hosted on.
Device* src_dev = nullptr;
Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
@@ -703,21 +703,47 @@ void RdmaTensorBuffer::SendNextItem() {
CHECK(s.ok()) << "dst device not found";
AllocatorAttributes dst_alloc_attr;
dst_alloc_attr.set_on_host(true);
+
+ bool can_memcpy = DataTypeCanUseMemcpy(in.dtype());
// string tensor needs to be serialized
+ Tensor copy;
+ StringPiece copy_buf;
+ TensorProto proto;
if (src_dev->tensorflow_gpu_device_info() &&
(!send_args.alloc_attrs.on_host())) {
CHECK(send_args.device_context)
- << "send dev name: " << src_dev->name()
- << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
- // "val" is on a GPU. Uses GPUUtil to fill the proto.
- s = VerbsUtil::SetProtoFromGPUSync(
- in, src_dev, send_args.device_context, &proto, is_dead);
- CHECK(s.ok()) << "set proto from gpu sync";
+ << "send dev name: " << src_dev->name()
+ << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+
+ if (can_memcpy) {
+ AllocatorAttributes host_alloc_attrs;
+ host_alloc_attrs.set_gpu_compatible(true);
+ host_alloc_attrs.set_on_host(true);
+ Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ copy = Tensor(alloc, in.dtype(), in.shape());
+ s = VerbsUtil::CopyGPUTensorToCPUSync(
+ src_dev, send_args.device_context, &in, &copy);
+ CHECK(s.ok()) << "copy tensor from gpu sync";
+ copy_buf = copy.tensor_data();
+ } else {
+ // "val" is on a GPU. Uses GPUUtil to fill the proto.
+ s = VerbsUtil::SetProtoFromGPUSync(
+ in, src_dev, send_args.device_context, &proto, is_dead);
+ CHECK(s.ok()) << "set proto from gpu sync";
+ }
} else {
// tensor is in CPU memory.
- in.AsProtoTensorContent(&proto);
+ if (can_memcpy) {
+ copy_buf = in.tensor_data();
+ } else {
+ in.AsProtoTensorContent(&proto);
+ }
+ }
+ if (can_memcpy) {
+ tensor_bytes = in.TotalBytes();
+ } else {
+ tensor_bytes = proto.ByteSize();
}
- tensor_bytes = proto.ByteSize();
// maybe some margin for string tensor?
buffer_size += tensor_bytes;
// prepare message
@@ -771,7 +797,16 @@ void RdmaTensorBuffer::SendNextItem() {
static_cast<void*>(static_cast<char*>(buffer_) +
RdmaMessage::kTensorBufferStartIndex);
CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
- proto.SerializeToArray(output, tensor_bytes);
+ if (can_memcpy) {
+ CHECK(copy_buf.size() == tensor_bytes)
+ << "unexpected tensor size: "
+ << copy_buf.size()
+ << " != "
+ << tensor_bytes;
+ memcpy(output, copy_buf.data(), tensor_bytes);
+ } else {
+ proto.SerializeToArray(output, tensor_bytes);
+ }
} else {
buffer_size = RdmaMessage::kMessageTotalBytes;
}
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index 5871400f26..9ea696589a 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -99,12 +100,40 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
if (!rm.is_dead_) {
void* input = static_cast<char*>(rb->buffer_) +
RdmaMessage::kTensorBufferStartIndex;
- TensorProto proto;
- CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
- rb->size_);
- CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
- << "fail to parse proto from array";
- s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
+ bool can_memcpy = DataTypeCanUseMemcpy(rm.data_type_);
+ if (can_memcpy) {
+ if (dst_dev->tensorflow_gpu_device_info() &&
+ (!recv_args.alloc_attrs.on_host())) {
+ CHECK(recv_args.device_context)
+ << "send dev name: " << src_dev->name()
+ << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+ Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
+ memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
+
+ Allocator* dst_alloc = dst_dev->GetAllocator(recv_args.alloc_attrs);
+ Tensor gpu_copy(dst_alloc, rm.data_type_, rm.tensor_shape_);
+ s = VerbsUtil::CopyCPUTensorToGPUSync(&copy, recv_args.device_context,
+ dst_dev, &gpu_copy);
+ CHECK(s.ok()) << "copy tensor to gpu sync";
+ val = std::move(gpu_copy);
+ } else {
+ AllocatorAttributes host_alloc_attrs;
+ host_alloc_attrs.set_gpu_compatible(true);
+ host_alloc_attrs.set_on_host(true);
+ Allocator* alloc = dst_dev->GetAllocator(host_alloc_attrs);
+ Tensor copy(alloc, rm.data_type_, rm.tensor_shape_);
+ memcpy(DMAHelper::base(&copy), input, rm.tensor_bytes_);
+ val = std::move(copy);
+ }
+ } else {
+ TensorProto proto;
+ CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
+ rb->size_);
+ CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
+ << "fail to parse proto from array";
+ s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
+ }
}
rc->RemoveRecvCallback(key_with_step_id);
diff --git a/tensorflow/contrib/verbs/verbs_util.cc b/tensorflow/contrib/verbs/verbs_util.cc
index c3350f7958..76e44d34a9 100644
--- a/tensorflow/contrib/verbs/verbs_util.cc
+++ b/tensorflow/contrib/verbs/verbs_util.cc
@@ -21,6 +21,40 @@ limitations under the License.
namespace tensorflow {
// static sync wrapper: calls GPUUtil::CopyGPUTensorToCPU and blocks until its completion callback has run, then returns the Status that callback received.
+Status VerbsUtil::CopyGPUTensorToCPUSync(Device* gpu_device,
+ const DeviceContext* device_context,
+ const Tensor* gpu_tensor,
+ Tensor* cpu_tensor) {
+ Notification n;  // signalled by the completion callback below
+ Status status;  // overwritten with the Status handed to the callback
+ GPUUtil::CopyGPUTensorToCPU(gpu_device, device_context,
+ gpu_tensor, cpu_tensor,
+ [&n, &status](const Status& s) {  // n and status are captured by reference
+ status = s;  // record the result before waking the waiter
+ n.Notify();
+ });
+ n.WaitForNotification();  // do not return until the callback has called Notify()
+ return status;
+}
+
+// static sync wrapper: calls GPUUtil::CopyCPUTensorToGPU and blocks until its completion callback has run, then returns the Status that callback received.
+Status VerbsUtil::CopyCPUTensorToGPUSync(const Tensor* cpu_tensor,
+ const DeviceContext* device_context,
+ Device* gpu_device,
+ Tensor* gpu_tensor) {
+ Notification n;  // signalled by the completion callback below
+ Status status;  // overwritten with the Status handed to the callback
+ GPUUtil::CopyCPUTensorToGPU(cpu_tensor, device_context,
+ gpu_device, gpu_tensor,
+ [&n, &status](const Status& s) {  // n and status are captured by reference
+ status = s;  // record the result before waking the waiter
+ n.Notify();
+ });
+ n.WaitForNotification();  // do not return until the callback has called Notify()
+ return status;
+}
+
+// static sync wrapper:
Status VerbsUtil::SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
const DeviceContext* device_context,
TensorProto* proto, bool is_dead) {
diff --git a/tensorflow/contrib/verbs/verbs_util.h b/tensorflow/contrib/verbs/verbs_util.h
index cbc01adae4..d9da396228 100644
--- a/tensorflow/contrib/verbs/verbs_util.h
+++ b/tensorflow/contrib/verbs/verbs_util.h
@@ -28,6 +28,16 @@ class TensorProto;
class VerbsUtil {
public:
+ // synchronous wrapper of CopyGPUTensorToCPU
+ static Status CopyGPUTensorToCPUSync(Device* gpu_device,
+ const DeviceContext* device_context,
+ const Tensor* gpu_tensor,
+ Tensor* cpu_tensor);
+ // synchronous wrapper of CopyCPUTensorToGPU
+ static Status CopyCPUTensorToGPUSync(const Tensor* cpu_tensor,
+ const DeviceContext* device_context,
+ Device* gpu_device,
+ Tensor* gpu_tensor);
// synchronous wrapper of SetProtoFromGPU
static Status SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
const DeviceContext* device_context,