aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_internal.h
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-07-27 14:27:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-27 14:31:43 -0700
commit22651083406ca01ac9d481e3367a3510d25f88cd (patch)
treebda0f3289d50f383eb1e632a595ab97258e35162 /tensorflow/c/c_api_internal.h
parent613bf1c7c1f8dfceed34fc85f2c71dd00432651e (diff)
C API: Groundwork for experimenting with TF_Tensor in device memory.
TF_Tensor objects are always backed by host memory. This commit lays the groundwork for allowing TF_Tensor objects to refer to tensor data on device (e.g., GPU) memory. PiperOrigin-RevId: 163388079
Diffstat (limited to 'tensorflow/c/c_api_internal.h')
-rw-r--r--tensorflow/c/c_api_internal.h53
1 files changed, 45 insertions, 8 deletions
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index d077ad264b..687e18aace 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -18,19 +18,25 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
-#include <vector>
#include <unordered_map>
+#include <vector>
+#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/public/session.h"
-#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+class Device;
+class DeviceMgr;
+} // namespace tensorflow
+class TF_BufferAndDevice;
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
@@ -40,9 +46,11 @@ struct TF_Status {
};
struct TF_Tensor {
+ ~TF_Tensor();
+
TF_DataType dtype;
tensorflow::TensorShape shape;
- tensorflow::TensorBuffer* buffer;
+ TF_BufferAndDevice* buffer;
};
struct TF_SessionOptions {
@@ -100,12 +108,19 @@ struct TF_Operation {
};
struct TF_Session {
- TF_Session(tensorflow::Session* s, TF_Graph* g)
- : session(s), graph(g), last_num_graph_nodes(0) {}
+ TF_Session(tensorflow::Session* s, TF_Graph* g);
+
tensorflow::Session* session;
TF_Graph* graph;
+
tensorflow::mutex mu;
int last_num_graph_nodes;
+
+ // NOTE(ashankar): Experimental fields to help keep the
+ // buffers of a TF_Tensor pinned in device memory.
+ const tensorflow::DeviceMgr* device_mgr; // Owned by session.
+ std::vector<tensorflow::Device*> devices; // Owned by device_mgr.
+ int num_outstanding_buffers GUARDED_BY(mu);
};
struct TF_ImportGraphDefOptions {
@@ -116,6 +131,28 @@ struct TF_DeviceList {
std::vector<tensorflow::DeviceAttributes> response;
};
+// TF_BufferAndDevice encapsulates the memory addresses of data backing a Tensor
+// and the device (e.g., GPU or host) whose memory the addresses refer to.
+class TF_BufferAndDevice {
+ public:
+ explicit TF_BufferAndDevice(tensorflow::TensorBuffer* buffer);
+ TF_BufferAndDevice(tensorflow::TensorBuffer* buffer, TF_Session* session,
+ int device_index);
+ ~TF_BufferAndDevice();
+
+ tensorflow::TensorBuffer* buffer() const { return buffer_; }
+ tensorflow::Device* device() const {
+ if (device_owner_ == nullptr) return nullptr;
+ return device_owner_->devices[device_index_];
+ }
+ bool on_cpu() const { return device() == nullptr; }
+
+ private:
+ tensorflow::TensorBuffer* buffer_;
+ TF_Session* device_owner_;
+ const int device_index_;
+};
+
namespace tensorflow {
class TensorCApi {