diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/function_optimizer.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/function_optimizer.cc | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index 645e4c2087..56364f0095 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -453,6 +453,7 @@ Status InitializeFunctionSpecializationSignature( } Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func, + const int graph_def_version, FunctionOptimizerContext* ctx, GraphDef* optimized_graph) { VLOG(2) << "Specialize function instantiation: " @@ -492,7 +493,8 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func, // Make a GrapplerFunctionItem and convert it back to FunctionDef after // pushing all constant inputs into the function body. GrapplerFunctionItem item; - TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib, + graph_def_version, &item)); // Push const inputs into the function body, and keep track of their control // dependencies. @@ -576,15 +578,15 @@ NodeDef InlinedFunctionOutputsNode(const NodeDef& func_node, Status InlineFunction(const NodeDef& func_node, const FunctionDef& func, const FunctionOptimizerContext& ctx, - GraphDef* optimized_graph) { + const int graph_def_version, GraphDef* optimized_graph) { VLOG(2) << "Inline function instantiation: " << SummarizeNodeDef(func_node); const std::unordered_map<string, AttrValue> func_attr( func_node.attr().begin(), func_node.attr().end()); GrapplerFunctionItem item; - Status item_status = - MakeGrapplerFunctionItem(func, func_attr, ctx.function_library(), &item); + Status item_status = MakeGrapplerFunctionItem( + func, func_attr, ctx.function_library(), graph_def_version, &item); if (!item_status.ok()) { return errors::InvalidArgument("Failed to inline function ", func_node.op(), @@ -645,7 +647,8 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func, if (func_body_node_func != nullptr) { // Recursively inline function calls. TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func, - ctx, optimized_graph)); + ctx, graph_def_version, + optimized_graph)); } else { // Annotate the node with the function attributes. for (const auto& attr : func.attr()) { @@ -824,7 +827,8 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (inline_func && ctx.IsInlinedFunction(func_name)) { // Inline function body into the optimized graph} TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED( - InlineFunction(node, *func, ctx, optimized_graph)); + InlineFunction(node, *func, ctx, item.graph.versions().producer(), + optimized_graph)); continue; } @@ -837,7 +841,8 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // TODO(ezhulenev): Specialize function call if input has a known shape. // Specialize function body for its instantiation attributes and inputs. TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED( - SpecializeFunction(node, *func, &ctx, optimized_graph)); + SpecializeFunction(node, *func, item.graph.versions().producer(), + &ctx, optimized_graph)); continue; } } |