diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-14 11:40:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-14 11:43:11 -0700 |
commit | 8e4c4144817bea5ffd9255df48a78740fdb14f57 (patch) | |
tree | 91595cd3f71825b5f54210a8fb735df506bc48fa /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | 8f7afe01a583058726b03a0d849add35fcde41a3 (diff) |
Optimized implementation of transpose conv. Uses an im2col array and GEMM, similar to conv.
PiperOrigin-RevId: 200592004
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 170a499d4e..b6f0d96900 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -211,12 +211,6 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { // might as well calculate the output shape and ensure it matches the // specified one - // Check if we have already run. - auto& output_array = model->GetArray(op->outputs[0]); - if (output_array.has_shape()) { - return; - } - // SPECIFIED OUTPUT SHAPE // The below is the specified, or prescribed output shape, _given_ to the // operator as an input. @@ -284,7 +278,17 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { // Set the output shape according to the specified output shape. std::vector<int32> const& specified_output_shape = specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data; + auto& output_array = model->GetArray(op->outputs[0]); *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape; + + // Set im2col array dimensions if there is one. + if (op->outputs.size() == 2) { + const int input_depth = weights_shape.dims(3); + auto& im2col_array = model->GetArray(op->outputs[1]); + im2col_array.copy_shape( + Shape{specified_output_shape[0], specified_output_shape[1], + specified_output_shape[2], input_depth * kheight * kwidth}); + } } void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { |