aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Rachel Lim <rachelim@google.com> 2018-09-20 10:41:29 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-09-20 10:46:06 -0700
commit 07c9ba4bba274d424404eedab14b3fab3f072350 (patch)
tree 68d8b9f1031fa7cbf08645d69c727b2498cb0d4f
parent 7eacd3278207baa64e0b67fe8463d4eee59b31de (diff)
[tf.data] Use vectorization_utils::VectorizeMapDefun in MapVectorization optimization
PiperOrigin-RevId: 213840320
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py 40
-rw-r--r-- tensorflow/core/grappler/optimizers/data/BUILD 1
-rw-r--r-- tensorflow/core/grappler/optimizers/data/map_vectorization.cc 31
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization_utils.cc 5
4 files changed, 56 insertions, 21 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
index e2c9bc82df..5b493f44c9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -173,16 +173,6 @@ class MapVectorizationBenchmark(test.Benchmark):
self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
return median_time
- def benchmark_CheapFns(self):
-
- input_sizes = [(10, 10, 3), (10, 100, 300)]
- batch_size = 1000
- for input_size in input_sizes:
- input_dataset = dataset_ops.Dataset.from_tensor_slices(
- (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
- for map_fn, str_id in self._get_known_cheap_fns():
- self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
-
def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
num_elems = np.prod(input_size)
name_template = "{}__batch_size_{}_input_size_{}_{}"
@@ -205,14 +195,28 @@ class MapVectorizationBenchmark(test.Benchmark):
"Speedup: {}\n".format(batch_size, input_size, str_id,
(unoptimized_time / optimized_time)))
- def _get_known_cheap_fns(self):
- return [
- (lambda *args: [array_ops.identity(x) for x in args], "identity"),
- (lambda *args: [x + 1 for x in args], "add_const"),
- (lambda *args: args[0], "select"),
- (lambda *args: [math_ops.cast(x, dtypes.float64) for x in args],
- "cast"),
- ]
+ # Known cheap functions
+ def benchmarkIdentity(self):
+ self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args],
+ "identity")
+
+ def benchmarkAddConst(self):
+ self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
+
+ def benchmarkSelect(self):
+ self._benchmark_helper(lambda *args: args[0], "select")
+
+ def benchmarkCast(self):
+ self._benchmark_helper(
+ lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast")
+
+ def _benchmark_helper(self, map_fn, str_id):
+ input_sizes = [(10, 10, 3), (10, 100, 300)]
+ batch_size = 1000
+ for input_size in input_sizes:
+ input_dataset = dataset_ops.Dataset.from_tensor_slices(
+ (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
+ self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
if __name__ == "__main__":
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 7128a50be0..79d5fe87b6 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -175,6 +175,7 @@ cc_library(
deps = [
":function_utils",
":graph_utils",
+ ":vectorization_utils",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:grappler_item",
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 07766aa7b3..ad6722a3ae 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -38,11 +39,11 @@ void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) {
(*to->mutable_attr())[attr_name] = from.attr().at(attr_name);
}
-FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
+// Returns a FunctionDef containing a MapDefun op that wraps the original
+// function.
+FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
const FunctionDef& orig_func,
FunctionDefLibrary* library) {
- // If we decide to use a different method of vectorization, we can just
- // swap out this part.
FunctionDef* vectorized_func = library->add_function();
// Function inputs and outputs are the same as original, just
// with different shapes.
@@ -82,6 +83,30 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
return vectorized_func;
}
+FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
+ const FunctionDef& orig_func,
+ FunctionDefLibrary* library) {
+ // Vectorizes orig_func naively by wrapping in a MapDefun op, then tries to
+ // do true vectorization with Vectorize.
+ FunctionDef* vectorized_func =
+ CreateMapDefunWrapper(map_node, orig_func, library);
+ NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0);
+ DCHECK_EQ(map_defun_node->op(), "MapDefun");
+
+ // Create a copy of the original function so that we can mutate it, and
+ // attach that to the map defun node.
+ FunctionDef* map_defun_fn = library->add_function();
+ *map_defun_fn = orig_func;
+ graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library,
+ map_defun_fn);
+ (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name(
+ map_defun_fn->signature().name());
+
+ vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn,
+ map_defun_node);
+ return vectorized_func;
+}
+
bool IsOutputShapesFullyDefined(const NodeDef& node) {
auto* shapes_attr = gtl::FindOrNull(node.attr(), "output_shapes");
if (shapes_attr == nullptr) return false;
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index 6a59eb0d32..5dd9d00511 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -333,6 +333,11 @@ void Vectorization::Vectorize() {
void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
NodeDef* map_defun_node) {
+ if (map_defun_node->attr().at("f").func().name() !=
+ map_defun_fn->signature().name()) {
+ LOG(ERROR) << "`map_defun_fn` and `map_defun_node` do not match";
+ return;
+ }
Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
}