aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/BUILD1
-rw-r--r--tensorflow/c/c_api.cc58
-rw-r--r--tensorflow/c/c_api_internal.h1
3 files changed, 33 insertions, 27 deletions
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 21b09865db..242a628d37 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -30,6 +30,7 @@ tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api.h"],
hdrs = ["c_api_internal.h"],
+ visibility = ["//tensorflow/c:__subpackages__"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index cb074cb2f1..663fec56f1 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -456,6 +456,24 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {
return Status::OK();
}
+// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
+// result in a zero-sized tensor.
+static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
+ static char empty;
+ tensorflow::int64 nelems = 1;
+ std::vector<tensorflow::int64> dims;
+ for (int i = 0; i < shape.dims(); ++i) {
+ dims.push_back(shape.dim_size(i));
+ nelems *= shape.dim_size(i);
+ }
+ CHECK_EQ(nelems, 0);
+ static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
+ "64-bit int types should match in size");
+ return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()),
+ shape.dims(), reinterpret_cast<void*>(&empty), 0,
+ [](void*, size_t, void*) {}, nullptr);
+}
+
// Non-static for testing.
TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
TF_Status* status) {
@@ -464,15 +482,19 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
"attempt to use a tensor with an uninitialized value");
return nullptr;
}
+ if (src.NumElements() == 0) {
+ return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
+ }
if (src.dtype() == DT_RESOURCE) {
- DCHECK_EQ(0, src.shape().dims()) << src.shape().DebugString();
if (src.shape().dims() != 0) {
- LOG(ERROR) << "Unexpected non-scalar DT_RESOURCE tensor seen (shape: "
- << src.shape().DebugString()
- << "). Please file a bug at "
- "https://github.com/tensorflow/tensorflow/issues/new, "
- "ideally with a "
- "short code snippet that reproduces this error.";
+ status->status = InvalidArgument(
+ "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ",
+ src.shape().DebugString(),
+ "). Please file a bug at "
+ "https://github.com/tensorflow/tensorflow/issues/new, "
+ "ideally with a "
+ "short code snippet that reproduces this error.");
+ return nullptr;
}
const string str = src.scalar<ResourceHandle>()().SerializeAsString();
TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size());
@@ -536,24 +558,6 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src,
dimvec.size(), base, size, DeleteArray, base);
}
-// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
-// result in a zero-sized tensor.
-static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
- static char empty;
- tensorflow::int64 nelems = 1;
- std::vector<tensorflow::int64> dims;
- for (int i = 0; i < shape.dims(); ++i) {
- dims.push_back(shape.dim_size(i));
- nelems *= shape.dim_size(i);
- }
- CHECK_EQ(nelems, 0);
- static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
- "64-bit int types should match in size");
- return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()),
- shape.dims(), reinterpret_cast<void*>(&empty), 0,
- [](void*, size_t, void*) {}, nullptr);
-}
-
// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
@@ -629,8 +633,8 @@ static void TF_Run_Helper(
for (int i = 0; i < noutputs; ++i) {
const Tensor& src = outputs[i];
if (!src.IsInitialized() || src.NumElements() == 0) {
- c_outputs[i] = tensorflow::EmptyTensor(
- static_cast<TF_DataType>(src.dtype()), src.shape());
+ c_outputs[i] =
+ EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
continue;
}
c_outputs[i] = TF_TensorFromTensor(src, status);
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index b89acbcf35..89621d8603 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -140,6 +140,7 @@ class TensorCApi {
}
};
+TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
} // end namespace tensorflow
#endif // TENSORFLOW_C_C_API_INTERNAL_H_