aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-23 13:44:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-23 13:48:42 -0700
commit105c7df01b12b77bc17909cfb4a0d0c0aff87571 (patch)
treef4aa40460f4e7c389fee5536e3cf299dbd74fef3 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parent19ee0605b6eadb516703c37b7ba38e7122a6c51f (diff)
More relaxed size checking for TransposeConv, and miscellaneous bug fixes.
PiperOrigin-RevId: 193977375
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc56
1 files changed, 19 insertions, 37 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index ba244cf5ef..7946492633 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -168,7 +168,9 @@ void ProcessConvOperator(Model* model, ConvOperator* op) {
return;
}
const auto& input_shape = input_array.shape();
- CHECK_EQ(input_shape.dimensions_count(), 4);
+ CHECK(input_shape.dimensions_count() == 4)
+ << "Conv ops require 4D inputs. Input array \"" << op->inputs[0]
+ << "\" is " << input_shape.dimensions_count() << "D.";
const auto& weights_array = model->GetArray(op->inputs[1]);
// Yield until weights dims have been resolved.
@@ -249,12 +251,6 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
<< op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
<< toco::ShapeToString(weights_shape) << ".";
- CHECK(weights_shape.dims(0) == 1 && weights_shape.dims(3) == 1)
- << "TransposeConv weights dimensions must begin and end with 1. Input "
- "weights \""
- << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
- << toco::ShapeToString(weights_shape) << ".";
-
// Compute padding
const int kheight = weights_shape.dims(1);
const int kwidth = weights_shape.dims(2);
@@ -269,9 +265,7 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
LOG(FATAL) << "TransposeConv only supports SAME or VALID padding";
}
- // VALIDATE OUTPUT SHAPE
- // Compute the output shape from the input and weights shapes to verify it
- // agrees with the specified output shape.
+ // VALIDATE some dimensions and set the output shape.
const auto& input_array =
model->GetArray(op->inputs[TransposeConvOperator::DATA_INPUT]);
if (!input_array.has_shape()) {
@@ -283,31 +277,13 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
<< "TransposeConv input shape must have 4 dimensions. Input \""
<< op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
<< toco::ShapeToString(weights_shape) << ".";
+ CHECK_EQ(input_shape.dims(3), weights_shape.dims(0))
+ << "Input shape depth and weight depth do not agree";
- // Compute output shape
- const int input_width = input_shape.dims(2);
- const int input_height = input_shape.dims(1);
- int output_height = op->stride_height * (input_height - 1);
- int output_width = op->stride_width * (input_width - 1);
- if (op->padding.type == PaddingType::kValid) {
- output_height += kheight;
- output_width += kwidth;
- } else if (op->padding.type == PaddingType::kSame) {
- output_height += 1;
- output_width += 1;
- }
-
- CHECK(specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data ==
- std::vector<int32>({input_shape.dims(0), output_height, output_width,
- weights_shape.dims(3)}))
- << "Specified output shape: " << ShapeToString(output_array.shape())
- << ", does not agree with shape computed from input data and weights: ["
- << input_shape.dims(0) << ", " << output_height << ", " << output_width
- << ", " << weights_shape.dims(3) << "].";
-
- // SUCCESS: Set the op's output shape according to the specified output shape.
- *(output_array.mutable_shape()->mutable_dims()) =
+ // Set the output shape according to the specified output shape.
+ std::vector<int32> const& specified_output_shape =
specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape;
}
void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
@@ -1179,6 +1155,11 @@ void ProcessRankOperator(Model* model, RankOperator* op) {
return;
}
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes
+ return;
+ }
+
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
@@ -1200,6 +1181,11 @@ void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
return;
}
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes
+ return;
+ }
+
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
@@ -1230,10 +1216,6 @@ void ProcessStackOperator(Model* model, StackOperator* op) {
}
Shape shape = input_array.shape();
- if (shape.dimensions_count() == 0) {
- // Convert 0D scalars to 1D scalars of shape {1}.
- shape.mutable_dims()->push_back(1);
- }
if (!stacked_shape) {
stacked_shape.reset(new Shape(shape));
} else {