aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts
diff options
context:
space:
mode:
authorGravatar Dan Smilkov <dsmilkov@gmail.com>2016-05-17 06:52:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-17 08:02:30 -0700
commit10cf1ab777373d2b6147cbd6443304ce53c9dc87 (patch)
tree5b32a66981855318e11968acab48474c43466e56 /tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts
parentb95b3d9bccca8156ae6c3f7095a0b80a82163f18 (diff)
Show the correct edge shape. Previously, the shape of the first output tensor of the source node was shown.
Now we show the shape of the exact output tensor that underlies the edge. Also show the shapes of each output tensor in the info card. Change: 122523849
Diffstat (limited to 'tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts')
-rw-r--r--tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts22
1 files changed, 11 insertions, 11 deletions
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts
index cb2972e890..f1d4b6b24b 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts
@@ -123,14 +123,17 @@ export function buildGroup(sceneGroup,
return edgeGroups;
};
-export function getShapeLabelFromNode(node: OpNode,
- renderInfo: render.RenderGraphInfo) {
+/**
+ * Returns the label for the given base edge.
+ * The label is the shape of the underlying tensor.
+ */
+export function getLabelForBaseEdge(
+ baseEdge: BaseEdge, renderInfo: render.RenderGraphInfo): string {
+ let node = <OpNode>renderInfo.getNodeByName(baseEdge.v);
if (node.outputShapes == null || node.outputShapes.length === 0) {
return null;
}
- // TODO(smilkov): Figure out exactly which output tensor this
- // edge is from.
- let shape = node.outputShapes[0];
+ let shape = node.outputShapes[baseEdge.outputTensorIndex];
if (shape == null) {
return null;
}
@@ -149,12 +152,9 @@ export function getShapeLabelFromNode(node: OpNode,
export function getLabelForEdge(metaedge: Metaedge,
renderInfo: render.RenderGraphInfo): string {
let isMultiEdge = metaedge.baseEdgeList.length > 1;
- if (isMultiEdge) {
- return metaedge.baseEdgeList.length + ' tensors';
- } else {
- let node = <OpNode> renderInfo.getNodeByName(metaedge.baseEdgeList[0].v);
- return getShapeLabelFromNode(node, renderInfo);
- }
+ return isMultiEdge ?
+ metaedge.baseEdgeList.length + ' tensors' :
+ getLabelForBaseEdge(metaedge.baseEdgeList[0], renderInfo);
}
/**