diff options
author | Rachel Lim <rachelim@google.com> | 2018-09-21 12:38:07 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 12:42:24 -0700 |
commit | 25c1a4441bbf364c8ed263f75e0bebad30f6599c (patch) | |
tree | ab11976563c54fa3c3edfee4434a88f08b19c382 /tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | |
parent | 7461ff7837bb9c57f0020d8adf46a73596dfb77d (diff) |
[tf.data] Add a ConverterRegistry for vectorization converters
PiperOrigin-RevId: 214027910
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | 70 |
1 files changed, 8 insertions, 62 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index bfca63b820..cb56b65985 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" #include "absl/strings/str_join.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -90,59 +91,6 @@ void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn, ->ExtractSubrange(output_position, 1, nullptr); } -Status ConvertCastOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs, - const NodeDef& cast_node, - std::map<string, string>* conversion_map) { - if (inputs.size() != 1) { - return errors::Internal("Cast op should only have one input."); - } - - // Add new Cast node - NodeDef* new_cast_node = outer_scope->add_node_def(); - *new_cast_node = cast_node; - new_cast_node->clear_name(); - function_utils::SetUniqueFunctionNodeName( - strings::StrCat("vectorized/", cast_node.name()), outer_scope, - new_cast_node); - new_cast_node->set_input(0, inputs[0]); - - // Add the output mapping to conversion map - (*conversion_map)[strings::StrCat(cast_node.name(), ":y:0")] = - strings::StrCat(new_cast_node->name(), ":y:0"); - - return Status::OK(); -} - -Status ConvertUnpackOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs, - const NodeDef& unpack_node, - std::map<string, string>* conversion_map) { - if (inputs.size() != 1) { - return errors::Internal("Unpack op should only have one input."); - } - - // Add new Unpack node - NodeDef* new_unpack_node = outer_scope->add_node_def(); - *new_unpack_node = unpack_node; - new_unpack_node->clear_name(); - function_utils::SetUniqueFunctionNodeName( - strings::StrCat("vectorized/", unpack_node.name()), outer_scope, - new_unpack_node); - - // Increment "axis" attr by 1: - (*new_unpack_node->mutable_attr())["axis"].set_i( - unpack_node.attr().at("axis").i() + 1); - new_unpack_node->set_input(0, inputs[0]); - - // Add the output mappings to conversion map - int num = new_unpack_node->attr().at("num").i(); - for (int i = 0; i < num; ++i) { - (*conversion_map)[strings::StrCat(unpack_node.name(), ":output:", i)] = - strings::StrCat(new_unpack_node->name(), ":output:", i); - } - - return Status::OK(); -} - int FindOutputToConvert(const FunctionDef& function, const std::set<string>& unconvertible, FunctionDefTensorDesc* f) { @@ -239,17 +187,15 @@ Status Vectorization::AddConversionMappingFromOp( ":output:", map_defun_fn_->signature().output_arg_size() + i)); } - if (node.op() == "Cast") { - TF_RETURN_IF_ERROR( - ConvertCastOp(outer_scope_, promoted_inputs, node, &conversion_map_)); - } else if (node.op() == "Unpack") { - TF_RETURN_IF_ERROR( - ConvertUnpackOp(outer_scope_, promoted_inputs, node, &conversion_map_)); - } else { - return errors::Unimplemented("Op converter for \"", node.op(), - "\" not implemented yet"); + auto vectorizer = VectorizerRegistry::Global()->Get(node.op()); + if (vectorizer == nullptr) { + return errors::Unimplemented("No vectorizer registered for op: ", + node.op()); } + TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_, + &conversion_map_)); + // If we get here, the conversion was successful, so we promote the inputs // of the ops to MapDefun outputs. for (int i = 0; i < types.size(); ++i) { |