aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/xla_device_context.h
blob: 912f8d779e72f44821bc4fb25efa30bd35d01412 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_

#include <memory>

#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Allocator backing every Tensor that is placed on an XLA device.
//
// Requests are not honored literally: the requested alignment and byte count
// are ignored, and every allocation hands back a new, empty XlaTensor rather
// than a raw buffer of the requested size.
class XlaDeviceAllocator : public Allocator {
 public:
  XlaDeviceAllocator();
  ~XlaDeviceAllocator() override;

  // Allocator entry points. AllocateRaw disregards `alignment` and
  // `num_bytes` (see the class comment).
  void* AllocateRaw(size_t alignment, size_t num_bytes) override;
  void DeallocateRaw(void* ptr) override;

  // Returns this allocator's name.
  string Name() override;

  // Fills `stats` with this allocator's statistics.
  void GetStats(AllocatorStats* stats) override;
};

// Moves tensor data between the host and an XLA device.
//
// Up to three streams are involved: the device's main compute stream plus one
// stream per transfer direction. Either transfer stream may simply be the
// compute stream itself.
class XlaTransferManager {
 public:
  explicit XlaTransferManager(
      se::Stream* compute_stream, se::Stream* host_to_device_stream,
      se::Stream* device_to_host_stream, xla::LocalClient* client,
      bool transfer_as_literal,
      XlaCompiler::ShapeRepresentationFn shape_representation_fn);

  // The device's main compute stream.
  se::Stream* stream() const { return stream_; }

  // Copies host-side `cpu_tensor` into `device_tensor`; completion is reported
  // through `done`.
  void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                             Tensor* device_tensor, StatusCallback done) const;

  // Copies `device_tensor` back into host-side `cpu_tensor`; completion is
  // reported through `done`.
  void CopyDeviceTensorToCPU(const Tensor* device_tensor,
                             StringPiece tensor_name, Device* device,
                             Tensor* cpu_tensor, StatusCallback done);

  // Copies `src_tensor` into `dst_tensor`; completion is reported through
  // `done`.
  void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
                                const StatusCallback& done);

 private:
  // True when host-to-device transfers run on a stream other than the compute
  // stream. Only host_to_device_stream_ is compared against stream_.
  // NOTE(review): device_to_host_stream_ is not consulted, so a setup with a
  // distinct device-to-host stream but a shared host-to-device stream reports
  // false — confirm that callers expect this.
  bool UseMultipleStreams() const { return host_to_device_stream_ != stream_; }

  // Literal-based transfer helpers; presumably the paths taken when
  // transfer_as_literal_ is set — TODO confirm against the implementation.
  Status TransferLiteralToDevice(const Tensor& host_tensor,
                                 Tensor* device_tensor) const;
  void TransferLiteralFromDevice(Tensor* host_tensor,
                                 const Tensor& device_tensor,
                                 const StatusCallback& done) const;

  // Data members. Keep this declaration order: it determines the order in
  // which the constructor initializes them.

  // Main compute stream of the device; the transfer streams are synchronized
  // against it when they are distinct.
  se::Stream* stream_;
  // Stream for host-to-device transfers. May be identical to stream_, but must
  // never be nullptr.
  se::Stream* host_to_device_stream_;
  // Stream for device-to-host transfers. May be identical to stream_, but must
  // never be nullptr.
  se::Stream* device_to_host_stream_;
  // Provides the underlying memory allocator and XLA's TransferManager.
  xla::LocalClient* client_;
  // XLA transfer manager used to marshal data to and from the device.
  xla::TransferManager* transfer_manager_;
  // Set when device transfers must go through XLA's TransferManager to be
  // correct.
  const bool transfer_as_literal_;
  // Shape representation function supplied at construction; presumably maps a
  // host tensor's shape/type to its on-device representation — TODO confirm.
  XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
};

// DeviceContext used by operators placed on XlaDevice devices.
//
// It has to derive from DeviceContext to plug into the runtime; apart from
// that it is a thin shell that forwards each operation to an owned
// XlaTransferManager.
class XlaDeviceContext : public DeviceContext {
 public:
  explicit XlaDeviceContext(
      se::Stream* compute_stream, se::Stream* host_to_device_stream,
      se::Stream* device_to_host_stream, xla::LocalClient* client,
      bool transfer_as_literal,
      XlaCompiler::ShapeRepresentationFn shape_representation_fn);

  // The compute stream, as reported by the wrapped transfer manager.
  se::Stream* stream() const override { return manager_.stream(); }

  // DeviceContext transfer overrides; both are forwarded to manager_.
  void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                             Tensor* device_tensor,
                             StatusCallback done) const override;
  void CopyDeviceTensorToCPU(const Tensor* device_tensor,
                             StringPiece tensor_name, Device* device,
                             Tensor* cpu_tensor, StatusCallback done) override;

  // Additional copy entry point forwarded to manager_. It is intentionally not
  // marked `override`.
  void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
                                const StatusCallback& done);

 private:
  // Performs the actual transfers on behalf of this context.
  XlaTransferManager manager_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_