diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers')
-rw-r--r-- | tensorflow/core/grappler/optimizers/memory_optimizer.cc | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index a0446bf566..dc1567c60a 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -490,6 +490,9 @@ std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap, (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group); (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group); + const DataType input_type = node->attr().at("T").type(); + (*swap_in_node->mutable_attr())["T"].set_type(input_type); + (*swap_out_node->mutable_attr())["T"].set_type(input_type); return std::make_pair(swap_out_node, swap_in_node); } |