aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-08-28 10:07:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 10:15:09 -0700
commitb7f2d11cc308631a8f0b733a1b2db39696507155 (patch)
tree7a450e82844f11eeb60df737ed65bac402c155f0 /tensorflow/core/framework
parent00045099ee05f85f05c8367a122bcd9ef6fc6b07 (diff)
[tf.data] Enable optimizations for input pipelines with stateful functions.
PiperOrigin-RevId: 210559796
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/dataset.cc21
-rw-r--r--tensorflow/core/framework/dataset.h22
2 files changed, 25 insertions, 18 deletions
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index f3c7189292..b0b27ce94f 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -133,22 +133,25 @@ Status GraphDefBuilderWrapper::AddDataset(
return Status::OK();
}
-Status GraphDefBuilderWrapper::AddFunction(
- const FunctionLibraryDefinition& flib_def, const string& function_name) {
+Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx,
+ const string& function_name) {
if (b_->HasFunction(function_name)) {
VLOG(1) << "Function with name " << function_name << "already exists in"
<< " the graph. It will not be added again.";
return Status::OK();
}
- TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(flib_def, function_name));
- const FunctionDef* f_def = flib_def.Find(function_name);
+ if (!ctx->allow_stateful_functions()) {
+ TF_RETURN_IF_ERROR(
+ EnsureFunctionIsStateless(ctx->flib_def(), function_name));
+ }
+ const FunctionDef* f_def = ctx->flib_def().Find(function_name);
if (f_def == nullptr) {
return errors::InvalidArgument("Unable to find FunctionDef for ",
function_name, " in the registry.");
}
FunctionDefLibrary def;
*def.add_function() = *f_def;
- const string gradient_func = flib_def.FindGradient(function_name);
+ const string gradient_func = ctx->flib_def().FindGradient(function_name);
if (!gradient_func.empty()) {
GradientDef* g_def = def.add_gradient();
g_def->set_function_name(function_name);
@@ -159,19 +162,19 @@ Status GraphDefBuilderWrapper::AddFunction(
// Recursively add functions in inputs of function_name.
for (const NodeDef& node_def : f_def->node_def()) {
const OpRegistrationData* op_reg_data = nullptr;
- TF_RETURN_IF_ERROR(flib_def.LookUp(node_def.op(), &op_reg_data));
+ TF_RETURN_IF_ERROR(ctx->flib_def().LookUp(node_def.op(), &op_reg_data));
if (op_reg_data->is_function_op) {
- TF_RETURN_IF_ERROR(AddFunction(flib_def, op_reg_data->op_def.name()));
+ TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
}
// Recursively add functions in attrs of this NodeDef.
for (const auto& pair : node_def.attr()) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, flib_def));
+ TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second));
}
}
// Recursively add functions in attrs of function_name.
for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, flib_def));
+ TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second));
}
return Status::OK();
}
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index e0c26d9286..e06ca68bca 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -41,6 +41,7 @@ limitations under the License.
namespace tensorflow {
class DatasetBase;
+class SerializationContext;
// Interface for reading values from a key-value store.
// Used for restoring iterator state.
@@ -155,11 +156,11 @@ class GraphDefBuilderWrapper {
// Adds a user-defined function with name `function_name` to the graph and
// recursively adds all functions it references. If a function with a matching
// name has already been added, returns with OK status. If a user-defined with
- // name `function_name` is not found in the FunctionLibraryDefinition, returns
- // an InvalidArgumentError. If the function with name `function_name` or any
- // of its dependent functions are stateful, returns an InvalidArgument error.
- Status AddFunction(const FunctionLibraryDefinition& flib_def,
- const string& function_name);
+ // name `function_name` is not found in the context's function library,
+ // returns an InvalidArgumentError. If the function with name `function_name`
+ // or any of its dependent functions are stateful, and the context does not
+ // explicitly permit stateful functions, returns an InvalidArgument error.
+ Status AddFunction(SerializationContext* ctx, const string& function_name);
template <typename T>
void BuildAttrValue(const T& value, AttrValue* attr) {
@@ -220,13 +221,13 @@ class GraphDefBuilderWrapper {
return false;
}
- Status AddAttrFunctions(const AttrValue& attr_value,
- const FunctionLibraryDefinition& flib_def) {
+ Status AddAttrFunctions(SerializationContext* ctx,
+ const AttrValue& attr_value) {
if (attr_value.has_func()) {
- TF_RETURN_IF_ERROR(AddFunction(flib_def, attr_value.func().name()));
+ TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name()));
} else if (attr_value.has_list()) {
for (const NameAttrList& name_attr_list : attr_value.list().func()) {
- TF_RETURN_IF_ERROR(AddFunction(flib_def, name_attr_list.name()));
+ TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name()));
}
}
return Status::OK();
@@ -332,11 +333,14 @@ class IteratorContext {
class SerializationContext {
public:
struct Params {
+ bool allow_stateful_functions = false;
const FunctionLibraryDefinition* flib_def; // Not owned.
};
explicit SerializationContext(Params params) : params_(std::move(params)) {}
+ bool allow_stateful_functions() { return params_.allow_stateful_functions; }
+
const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; }
private: