aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/function_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/function_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer.cc19
1 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 645e4c2087..56364f0095 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -453,6 +453,7 @@ Status InitializeFunctionSpecializationSignature(
}
Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
+ const int graph_def_version,
FunctionOptimizerContext* ctx,
GraphDef* optimized_graph) {
VLOG(2) << "Specialize function instantiation: "
@@ -492,7 +493,8 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
// Make a GrapplerFunctionItem and convert it back to FunctionDef after
// pushing all constant inputs into the function body.
GrapplerFunctionItem item;
- TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib,
+ graph_def_version, &item));
// Push const inputs into the function body, and keep track of their control
// dependencies.
@@ -576,15 +578,15 @@ NodeDef InlinedFunctionOutputsNode(const NodeDef& func_node,
Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
const FunctionOptimizerContext& ctx,
- GraphDef* optimized_graph) {
+ const int graph_def_version, GraphDef* optimized_graph) {
VLOG(2) << "Inline function instantiation: " << SummarizeNodeDef(func_node);
const std::unordered_map<string, AttrValue> func_attr(
func_node.attr().begin(), func_node.attr().end());
GrapplerFunctionItem item;
- Status item_status =
- MakeGrapplerFunctionItem(func, func_attr, ctx.function_library(), &item);
+ Status item_status = MakeGrapplerFunctionItem(
+ func, func_attr, ctx.function_library(), graph_def_version, &item);
if (!item_status.ok()) {
return errors::InvalidArgument("Failed to inline function ", func_node.op(),
@@ -645,7 +647,8 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
if (func_body_node_func != nullptr) {
// Recursively inline function calls.
TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func,
- ctx, optimized_graph));
+ ctx, graph_def_version,
+ optimized_graph));
} else {
// Annotate the node with the function attributes.
for (const auto& attr : func.attr()) {
@@ -824,7 +827,8 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (inline_func && ctx.IsInlinedFunction(func_name)) {
// Inline function body into the optimized graph}
TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
- InlineFunction(node, *func, ctx, optimized_graph));
+ InlineFunction(node, *func, ctx, item.graph.versions().producer(),
+ optimized_graph));
continue;
}
@@ -837,7 +841,8 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// TODO(ezhulenev): Specialize function call if input has a known shape.
// Specialize function body for its instantiation attributes and inputs.
TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
- SpecializeFunction(node, *func, &ctx, optimized_graph));
+ SpecializeFunction(node, *func, item.graph.versions().producer(),
+ &ctx, optimized_graph));
continue;
}
}