author    Peter Hawkins <phawkins@google.com>  2017-01-09 12:04:37 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-01-09 12:26:35 -0800
commit 1e67c90e2caceeff82d09793d1ef5fa0300d219b (patch)
tree   6567ea8b0fa01fcfcd608b7e4c636865d33c7032 /tensorflow/compiler/xla/service/generic_transfer_manager.cc
parent 7ad7e4dfae4344d6b955b5eb61dc4b6bb792f1b3 (diff)
Initial open-source release of XLA: Accelerated Linear Algebra.
XLA is a compiler-based linear algebra execution engine that targets CPUs, GPUs and custom accelerators. XLA is still experimental; we are releasing it early to get the community involved. Change: 143990941
Diffstat (limited to 'tensorflow/compiler/xla/service/generic_transfer_manager.cc')
-rw-r--r--  tensorflow/compiler/xla/service/generic_transfer_manager.cc | 183
1 file changed, 183 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
new file mode 100644
index 0000000000..086306696d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -0,0 +1,183 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id)
+ : platform_id_(platform_id) {
+ // We currently only support kHostPlatformId for CPU and kCudaPlatformId for
+ // GPU. Before supporting other platforms, we need to test this transfer
+ // manager on them.
+ CHECK(platform_id_ == se::host::kHostPlatformId ||
+ platform_id_ == se::cuda::kCudaPlatformId);
+}
+
+se::Platform::Id GenericTransferManager::PlatformId() const {
+ if (platform_id_ == se::cuda::kCudaPlatformId ||
+ platform_id_ == se::host::kHostPlatformId) {
+ return platform_id_;
+ }
+ CHECK(false) << "GenericTransferManager::platform_id_ is invalid";
+}
+
+Status GenericTransferManager::TransferLiteralFromDevice(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& device_shape, const Shape& literal_shape, Literal* literal) {
+ VLOG(2) << "transferring literal shape from device: "
+ << ShapeUtil::HumanString(literal_shape)
+ << "; device location: " << source.opaque();
+ TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape));
+
+ // Tuples are a special case and contain one or more shapes inside of them to
+ // an arbitrary nesting depth.
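+  // For example, a literal with shape (f32[2,3], (s32[], f32[4])) is a tuple
+  // whose second element is itself a tuple; each leaf array is backed by its
+  // own device buffer.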
+ if (device_shape.element_type() == TUPLE) {
+ *literal->mutable_shape() = literal_shape;
+ TF_ASSIGN_OR_RETURN(
+ std::vector<se::DeviceMemoryBase> element_buffers,
+ ShallowCopyTupleFromDevice(executor, source, device_shape));
+ TF_RET_CHECK(element_buffers.size() ==
+ ShapeUtil::TupleElementCount(device_shape));
+ for (int64 i = 0; i < element_buffers.size(); ++i) {
+ const Shape& element_device_shape = device_shape.tuple_shapes(i);
+ const Shape& element_literal_shape = literal_shape.tuple_shapes(i);
+ Literal* element_literal = literal->add_tuple_literals();
+      // Recursively call TransferLiteralFromDevice to copy over the data in
+      // the element array.
+ TF_RETURN_IF_ERROR(TransferLiteralFromDevice(
+ executor, element_buffers[i], /*device_shape=*/element_device_shape,
+ /*literal_shape=*/element_literal_shape, element_literal));
+ }
+ return Status::OK();
+ }
+
+ *literal->mutable_shape() = device_shape;
+ LiteralUtil::Reserve(ShapeUtil::ElementsIn(device_shape), literal);
+ TF_RETURN_IF_ERROR(TransferBufferFromDevice(
+ executor, source, /*size=*/ShapeUtil::ByteSizeOf(device_shape),
+ /*destination=*/LiteralUtil::MutableInternalData(literal)));
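+  // The device copy may use a different layout than the one requested for the
+  // literal; if so, relayout the transferred data to match literal_shape.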
+ if (!ShapeUtil::Equal(literal_shape, device_shape)) {
+ literal->Swap(
+ LiteralUtil::Relayout(*literal, literal_shape.layout()).get());
+ }
+ TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape()));
+ return Status::OK();
+}
+
+StatusOr<std::vector<se::DeviceMemoryBase>>
+GenericTransferManager::ShallowCopyTupleFromDevice(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& shape) {
+ TF_RET_CHECK(ShapeUtil::IsTuple(shape));
+
+ // For devices which use the GenericTransferManager, a tuple is stored as an
+ // array of pointers to buffers. Copy the contents of the tuple buffer into
+ // a vector of void* pointers.
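+  // As used here, an N-element tuple buffer is therefore N * sizeof(void*)
+  // bytes, one device pointer per (possibly itself tuple-shaped) element.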
+ std::vector<void*> element_pointers(ShapeUtil::TupleElementCount(shape),
+ nullptr);
+ int64 tuple_size = ShapeUtil::ByteSizeOf(shape);
+ auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size,
+ element_pointers.data());
+ if (!copy_status.ok()) {
+ return AddStatus(
+ Status(static_cast<tensorflow::error::Code>(copy_status.code()),
+ copy_status.error_message()),
+ "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape));
+ }
+
+ // Create a DeviceMemoryBase from each void* pointer.
+ std::vector<se::DeviceMemoryBase> destination;
+ for (int i = 0; i < element_pointers.size(); ++i) {
+ if (element_pointers[i] == nullptr &&
+ !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
+ return FailedPrecondition("tuple contains nullptr at element %d", i);
+ }
+ int64 buffer_size = ShapeUtil::ByteSizeOf(shape.tuple_shapes(i));
+ destination.emplace_back(element_pointers[i], buffer_size);
+ }
+ return std::move(destination);
+}
+
+Status GenericTransferManager::TransferLiteralToDevice(
+ se::StreamExecutor* executor, const Literal& literal,
+ se::DeviceMemoryBase* destination) {
+ const Shape& shape = literal.shape();
+ VLOG(2) << "transferring literal shape to device: "
+ << ShapeUtil::HumanString(shape)
+ << "; device location: " << destination->opaque();
+
+ if (ShapeUtil::IsTuple(literal.shape())) {
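+    // Tuple case: transfer each element into its own freshly allocated device
+    // buffer, then write the resulting array of element pointers into
+    // `destination` as the tuple's buffer.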
+ std::vector<void*> tuple_elements_on_device;
+ for (const Literal& tuple_element : literal.tuple_literals()) {
+ se::DeviceMemoryBase allocation = executor->AllocateArray<uint8>(
+ GetByteSizeRequirement(tuple_element.shape()));
+ TF_RETURN_IF_ERROR(
+ TransferLiteralToDevice(executor, tuple_element, &allocation));
+ tuple_elements_on_device.push_back(allocation.opaque());
+ }
+ return TransferBufferToDevice(
+ executor, tuple_elements_on_device.size() * sizeof(void*),
+ tuple_elements_on_device.data(), destination);
+ }
+
+ return TransferBufferToDevice(
+ executor, /*size=*/GetByteSizeRequirement(shape),
+ /*source=*/LiteralUtil::InternalData(literal), destination);
+}
+
+Status GenericTransferManager::TransferLiteralToInfeed(
+ se::StreamExecutor* executor, const Literal& literal) {
+ return Unimplemented("Infeed is not supported on GPU (b/30467474)");
+}
+
+Status GenericTransferManager::ResetDevice(se::StreamExecutor* executor) {
+ return Unimplemented(
+ "Device reset is not yet supported on CPU and GPU (b/30481585)");
+}
+
+int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) {
+ return ShapeUtil::ByteSizeOf(shape);
+}
+
+} // namespace xla
+
+static xla::TransferManager* CreateGenericTransferManager() {
+ return new xla::GenericTransferManager(se::cuda::kCudaPlatformId);
+}
+
+static bool InitModule() {
+ xla::TransferManager::RegisterTransferManager(se::cuda::kCudaPlatformId,
+ CreateGenericTransferManager);
+ return true;
+}
+static bool module_initialized = InitModule();
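
The InitModule() boilerplate above registers GenericTransferManager for the CUDA platform id at static-initialization time. As a minimal sketch only (not part of this commit), a backend with a hypothetical StreamExecutor platform id kMyPlatformId could reuse the same registration pattern. The my_platform namespace and its id are assumptions for illustration, and the constructor's CHECK on platform_id_ would have to be extended to admit the new id before this could run:

namespace se = ::perftools::gputools;

static xla::TransferManager* CreateMyPlatformTransferManager() {
  // Hypothetical platform id; GenericTransferManager currently CHECKs that the
  // id is the host or CUDA platform, so that CHECK would need to allow it too.
  return new xla::GenericTransferManager(se::my_platform::kMyPlatformId);
}

static bool InitMyPlatformModule() {
  // Same registration hook used in this file: map a platform id to a factory.
  xla::TransferManager::RegisterTransferManager(
      se::my_platform::kMyPlatformId, CreateMyPlatformTransferManager);
  return true;
}
static bool my_platform_module_initialized = InitMyPlatformModule();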