diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-23 13:44:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-23 13:48:42 -0700 |
commit | 105c7df01b12b77bc17909cfb4a0d0c0aff87571 (patch) | |
tree | f4aa40460f4e7c389fee5536e3cf299dbd74fef3 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | 19ee0605b6eadb516703c37b7ba38e7122a6c51f (diff) |
More relaxed size checking for TransposeConv, and miscellaneous bug fixes.
PiperOrigin-RevId: 193977375
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 56 |
1 files changed, 19 insertions, 37 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index ba244cf5ef..7946492633 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -168,7 +168,9 @@ void ProcessConvOperator(Model* model, ConvOperator* op) { return; } const auto& input_shape = input_array.shape(); - CHECK_EQ(input_shape.dimensions_count(), 4); + CHECK(input_shape.dimensions_count() == 4) + << "Conv ops require 4D inputs. Input array \"" << op->inputs[0] + << "\" is " << input_shape.dimensions_count() << "D."; const auto& weights_array = model->GetArray(op->inputs[1]); // Yield until weights dims have been resolved. @@ -249,12 +251,6 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape " << toco::ShapeToString(weights_shape) << "."; - CHECK(weights_shape.dims(0) == 1 && weights_shape.dims(3) == 1) - << "TransposeConv weights dimensions must begin and end with 1. Input " - "weights \"" - << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape " - << toco::ShapeToString(weights_shape) << "."; - // Compute padding const int kheight = weights_shape.dims(1); const int kwidth = weights_shape.dims(2); @@ -269,9 +265,7 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { LOG(FATAL) << "TransposeConv only supports SAME or VALID padding"; } - // VALIDATE OUTPUT SHAPE - // Compute the output shape from the input and weights shapes to verify it - // agrees with the specified output shape. + // VALIDATE some dimensions and set the output shape. const auto& input_array = model->GetArray(op->inputs[TransposeConvOperator::DATA_INPUT]); if (!input_array.has_shape()) { @@ -283,31 +277,13 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { << "TransposeConv input shape must have 4 dimensions. Input \"" << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape " << toco::ShapeToString(weights_shape) << "."; + CHECK_EQ(input_shape.dims(3), weights_shape.dims(0)) + << "Input shape depth and weight depth do not agree"; - // Compute output shape - const int input_width = input_shape.dims(2); - const int input_height = input_shape.dims(1); - int output_height = op->stride_height * (input_height - 1); - int output_width = op->stride_width * (input_width - 1); - if (op->padding.type == PaddingType::kValid) { - output_height += kheight; - output_width += kwidth; - } else if (op->padding.type == PaddingType::kSame) { - output_height += 1; - output_width += 1; - } - - CHECK(specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data == - std::vector<int32>({input_shape.dims(0), output_height, output_width, - weights_shape.dims(3)})) - << "Specified output shape: " << ShapeToString(output_array.shape()) - << ", does not agree with shape computed from input data and weights: [" - << input_shape.dims(0) << ", " << output_height << ", " << output_width - << ", " << weights_shape.dims(3) << "]."; - - // SUCCESS: Set the op's output shape according to the specified output shape. - *(output_array.mutable_shape()->mutable_dims()) = + // Set the output shape according to the specified output shape. + std::vector<int32> const& specified_output_shape = specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data; + *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape; } void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { @@ -1179,6 +1155,11 @@ void ProcessRankOperator(Model* model, RankOperator* op) { return; } + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes + return; + } + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. @@ -1200,6 +1181,11 @@ void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) { return; } + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes + return; + } + const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until input dims have been resolved. @@ -1230,10 +1216,6 @@ void ProcessStackOperator(Model* model, StackOperator* op) { } Shape shape = input_array.shape(); - if (shape.dimensions_count() == 0) { - // Convert 0D scalars to 1D scalars of shape {1}. - shape.mutable_dims()->push_back(1); - } if (!stacked_shape) { stacked_shape.reset(new Shape(shape)); } else { |