author    Benoit Steiner <bsteiner@google.com>  2018-01-05 15:49:41 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-05 15:53:24 -0800
commit    35a068e7c29202f298575e51320c469f91f22f95 (patch)
tree      297cdec2284b671140dc7fe97bad70fcb7565aa8
parent    2f0c40624112bfdcf4e284aeb862c5f51761e909 (diff)
Properly set the type of the swap nodes.
PiperOrigin-RevId: 180985878
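
In short: BuildSwapPair previously read the type to assign to the swap nodes
from the node's "T" attribute and returned the pair of swap nodes directly.
It now resolves the declared type of the input from the registered OpDef,
refuses reference-typed inputs, and reports failure through a Status, with
the pair returned via an out parameter. Below is a minimal sketch of that
lookup-and-validate pattern, pulled out into a free-standing helper;
ValidateSwappableInput is a hypothetical name, not part of this commit,
while the includes and calls follow TensorFlow's core framework API:

    #include "tensorflow/core/framework/node_def_util.h"
    #include "tensorflow/core/framework/op.h"
    #include "tensorflow/core/framework/types.h"
    #include "tensorflow/core/lib/core/errors.h"

    namespace tensorflow {

    // Resolve the declared type of input `input_to_swap` from the op's
    // registered OpDef, and reject reference types, which can't be routed
    // through a swap-out/swap-in copy on the host.
    Status ValidateSwappableInput(const NodeDef& node, int input_to_swap,
                                  DataType* input_type) {
      const OpDef* op_def;
      TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
      TF_RETURN_IF_ERROR(
          InputTypeForNode(node, *op_def, input_to_swap, input_type));
      if (IsRefType(*input_type)) {
        return errors::InvalidArgument("Can't swap input ", input_to_swap,
                                       " of node ", node.name(),
                                       " since it expects a reference");
      }
      return Status::OK();
    }

    }  // namespace tensorflow
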
-rw-r--r--  tensorflow/core/grappler/optimizers/memory_optimizer.cc | 30
1 file changed, 21 insertions(+), 9 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 74d7f2f94d..bb4839d2e1 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -481,8 +481,19 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
   }
 }
 
-std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap,
-                                            GraphDef* graph) {
+Status BuildSwapPair(NodeDef* node, int input_to_swap, GraphDef* graph,
+                     std::pair<NodeDef*, NodeDef*>* swap_pair) {
+  const OpDef* op_def;
+  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
+  DataType input_type;
+  TF_RETURN_IF_ERROR(
+      InputTypeForNode(*node, *op_def, input_to_swap, &input_type));
+  if (IsRefType(input_type)) {
+    return errors::InvalidArgument("Can't swap input ", input_to_swap,
+                                   " of node ", node->name(),
+                                   " since it expects a reference");
+  }
+
   string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
 
   // Force the tensor to be copied to cpu.
@@ -502,10 +513,11 @@ std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap,
   (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
   (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
 
-  const DataType input_type = node->attr().at("T").type();
   (*swap_in_node->mutable_attr())["T"].set_type(input_type);
   (*swap_out_node->mutable_attr())["T"].set_type(input_type);
-  return std::make_pair(swap_out_node, swap_in_node);
+  *swap_pair = std::make_pair(swap_out_node, swap_in_node);
+
+  return Status::OK();
 }
 
 static int64 EstimateSize(const OpInfo::TensorProperties& t) {
@@ -762,7 +774,6 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     return Status::OK();
   }
 
-  {
     // Estimate the size of the data to swap for each node.
     GraphProperties properties(item);
     TF_RETURN_IF_ERROR(properties.InferStatically(true));
@@ -779,7 +790,6 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
       // Let's assume we're going to swap over PCIe running at 16 GBps.
       swap_info.time_to_swap = bytes_to_swap / 16;
     }
-  }
 
   std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
   TF_RETURN_IF_ERROR(
@@ -792,7 +802,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
 
   for (auto& swap : nodes_to_swap) {
     NodeDef* node = swap.first;
-    SwapInfo& swap_info = swap.second;
+    const SwapInfo& swap_info = swap.second;
     // Make sure the tensor isn't swapped back in right away: look for node that
     // will execute just before we need to swap the data back, and add a control
     // dependency from that node to the swap node.
@@ -804,8 +814,10 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 
     // Swap all the tensors that are marked with the 'swap_to_host' attribute.
    for (int input_id : swap_info.inputs_to_swap) {
-      std::pair<NodeDef*, NodeDef*> swap_nodes =
-          BuildSwapPair(node, input_id, optimized_graph);
+      std::pair<NodeDef*, NodeDef*> swap_nodes;
+      if (!BuildSwapPair(node, input_id, optimized_graph, &swap_nodes).ok()) {
+        continue;
+      }
       *swap_nodes.first->add_input() = node->input(input_id);
       *node->mutable_input(input_id) = swap_nodes.second->name();
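
Design note on the call site above: a failed BuildSwapPair no longer needs
to abort MemoryOptimizer::Optimize; the optimizer skips the offending input
via the "continue" and leaves it un-swapped. For intuition, a hypothetical
caller of the ValidateSwappableInput sketch above (names assumed, not from
this commit) would see the guard fire on a ref-typed input, e.g. the
variable input of an Assign node:

    // Sketch, assuming the ValidateSwappableInput helper sketched earlier.
    DataType dtype;
    Status s = ValidateSwappableInput(node_def, /*input_to_swap=*/0, &dtype);
    if (!s.ok()) {
      // e.g. "Can't swap input 0 of node my_assign since it expects a reference"
      LOG(WARNING) << s.error_message();
    }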