diff options
author | 2017-03-28 10:56:13 -0800 | |
---|---|---|
committer | 2017-03-28 12:05:49 -0700 | |
commit | 0a6bbfaf2f3c4c9184cae9c239b99b7b855638a4 (patch) | |
tree | e9b554916fab9476c36b4d9bae1f292d8b888f8e /tensorflow/core/kernels/remote_fused_graph_execute_utils.cc | |
parent | b02d858e067afd574b120f7194a68ce3df90774b (diff) |
Put shape and type information into NodeDef to simplify GraphTransferer
Change: 151478672
Diffstat (limited to 'tensorflow/core/kernels/remote_fused_graph_execute_utils.cc')
-rw-r--r-- | tensorflow/core/kernels/remote_fused_graph_execute_utils.cc | 63 |
1 file changed, 49 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index eacd18a793..ee470ed465 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -135,20 +135,29 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() { std::vector<Tensor> output_tensors; output_tensors.reserve(graph_def.node_size()); std::vector<string> output_node_names; - for (const NodeDef& node : graph_def.node()) { - if (!IsInputNode(input_node_info_list, node.name())) { - // CAVEAT: We only support one output. Use shape Inference Version - // if there are two or more outputs in a node. - output_node_names.emplace_back(strings::StrCat(node.name(), ":", 0)); + + Graph graph(OpRegistry::Global()); + Status status = ImportGraphDef({}, graph_def, &graph, nullptr); + if (!status.ok()) { + return status; + } + + for (const Node* node : graph.nodes()) { + if (IsInputNode(input_node_info_list, node->name())) { + continue; + } + for (int i = 0; i < node->num_outputs(); ++i) { + output_node_names.emplace_back(strings::StrCat(node->name(), ":", i)); } } - const Status status = - DryRunInference(graph_def, input_node_info_list, output_node_names, - initialize_by_zero, &output_tensors); + + status = DryRunInference(graph_def, input_node_info_list, output_node_names, + initialize_by_zero, &output_tensors); if (!status.ok()) { VLOG(1) << "Failed to dryrun " << status; return status; } + CHECK_EQ(output_node_names.size(), output_tensors.size()) << output_node_names.size() << ", " << output_tensors.size(); @@ -169,7 +178,8 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() { const Tensor& tensor = output_tensors.at(output_node_names.size() + i); EmplaceTensorShapeType(name, tensor, tensor_shape_map); } - CHECK(graph_def.node_size() == output_tensors.size()); + CHECK_EQ(output_node_names.size() + input_node_info_list.size(), + 
output_tensors.size()); return status; } @@ -248,6 +258,26 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( return Status::OK(); } +/* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( + const NodeDef& node_def, std::vector<DataType>* data_types, + std::vector<TensorShape>* shapes) { + Status status; + if (data_types != nullptr) { + status = GetNodeAttr(node_def, ATTR_OUTPUT_DATA_TYPES, data_types); + } + if (!status.ok()) { + return status; + } + if (shapes != nullptr) { + status = GetNodeAttr(node_def, ATTR_OUTPUT_SHAPES, shapes); + if (status.ok() && data_types != nullptr) { + CHECK_EQ(data_types->size(), shapes->size()); + } + } + + return status; +} + /* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference( const GraphDef& graph_def, const std::vector<std::pair<string, Tensor>>& input_node_info_list, @@ -269,8 +299,13 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( shape_inference::ShapeHandle handle; status = context->MakeShapeFromTensorShape( input_node_info.second.shape(), &handle); - // TODO(b/32704451): Don't just ignore this status! - shape_refiner->SetShape(node, 0, handle).IgnoreError(); + if (!status.ok()) { + break; + } + status = shape_refiner->SetShape(node, 0, handle); + if (!status.ok()) { + break; + } is_input_node = true; } if (!status.ok()) { @@ -280,9 +315,9 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( // If not an input node call AddNode() that recomputes the shape. if (!is_input_node && status.ok()) { status = shape_refiner->AddNode(node); - if (!status.ok()) { - VLOG(1) << "Shape inference failed for node: " << node->name(); - } + } + if (!status.ok()) { + VLOG(1) << "Shape inference failed for node: " << node->name(); } }; |