diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-08-28 10:07:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 10:15:09 -0700 |
commit | b7f2d11cc308631a8f0b733a1b2db39696507155 (patch) | |
tree | 7a450e82844f11eeb60df737ed65bac402c155f0 /tensorflow/core/framework | |
parent | 00045099ee05f85f05c8367a122bcd9ef6fc6b07 (diff) |
[tf.data] Enable optimizations for input pipelines with stateful functions.
PiperOrigin-RevId: 210559796
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/dataset.cc | 21 | ||||
-rw-r--r-- | tensorflow/core/framework/dataset.h | 22 |
2 files changed, 25 insertions, 18 deletions
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index f3c7189292..b0b27ce94f 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -133,22 +133,25 @@ Status GraphDefBuilderWrapper::AddDataset( return Status::OK(); } -Status GraphDefBuilderWrapper::AddFunction( - const FunctionLibraryDefinition& flib_def, const string& function_name) { +Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx, + const string& function_name) { if (b_->HasFunction(function_name)) { VLOG(1) << "Function with name " << function_name << "already exists in" << " the graph. It will not be added again."; return Status::OK(); } - TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(flib_def, function_name)); - const FunctionDef* f_def = flib_def.Find(function_name); + if (!ctx->allow_stateful_functions()) { + TF_RETURN_IF_ERROR( + EnsureFunctionIsStateless(ctx->flib_def(), function_name)); + } + const FunctionDef* f_def = ctx->flib_def().Find(function_name); if (f_def == nullptr) { return errors::InvalidArgument("Unable to find FunctionDef for ", function_name, " in the registry."); } FunctionDefLibrary def; *def.add_function() = *f_def; - const string gradient_func = flib_def.FindGradient(function_name); + const string gradient_func = ctx->flib_def().FindGradient(function_name); if (!gradient_func.empty()) { GradientDef* g_def = def.add_gradient(); g_def->set_function_name(function_name); @@ -159,19 +162,19 @@ Status GraphDefBuilderWrapper::AddFunction( // Recursively add functions in inputs of function_name. for (const NodeDef& node_def : f_def->node_def()) { const OpRegistrationData* op_reg_data = nullptr; - TF_RETURN_IF_ERROR(flib_def.LookUp(node_def.op(), &op_reg_data)); + TF_RETURN_IF_ERROR(ctx->flib_def().LookUp(node_def.op(), &op_reg_data)); if (op_reg_data->is_function_op) { - TF_RETURN_IF_ERROR(AddFunction(flib_def, op_reg_data->op_def.name())); + TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name())); } // Recursively add functions in attrs of this NodeDef. for (const auto& pair : node_def.attr()) { - TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, flib_def)); + TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second)); } } // Recursively add functions in attrs of function_name. for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) { - TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, flib_def)); + TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second)); } return Status::OK(); } diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index e0c26d9286..e06ca68bca 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -41,6 +41,7 @@ limitations under the License. namespace tensorflow { class DatasetBase; +class SerializationContext; // Interface for reading values from a key-value store. // Used for restoring iterator state. @@ -155,11 +156,11 @@ class GraphDefBuilderWrapper { // Adds a user-defined function with name `function_name` to the graph and // recursively adds all functions it references. If a function with a matching // name has already been added, returns with OK status. If a user-defined with - // name `function_name` is not found in the FunctionLibraryDefinition, returns - // an InvalidArgumentError. If the function with name `function_name` or any - // of its dependent functions are stateful, returns an InvalidArgument error. - Status AddFunction(const FunctionLibraryDefinition& flib_def, - const string& function_name); + // name `function_name` is not found in the context's function library, + // returns an InvalidArgumentError. If the function with name `function_name` + // or any of its dependent functions are stateful, and the context does not + // explicitly permit stateful functions, returns an InvalidArgument error. + Status AddFunction(SerializationContext* ctx, const string& function_name); template <typename T> void BuildAttrValue(const T& value, AttrValue* attr) { @@ -220,13 +221,13 @@ class GraphDefBuilderWrapper { return false; } - Status AddAttrFunctions(const AttrValue& attr_value, - const FunctionLibraryDefinition& flib_def) { + Status AddAttrFunctions(SerializationContext* ctx, + const AttrValue& attr_value) { if (attr_value.has_func()) { - TF_RETURN_IF_ERROR(AddFunction(flib_def, attr_value.func().name())); + TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name())); } else if (attr_value.has_list()) { for (const NameAttrList& name_attr_list : attr_value.list().func()) { - TF_RETURN_IF_ERROR(AddFunction(flib_def, name_attr_list.name())); + TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name())); } } return Status::OK(); @@ -332,11 +333,14 @@ class IteratorContext { class SerializationContext { public: struct Params { + bool allow_stateful_functions = false; const FunctionLibraryDefinition* flib_def; // Not owned. }; explicit SerializationContext(Params params) : params_(std::move(params)) {} + bool allow_stateful_functions() { return params_.allow_stateful_functions; } + const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; } private: |