diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc index dd9e26e68b..e19527968d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc @@ -22,7 +22,10 @@ limitations under the License. namespace toco { -bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { +::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; // Collapses a partitioned tf.nn.embedding_lookup back into a single Gather. // https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup // This transform attempts to identify the len(params) > 1 case and collapse @@ -47,7 +50,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { // First look for the final DynamicStitch. auto op_it = model->operators.begin() + op_index; if (op_it->get()->type != OperatorType::kDynamicStitch) { - return false; + return ::tensorflow::Status::OK(); } auto* stitch_op = static_cast<DynamicStitchOperator*>(op_it->get()); @@ -72,7 +75,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { "Skipping because indices input %s into " "%s is unexpected", LogName(*op), LogName(*stitch_op)); - return false; + return ::tensorflow::Status::OK(); } if (!indices_partition_op) { indices_partition_op = static_cast<DynamicPartitionOperator*>(op); @@ -83,7 +86,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { "Skipping because indices input %s into " "%s is from a different source op than others", LogName(*op), LogName(*stitch_op)); - return false; + return ::tensorflow::Status::OK(); } } } @@ -92,12 +95,12 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { // The data for the indices must be a constant range of the array shape. if (!IsConstantParameterArray(*model, indices_partition_op->inputs[0])) { AddMessageF("Skipping because indices partition data is non-constant"); - return false; + return ::tensorflow::Status::OK(); } auto& indices_data_array = model->GetArray(indices_partition_op->inputs[0]); if (indices_data_array.data_type == ArrayDataType::kNone) { // Yield until data types are propagated. - return false; + return ::tensorflow::Status::OK(); } CHECK(indices_data_array.data_type == ArrayDataType::kInt32) << "Indices partition inputs must be int32"; @@ -117,7 +120,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { "Skipping because data input %s into %s " "is unexpected", LogName(*op), LogName(*stitch_op)); - return false; + return ::tensorflow::Status::OK(); } gather_ops.push_back(static_cast<GatherOperator*>(op)); } @@ -132,7 +135,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { "Skipping because data input %s into " "%s is unexpected", LogName(*op), LogName(*gather_op)); - return false; + return ::tensorflow::Status::OK(); } if (!data_partition_op) { data_partition_op = static_cast<DynamicPartitionOperator*>(op); @@ -143,7 +146,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { "Skipping because data input %s into " "%s is from a different source op than others", LogName(*op), LogName(*gather_op)); - return false; + return ::tensorflow::Status::OK(); } } } @@ -236,7 +239,8 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) { DeleteOpAndArraysIfUnused(model, indices_partition_op); DeleteOpAndArraysIfUnused(model, data_partition_op); DeleteOpAndArraysIfUnused(model, stitch_op); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |