diff options
author | Asim Shankar <ashankar@google.com> | 2017-07-27 14:27:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-27 14:31:43 -0700 |
commit | 22651083406ca01ac9d481e3367a3510d25f88cd (patch) | |
tree | bda0f3289d50f383eb1e632a595ab97258e35162 /tensorflow/c/c_api_internal.h | |
parent | 613bf1c7c1f8dfceed34fc85f2c71dd00432651e (diff) |
C API: Groundwork for experimenting with TF_Tensor in device memory.
TF_Tensor objects are always backed by host memory. This commit lays
the groundwork for allowing TF_Tensor objects to refer to tensor data
on device (e.g., GPU) memory.
PiperOrigin-RevId: 163388079
Diffstat (limited to 'tensorflow/c/c_api_internal.h')
-rw-r--r-- | tensorflow/c/c_api_internal.h | 53 |
1 files changed, 45 insertions, 8 deletions
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index d077ad264b..687e18aace 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -18,19 +18,25 @@ limitations under the License. #include "tensorflow/c/c_api.h" -#include <vector> #include <unordered_map> +#include <vector> +#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/public/session.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +class Device; +class DeviceMgr; +} // namespace tensorflow +class TF_BufferAndDevice; // Internal structures used by the C API. These are likely to change and should // not be depended on. @@ -40,9 +46,11 @@ struct TF_Status { }; struct TF_Tensor { + ~TF_Tensor(); + TF_DataType dtype; tensorflow::TensorShape shape; - tensorflow::TensorBuffer* buffer; + TF_BufferAndDevice* buffer; }; struct TF_SessionOptions { @@ -100,12 +108,19 @@ struct TF_Operation { }; struct TF_Session { - TF_Session(tensorflow::Session* s, TF_Graph* g) - : session(s), graph(g), last_num_graph_nodes(0) {} + TF_Session(tensorflow::Session* s, TF_Graph* g); + tensorflow::Session* session; TF_Graph* graph; + tensorflow::mutex mu; int last_num_graph_nodes; + + // NOTE(ashankar): Experimental fields to help keep the + // buffers of a TF_Tensor pinned in device memory. + const tensorflow::DeviceMgr* device_mgr; // Owned by session. + std::vector<tensorflow::Device*> devices; // Owned by device_mgr. + int num_outstanding_buffers GUARDED_BY(mu); }; struct TF_ImportGraphDefOptions { @@ -116,6 +131,28 @@ struct TF_DeviceList { std::vector<tensorflow::DeviceAttributes> response; }; +// TF_BufferAndDevice encapsulates the memory addresses of data backing a Tensor +// and the device (e.g., GPU or host) whose memory the addresses refer to. +class TF_BufferAndDevice { + public: + explicit TF_BufferAndDevice(tensorflow::TensorBuffer* buffer); + TF_BufferAndDevice(tensorflow::TensorBuffer* buffer, TF_Session* session, + int device_index); + ~TF_BufferAndDevice(); + + tensorflow::TensorBuffer* buffer() const { return buffer_; } + tensorflow::Device* device() const { + if (device_owner_ == nullptr) return nullptr; + return device_owner_->devices[device_index_]; + } + bool on_cpu() const { return device() == nullptr; } + + private: + tensorflow::TensorBuffer* buffer_; + TF_Session* device_owner_; + const int device_index_; +}; + namespace tensorflow { class TensorCApi { |