aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc24
1 files changed, 14 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc
index dd9e26e68b..e19527968d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unpartition_embedding_lookup.cc
@@ -22,7 +22,10 @@ limitations under the License.
namespace toco {
-bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status UnpartitionEmbeddingLookup::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
// Collapses a partitioned tf.nn.embedding_lookup back into a single Gather.
// https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup
// This transform attempts to identify the len(params) > 1 case and collapse
@@ -47,7 +50,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
// First look for the final DynamicStitch.
auto op_it = model->operators.begin() + op_index;
if (op_it->get()->type != OperatorType::kDynamicStitch) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto* stitch_op = static_cast<DynamicStitchOperator*>(op_it->get());
@@ -72,7 +75,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because indices input %s into "
"%s is unexpected",
LogName(*op), LogName(*stitch_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (!indices_partition_op) {
indices_partition_op = static_cast<DynamicPartitionOperator*>(op);
@@ -83,7 +86,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because indices input %s into "
"%s is from a different source op than others",
LogName(*op), LogName(*stitch_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
}
@@ -92,12 +95,12 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
// The data for the indices must be a constant range of the array shape.
if (!IsConstantParameterArray(*model, indices_partition_op->inputs[0])) {
AddMessageF("Skipping because indices partition data is non-constant");
- return false;
+ return ::tensorflow::Status::OK();
}
auto& indices_data_array = model->GetArray(indices_partition_op->inputs[0]);
if (indices_data_array.data_type == ArrayDataType::kNone) {
// Yield until data types are propagated.
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK(indices_data_array.data_type == ArrayDataType::kInt32)
<< "Indices partition inputs must be int32";
@@ -117,7 +120,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because data input %s into %s "
"is unexpected",
LogName(*op), LogName(*stitch_op));
- return false;
+ return ::tensorflow::Status::OK();
}
gather_ops.push_back(static_cast<GatherOperator*>(op));
}
@@ -132,7 +135,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because data input %s into "
"%s is unexpected",
LogName(*op), LogName(*gather_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (!data_partition_op) {
data_partition_op = static_cast<DynamicPartitionOperator*>(op);
@@ -143,7 +146,7 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
"Skipping because data input %s into "
"%s is from a different source op than others",
LogName(*op), LogName(*gather_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
}
@@ -236,7 +239,8 @@ bool UnpartitionEmbeddingLookup::Run(Model* model, std::size_t op_index) {
DeleteOpAndArraysIfUnused(model, indices_partition_op);
DeleteOpAndArraysIfUnused(model, data_partition_op);
DeleteOpAndArraysIfUnused(model, stitch_op);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco