diff options
author | Dan Smilkov <dsmilkov@gmail.com> | 2016-05-17 06:52:52 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-05-17 08:02:30 -0700 |
commit | 10cf1ab777373d2b6147cbd6443304ce53c9dc87 (patch) | |
tree | 5b32a66981855318e11968acab48474c43466e56 /tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts | |
parent | b95b3d9bccca8156ae6c3f7095a0b80a82163f18 (diff) |
Show the correct edge shape. Previously, the shape of the first output tensor of the source node was shown.
Now we show the shape of the exact output tensor that underlies the edge.
Also show the shapes of each output tensor in the info card.
Change: 122523849
Diffstat (limited to 'tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts')
-rw-r--r-- | tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts index cb2972e890..f1d4b6b24b 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts @@ -123,14 +123,17 @@ export function buildGroup(sceneGroup, return edgeGroups; }; -export function getShapeLabelFromNode(node: OpNode, - renderInfo: render.RenderGraphInfo) { +/** + * Returns the label for the given base edge. + * The label is the shape of the underlying tensor. + */ +export function getLabelForBaseEdge( + baseEdge: BaseEdge, renderInfo: render.RenderGraphInfo): string { + let node = <OpNode>renderInfo.getNodeByName(baseEdge.v); if (node.outputShapes == null || node.outputShapes.length === 0) { return null; } - // TODO(smilkov): Figure out exactly which output tensor this - // edge is from. - let shape = node.outputShapes[0]; + let shape = node.outputShapes[baseEdge.outputTensorIndex]; if (shape == null) { return null; } @@ -149,12 +152,9 @@ export function getShapeLabelFromNode(node: OpNode, export function getLabelForEdge(metaedge: Metaedge, renderInfo: render.RenderGraphInfo): string { let isMultiEdge = metaedge.baseEdgeList.length > 1; - if (isMultiEdge) { - return metaedge.baseEdgeList.length + ' tensors'; - } else { - let node = <OpNode> renderInfo.getNodeByName(metaedge.baseEdgeList[0].v); - return getShapeLabelFromNode(node, renderInfo); - } + return isMultiEdge ? + metaedge.baseEdgeList.length + ' tensors' : + getLabelForBaseEdge(metaedge.baseEdgeList[0], renderInfo); } /** |