aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Jiri Simsa <jsimsa@google.com> 2018-08-28 10:07:45 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-28 10:15:09 -0700
commit: b7f2d11cc308631a8f0b733a1b2db39696507155 (patch)
tree: 7a450e82844f11eeb60df737ed65bac402c155f0
parent: 00045099ee05f85f05c8367a122bcd9ef6fc6b07 (diff)
[tf.data] Enable optimizations for input pipelines with stateful functions.
PiperOrigin-RevId: 210559796
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py | 19
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py | 13
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/test_utils.py | 18
-rw-r--r-- tensorflow/core/framework/dataset.cc | 21
-rw-r--r-- tensorflow/core/framework/dataset.h | 22
-rw-r--r-- tensorflow/core/kernels/data/filter_dataset_op.cc | 2
-rw-r--r-- tensorflow/core/kernels/data/flat_map_dataset_op.cc | 2
-rw-r--r-- tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc | 9
-rw-r--r-- tensorflow/core/kernels/data/group_by_window_dataset_op.cc | 7
-rw-r--r-- tensorflow/core/kernels/data/interleave_dataset_op.cc | 2
-rw-r--r-- tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 2
-rw-r--r-- tensorflow/core/kernels/data/map_dataset_op.cc | 2
-rw-r--r-- tensorflow/core/kernels/data/optimize_dataset_op.cc | 1
-rw-r--r-- tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 3
-rw-r--r-- tensorflow/core/kernels/data/parallel_map_dataset_op.cc | 2
-rw-r--r-- tensorflow/core/kernels/data/scan_dataset_op.cc | 2
16 files changed, 77 insertions, 50 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
index 57bf22591a..e2c9bc82df 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -122,15 +122,12 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
[3, 4]]).repeat(5)
- _, optimized = self._get_test_datasets(
+ unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- nxt = optimized.make_one_shot_iterator().get_next()
-
- # NOTE: Right now, it raises an error because we can't save datasets that
- # are stateful, and we rely on this saving mechanism to optimize datasets,
- # so stateful functions can't be optimized.
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "[Ss]tateful"):
- self.evaluate(nxt)
+ self._assert_datasets_raise_same_error(
+ unoptimized, optimized, errors.InvalidArgumentError,
+ [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)])
def testOptimizationIgnoreRagged(self):
# Make sure we ignore inputs that might not be uniformly sized
@@ -151,8 +148,10 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(unoptimized, optimized,
- errors.InvalidArgumentError)
+ self._assert_datasets_raise_same_error(
+ unoptimized, optimized, errors.InvalidArgumentError,
+ [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)])
class MapVectorizationBenchmark(test.Benchmark):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index ec43bc3653..446bf8d749 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -22,6 +22,7 @@ from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
+from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -100,6 +101,18 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testStatefulFunctionOptimization(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next([
+ "MapAndBatch"
+ ])).map(lambda _: random_ops.random_uniform([])).batch(10).apply(
+ optimization.optimize(["map_and_batch_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
index 1b962b3418..1d70b16041 100644
--- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py
+++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import re
+
from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
@@ -45,7 +47,11 @@ class DatasetTestBase(test.TestCase):
for i in range(len(op1)):
self.assertAllEqual(op1[i], op2[i])
- def _assert_datasets_raise_same_error(self, dataset1, dataset2, exc_class):
+ def _assert_datasets_raise_same_error(self,
+ dataset1,
+ dataset2,
+ exception_class,
+ replacements=None):
next1 = dataset1.make_one_shot_iterator().get_next()
next2 = dataset2.make_one_shot_iterator().get_next()
with self.test_session() as sess:
@@ -53,8 +59,12 @@ class DatasetTestBase(test.TestCase):
sess.run(next1)
raise ValueError(
"Expected dataset to raise an error of type %s, but it did not." %
- repr(exc_class))
- except exc_class as e:
+ repr(exception_class))
+ except exception_class as e:
+ expected_message = e.message
+ for old, new, count in replacements or []:  # default None: nothing to rewrite
+ expected_message = expected_message.replace(old, new, count)
# Check that the first segment of the error messages are the same.
- with self.assertRaisesRegexp(exc_class, e.message.split(". ")[0]):
+ with self.assertRaisesRegexp(exception_class,
+ re.escape(expected_message)):
sess.run(next2)
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index f3c7189292..b0b27ce94f 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -133,22 +133,25 @@ Status GraphDefBuilderWrapper::AddDataset(
return Status::OK();
}
-Status GraphDefBuilderWrapper::AddFunction(
- const FunctionLibraryDefinition& flib_def, const string& function_name) {
+Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx,
+ const string& function_name) {
if (b_->HasFunction(function_name)) {
VLOG(1) << "Function with name " << function_name << "already exists in"
<< " the graph. It will not be added again.";
return Status::OK();
}
- TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(flib_def, function_name));
- const FunctionDef* f_def = flib_def.Find(function_name);
+ if (!ctx->allow_stateful_functions()) {
+ TF_RETURN_IF_ERROR(
+ EnsureFunctionIsStateless(ctx->flib_def(), function_name));
+ }
+ const FunctionDef* f_def = ctx->flib_def().Find(function_name);
if (f_def == nullptr) {
return errors::InvalidArgument("Unable to find FunctionDef for ",
function_name, " in the registry.");
}
FunctionDefLibrary def;
*def.add_function() = *f_def;
- const string gradient_func = flib_def.FindGradient(function_name);
+ const string gradient_func = ctx->flib_def().FindGradient(function_name);
if (!gradient_func.empty()) {
GradientDef* g_def = def.add_gradient();
g_def->set_function_name(function_name);
@@ -159,19 +162,19 @@ Status GraphDefBuilderWrapper::AddFunction(
// Recursively add functions in inputs of function_name.
for (const NodeDef& node_def : f_def->node_def()) {
const OpRegistrationData* op_reg_data = nullptr;
- TF_RETURN_IF_ERROR(flib_def.LookUp(node_def.op(), &op_reg_data));
+ TF_RETURN_IF_ERROR(ctx->flib_def().LookUp(node_def.op(), &op_reg_data));
if (op_reg_data->is_function_op) {
- TF_RETURN_IF_ERROR(AddFunction(flib_def, op_reg_data->op_def.name()));
+ TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
}
// Recursively add functions in attrs of this NodeDef.
for (const auto& pair : node_def.attr()) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, flib_def));
+ TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second));
}
}
// Recursively add functions in attrs of function_name.
for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, flib_def));
+ TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second));
}
return Status::OK();
}
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index e0c26d9286..e06ca68bca 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -41,6 +41,7 @@ limitations under the License.
namespace tensorflow {
class DatasetBase;
+class SerializationContext;
// Interface for reading values from a key-value store.
// Used for restoring iterator state.
@@ -155,11 +156,11 @@ class GraphDefBuilderWrapper {
// Adds a user-defined function with name `function_name` to the graph and
// recursively adds all functions it references. If a function with a matching
// name has already been added, returns with OK status. If a user-defined with
- // name `function_name` is not found in the FunctionLibraryDefinition, returns
- // an InvalidArgumentError. If the function with name `function_name` or any
- // of its dependent functions are stateful, returns an InvalidArgument error.
- Status AddFunction(const FunctionLibraryDefinition& flib_def,
- const string& function_name);
+ // name `function_name` is not found in the context's function library,
+ // returns an InvalidArgumentError. If the function with name `function_name`
+ // or any of its dependent functions are stateful, and the context does not
+ // explicitly permit stateful functions, returns an InvalidArgument error.
+ Status AddFunction(SerializationContext* ctx, const string& function_name);
template <typename T>
void BuildAttrValue(const T& value, AttrValue* attr) {
@@ -220,13 +221,13 @@ class GraphDefBuilderWrapper {
return false;
}
- Status AddAttrFunctions(const AttrValue& attr_value,
- const FunctionLibraryDefinition& flib_def) {
+ Status AddAttrFunctions(SerializationContext* ctx,
+ const AttrValue& attr_value) {
if (attr_value.has_func()) {
- TF_RETURN_IF_ERROR(AddFunction(flib_def, attr_value.func().name()));
+ TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name()));
} else if (attr_value.has_list()) {
for (const NameAttrList& name_attr_list : attr_value.list().func()) {
- TF_RETURN_IF_ERROR(AddFunction(flib_def, name_attr_list.name()));
+ TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name()));
}
}
return Status::OK();
@@ -332,11 +333,14 @@ class IteratorContext {
class SerializationContext {
public:
struct Params {
+ bool allow_stateful_functions = false;
const FunctionLibraryDefinition* flib_def; // Not owned.
};
explicit SerializationContext(Params params) : params_(std::move(params)) {}
+ bool allow_stateful_functions() { return params_.allow_stateful_functions; }
+
const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; }
private:
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index f5c7d336a6..bbce001eaf 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -112,7 +112,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index 21e627a8e8..b1eb2fd849 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -94,7 +94,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index 4a388645f2..130f04da3e 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -109,11 +109,10 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), key_func().name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), init_func().name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), reduce_func().name()));
- TF_RETURN_IF_ERROR(
- b->AddFunction(ctx->flib_def(), finalize_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index f993a68934..46a3185b49 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -139,10 +139,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), key_func_.name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), reduce_func_.name()));
- TF_RETURN_IF_ERROR(
- b->AddFunction(ctx->flib_def(), window_size_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, window_size_func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 6bba667759..716e040277 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -116,7 +116,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node;
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index c4df7f2756..8b0c9ad6b2 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -147,7 +147,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), map_fn_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* batch_size_node;
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 26ae26a7fd..7f8182d917 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -92,7 +92,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 9b14078407..831e7252da 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -92,6 +92,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
DatasetGraphDefBuilder db(&b);
Node* input_node = nullptr;
SerializationContext::Params params;
+ params.allow_stateful_functions = true;
params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
SerializationContext serialization_ctx(params);
TF_RETURN_IF_ERROR(
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index bf86361a71..f6b3fd97e3 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -137,8 +137,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(
- b->AddFunction(ctx->flib_def(), interleave_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node;
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index e03a4e353b..bff54813d6 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -142,7 +142,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
b->AddScalar(num_parallel_calls_, &num_parallel_calls));
// Attr: f
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
AttrValue f;
b->BuildAttrValue(func_, &f);
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index 5d3319b19f..fccad933d0 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -109,7 +109,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
std::vector<Node*> initial_state_nodes;