diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc | 27 |
1 files changed, 15 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc index 22c258cec5..e9f24a29ab 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc @@ -24,15 +24,17 @@ limitations under the License. namespace toco { -bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; Operator* op = model->operators[op_index].get(); if (op->type != OperatorType::kFullyConnected) { - return false; + return ::tensorflow::Status::OK(); } FullyConnectedOperator* fc_op = static_cast<FullyConnectedOperator*>(op); // Exit if this FC op already has shuffled weights if (fc_op->weights_format != FullyConnectedWeightsFormat::kDefault) { - return false; + return ::tensorflow::Status::OK(); } const Array& input_array = model->GetArray(fc_op->inputs[0]); const string& weights_name = fc_op->inputs[1]; @@ -46,11 +48,11 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { output_array.data_type != ArrayDataType::kInt16 || !input_array.quantization_params || !weights_array.quantization_params || !output_array.quantization_params) { - return false; + return ::tensorflow::Status::OK(); } // Exit if the shapes aren't known if (!input_array.has_shape() || !weights_array.has_shape()) { - return false; + return ::tensorflow::Status::OK(); } // Exit if, based on the known shapes, this FC op is not a GEMV. // The shuffling of FC weights is only useful to enable fast GEMV paths. @@ -64,7 +66,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { "the input shape is not 1D or 2D (possibly with additional inner " "dimensions of size 1)", LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } } if (input_shape.dims(0) != 1 && input_shape.dims(0) != 4) { @@ -73,7 +75,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { "the input shape's leading dimension, i.e. the 'batch size', is not " "equal to 1 or 4", LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } // Exit if the weights shape isn't an integral multiple of the shuffled // block shape, 4x16. We don't want to have to write code dealing with @@ -88,7 +90,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { // two. const Shape& weights_shape = weights_array.shape(); if (weights_shape.dimensions_count() != 2) { - return false; + return ::tensorflow::Status::OK(); } const int rows = weights_shape.dims(0); const int cols = weights_shape.dims(1); @@ -97,11 +99,11 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { "Not applying experimental shuffling to the weights of %s because its " "shape isn't a multiple of the shuffling block shape, 4x16", LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } // Exit if the weights aren't already a constant array. if (!weights_array.buffer) { - return false; + return ::tensorflow::Status::OK(); } // Exit if the weights are used by more than one op. if (CountOpsWithInput(*model, weights_name) != 1) { @@ -109,7 +111,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { "Not applying experimental shuffling to the weights of %s because that " "array is consumed by other operators", LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } // Compute the shuffled weights auto& weights_data = @@ -152,7 +154,8 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) { shuffled_input_workspace_array.GetOrCreateQuantizationParams() = input_array.GetQuantizationParams(); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |