aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-09-21 12:38:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 12:42:24 -0700
commit25c1a4441bbf364c8ed263f75e0bebad30f6599c (patch)
treeab11976563c54fa3c3edfee4434a88f08b19c382 /tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
parent7461ff7837bb9c57f0020d8adf46a73596dfb77d (diff)
[tf.data] Add a ConverterRegistry for vectorization converters
PiperOrigin-RevId: 214027910
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc70
1 files changed, 8 insertions, 62 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index bfca63b820..cb56b65985 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -90,59 +91,6 @@ void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
->ExtractSubrange(output_position, 1, nullptr);
}
-Status ConvertCastOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs,
- const NodeDef& cast_node,
- std::map<string, string>* conversion_map) {
- if (inputs.size() != 1) {
- return errors::Internal("Cast op should only have one input.");
- }
-
- // Add new Cast node
- NodeDef* new_cast_node = outer_scope->add_node_def();
- *new_cast_node = cast_node;
- new_cast_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", cast_node.name()), outer_scope,
- new_cast_node);
- new_cast_node->set_input(0, inputs[0]);
-
- // Add the output mapping to conversion map
- (*conversion_map)[strings::StrCat(cast_node.name(), ":y:0")] =
- strings::StrCat(new_cast_node->name(), ":y:0");
-
- return Status::OK();
-}
-
-Status ConvertUnpackOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs,
- const NodeDef& unpack_node,
- std::map<string, string>* conversion_map) {
- if (inputs.size() != 1) {
- return errors::Internal("Unpack op should only have one input.");
- }
-
- // Add new Unpack node
- NodeDef* new_unpack_node = outer_scope->add_node_def();
- *new_unpack_node = unpack_node;
- new_unpack_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", unpack_node.name()), outer_scope,
- new_unpack_node);
-
- // Increment "axis" attr by 1:
- (*new_unpack_node->mutable_attr())["axis"].set_i(
- unpack_node.attr().at("axis").i() + 1);
- new_unpack_node->set_input(0, inputs[0]);
-
- // Add the output mappings to conversion map
- int num = new_unpack_node->attr().at("num").i();
- for (int i = 0; i < num; ++i) {
- (*conversion_map)[strings::StrCat(unpack_node.name(), ":output:", i)] =
- strings::StrCat(new_unpack_node->name(), ":output:", i);
- }
-
- return Status::OK();
-}
-
int FindOutputToConvert(const FunctionDef& function,
const std::set<string>& unconvertible,
FunctionDefTensorDesc* f) {
@@ -239,17 +187,15 @@ Status Vectorization::AddConversionMappingFromOp(
":output:", map_defun_fn_->signature().output_arg_size() + i));
}
- if (node.op() == "Cast") {
- TF_RETURN_IF_ERROR(
- ConvertCastOp(outer_scope_, promoted_inputs, node, &conversion_map_));
- } else if (node.op() == "Unpack") {
- TF_RETURN_IF_ERROR(
- ConvertUnpackOp(outer_scope_, promoted_inputs, node, &conversion_map_));
- } else {
- return errors::Unimplemented("Op converter for \"", node.op(),
- "\" not implemented yet");
+ auto vectorizer = VectorizerRegistry::Global()->Get(node.op());
+ if (vectorizer == nullptr) {
+ return errors::Unimplemented("No vectorizer registered for op: ",
+ node.op());
}
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_,
+ &conversion_map_));
+
// If we get here, the conversion was successful, so we promote the inputs
// of the ops to MapDefun outputs.
for (int i = 0; i < types.size(); ++i) {