aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-28 10:56:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-28 12:05:49 -0700
commit0a6bbfaf2f3c4c9184cae9c239b99b7b855638a4 (patch)
treee9b554916fab9476c36b4d9bae1f292d8b888f8e /tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
parentb02d858e067afd574b120f7194a68ce3df90774b (diff)
Put shape and type information into NodeDef to simplify GraphTransferer
Change: 151478672
Diffstat (limited to 'tensorflow/core/kernels/remote_fused_graph_execute_utils.cc')
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_utils.cc63
1 files changed, 49 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
index eacd18a793..ee470ed465 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
@@ -135,20 +135,29 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
std::vector<Tensor> output_tensors;
output_tensors.reserve(graph_def.node_size());
std::vector<string> output_node_names;
- for (const NodeDef& node : graph_def.node()) {
- if (!IsInputNode(input_node_info_list, node.name())) {
- // CAVEAT: We only support one output. Use shape Inference Version
- // if there are two or more outputs in a node.
- output_node_names.emplace_back(strings::StrCat(node.name(), ":", 0));
+
+ Graph graph(OpRegistry::Global());
+ Status status = ImportGraphDef({}, graph_def, &graph, nullptr);
+ if (!status.ok()) {
+ return status;
+ }
+
+ for (const Node* node : graph.nodes()) {
+ if (IsInputNode(input_node_info_list, node->name())) {
+ continue;
+ }
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ output_node_names.emplace_back(strings::StrCat(node->name(), ":", i));
}
}
- const Status status =
- DryRunInference(graph_def, input_node_info_list, output_node_names,
- initialize_by_zero, &output_tensors);
+
+ status = DryRunInference(graph_def, input_node_info_list, output_node_names,
+ initialize_by_zero, &output_tensors);
if (!status.ok()) {
VLOG(1) << "Failed to dryrun " << status;
return status;
}
+
CHECK_EQ(output_node_names.size(), output_tensors.size())
<< output_node_names.size() << ", " << output_tensors.size();
@@ -169,7 +178,8 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
const Tensor& tensor = output_tensors.at(output_node_names.size() + i);
EmplaceTensorShapeType(name, tensor, tensor_shape_map);
}
- CHECK(graph_def.node_size() == output_tensors.size());
+ CHECK_EQ(output_node_names.size() + input_node_info_list.size(),
+ output_tensors.size());
return status;
}
@@ -248,6 +258,26 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
return Status::OK();
}
+/* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
+ const NodeDef& node_def, std::vector<DataType>* data_types,
+ std::vector<TensorShape>* shapes) {
+ Status status;
+ if (data_types != nullptr) {
+ status = GetNodeAttr(node_def, ATTR_OUTPUT_DATA_TYPES, data_types);
+ }
+ if (!status.ok()) {
+ return status;
+ }
+ if (shapes != nullptr) {
+ status = GetNodeAttr(node_def, ATTR_OUTPUT_SHAPES, shapes);
+ if (status.ok() && data_types != nullptr) {
+ CHECK_EQ(data_types->size(), shapes->size());
+ }
+ }
+
+ return status;
+}
+
/* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference(
const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
@@ -269,8 +299,13 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
shape_inference::ShapeHandle handle;
status = context->MakeShapeFromTensorShape(
input_node_info.second.shape(), &handle);
- // TODO(b/32704451): Don't just ignore this status!
- shape_refiner->SetShape(node, 0, handle).IgnoreError();
+ if (!status.ok()) {
+ break;
+ }
+ status = shape_refiner->SetShape(node, 0, handle);
+ if (!status.ok()) {
+ break;
+ }
is_input_node = true;
}
if (!status.ok()) {
@@ -280,9 +315,9 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
// If not an input node call AddNode() that recomputes the shape.
if (!is_input_node && status.ok()) {
status = shape_refiner->AddNode(node);
- if (!status.ok()) {
- VLOG(1) << "Shape inference failed for node: " << node->name();
- }
+ }
+ if (!status.ok()) {
+ VLOG(1) << "Shape inference failed for node: " << node->name();
}
};