aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/c/c_api.cc')
-rw-r--r--tensorflow/c/c_api.cc93
1 files changed, 71 insertions, 22 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 371264ef6c..f5013c3f6a 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/loader.h"
#endif
#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
@@ -79,6 +80,7 @@ using tensorflow::TensorId;
using tensorflow::TensorShape;
using tensorflow::TensorShapeProto;
using tensorflow::error::Code;
+using tensorflow::errors::FailedPrecondition;
using tensorflow::errors::InvalidArgument;
using tensorflow::gtl::ArraySlice;
using tensorflow::mutex_lock;
@@ -179,6 +181,26 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
} // namespace
+TF_BufferAndDevice::TF_BufferAndDevice(TensorBuffer* buffer)
+ : buffer_(buffer), device_owner_(nullptr), device_index_(-1) {}
+
+TF_BufferAndDevice::TF_BufferAndDevice(TensorBuffer* buffer,
+ TF_Session* session, int device_index)
+ : buffer_(buffer), device_owner_(session), device_index_(device_index) {
+ mutex_lock l(device_owner_->mu);
+ device_owner_->num_outstanding_buffers++;
+}
+
+TF_BufferAndDevice::~TF_BufferAndDevice() {
+ buffer_->Unref();
+ if (device_owner_ != nullptr) {
+ mutex_lock l(device_owner_->mu);
+ device_owner_->num_outstanding_buffers--;
+ }
+}
+
+TF_Tensor::~TF_Tensor() { delete buffer; }
+
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
int num_dims, size_t len) {
void* data = allocate_tensor("TF_AllocateTensor", len);
@@ -211,33 +233,35 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
buf->deallocator_ = deallocator;
buf->deallocator_arg_ = deallocator_arg;
}
- return new TF_Tensor{dtype, TensorShape(dimvec), buf};
+ return new TF_Tensor{dtype, TensorShape(dimvec), new TF_BufferAndDevice(buf)};
}
TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) {
// It is safe to move the Tensor if and only if we own the unique reference to
// it. In that case, we might as well not delete and reallocate, but a future
// implementation might need to do so.
- if (tensor->buffer->RefCountIsOne() &&
- tensor->buffer->root_buffer()->RefCountIsOne() &&
- tensor->buffer->OwnsMemory()) {
+ TensorBuffer* buf = tensor->buffer->buffer();
+ if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() &&
+ buf->OwnsMemory()) {
return tensor;
}
return nullptr;
}
-void TF_DeleteTensor(TF_Tensor* t) {
- t->buffer->Unref();
- delete t;
-}
+void TF_DeleteTensor(TF_Tensor* t) { delete t; }
TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; }
int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); }
int64_t TF_Dim(const TF_Tensor* t, int dim_index) {
return static_cast<int64_t>(t->shape.dim_size(dim_index));
}
-size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); }
-void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); }
+size_t TF_TensorByteSize(const TF_Tensor* t) {
+ return t->buffer->buffer()->size();
+}
+void* TF_TensorData(const TF_Tensor* t) {
+ if (t->buffer->on_cpu()) return t->buffer->buffer()->data();
+ return nullptr;
+}
// --------------------------------------------------------------------------
size_t TF_StringEncode(const char* src, size_t src_len, char* dst,
@@ -396,7 +420,8 @@ namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
if (src->dtype != TF_STRING) {
- *dst = TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer);
+ *dst =
+ TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer->buffer());
return Status::OK();
}
// TF_STRING tensors require copying since Tensor class expects a sequence of
@@ -437,7 +462,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src) {
TensorBuffer* buf = TensorCApi::Buffer(src);
buf->Ref();
return new TF_Tensor{static_cast<TF_DataType>(src.dtype()), src.shape(),
- buf};
+ new TF_BufferAndDevice(buf)};
}
// DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly
// encoded sequence of strings.
@@ -2119,6 +2144,17 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
// TF_Session functions ----------------------------------------------
+TF_Session::TF_Session(tensorflow::Session* s, TF_Graph* g)
+ : session(s),
+ graph(g),
+ last_num_graph_nodes(0),
+ device_mgr(nullptr),
+ num_outstanding_buffers(0) {
+ if (s->LocalDeviceManager(&device_mgr).ok()) {
+ devices = device_mgr->ListDevices();
+ }
+}
+
TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
TF_Status* status) {
Session* session;
@@ -2149,7 +2185,6 @@ TF_Session* TF_LoadSessionFromSavedModel(
return nullptr;
#else
mutex_lock l(graph->mu);
-
if (!graph->name_map.empty()) {
status->status = InvalidArgument("Graph is non-empty.");
return nullptr;
@@ -2203,16 +2238,30 @@ void TF_CloseSession(TF_Session* s, TF_Status* status) {
}
void TF_DeleteSession(TF_Session* s, TF_Status* status) {
- status->status = Status::OK();
- TF_Graph* const graph = s->graph;
- if (graph != nullptr) {
- graph->mu.lock();
- graph->num_sessions -= 1;
- const bool del = graph->delete_requested && graph->num_sessions == 0;
- graph->mu.unlock();
- if (del) delete graph;
+ {
+ mutex_lock l(s->mu);
+ if (s->num_outstanding_buffers > 0) {
+ // This can probably be relaxed: An alternative might be to mark
+ // this session for deletion and do the actual delete only when
+ // the last TF_BufferAndDevice has been deleted.
+ status->status = FailedPrecondition(
+ s->num_outstanding_buffers,
+ " TF_Tensor objects with memory backed by a device "
+ "owned by this TF_Session are still alive. Release "
+ "them using TF_DeleteTensor and retry");
+ return;
+ }
+ status->status = Status::OK();
+ TF_Graph* const graph = s->graph;
+ if (graph != nullptr) {
+ graph->mu.lock();
+ graph->num_sessions -= 1;
+ const bool del = graph->delete_requested && graph->num_sessions == 0;
+ graph->mu.unlock();
+ if (del) delete graph;
+ }
+ delete s->session;
}
- delete s->session;
delete s;
}