aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tooling_util.cc
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-08-08 11:03:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 11:07:58 -0700
commitb93ba55c00df027fcd3b00f025eed4c9c487de6e (patch)
treedc17d4ce7964545d448532e7ae12d31db0e453e0 /tensorflow/contrib/lite/toco/tooling_util.cc
parent4a4ae62c75f1de3455c3adea96802d22c7e986e3 (diff)
Allow empty shapes in certain cases within toco
PiperOrigin-RevId: 207913842
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc43
1 files changed, 28 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 80df09eb08..2ad2719811 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -602,14 +602,33 @@ void UnextendShape(Shape* shape, int new_shape_size) {
shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
}
-bool IsValid(const Shape& shape) {
+// In general, zero-sized dimensions are disallowed, but there are exceptions,
+// e.g., if the tensor data itself represents a scalar (rank 0) shape, its
+// shape will have dimensions [0]. CheckNonEmptyShapeDimensions is more
+// strict, and is appropriate for ops and comparisons where an empty shape
+// doesn't make sense.
+template <typename Dims>
+void CheckValidShapeDimensions(const Dims& dims) {
+ if (dims.size() == 1 && dims[0] == 0) {
+ return;
+ }
+ for (const auto& dim : dims) {
+ CHECK_GE(dim, 1);
+ }
+}
+
+void CheckValidShape(const Shape& shape) {
+ CheckValidShapeDimensions(shape.dims());
+}
+
+bool IsNonEmpty(const Shape& shape) {
for (int i = 0; i < shape.dimensions_count(); ++i) {
if (shape.dims(i) < 1) return false;
}
return true;
}
-void CheckShapeDimensions(const Shape& shape) {
+void CheckNonEmptyShapeDimensions(const Shape& shape) {
for (int i = 0; i < shape.dimensions_count(); ++i) {
CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i
<< ". shape = " << ShapeToString(shape);
@@ -617,8 +636,8 @@ void CheckShapeDimensions(const Shape& shape) {
}
bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
- CheckShapeDimensions(shape0);
- CheckShapeDimensions(shape1);
+ CheckNonEmptyShapeDimensions(shape0);
+ CheckNonEmptyShapeDimensions(shape1);
const Shape* longer = &shape0;
const Shape* shorter = &shape1;
@@ -645,8 +664,8 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
}
bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
- CheckShapeDimensions(shape0);
- CheckShapeDimensions(shape1);
+ CheckNonEmptyShapeDimensions(shape0);
+ CheckNonEmptyShapeDimensions(shape1);
const Shape* longer = &shape0;
const Shape* shorter = &shape1;
@@ -683,9 +702,9 @@ bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
}
int RequiredBufferSizeForShape(const Shape& shape) {
+ CheckValidShape(shape);
int max_offset = 1;
for (const auto& dim : shape.dims()) {
- CHECK_GE(dim, 1);
max_offset *= dim;
}
return max_offset;
@@ -946,13 +965,7 @@ void CheckEachArray(const Model& model) {
// shape.
CHECK(array->has_shape());
// Constant buffer should has a valid shape.
- bool is_scalar =
- array->shape().dimensions_count() == 1 && array->shape().dims(0) == 0;
- if (!is_scalar) {
- for (int d : array->shape().dims()) {
- CHECK_GE(d, 1);
- }
- }
+ CheckValidShape(array->shape());
// The shape flat-size should agree with the buffer length.
CHECK_EQ(array->buffer->Length(),
RequiredBufferSizeForShape(array->shape()));
@@ -1544,8 +1557,8 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
if (!input_array.has_shape()) {
if (input_array_proto.has_shape()) {
auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
+ CheckValidShapeDimensions(input_array_proto.shape().dims());
for (auto dim : input_array_proto.shape().dims()) {
- CHECK_GE(dim, 1);
input_array_dims.push_back(dim);
}
}