aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-09-21 12:38:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 12:42:24 -0700
commit25c1a4441bbf364c8ed263f75e0bebad30f6599c (patch)
treeab11976563c54fa3c3edfee4434a88f08b19c382
parent7461ff7837bb9c57f0020d8adf46a73596dfb77d (diff)
[tf.data] Add a ConverterRegistry for vectorization converters
PiperOrigin-RevId: 214027910
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/BUILD69
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc54
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc61
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h49
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc47
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h75
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc50
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc70
9 files changed, 414 insertions, 62 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 79d5fe87b6..cf305cebe1 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -464,6 +464,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/optimizers/data/vectorization",
"//tensorflow/core/grappler/utils:functions",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
new file mode 100644
index 0000000000..1462cb234d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -0,0 +1,69 @@
+package(
+ default_visibility = ["//visibility:private"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
+
+VECTORIZER_DEPS = [
+ ":vectorizer_registry",
+ "//tensorflow/core/grappler/optimizers/data:function_utils",
+] + tf_protos_all()
+
+cc_library(
+ name = "vectorizer",
+ hdrs = ["vectorizer.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ ] + tf_protos_all(),
+)
+
+cc_library(
+ name = "vectorizer_registry",
+ srcs = ["vectorizer_registry.cc"],
+ hdrs = ["vectorizer_registry.h"],
+ deps = [
+ ":vectorizer",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "cast_vectorizer",
+ srcs = ["cast_vectorizer.cc"],
+ deps = VECTORIZER_DEPS,
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "unpack_vectorizer",
+ srcs = ["unpack_vectorizer.cc"],
+ deps = VECTORIZER_DEPS,
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "vectorization",
+ hdrs = ["vectorizer_registry.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cast_vectorizer",
+ ":unpack_vectorizer",
+ ":vectorizer",
+ ":vectorizer_registry",
+ ],
+)
+
+tf_cc_test(
+ name = "vectorizer_registry_test",
+ srcs = ["vectorizer_registry_test.cc"],
+ deps = [
+ ":vectorizer_registry",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ] + tf_protos_all(),
+)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
new file mode 100644
index 0000000000..c1739737a0
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+class CastVectorizer : public Vectorizer {
+ public:
+ Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+ FunctionDef* outer_scope,
+ std::map<string, string>* conversion_map) override {
+ if (inputs.size() != 1) {
+ return errors::Internal("Cast op should only have one input.");
+ }
+
+ // Add new Cast node
+ NodeDef* new_cast_node = outer_scope->add_node_def();
+ *new_cast_node = node;
+ new_cast_node->clear_name();
+ function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("vectorized/", node.name()), outer_scope,
+ new_cast_node);
+ new_cast_node->set_input(0, inputs[0]);
+
+ // Add the output mapping to conversion map
+ (*conversion_map)[strings::StrCat(node.name(), ":y:0")] =
+ strings::StrCat(new_cast_node->name(), ":y:0");
+
+ return Status::OK();
+ }
+};
+
+REGISTER_VECTORIZER("Cast", CastVectorizer);
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
new file mode 100644
index 0000000000..776d3179c5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -0,0 +1,61 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+class UnpackVectorizer : public Vectorizer {
+ public:
+ Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+ FunctionDef* outer_scope,
+ std::map<string, string>* conversion_map) override {
+ if (inputs.size() != 1) {
+ return errors::Internal("Unpack op should only have one input.");
+ }
+
+ // Add new Unpack node
+ NodeDef* new_unpack_node = outer_scope->add_node_def();
+ *new_unpack_node = node;
+ new_unpack_node->clear_name();
+ function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("vectorized/", node.name()), outer_scope,
+ new_unpack_node);
+
+ // Increment "axis" attr by 1:
+ (*new_unpack_node->mutable_attr())["axis"].set_i(
+ node.attr().at("axis").i() + 1);
+ new_unpack_node->set_input(0, inputs[0]);
+
+ // Add the output mappings to conversion map
+ int num = new_unpack_node->attr().at("num").i();
+ for (int i = 0; i < num; ++i) {
+ (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] =
+ strings::StrCat(new_unpack_node->name(), ":output:", i);
+ }
+
+ return Status::OK();
+ }
+};
+
+REGISTER_VECTORIZER("Unpack", UnpackVectorizer);
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
new file mode 100644
index 0000000000..d341dbba7d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
+// for an example.
+class Vectorizer {
+ public:
+ virtual ~Vectorizer() {}
+
+ // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope`
+ // that produce the same vector output(s) as executing `node`'s op
+ // on elements of the vector inputs, and adding mappings to `conversion_map`
+ // from old output tensor names to new (vectorized) output tensor names.
+ // The new node(s) collectively have the same number of inputs and outputs as
+ // the node being converted, and use the tensor names in `inputs` as their
+ // inputs.
+ virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+ FunctionDef* outer_scope,
+ std::map<string, string>* conversion_map) = 0;
+};
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
new file mode 100644
index 0000000000..a6551e36ac
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+VectorizerRegistry* VectorizerRegistry::Global() {
+ static VectorizerRegistry* registry = new VectorizerRegistry;
+ return registry;
+}
+
+Vectorizer* VectorizerRegistry::Get(const string& op_type) {
+ auto found = vectorizers_.find(op_type);
+ if (found == vectorizers_.end()) {
+ return nullptr;
+ }
+ return found->second.get();
+}
+
+void VectorizerRegistry::Register(const string& op_type,
+ std::unique_ptr<Vectorizer> vectorizer) {
+ auto existing = Get(op_type);
+ CHECK_EQ(existing, nullptr)
+ << "Vectorizer for op type: " << op_type << " already registered";
+ vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>(
+ op_type, std::move(vectorizer)));
+}
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
new file mode 100644
index 0000000000..16159d47ca
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
@@ -0,0 +1,75 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
+
+#include <functional>
+#include <map>
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// A global VectorizerRegistry is used to hold all the vectorizers.
+class VectorizerRegistry {
+ public:
+ // Returns a pointer to a global VectorizerRegistry object.
+ static VectorizerRegistry* Global();
+
+ // Returns a pointer to a vectorizer that can vectorize an op for the op type.
+ Vectorizer* Get(const string& op_type);
+
+ // Registers a vectorizer that can vectorize an op for the given op type.
+ void Register(const string& op_type, std::unique_ptr<Vectorizer> vectorizer);
+
+ private:
+ std::map<string, std::unique_ptr<Vectorizer>> vectorizers_;
+};
+
+namespace vectorizer_registration {
+
+class VectorizerRegistration {
+ public:
+ VectorizerRegistration(const string& op_type,
+ std::unique_ptr<Vectorizer> vectorizer) {
+ VectorizerRegistry::Global()->Register(op_type, std::move(vectorizer));
+ }
+};
+
+} // namespace vectorizer_registration
+
+#define REGISTER_VECTORIZER(op_type, vectorizer) \
+ REGISTER_VECTORIZER_UNIQ_HELPER(__COUNTER__, op_type, vectorizer)
+
+#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \
+ REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)
+
+#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
+ static ::tensorflow::grappler::vectorization_utils:: \
+ vectorizer_registration::VectorizerRegistration \
+ vectorizer_registration_##ctr( \
+ op_type, \
+ ::std::unique_ptr< \
+ ::tensorflow::grappler::vectorization_utils::Vectorizer>( \
+ new vectorizer()))
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
new file mode 100644
index 0000000000..86e303564b
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+class TestVectorizer : public Vectorizer {
+ public:
+ Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+ FunctionDef* outer_scope,
+ std::map<string, string>* conversion_map) override {
+ return Status::OK();
+ }
+};
+
+REGISTER_VECTORIZER("test_op", TestVectorizer);
+
+TEST(TestVectorizer, TestTestVectorizer) {
+ EXPECT_EQ(VectorizerRegistry::Global()->Get("nonexistent"), nullptr);
+
+ auto vectorizer = VectorizerRegistry::Global()->Get("test_op");
+ EXPECT_NE(vectorizer, nullptr);
+
+ FunctionDef function;
+ NodeDef node;
+ std::map<string, string> conversion_map;
+ EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok());
+}
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index bfca63b820..cb56b65985 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -90,59 +91,6 @@ void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
->ExtractSubrange(output_position, 1, nullptr);
}
-Status ConvertCastOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs,
- const NodeDef& cast_node,
- std::map<string, string>* conversion_map) {
- if (inputs.size() != 1) {
- return errors::Internal("Cast op should only have one input.");
- }
-
- // Add new Cast node
- NodeDef* new_cast_node = outer_scope->add_node_def();
- *new_cast_node = cast_node;
- new_cast_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", cast_node.name()), outer_scope,
- new_cast_node);
- new_cast_node->set_input(0, inputs[0]);
-
- // Add the output mapping to conversion map
- (*conversion_map)[strings::StrCat(cast_node.name(), ":y:0")] =
- strings::StrCat(new_cast_node->name(), ":y:0");
-
- return Status::OK();
-}
-
-Status ConvertUnpackOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs,
- const NodeDef& unpack_node,
- std::map<string, string>* conversion_map) {
- if (inputs.size() != 1) {
- return errors::Internal("Unpack op should only have one input.");
- }
-
- // Add new Unpack node
- NodeDef* new_unpack_node = outer_scope->add_node_def();
- *new_unpack_node = unpack_node;
- new_unpack_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", unpack_node.name()), outer_scope,
- new_unpack_node);
-
- // Increment "axis" attr by 1:
- (*new_unpack_node->mutable_attr())["axis"].set_i(
- unpack_node.attr().at("axis").i() + 1);
- new_unpack_node->set_input(0, inputs[0]);
-
- // Add the output mappings to conversion map
- int num = new_unpack_node->attr().at("num").i();
- for (int i = 0; i < num; ++i) {
- (*conversion_map)[strings::StrCat(unpack_node.name(), ":output:", i)] =
- strings::StrCat(new_unpack_node->name(), ":output:", i);
- }
-
- return Status::OK();
-}
-
int FindOutputToConvert(const FunctionDef& function,
const std::set<string>& unconvertible,
FunctionDefTensorDesc* f) {
@@ -239,17 +187,15 @@ Status Vectorization::AddConversionMappingFromOp(
":output:", map_defun_fn_->signature().output_arg_size() + i));
}
- if (node.op() == "Cast") {
- TF_RETURN_IF_ERROR(
- ConvertCastOp(outer_scope_, promoted_inputs, node, &conversion_map_));
- } else if (node.op() == "Unpack") {
- TF_RETURN_IF_ERROR(
- ConvertUnpackOp(outer_scope_, promoted_inputs, node, &conversion_map_));
- } else {
- return errors::Unimplemented("Op converter for \"", node.op(),
- "\" not implemented yet");
+ auto vectorizer = VectorizerRegistry::Global()->Get(node.op());
+ if (vectorizer == nullptr) {
+ return errors::Unimplemented("No vectorizer registered for op: ",
+ node.op());
}
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_,
+ &conversion_map_));
+
// If we get here, the conversion was successful, so we promote the inputs
// of the ops to MapDefun outputs.
for (int i = 0; i < types.size(); ++i) {