diff options
author | Rachel Lim <rachelim@google.com> | 2018-09-21 12:38:07 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 12:42:24 -0700 |
commit | 25c1a4441bbf364c8ed263f75e0bebad30f6599c (patch) | |
tree | ab11976563c54fa3c3edfee4434a88f08b19c382 | |
parent | 7461ff7837bb9c57f0020d8adf46a73596dfb77d (diff) |
[tf.data] Add a ConverterRegistry for vectorization converters
PiperOrigin-RevId: 214027910
9 files changed, 414 insertions, 62 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 79d5fe87b6..cf305cebe1 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -464,6 +464,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/optimizers/data/vectorization", "//tensorflow/core/grappler/utils:functions", ] + tf_protos_all(), ) diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD new file mode 100644 index 0000000000..1462cb234d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD @@ -0,0 +1,69 @@ +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all") + +VECTORIZER_DEPS = [ + ":vectorizer_registry", + "//tensorflow/core/grappler/optimizers/data:function_utils", +] + tf_protos_all() + +cc_library( + name = "vectorizer", + hdrs = ["vectorizer.h"], + deps = [ + "//tensorflow/core:lib", + ] + tf_protos_all(), +) + +cc_library( + name = "vectorizer_registry", + srcs = ["vectorizer_registry.cc"], + hdrs = ["vectorizer_registry.h"], + deps = [ + ":vectorizer", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "cast_vectorizer", + srcs = ["cast_vectorizer.cc"], + deps = VECTORIZER_DEPS, + alwayslink = 1, +) + +cc_library( + name = "unpack_vectorizer", + srcs = ["unpack_vectorizer.cc"], + deps = VECTORIZER_DEPS, + alwayslink = 1, +) + +cc_library( + name = "vectorization", + hdrs = ["vectorizer_registry.h"], + visibility = ["//visibility:public"], + deps = [ + ":cast_vectorizer", + ":unpack_vectorizer", + ":vectorizer", + ":vectorizer_registry", + ], +) + +tf_cc_test( + name = "vectorizer_registry_test", + srcs = ["vectorizer_registry_test.cc"], + deps = [ + ":vectorizer_registry", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ] + tf_protos_all(), +) diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc new file mode 100644 index 0000000000..c1739737a0 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" + +namespace tensorflow { +namespace grappler { +namespace vectorization_utils { + +class CastVectorizer : public Vectorizer { + public: + Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, + FunctionDef* outer_scope, + std::map<string, string>* conversion_map) override { + if (inputs.size() != 1) { + return errors::Internal("Cast op should only have one input."); + } + + // Add new Cast node + NodeDef* new_cast_node = outer_scope->add_node_def(); + *new_cast_node = node; + new_cast_node->clear_name(); + function_utils::SetUniqueFunctionNodeName( + strings::StrCat("vectorized/", node.name()), outer_scope, + new_cast_node); + new_cast_node->set_input(0, inputs[0]); + + // Add the output mapping to conversion map + (*conversion_map)[strings::StrCat(node.name(), ":y:0")] = + strings::StrCat(new_cast_node->name(), ":y:0"); + + return Status::OK(); + } +}; + +REGISTER_VECTORIZER("Cast", CastVectorizer); + +} // namespace vectorization_utils +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc new file mode 100644 index 0000000000..776d3179c5 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" + +namespace tensorflow { +namespace grappler { +namespace vectorization_utils { + +class UnpackVectorizer : public Vectorizer { + public: + Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, + FunctionDef* outer_scope, + std::map<string, string>* conversion_map) override { + if (inputs.size() != 1) { + return errors::Internal("Unpack op should only have one input."); + } + + // Add new Unpack node + NodeDef* new_unpack_node = outer_scope->add_node_def(); + *new_unpack_node = node; + new_unpack_node->clear_name(); + function_utils::SetUniqueFunctionNodeName( + strings::StrCat("vectorized/", node.name()), outer_scope, + new_unpack_node); + + // Increment "axis" attr by 1: + (*new_unpack_node->mutable_attr())["axis"].set_i( + node.attr().at("axis").i() + 1); + new_unpack_node->set_input(0, inputs[0]); + + // Add the output mappings to conversion map + int num = new_unpack_node->attr().at("num").i(); + for (int i = 0; i < num; ++i) { + (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] = + strings::StrCat(new_unpack_node->name(), ":output:", i); + } + + return Status::OK(); + } +}; + +REGISTER_VECTORIZER("Unpack", UnpackVectorizer); + +} // namespace vectorization_utils +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h new file mode 100644 index 0000000000..d341dbba7d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_ + +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { +namespace grappler { +namespace vectorization_utils { + +// Interface for vectorization of TensorFlow operations. See `CastVectorizer` +// for an example. +class Vectorizer { + public: + virtual ~Vectorizer() {} + + // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope` + // that produce the same vector output(s) as executing `node`'s op + // on elements of the vector inputs, and adding mappings to `conversion_map` + // from old output tensor names to new (vectorized) output tensor names. + // The new node(s) collectively have the same number of inputs and outputs as + // the node being converted, and use the tensor names in `inputs` as their + // inputs. + virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, + FunctionDef* outer_scope, + std::map<string, string>* conversion_map) = 0; +}; + +} // namespace vectorization_utils +} // namespace grappler +} // namespace tensorflow +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc new file mode 100644 index 0000000000..a6551e36ac --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace grappler { +namespace vectorization_utils { + +VectorizerRegistry* VectorizerRegistry::Global() { + static VectorizerRegistry* registry = new VectorizerRegistry; + return registry; +} + +Vectorizer* VectorizerRegistry::Get(const string& op_type) { + auto found = vectorizers_.find(op_type); + if (found == vectorizers_.end()) { + return nullptr; + } + return found->second.get(); +} + +void VectorizerRegistry::Register(const string& op_type, + std::unique_ptr<Vectorizer> vectorizer) { + auto existing = Get(op_type); + CHECK_EQ(existing, nullptr) + << "Vectorizer for op type: " << op_type << " already registered"; + vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>( + op_type, std::move(vectorizer))); +} +} // namespace vectorization_utils +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h new file mode 100644 index 0000000000..16159d47ca --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h @@ -0,0 +1,75 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_ + +#include <functional> +#include <map> + +#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h" + +namespace tensorflow { +namespace grappler { +namespace vectorization_utils { + +// A global VectorizerRegistry is used to hold all the vectorizers. +class VectorizerRegistry { + public: + // Returns a pointer to a global VectorizerRegistry object. + static VectorizerRegistry* Global(); + + // Returns a pointer to a vectorizer that can vectorize an op for the op type. + Vectorizer* Get(const string& op_type); + + // Registers a vectorizer that can vectorize an op for the given op type. + void Register(const string& op_type, std::unique_ptr<Vectorizer> vectorizer); + + private: + std::map<string, std::unique_ptr<Vectorizer>> vectorizers_; +}; + +namespace vectorizer_registration { + +class VectorizerRegistration { + public: + VectorizerRegistration(const string& op_type, + std::unique_ptr<Vectorizer> vectorizer) { + VectorizerRegistry::Global()->Register(op_type, std::move(vectorizer)); + } +}; + +} // namespace vectorizer_registration + +#define REGISTER_VECTORIZER(op_type, vectorizer) \ + REGISTER_VECTORIZER_UNIQ_HELPER(__COUNTER__, op_type, vectorizer) + +#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \ + REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) + +#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \ + static ::tensorflow::grappler::vectorization_utils:: \ + vectorizer_registration::VectorizerRegistration \ + vectorizer_registration_##ctr( \ + op_type, \ + ::std::unique_ptr< \ + ::tensorflow::grappler::vectorization_utils::Vectorizer>( \ + new vectorizer())) + +} // namespace vectorization_utils +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_ diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc new file mode 100644 index 0000000000..86e303564b --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace vectorization_utils { + +class TestVectorizer : public Vectorizer { + public: + Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, + FunctionDef* outer_scope, + std::map<string, string>* conversion_map) override { + return Status::OK(); + } +}; + +REGISTER_VECTORIZER("test_op", TestVectorizer); + +TEST(TestVectorizer, TestTestVectorizer) { + EXPECT_EQ(VectorizerRegistry::Global()->Get("nonexistent"), nullptr); + + auto vectorizer = VectorizerRegistry::Global()->Get("test_op"); + EXPECT_NE(vectorizer, nullptr); + + FunctionDef function; + NodeDef node; + std::map<string, string> conversion_map; + EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok()); +} + +} // namespace vectorization_utils +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index bfca63b820..cb56b65985 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" #include "absl/strings/str_join.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -90,59 +91,6 @@ void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn, ->ExtractSubrange(output_position, 1, nullptr); } -Status ConvertCastOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs, - const NodeDef& cast_node, - std::map<string, string>* conversion_map) { - if (inputs.size() != 1) { - return errors::Internal("Cast op should only have one input."); - } - - // Add new Cast node - NodeDef* new_cast_node = outer_scope->add_node_def(); - *new_cast_node = cast_node; - new_cast_node->clear_name(); - function_utils::SetUniqueFunctionNodeName( - strings::StrCat("vectorized/", cast_node.name()), outer_scope, - new_cast_node); - new_cast_node->set_input(0, inputs[0]); - - // Add the output mapping to conversion map - (*conversion_map)[strings::StrCat(cast_node.name(), ":y:0")] = - strings::StrCat(new_cast_node->name(), ":y:0"); - - return Status::OK(); -} - -Status ConvertUnpackOp(FunctionDef* outer_scope, gtl::ArraySlice<string> inputs, - const NodeDef& unpack_node, - std::map<string, string>* conversion_map) { - if (inputs.size() != 1) { - return errors::Internal("Unpack op should only have one input."); - } - - // Add new Unpack node - NodeDef* new_unpack_node = outer_scope->add_node_def(); - *new_unpack_node = unpack_node; - new_unpack_node->clear_name(); - function_utils::SetUniqueFunctionNodeName( - strings::StrCat("vectorized/", unpack_node.name()), outer_scope, - new_unpack_node); - - // Increment "axis" attr by 1: - (*new_unpack_node->mutable_attr())["axis"].set_i( - unpack_node.attr().at("axis").i() + 1); - new_unpack_node->set_input(0, inputs[0]); - - // Add the output mappings to conversion map - int num = new_unpack_node->attr().at("num").i(); - for (int i = 0; i < num; ++i) { - (*conversion_map)[strings::StrCat(unpack_node.name(), ":output:", i)] = - strings::StrCat(new_unpack_node->name(), ":output:", i); - } - - return Status::OK(); -} - int FindOutputToConvert(const FunctionDef& function, const std::set<string>& unconvertible, FunctionDefTensorDesc* f) { @@ -239,17 +187,15 @@ Status Vectorization::AddConversionMappingFromOp( ":output:", map_defun_fn_->signature().output_arg_size() + i)); } - if (node.op() == "Cast") { - TF_RETURN_IF_ERROR( - ConvertCastOp(outer_scope_, promoted_inputs, node, &conversion_map_)); - } else if (node.op() == "Unpack") { - TF_RETURN_IF_ERROR( - ConvertUnpackOp(outer_scope_, promoted_inputs, node, &conversion_map_)); - } else { - return errors::Unimplemented("Op converter for \"", node.op(), - "\" not implemented yet"); + auto vectorizer = VectorizerRegistry::Global()->Get(node.op()); + if (vectorizer == nullptr) { + return errors::Unimplemented("No vectorizer registered for op: ", + node.op()); } + TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_, + &conversion_map_)); + // If we get here, the conversion was successful, so we promote the inputs // of the ops to MapDefun outputs. for (int i = 0; i < types.size(); ++i) { |