From 35a068e7c29202f298575e51320c469f91f22f95 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 5 Jan 2018 15:49:41 -0800 Subject: Properly set the type of the swap nodes. PiperOrigin-RevId: 180985878 --- .../core/grappler/optimizers/memory_optimizer.cc | 30 +++++++++++++++------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index 74d7f2f94d..bb4839d2e1 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -481,8 +481,19 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level, } } -std::pair BuildSwapPair(NodeDef* node, int input_to_swap, - GraphDef* graph) { +Status BuildSwapPair(NodeDef* node, int input_to_swap, GraphDef* graph, + std::pair* swap_pair) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def)); + DataType input_type; + TF_RETURN_IF_ERROR( + InputTypeForNode(*node, *op_def, input_to_swap, &input_type)); + if (IsRefType(input_type)) { + return errors::InvalidArgument("Can't swap input ", input_to_swap, + " of node ", node->name(), + " since it expects a reference"); + } + string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap); // Force the tensor to be copied to cpu. @@ -502,10 +513,11 @@ std::pair BuildSwapPair(NodeDef* node, int input_to_swap, (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group); (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group); - const DataType input_type = node->attr().at("T").type(); (*swap_in_node->mutable_attr())["T"].set_type(input_type); (*swap_out_node->mutable_attr())["T"].set_type(input_type); - return std::make_pair(swap_out_node, swap_in_node); + *swap_pair = std::make_pair(swap_out_node, swap_in_node); + + return Status::OK(); } static int64 EstimateSize(const OpInfo::TensorProperties& t) { @@ -762,7 +774,6 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, return Status::OK(); } - { // Estimate the size of the data to swap for each node. GraphProperties properties(item); TF_RETURN_IF_ERROR(properties.InferStatically(true)); @@ -779,7 +790,6 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Let's assume we're going to swap over PCIe running at 16 GBps. swap_info.time_to_swap = bytes_to_swap / 16; } - } std::unordered_map execution_times; TF_RETURN_IF_ERROR( @@ -792,7 +802,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, for (auto& swap : nodes_to_swap) { NodeDef* node = swap.first; - SwapInfo& swap_info = swap.second; + const SwapInfo& swap_info = swap.second; // Make sure the tensor isn't swapped back in right away: look for node that // will execute just before we need to swap the data back, and add a control @@ -804,8 +814,10 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } // Swap all the tensors that are marked with the 'swap_to_host' attribute. for (int input_id : swap_info.inputs_to_swap) { - std::pair swap_nodes = - BuildSwapPair(node, input_id, optimized_graph); + std::pair swap_nodes; + if (!BuildSwapPair(node, input_id, optimized_graph, &swap_nodes).ok()) { + continue; + } *swap_nodes.first->add_input() = node->input(input_id); *node->mutable_input(input_id) = swap_nodes.second->name(); -- cgit v1.2.3