diff options
author | 2017-03-07 12:57:39 -0800 | |
---|---|---|
committer | 2017-03-07 13:08:04 -0800 | |
commit | b59b9043afd453d952dc6ae829fa05f68408e3b6 (patch) | |
tree | adb2c4286a40bd183aef5c69b60a0fbaf2332d33 /tensorflow/core/framework/tensor_util.cc | |
parent | 59b3144f8bd0352326dbcc05acfd8c4b7848ab91 (diff) |
Create non-crashy versions of Concat() and Split(), to use in serving code.
Change: 149454449
Diffstat (limited to 'tensorflow/core/framework/tensor_util.cc')
-rw-r--r-- | tensorflow/core/framework/tensor_util.cc | 70 |
1 files changed, 52 insertions, 18 deletions
diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc index 6d9ae6a350..9628d002dd 100644 --- a/tensorflow/core/framework/tensor_util.cc +++ b/tensorflow/core/framework/tensor_util.cc @@ -43,22 +43,41 @@ Tensor DeepCopy(const Tensor& other) { } Tensor Concat(const gtl::ArraySlice<Tensor>& tensors) { - CHECK_GT(tensors.size(), size_t{0}); + Tensor result; + TF_CHECK_OK(TryConcat(tensors, &result)); + return result; +} + +Status TryConcat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) { + if (tensors.empty()) { + return errors::InvalidArgument("Cannot concatenate zero tensors"); + } int64 total_dim0_size = 0; for (const Tensor& tensor : tensors) { - CHECK_GT(tensor.dims(), 0); + if (tensor.dims() == 0) { + return errors::InvalidArgument( + "Cannot concatenate a zero-dimensional tensor"); + } total_dim0_size += tensor.dim_size(0); } TensorShape shape = tensors[0].shape(); shape.set_dim(0, total_dim0_size); - Tensor result = Tensor(tensors[0].dtype(), shape); + + const DataType dtype = tensors[0].dtype(); + for (int i = 1; i < tensors.size(); ++i) { + if (tensors[i].dtype() != dtype) { + return errors::InvalidArgument( + "Cannot concatenate tensors that have different data types"); + } + } + *result = Tensor(dtype, shape); // We use StringPiece as a convenient map over the tensor buffer, // but we cast the type to get to the underlying buffer to do the // copy. - StringPiece to_data = result.tensor_data(); + StringPiece to_data = result->tensor_data(); - if (DataTypeCanUseMemcpy(result.dtype())) { + if (DataTypeCanUseMemcpy(dtype)) { int64 offset = 0; for (const Tensor& tensor : tensors) { StringPiece from_data = tensor.tensor_data(); @@ -69,14 +88,16 @@ Tensor Concat(const gtl::ArraySlice<Tensor>& tensors) { offset += from_data.size(); } } else { - CHECK_EQ(DT_STRING, result.dtype()); + if (dtype != DT_STRING) { + return errors::Internal("Unexpected data type"); + } string* to_strings = reinterpret_cast<string*>(const_cast<char*>(to_data.data())); int64 offset = 0; for (const Tensor& tensor : tensors) { auto from_strings = tensor.flat<string>(); - CHECK_LE(offset + tensor.NumElements(), result.NumElements()); + CHECK_LE(offset + tensor.NumElements(), result->NumElements()); for (int i = 0; i < tensor.NumElements(); ++i) { to_strings[offset + i] = from_strings(i); } @@ -85,19 +106,30 @@ Tensor Concat(const gtl::ArraySlice<Tensor>& tensors) { } } - return result; + return Status::OK(); } std::vector<Tensor> Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes) { - CHECK_GT(tensor.dims(), 0); + std::vector<Tensor> result; + TF_CHECK_OK(TrySplit(tensor, sizes, &result)); + return result; +} + +Status TrySplit(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes, + std::vector<Tensor>* result) { + if (tensor.dims() == 0) { + return errors::InvalidArgument("Cannot split a zero-dimensional tensor"); + } int64 total_size = 0; for (int64 size : sizes) { total_size += size; } - CHECK_EQ(total_size, tensor.dim_size(0)); - - std::vector<Tensor> result; + if (total_size != tensor.dim_size(0)) { + return errors::InvalidArgument( + "The values in 'sizes' do not sum to the zeroth-dimension size of " + "'tensor'"); + } StringPiece from_data = tensor.tensor_data(); @@ -106,8 +138,8 @@ std::vector<Tensor> Split(const Tensor& tensor, for (int64 size : sizes) { TensorShape shape = tensor.shape(); shape.set_dim(0, size); - result.emplace_back(tensor.dtype(), shape); - Tensor* split = &result[result.size() - 1]; + result->emplace_back(tensor.dtype(), shape); + Tensor* split = &(*result)[result->size() - 1]; // We use StringPiece as a convenient map over the tensor buffer, // but we cast the type to get to the underlying buffer to do the @@ -120,15 +152,17 @@ std::vector<Tensor> Split(const Tensor& tensor, offset += to_data.size(); } } else { - CHECK_EQ(DT_STRING, tensor.dtype()); + if (tensor.dtype() != DT_STRING) { + return errors::Internal("Unexpected data type"); + } auto from_strings = tensor.flat<string>(); int64 offset = 0; for (int64 size : sizes) { TensorShape shape = tensor.shape(); shape.set_dim(0, size); - result.emplace_back(tensor.dtype(), shape); - Tensor& split = result[result.size() - 1]; + result->emplace_back(tensor.dtype(), shape); + Tensor& split = (*result)[result->size() - 1]; string* to_strings = reinterpret_cast<string*>( const_cast<char*>(split.tensor_data().data())); @@ -141,7 +175,7 @@ std::vector<Tensor> Split(const Tensor& tensor, } } - return result; + return Status::OK(); } } // namespace tensor |