aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/memory_optimizer.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/optimizers/memory_optimizer.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer.cc3
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index a0446bf566..dc1567c60a 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -490,6 +490,9 @@ std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap,
(*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
(*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
+ const DataType input_type = node->attr().at("T").type();
+ (*swap_in_node->mutable_attr())["T"].set_type(input_type);
+ (*swap_out_node->mutable_attr())["T"].set_type(input_type);
return std::make_pair(swap_out_node, swap_in_node);
}