aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/tensor_util.cc
diff options
context:
space:
mode:
authorGravatar Christopher Olston <olston@google.com>2017-03-07 12:57:39 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-07 13:08:04 -0800
commitb59b9043afd453d952dc6ae829fa05f68408e3b6 (patch)
treeadb2c4286a40bd183aef5c69b60a0fbaf2332d33 /tensorflow/core/framework/tensor_util.cc
parent59b3144f8bd0352326dbcc05acfd8c4b7848ab91 (diff)
Create non-crashy versions of Concat() and Split(), to use in serving code.
Change: 149454449
Diffstat (limited to 'tensorflow/core/framework/tensor_util.cc')
-rw-r--r--tensorflow/core/framework/tensor_util.cc70
1 file changed, 52 insertions, 18 deletions
diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc
index 6d9ae6a350..9628d002dd 100644
--- a/tensorflow/core/framework/tensor_util.cc
+++ b/tensorflow/core/framework/tensor_util.cc
@@ -43,22 +43,41 @@ Tensor DeepCopy(const Tensor& other) {
}
Tensor Concat(const gtl::ArraySlice<Tensor>& tensors) {
- CHECK_GT(tensors.size(), size_t{0});
+ Tensor result;
+ TF_CHECK_OK(TryConcat(tensors, &result));
+ return result;
+}
+
+Status TryConcat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
+ if (tensors.empty()) {
+ return errors::InvalidArgument("Cannot concatenate zero tensors");
+ }
int64 total_dim0_size = 0;
for (const Tensor& tensor : tensors) {
- CHECK_GT(tensor.dims(), 0);
+ if (tensor.dims() == 0) {
+ return errors::InvalidArgument(
+ "Cannot concatenate a zero-dimensional tensor");
+ }
total_dim0_size += tensor.dim_size(0);
}
TensorShape shape = tensors[0].shape();
shape.set_dim(0, total_dim0_size);
- Tensor result = Tensor(tensors[0].dtype(), shape);
+
+ const DataType dtype = tensors[0].dtype();
+ for (int i = 1; i < tensors.size(); ++i) {
+ if (tensors[i].dtype() != dtype) {
+ return errors::InvalidArgument(
+ "Cannot concatenate tensors that have different data types");
+ }
+ }
+ *result = Tensor(dtype, shape);
// We use StringPiece as a convenient map over the tensor buffer,
// but we cast the type to get to the underlying buffer to do the
// copy.
- StringPiece to_data = result.tensor_data();
+ StringPiece to_data = result->tensor_data();
- if (DataTypeCanUseMemcpy(result.dtype())) {
+ if (DataTypeCanUseMemcpy(dtype)) {
int64 offset = 0;
for (const Tensor& tensor : tensors) {
StringPiece from_data = tensor.tensor_data();
@@ -69,14 +88,16 @@ Tensor Concat(const gtl::ArraySlice<Tensor>& tensors) {
offset += from_data.size();
}
} else {
- CHECK_EQ(DT_STRING, result.dtype());
+ if (dtype != DT_STRING) {
+ return errors::Internal("Unexpected data type");
+ }
string* to_strings =
reinterpret_cast<string*>(const_cast<char*>(to_data.data()));
int64 offset = 0;
for (const Tensor& tensor : tensors) {
auto from_strings = tensor.flat<string>();
- CHECK_LE(offset + tensor.NumElements(), result.NumElements());
+ CHECK_LE(offset + tensor.NumElements(), result->NumElements());
for (int i = 0; i < tensor.NumElements(); ++i) {
to_strings[offset + i] = from_strings(i);
}
@@ -85,19 +106,30 @@ Tensor Concat(const gtl::ArraySlice<Tensor>& tensors) {
}
}
- return result;
+ return Status::OK();
}
std::vector<Tensor> Split(const Tensor& tensor,
const gtl::ArraySlice<int64>& sizes) {
- CHECK_GT(tensor.dims(), 0);
+ std::vector<Tensor> result;
+ TF_CHECK_OK(TrySplit(tensor, sizes, &result));
+ return result;
+}
+
+Status TrySplit(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
+ std::vector<Tensor>* result) {
+ if (tensor.dims() == 0) {
+ return errors::InvalidArgument("Cannot split a zero-dimensional tensor");
+ }
int64 total_size = 0;
for (int64 size : sizes) {
total_size += size;
}
- CHECK_EQ(total_size, tensor.dim_size(0));
-
- std::vector<Tensor> result;
+ if (total_size != tensor.dim_size(0)) {
+ return errors::InvalidArgument(
+ "The values in 'sizes' do not sum to the zeroth-dimension size of "
+ "'tensor'");
+ }
StringPiece from_data = tensor.tensor_data();
@@ -106,8 +138,8 @@ std::vector<Tensor> Split(const Tensor& tensor,
for (int64 size : sizes) {
TensorShape shape = tensor.shape();
shape.set_dim(0, size);
- result.emplace_back(tensor.dtype(), shape);
- Tensor* split = &result[result.size() - 1];
+ result->emplace_back(tensor.dtype(), shape);
+ Tensor* split = &(*result)[result->size() - 1];
// We use StringPiece as a convenient map over the tensor buffer,
// but we cast the type to get to the underlying buffer to do the
@@ -120,15 +152,17 @@ std::vector<Tensor> Split(const Tensor& tensor,
offset += to_data.size();
}
} else {
- CHECK_EQ(DT_STRING, tensor.dtype());
+ if (tensor.dtype() != DT_STRING) {
+ return errors::Internal("Unexpected data type");
+ }
auto from_strings = tensor.flat<string>();
int64 offset = 0;
for (int64 size : sizes) {
TensorShape shape = tensor.shape();
shape.set_dim(0, size);
- result.emplace_back(tensor.dtype(), shape);
- Tensor& split = result[result.size() - 1];
+ result->emplace_back(tensor.dtype(), shape);
+ Tensor& split = (*result)[result->size() - 1];
string* to_strings = reinterpret_cast<string*>(
const_cast<char*>(split.tensor_data().data()));
@@ -141,7 +175,7 @@ std::vector<Tensor> Split(const Tensor& tensor,
}
}
- return result;
+ return Status::OK();
}
} // namespace tensor