aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-14 11:40:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-14 11:43:11 -0700
commit8e4c4144817bea5ffd9255df48a78740fdb14f57 (patch)
tree91595cd3f71825b5f54210a8fb735df506bc48fa /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parent8f7afe01a583058726b03a0d849add35fcde41a3 (diff)
Optimized implementation of transpose conv. Uses an im2col array and GEMM, similar to conv.
PiperOrigin-RevId: 200592004
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 170a499d4e..b6f0d96900 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -211,12 +211,6 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
// might as well calculate the output shape and ensure it matches the
// specified one
- // Check if we have already run.
- auto& output_array = model->GetArray(op->outputs[0]);
- if (output_array.has_shape()) {
- return;
- }
-
// SPECIFIED OUTPUT SHAPE
// The below is the specified, or prescribed output shape, _given_ to the
// operator as an input.
@@ -284,7 +278,17 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
// Set the output shape according to the specified output shape.
std::vector<int32> const& specified_output_shape =
specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ auto& output_array = model->GetArray(op->outputs[0]);
*(output_array.mutable_shape()->mutable_dims()) = specified_output_shape;
+
+ // Set im2col array dimensions if there is one.
+ if (op->outputs.size() == 2) {
+ const int input_depth = weights_shape.dims(3);
+ auto& im2col_array = model->GetArray(op->outputs[1]);
+ im2col_array.copy_shape(
+ Shape{specified_output_shape[0], specified_output_shape[1],
+ specified_output_shape[2], input_depth * kheight * kwidth});
+ }
}
void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {