aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/eager/tensor_handle.h
blob: 46bc94f8752d61a0a9077c78dbcd8731174c7388 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Associates a Tensor and a Device, used in the eager runtime. Internal version
// of the TFE_TensorHandle struct and the python EagerTensor class
// (unrelated to python TensorHandle).
class TensorHandle : public core::RefCounted {
 public:
  TensorHandle(const Tensor& t, Device* d, Device* op_device, EagerContext* ctx)
      : dtype(t.dtype()),
        node_id_(0),
        tensor_(t),
        device_(d),
        op_device_(op_device),
        remote_op_id_(-1),
        remote_output_num_(-1),
        remote_shape_node_id_(-1),
        ctx_(ctx),
        is_ready_(true) {}

  TensorHandle(uint64 node_id, DataType dtype, EagerContext* ctx)
      : dtype(dtype),
        node_id_(node_id),
        tensor_(dtype),
        device_(nullptr),
        op_device_(nullptr),
        remote_op_id_(-1),
        remote_output_num_(-1),
        remote_shape_node_id_(-1),
        ctx_(ctx),
        is_ready_(ctx == nullptr) {
    DCHECK_GT(node_id_, 0);
  }

  // Remote tensor handle constructor.
  TensorHandle(int64 op_id, int32 output_num, uint64 remote_shape_node_id,
               DataType dtype, std::function<void()> call_on_destroy, Device* d,
               Device* op_device, EagerContext* ctx)
      : dtype(dtype),
        node_id_(0),
        device_(d),
        op_device_(op_device),
        remote_op_id_(op_id),
        remote_output_num_(output_num),
        remote_shape_node_id_(remote_shape_node_id),
        call_on_destroy_(std::move(call_on_destroy)),
        ctx_(ctx),
        is_ready_(true) {
    DCHECK(IsRemote()) << "Op ID and output num should be >= 0. Op ID: "
                       << op_id << ", Output num: " << output_num;
  }

  ~TensorHandle() override {
    if (call_on_destroy_) {
      call_on_destroy_();
    }
  }

  Status Tensor(const tensorflow::Tensor** t);

  Status Device(tensorflow::Device** d);

  Status OpDevice(tensorflow::Device** d);

  Status TensorAndDevice(const tensorflow::Tensor** tensor,
                         tensorflow::Device** device,
                         tensorflow::Device** op_device);

  Status NumDims(int* num_dims);
  Status Dim(int dim_index, int64* dim);

  // Return the op_id and output num if the handle refers to a remote tensor.
  Status RemoteAddress(int64* op_id, int32* output_num);

  // Note that this can be called at most once, and only on non-ready handles,
  // and makes them ready.
  void SetTensorAndDevice(const tensorflow::Tensor& tensor,
                          tensorflow::Device* device,
                          tensorflow::Device* op_device);

  Status CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
                      TensorHandle** output);

  // Warning: can return nullptr for CPU tensors.
  EagerContext* Context() {
    mutex_lock ml(ctx_mutex_);
    return ctx_;
  }

  // dtype for the handle. It must be the same as t.dtype() once the handle is
  // ready.
  const DataType dtype;

  void SetRemoteShape(std::unique_ptr<TensorShape> remote_shape) {
    remote_shape_ = std::move(remote_shape);
  }

 private:
  // If the contents of the Tensor pointed to by this handle is yet to be
  // computed by a EagerNode, this function will block till that compuatation is
  // done and the handle is "ready".
  Status WaitReady();
  Status WaitForNode(uint64 node_id, bool return_if_is_ready);

  bool IsReady();

  bool IsRemote();

  // Id for the EagerNode that will compute the value pointed to by this handle.
  // If the value is 0, the handle is already ready, but not vice-versa.
  const uint64 node_id_;

  tensorflow::Tensor tensor_;

  // TODO(ashankar): device_ == nullptr iff local CPU
  // This was expedient, but perhaps worth revisiting ('device_' should always
  // be a valid pointer?)
  // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
  // provided with the appropriate TFE_Context.
  //
  // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
  // TFE_TensorHandle does not outlive the TFE_Context from which it came?
  tensorflow::Device* device_;

  // Device in which the op producing this tensor was executed. Equals to
  // device_ for constant tensors.
  tensorflow::Device* op_device_;

  // IDs required when this class is representing a remote tensor handle.
  const int64 remote_op_id_;
  const int32 remote_output_num_;
  std::unique_ptr<TensorShape> remote_shape_;
  const uint64 remote_shape_node_id_;

  // A callback that is executed when the class is destroyed.
  //
  // This is currently used for remote tensor handles.
  const std::function<void()> call_on_destroy_;

  mutex ctx_mutex_;

  // `ctx` is only guaranteed to be set if the handle is not "ready". This is
  // typically true when the handle was produced during async execution.
  // `ctx` object is not owned and should outlive this handle.
  EagerContext* ctx_ GUARDED_BY(ctx_mutex_);
  bool is_ready_ GUARDED_BY(ctx_mutex_);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_