Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r-- tensorflow/contrib/lite/toco/tooling_util.cc | 67
1 file changed, 47 insertions(+), 20 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 98e416b76e..3a4542f522 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -356,6 +356,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min
HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
HANDLE_OPERATORTYPENAME_CASE(Neg)
+ HANDLE_OPERATORTYPENAME_CASE(OneHot)
HANDLE_OPERATORTYPENAME_CASE(Pack)
HANDLE_OPERATORTYPENAME_CASE(Pad)
HANDLE_OPERATORTYPENAME_CASE(PadV2)
@@ -402,6 +403,8 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Any)
HANDLE_OPERATORTYPENAME_CASE(LogicalAnd)
HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
+ HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
+ HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
@@ -599,14 +602,33 @@ void UnextendShape(Shape* shape, int new_shape_size) {
shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
}
-bool IsValid(const Shape& shape) {
+// In general, zero-sized dimensions are disallowed, but there are exceptions,
+// e.g., if the tensor data itself represents a scalar (rank 0) shape, its
+// shape will have dimensions [0]. CheckNonEmptyShapeDimensions is more
+// strict, and is appropriate for ops and comparisons where an empty shape
+// doesn't make sense.
+template <typename Dims>
+void CheckValidShapeDimensions(const Dims& dims) {
+ if (dims.size() == 1 && dims[0] == 0) {
+ return;
+ }
+ for (const auto& dim : dims) {
+ CHECK_GE(dim, 1);
+ }
+}
+
+void CheckValidShape(const Shape& shape) {
+ CheckValidShapeDimensions(shape.dims());
+}
+
+bool IsNonEmpty(const Shape& shape) {
for (int i = 0; i < shape.dimensions_count(); ++i) {
if (shape.dims(i) < 1) return false;
}
return true;
}
-void CheckShapeDimensions(const Shape& shape) {
+void CheckNonEmptyShapeDimensions(const Shape& shape) {
for (int i = 0; i < shape.dimensions_count(); ++i) {
CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index " << i
<< ". shape = " << ShapeToString(shape);
@@ -614,8 +636,8 @@ void CheckShapeDimensions(const Shape& shape) {
}
bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
- CheckShapeDimensions(shape0);
- CheckShapeDimensions(shape1);
+ CheckNonEmptyShapeDimensions(shape0);
+ CheckNonEmptyShapeDimensions(shape1);
const Shape* longer = &shape0;
const Shape* shorter = &shape1;
@@ -642,8 +664,8 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1) {
}
bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
- CheckShapeDimensions(shape0);
- CheckShapeDimensions(shape1);
+ CheckNonEmptyShapeDimensions(shape0);
+ CheckNonEmptyShapeDimensions(shape1);
const Shape* longer = &shape0;
const Shape* shorter = &shape1;
@@ -680,9 +702,9 @@ bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1) {
}
int RequiredBufferSizeForShape(const Shape& shape) {
+ CheckValidShape(shape);
int max_offset = 1;
for (const auto& dim : shape.dims()) {
- CHECK_GE(dim, 1);
max_offset *= dim;
}
return max_offset;
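
With this change, RequiredBufferSizeForShape delegates the per-dimension check to CheckValidShape and simply returns the product of the dimensions (the flat size). A minimal standalone sketch of that computation, again over std::vector<int> for illustration only:

#include <vector>

// Flat size of a shape: the product of its dimensions.
// For example, {2, 3, 4} -> 24. The scalar marker {0} yields 0 here,
// mirroring the patched loop, which multiplies by the zero dimension.
int FlatSizeForDims(const std::vector<int>& dims) {
  int size = 1;
  for (int dim : dims) {
    size *= dim;
  }
  return size;
}
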
@@ -943,13 +965,7 @@ void CheckEachArray(const Model& model) {
// shape.
CHECK(array->has_shape());
// Constant buffer should have a valid shape.
- bool is_scalar =
- array->shape().dimensions_count() == 1 && array->shape().dims(0) == 0;
- if (!is_scalar) {
- for (int d : array->shape().dims()) {
- CHECK_GE(d, 1);
- }
- }
+ CheckValidShape(array->shape());
// The shape flat-size should agree with the buffer length.
CHECK_EQ(array->buffer->Length(),
RequiredBufferSizeForShape(array->shape()));
@@ -1541,8 +1557,8 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
if (!input_array.has_shape()) {
if (input_array_proto.has_shape()) {
auto& input_array_dims = *input_array.mutable_shape()->mutable_dims();
+ CheckValidShapeDimensions(input_array_proto.shape().dims());
for (auto dim : input_array_proto.shape().dims()) {
- CHECK_GE(dim, 1);
input_array_dims.push_back(dim);
}
}
@@ -1617,11 +1633,12 @@ void CheckIsReadyForQuantization(const Model& model) {
<< "Array " << input << ", which is an input to the "
<< HelpfulOperatorTypeName(*op) << " operator producing the output "
<< "array " << op->outputs[0] << ", is lacking min/max data, "
- << "which is necessary for quantization. Either target a "
- << "non-quantized output format, or change the input graph to "
- << "contain min/max information, or pass --default_ranges_min= and "
- << "--default_ranges_max= if you do not care about the accuracy of "
- << "results.";
+ << "which is necessary for quantization. If accuracy matters, either "
+ << "target a non-quantized output format, or run quantized training "
+ << "with your model from a floating point checkpoint to change the "
+ << "input graph to contain min/max information. If you don't care "
+ << "about accuracy, you can pass --default_ranges_min= and "
+ << "--default_ranges_max= for easy experimentation.";
}
}
}
@@ -2261,4 +2278,14 @@ void UndoWeightsShuffling(Model* model) {
}
}
+void CopyMinMaxAndQuantizationRelatedFields(const Array& src, Array* dst) {
+ if (src.minmax) {
+ dst->GetOrCreateMinMax() = src.GetMinMax();
+ }
+ if (src.quantization_params) {
+ dst->GetOrCreateQuantizationParams() = src.GetQuantizationParams();
+ }
+ dst->narrow_range = src.narrow_range;
+}
+
} // namespace toco
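
The new CopyMinMaxAndQuantizationRelatedFields helper is the kind of utility a graph transformation would call when it rewrites an array. A hedged usage sketch follows; the surrounding function and array names are hypothetical, and it assumes toco's Model::GetArray / Model::GetOrCreateArray API from model.h, with only the helper itself coming from this patch.

// Hypothetical caller: when a transformation replaces one array with another,
// carry over min/max, quantization params, and the narrow_range flag so that
// later quantization passes still see them on the new array.
void CopyQuantizationInfoToReplacement(Model* model,
                                       const string& old_array_name,
                                       const string& new_array_name) {
  const Array& old_array = model->GetArray(old_array_name);
  Array& new_array = model->GetOrCreateArray(new_array_name);
  CopyMinMaxAndQuantizationRelatedFields(old_array, &new_array);
}
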