aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc27
1 files changed, 15 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc
index 22c258cec5..e9f24a29ab 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc
@@ -24,15 +24,17 @@ limitations under the License.
namespace toco {
-bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ShuffleFCWeights::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
Operator* op = model->operators[op_index].get();
if (op->type != OperatorType::kFullyConnected) {
- return false;
+ return ::tensorflow::Status::OK();
}
FullyConnectedOperator* fc_op = static_cast<FullyConnectedOperator*>(op);
// Exit if this FC op already has shuffled weights
if (fc_op->weights_format != FullyConnectedWeightsFormat::kDefault) {
- return false;
+ return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(fc_op->inputs[0]);
const string& weights_name = fc_op->inputs[1];
@@ -46,11 +48,11 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
output_array.data_type != ArrayDataType::kInt16 ||
!input_array.quantization_params || !weights_array.quantization_params ||
!output_array.quantization_params) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Exit if the shapes aren't known
if (!input_array.has_shape() || !weights_array.has_shape()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Exit if, based on the known shapes, this FC op is not a GEMV.
// The shuffling of FC weights is only useful to enable fast GEMV paths.
@@ -64,7 +66,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
"the input shape is not 1D or 2D (possibly with additional inner "
"dimensions of size 1)",
LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
if (input_shape.dims(0) != 1 && input_shape.dims(0) != 4) {
@@ -73,7 +75,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
"the input shape's leading dimension, i.e. the 'batch size', is not "
"equal to 1 or 4",
LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Exit if the weights shape isn't an integral multiple of the shuffled
// block shape, 4x16. We don't want to have to write code dealing with
@@ -88,7 +90,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
// two.
const Shape& weights_shape = weights_array.shape();
if (weights_shape.dimensions_count() != 2) {
- return false;
+ return ::tensorflow::Status::OK();
}
const int rows = weights_shape.dims(0);
const int cols = weights_shape.dims(1);
@@ -97,11 +99,11 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
"Not applying experimental shuffling to the weights of %s because its "
"shape isn't a multiple of the shuffling block shape, 4x16",
LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Exit if the weights aren't already a constant array.
if (!weights_array.buffer) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Exit if the weights are used by more than one op.
if (CountOpsWithInput(*model, weights_name) != 1) {
@@ -109,7 +111,7 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
"Not applying experimental shuffling to the weights of %s because that "
"array is consumed by other operators",
LogName(*op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Compute the shuffled weights
auto& weights_data =
@@ -152,7 +154,8 @@ bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
shuffled_input_workspace_array.GetOrCreateQuantizationParams() =
input_array.GetQuantizationParams();
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco