From b7f2d11cc308631a8f0b733a1b2db39696507155 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Tue, 28 Aug 2018 10:07:45 -0700 Subject: [tf.data] Enable optimizations for input pipelines with stateful functions. PiperOrigin-RevId: 210559796 --- .../optimization/map_vectorization_test.py | 19 +++++++++---------- .../kernel_tests/optimize_dataset_op_test.py | 13 +++++++++++++ .../contrib/data/python/kernel_tests/test_utils.py | 18 ++++++++++++++---- tensorflow/core/framework/dataset.cc | 21 ++++++++++++--------- tensorflow/core/framework/dataset.h | 22 +++++++++++++--------- tensorflow/core/kernels/data/filter_dataset_op.cc | 2 +- .../core/kernels/data/flat_map_dataset_op.cc | 2 +- .../kernels/data/group_by_reducer_dataset_op.cc | 9 ++++----- .../kernels/data/group_by_window_dataset_op.cc | 7 +++---- .../core/kernels/data/interleave_dataset_op.cc | 2 +- .../core/kernels/data/map_and_batch_dataset_op.cc | 2 +- tensorflow/core/kernels/data/map_dataset_op.cc | 2 +- .../core/kernels/data/optimize_dataset_op.cc | 1 + .../kernels/data/parallel_interleave_dataset_op.cc | 3 +-- .../core/kernels/data/parallel_map_dataset_op.cc | 2 +- tensorflow/core/kernels/data/scan_dataset_op.cc | 2 +- 16 files changed, 77 insertions(+), 50 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py index 57bf22591a..e2c9bc82df 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py @@ -122,15 +122,12 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], [3, 4]]).repeat(5) - _, optimized = self._get_test_datasets( + unoptimized, optimized = self._get_test_datasets( base_dataset, map_fn, expect_optimized=False) - nxt = optimized.make_one_shot_iterator().get_next() - - # NOTE: Right now, it raises an error because we can't save datasets that - # are stateful, and we rely on this saving mechanism to optimize datasets, - # so stateful functions can't be optimized. - with self.assertRaisesRegexp(errors.InvalidArgumentError, "[Ss]tateful"): - self.evaluate(nxt) + self._assert_datasets_raise_same_error( + unoptimized, optimized, errors.InvalidArgumentError, + [("OneShotIterator", "OneShotIterator_1", 1), + ("IteratorGetNext", "IteratorGetNext_1", 1)]) def testOptimizationIgnoreRagged(self): # Make sure we ignore inputs that might not be uniformly sized @@ -151,8 +148,10 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True) unoptimized, optimized = self._get_test_datasets( base_dataset, map_fn, expect_optimized=False) - self._assert_datasets_raise_same_error(unoptimized, optimized, - errors.InvalidArgumentError) + self._assert_datasets_raise_same_error( + unoptimized, optimized, errors.InvalidArgumentError, + [("OneShotIterator", "OneShotIterator_1", 1), + ("IteratorGetNext", "IteratorGetNext_1", 1)]) class MapVectorizationBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py index ec43bc3653..446bf8d749 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -22,6 +22,7 @@ from absl.testing import parameterized from tensorflow.contrib.data.python.ops import optimization from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors +from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -100,6 +101,18 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testStatefulFunctionOptimization(self): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next([ + "MapAndBatch" + ])).map(lambda _: random_ops.random_uniform([])).batch(10).apply( + optimization.optimize(["map_and_batch_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py index 1b962b3418..1d70b16041 100644 --- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py +++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re + from tensorflow.python.data.util import nest from tensorflow.python.framework import errors from tensorflow.python.platform import test @@ -45,7 +47,11 @@ class DatasetTestBase(test.TestCase): for i in range(len(op1)): self.assertAllEqual(op1[i], op2[i]) - def _assert_datasets_raise_same_error(self, dataset1, dataset2, exc_class): + def _assert_datasets_raise_same_error(self, + dataset1, + dataset2, + exception_class, + replacements=None): next1 = dataset1.make_one_shot_iterator().get_next() next2 = dataset2.make_one_shot_iterator().get_next() with self.test_session() as sess: @@ -53,8 +59,12 @@ class DatasetTestBase(test.TestCase): sess.run(next1) raise ValueError( "Expected dataset to raise an error of type %s, but it did not." % - repr(exc_class)) - except exc_class as e: + repr(exception_class)) + except exception_class as e: + expected_message = e.message + for old, new, count in replacements: + expected_message = expected_message.replace(old, new, count) # Check that the first segment of the error messages are the same. - with self.assertRaisesRegexp(exc_class, e.message.split(". ")[0]): + with self.assertRaisesRegexp(exception_class, + re.escape(expected_message)): sess.run(next2) diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index f3c7189292..b0b27ce94f 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -133,22 +133,25 @@ Status GraphDefBuilderWrapper::AddDataset( return Status::OK(); } -Status GraphDefBuilderWrapper::AddFunction( - const FunctionLibraryDefinition& flib_def, const string& function_name) { +Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx, + const string& function_name) { if (b_->HasFunction(function_name)) { VLOG(1) << "Function with name " << function_name << "already exists in" << " the graph. It will not be added again."; return Status::OK(); } - TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(flib_def, function_name)); - const FunctionDef* f_def = flib_def.Find(function_name); + if (!ctx->allow_stateful_functions()) { + TF_RETURN_IF_ERROR( + EnsureFunctionIsStateless(ctx->flib_def(), function_name)); + } + const FunctionDef* f_def = ctx->flib_def().Find(function_name); if (f_def == nullptr) { return errors::InvalidArgument("Unable to find FunctionDef for ", function_name, " in the registry."); } FunctionDefLibrary def; *def.add_function() = *f_def; - const string gradient_func = flib_def.FindGradient(function_name); + const string gradient_func = ctx->flib_def().FindGradient(function_name); if (!gradient_func.empty()) { GradientDef* g_def = def.add_gradient(); g_def->set_function_name(function_name); @@ -159,19 +162,19 @@ Status GraphDefBuilderWrapper::AddFunction( // Recursively add functions in inputs of function_name. for (const NodeDef& node_def : f_def->node_def()) { const OpRegistrationData* op_reg_data = nullptr; - TF_RETURN_IF_ERROR(flib_def.LookUp(node_def.op(), &op_reg_data)); + TF_RETURN_IF_ERROR(ctx->flib_def().LookUp(node_def.op(), &op_reg_data)); if (op_reg_data->is_function_op) { - TF_RETURN_IF_ERROR(AddFunction(flib_def, op_reg_data->op_def.name())); + TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name())); } // Recursively add functions in attrs of this NodeDef. for (const auto& pair : node_def.attr()) { - TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, flib_def)); + TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second)); } } // Recursively add functions in attrs of function_name. for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) { - TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, flib_def)); + TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second)); } return Status::OK(); } diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index e0c26d9286..e06ca68bca 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -41,6 +41,7 @@ limitations under the License. namespace tensorflow { class DatasetBase; +class SerializationContext; // Interface for reading values from a key-value store. // Used for restoring iterator state. @@ -155,11 +156,11 @@ class GraphDefBuilderWrapper { // Adds a user-defined function with name `function_name` to the graph and // recursively adds all functions it references. If a function with a matching // name has already been added, returns with OK status. If a user-defined with - // name `function_name` is not found in the FunctionLibraryDefinition, returns - // an InvalidArgumentError. If the function with name `function_name` or any - // of its dependent functions are stateful, returns an InvalidArgument error. - Status AddFunction(const FunctionLibraryDefinition& flib_def, - const string& function_name); + // name `function_name` is not found in the context's function library, + // returns an InvalidArgumentError. If the function with name `function_name` + // or any of its dependent functions are stateful, and the context does not + // explicitly permit stateful functions, returns an InvalidArgument error. + Status AddFunction(SerializationContext* ctx, const string& function_name); template void BuildAttrValue(const T& value, AttrValue* attr) { @@ -220,13 +221,13 @@ class GraphDefBuilderWrapper { return false; } - Status AddAttrFunctions(const AttrValue& attr_value, - const FunctionLibraryDefinition& flib_def) { + Status AddAttrFunctions(SerializationContext* ctx, + const AttrValue& attr_value) { if (attr_value.has_func()) { - TF_RETURN_IF_ERROR(AddFunction(flib_def, attr_value.func().name())); + TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name())); } else if (attr_value.has_list()) { for (const NameAttrList& name_attr_list : attr_value.list().func()) { - TF_RETURN_IF_ERROR(AddFunction(flib_def, name_attr_list.name())); + TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name())); } } return Status::OK(); @@ -332,11 +333,14 @@ class IteratorContext { class SerializationContext { public: struct Params { + bool allow_stateful_functions = false; const FunctionLibraryDefinition* flib_def; // Not owned. }; explicit SerializationContext(Params params) : params_(std::move(params)) {} + bool allow_stateful_functions() { return params_.allow_stateful_functions; } + const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; } private: diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index f5c7d336a6..bbce001eaf 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -112,7 +112,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_graph_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index 21e627a8e8..b1eb2fd849 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -94,7 +94,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc index 4a388645f2..130f04da3e 100644 --- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc @@ -109,11 +109,10 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), key_func().name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), init_func().name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), reduce_func().name())); - TF_RETURN_IF_ERROR( - b->AddFunction(ctx->flib_def(), finalize_func().name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index f993a68934..46a3185b49 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -139,10 +139,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), key_func_.name())); - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), reduce_func_.name())); - TF_RETURN_IF_ERROR( - b->AddFunction(ctx->flib_def(), window_size_func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, window_size_func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 6bba667759..716e040277 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -116,7 +116,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* cycle_length_node; diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index c4df7f2756..8b0c9ad6b2 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -147,7 +147,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), map_fn_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size_node; diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 26ae26a7fd..7f8182d917 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -92,7 +92,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 9b14078407..831e7252da 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -92,6 +92,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { DatasetGraphDefBuilder db(&b); Node* input_node = nullptr; SerializationContext::Params params; + params.allow_stateful_functions = true; params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); SerializationContext serialization_ctx(params); TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index bf86361a71..f6b3fd97e3 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -137,8 +137,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR( - b->AddFunction(ctx->flib_def(), interleave_func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name())); Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* cycle_length_node; diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index e03a4e353b..bff54813d6 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -142,7 +142,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { b->AddScalar(num_parallel_calls_, &num_parallel_calls)); // Attr: f - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); AttrValue f; b->BuildAttrValue(func_, &f); diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index 5d3319b19f..fccad933d0 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -109,7 +109,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name())); Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); std::vector initial_state_nodes; -- cgit v1.2.3