diff options
-rw-r--r-- | tensorflow/c/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/c/c_api.cc | 58 | ||||
-rw-r--r-- | tensorflow/c/c_api_internal.h | 1 |
3 files changed, 33 insertions, 27 deletions
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 21b09865db..242a628d37 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -30,6 +30,7 @@ tf_cuda_library( name = "c_api_internal", srcs = ["c_api.h"], hdrs = ["c_api_internal.h"], + visibility = ["//tensorflow/c:__subpackages__"], deps = select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index cb074cb2f1..663fec56f1 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -456,6 +456,24 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { return Status::OK(); } +// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to +// result in a zero-sized tensor. +static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { + static char empty; + tensorflow::int64 nelems = 1; + std::vector<tensorflow::int64> dims; + for (int i = 0; i < shape.dims(); ++i) { + dims.push_back(shape.dim_size(i)); + nelems *= shape.dim_size(i); + } + CHECK_EQ(nelems, 0); + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()), + shape.dims(), reinterpret_cast<void*>(&empty), 0, + [](void*, size_t, void*) {}, nullptr); +} + // Non-static for testing. TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, TF_Status* status) { @@ -464,15 +482,19 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, "attempt to use a tensor with an uninitialized value"); return nullptr; } + if (src.NumElements() == 0) { + return EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape()); + } if (src.dtype() == DT_RESOURCE) { - DCHECK_EQ(0, src.shape().dims()) << src.shape().DebugString(); if (src.shape().dims() != 0) { - LOG(ERROR) << "Unexpected non-scalar DT_RESOURCE tensor seen (shape: " - << src.shape().DebugString() - << "). Please file a bug at " - "https://github.com/tensorflow/tensorflow/issues/new, " - "ideally with a " - "short code snippet that reproduces this error."; + status->status = InvalidArgument( + "Unexpected non-scalar DT_RESOURCE tensor seen (shape: ", + src.shape().DebugString(), + "). Please file a bug at " + "https://github.com/tensorflow/tensorflow/issues/new, " + "ideally with a " + "short code snippet that reproduces this error."); + return nullptr; } const string str = src.scalar<ResourceHandle>()().SerializeAsString(); TF_Tensor* t = TF_AllocateTensor(TF_RESOURCE, {}, 0, str.size()); @@ -536,24 +558,6 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, dimvec.size(), base, size, DeleteArray, base); } -// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to -// result in a zero-sized tensor. -static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) { - static char empty; - tensorflow::int64 nelems = 1; - std::vector<tensorflow::int64> dims; - for (int i = 0; i < shape.dims(); ++i) { - dims.push_back(shape.dim_size(i)); - nelems *= shape.dim_size(i); - } - CHECK_EQ(nelems, 0); - static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), - "64-bit int types should match in size"); - return TF_NewTensor(dtype, reinterpret_cast<const int64_t*>(dims.data()), - shape.dims(), reinterpret_cast<void*>(&empty), 0, - [](void*, size_t, void*) {}, nullptr); -} - // Helpers for loading a TensorFlow plugin (a .so file). Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); @@ -629,8 +633,8 @@ static void TF_Run_Helper( for (int i = 0; i < noutputs; ++i) { const Tensor& src = outputs[i]; if (!src.IsInitialized() || src.NumElements() == 0) { - c_outputs[i] = tensorflow::EmptyTensor( - static_cast<TF_DataType>(src.dtype()), src.shape()); + c_outputs[i] = + EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape()); continue; } c_outputs[i] = TF_TensorFromTensor(src, status); diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index b89acbcf35..89621d8603 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -140,6 +140,7 @@ class TensorCApi { } }; +TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ |