diff options
author | Rachel Lim <rachelim@google.com> | 2018-09-20 10:41:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 10:46:06 -0700 |
commit | 07c9ba4bba274d424404eedab14b3fab3f072350 (patch) | |
tree | 68d8b9f1031fa7cbf08645d69c727b2498cb0d4f | |
parent | 7eacd3278207baa64e0b67fe8463d4eee59b31de (diff) |
[tf.data] Use vectorization_utils::VectorizeMapDefun in MapVectorization optimization
PiperOrigin-RevId: 213840320
4 files changed, 56 insertions, 21 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py index e2c9bc82df..5b493f44c9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py @@ -173,16 +173,6 @@ class MapVectorizationBenchmark(test.Benchmark): self.report_benchmark(iters=num_iters, wall_time=median_time, name=name) return median_time - def benchmark_CheapFns(self): - - input_sizes = [(10, 10, 3), (10, 100, 300)] - batch_size = 1000 - for input_size in input_sizes: - input_dataset = dataset_ops.Dataset.from_tensor_slices( - (np.random.rand(*input_size), np.random.rand(*input_size))).repeat() - for map_fn, str_id in self._get_known_cheap_fns(): - self._compare(input_dataset, map_fn, batch_size, input_size, str_id) - def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id): num_elems = np.prod(input_size) name_template = "{}__batch_size_{}_input_size_{}_{}" @@ -205,14 +195,28 @@ class MapVectorizationBenchmark(test.Benchmark): "Speedup: {}\n".format(batch_size, input_size, str_id, (unoptimized_time / optimized_time))) - def _get_known_cheap_fns(self): - return [ - (lambda *args: [array_ops.identity(x) for x in args], "identity"), - (lambda *args: [x + 1 for x in args], "add_const"), - (lambda *args: args[0], "select"), - (lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], - "cast"), - ] + # Known cheap functions + def benchmarkIdentity(self): + self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args], + "identity") + + def benchmarkAddConst(self): + self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const") + + def benchmarkSelect(self): + self._benchmark_helper(lambda *args: args[0], "select") + + def benchmarkCast(self): + self._benchmark_helper( + lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast") + + def _benchmark_helper(self, map_fn, str_id): + input_sizes = [(10, 10, 3), (10, 100, 300)] + batch_size = 1000 + for input_size in input_sizes: + input_dataset = dataset_ops.Dataset.from_tensor_slices( + (np.random.rand(*input_size), np.random.rand(*input_size))).repeat() + self._compare(input_dataset, map_fn, batch_size, input_size, str_id) if __name__ == "__main__": diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 7128a50be0..79d5fe87b6 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -175,6 +175,7 @@ cc_library( deps = [ ":function_utils", ":graph_utils", + ":vectorization_utils", "//tensorflow/core:lib", "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:grappler_item", diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index 07766aa7b3..ad6722a3ae 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/data/map_vectorization.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -38,11 +39,11 @@ void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) { (*to->mutable_attr())[attr_name] = from.attr().at(attr_name); } -FunctionDef* AddVectorizedFunction(const NodeDef& map_node, +// Returns a FunctionDef containing a MapDefun op that wraps the original +// function. +FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, const FunctionDef& orig_func, FunctionDefLibrary* library) { - // If we decide to use a different method of vectorization, we can just - // swap out this part. FunctionDef* vectorized_func = library->add_function(); // Function inputs and outputs are the same as original, just // with different shapes. @@ -82,6 +83,30 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node, return vectorized_func; } +FunctionDef* AddVectorizedFunction(const NodeDef& map_node, + const FunctionDef& orig_func, + FunctionDefLibrary* library) { + // Vectorizes orig_func naively by wrapping in a MapDefun op, then tries to + // do true vectorization with Vectorize. + FunctionDef* vectorized_func = + CreateMapDefunWrapper(map_node, orig_func, library); + NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0); + DCHECK_EQ(map_defun_node->op(), "MapDefun"); + + // Create a copy of the original function so that we can mutate it, and + // attach that to the map defun node. + FunctionDef* map_defun_fn = library->add_function(); + *map_defun_fn = orig_func; + graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library, + map_defun_fn); + (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name( + map_defun_fn->signature().name()); + + vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn, + map_defun_node); + return vectorized_func; +} + bool IsOutputShapesFullyDefined(const NodeDef& node) { auto* shapes_attr = gtl::FindOrNull(node.attr(), "output_shapes"); if (shapes_attr == nullptr) return false; diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index 6a59eb0d32..5dd9d00511 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -333,6 +333,11 @@ void Vectorization::Vectorize() { void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn, NodeDef* map_defun_node) { + if (map_defun_node->attr().at("f").func().name() != + map_defun_fn->signature().name()) { + LOG(ERROR) << "`map_defun_fn` and `map_defun_node` do not match"; + return; + } Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize(); } |