From 815fa1b32ded0d351689177dca65485afc0f7dcf Mon Sep 17 00:00:00 2001 From: Charles Nicholson Date: Thu, 17 Nov 2016 13:06:08 -0800 Subject: Start moving scatter plot methods out of vz-projector and into the scatter plot adapter. Add DataSet reference to Projection class. The projector adapter now listens to the the distance metric changed event, as well as creates + owns scatter plot. Change: 139496759 --- .../tensorboard/components/vz_projector/data.ts | 3 +- .../vz_projector/projectorEventContext.ts | 11 +- .../vz_projector/projectorScatterPlotAdapter.ts | 238 +++++++++++++++++---- .../components/vz_projector/scatterPlot.ts | 14 +- .../vz_projector/scatterPlotVisualizerSprites.ts | 5 +- .../vz_projector/vz-projector-projections-panel.ts | 12 +- .../components/vz_projector/vz-projector.ts | 174 +++++---------- 7 files changed, 274 insertions(+), 183 deletions(-) diff --git a/tensorflow/tensorboard/components/vz_projector/data.ts b/tensorflow/tensorboard/components/vz_projector/data.ts index 7d5e8c2cf2..4adbb56b80 100644 --- a/tensorflow/tensorboard/components/vz_projector/data.ts +++ b/tensorflow/tensorboard/components/vz_projector/data.ts @@ -415,7 +415,8 @@ export type ProjectionType = 'tsne' | 'pca' | 'custom'; export class Projection { constructor( public projectionType: ProjectionType, - public pointAccessors: PointAccessors3D, public dimensionality: number) {} + public pointAccessors: PointAccessors3D, public dimensionality: number, + public dataSet: DataSet) {} } export interface ColorOption { diff --git a/tensorflow/tensorboard/components/vz_projector/projectorEventContext.ts b/tensorflow/tensorboard/components/vz_projector/projectorEventContext.ts index 1b546e0ba5..36f5c4c584 100644 --- a/tensorflow/tensorboard/components/vz_projector/projectorEventContext.ts +++ b/tensorflow/tensorboard/components/vz_projector/projectorEventContext.ts @@ -13,15 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {DataSet, DistanceFunction} from './data'; +import {DistanceFunction, Projection} from './data'; import {NearestEntry} from './knn'; export type HoverListener = (index: number) => void; export type SelectionChangedListener = (selectedPointIndices: number[], neighborsOfFirstPoint: NearestEntry[]) => void; -export type ProjectionChangedListener = (dataSet: DataSet) => void; - +export type ProjectionChangedListener = (projection: Projection) => void; +export type DistanceMetricChangedListener = + (distanceMetric: DistanceFunction) => void; export interface ProjectorEventContext { /** Register a callback to be invoked when the mouse hovers over a point. */ registerHoverListener(listener: HoverListener); @@ -37,6 +38,8 @@ export interface ProjectorEventContext { /** Registers a callback to be invoked when the projection changes. */ registerProjectionChangedListener(listener: ProjectionChangedListener); /** Notify listeners that a reprojection occurred. */ - notifyProjectionChanged(dataSet: DataSet); + notifyProjectionChanged(projection: Projection); + registerDistanceMetricChangedListener(listener: + DistanceMetricChangedListener); notifyDistanceMetricChanged(distMetric: DistanceFunction); } diff --git a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts index 7855f5c2b1..ecf34cb977 100644 --- a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts +++ b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts @@ -13,9 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {DataSet, DistanceFunction, PointAccessors3D} from './data'; +import {DataSet, DistanceFunction, PointAccessors3D, Projection, State} from './data'; import {NearestEntry} from './knn'; +import {ProjectorEventContext} from './projectorEventContext'; import {LabelRenderParams} from './renderContext'; +import {ScatterPlot} from './scatterPlot'; +import {ScatterPlotVisualizer3DLabels} from './scatterPlotVisualizer3DLabels'; +import {ScatterPlotVisualizerCanvasLabels} from './scatterPlotVisualizerCanvasLabels'; +import {ScatterPlotVisualizerSprites} from './scatterPlotVisualizerSprites'; +import {ScatterPlotVisualizerTraces} from './scatterPlotVisualizerTraces'; import * as vector from './vector'; const LABEL_FONT_SIZE = 10; @@ -69,6 +75,119 @@ const NN_COLOR_SCALE = * to use the ScatterPlot to render the current projected data set. */ export class ProjectorScatterPlotAdapter { + public scatterPlot: ScatterPlot; + private scatterPlotContainer: d3.Selection; + private projection: Projection; + private hoverPointIndex: number; + private selectedPointIndices: number[]; + private neighborsOfFirstSelectedPoint: NearestEntry[]; + private renderLabelsIn3D: boolean = false; + private legendPointColorer: (index: number) => string; + private distanceMetric: DistanceFunction; + + constructor( + scatterPlotContainer: d3.Selection, + projectorEventContext: ProjectorEventContext) { + this.scatterPlot = + new ScatterPlot(scatterPlotContainer, projectorEventContext); + this.scatterPlotContainer = scatterPlotContainer; + projectorEventContext.registerProjectionChangedListener(projection => { + this.projection = projection; + this.updateScatterPlotWithNewProjection(projection); + }); + projectorEventContext.registerSelectionChangedListener( + (selectedPointIndices, neighbors) => { + this.selectedPointIndices = selectedPointIndices; + this.neighborsOfFirstSelectedPoint = neighbors; + this.updateScatterPlotAttributes(); + this.scatterPlot.render(); + }); + projectorEventContext.registerHoverListener(hoverPointIndex => { + this.hoverPointIndex = hoverPointIndex; + this.updateScatterPlotAttributes(); + this.scatterPlot.render(); + }); + projectorEventContext.registerDistanceMetricChangedListener( + distanceMetric => { + this.distanceMetric = distanceMetric; + this.updateScatterPlotAttributes(); + this.scatterPlot.render(); + }); + this.createVisualizers(false); + } + + notifyProjectionPositionsUpdated() { + this.updateScatterPlotPositions(); + this.scatterPlot.render(); + } + + set3DLabelMode(renderLabelsIn3D: boolean) { + this.renderLabelsIn3D = renderLabelsIn3D; + this.createVisualizers(renderLabelsIn3D); + this.updateScatterPlotAttributes(); + this.scatterPlot.render(); + } + + setLegendPointColorer(legendPointColorer: (index: number) => string) { + this.legendPointColorer = legendPointColorer; + } + + resize() { + this.scatterPlot.resize(); + } + + populateBookmarkFromUI(state: State) { + state.cameraDef = this.scatterPlot.getCameraDef(); + } + + restoreUIFromBookmark(state: State) { + this.scatterPlot.setCameraParametersForNextCameraCreation( + state.cameraDef, false); + } + + updateScatterPlotPositions() { + const ds = (this.projection == null) ? null : this.projection.dataSet; + const accessors = + (this.projection == null) ? null : this.projection.pointAccessors; + const newPositions = this.generatePointPositionArray(ds, accessors); + this.scatterPlot.setPointPositions(ds, newPositions); + } + + updateScatterPlotAttributes() { + if (this.projection == null) { + return; + } + const dataSet = this.projection.dataSet; + const selectedSet = this.selectedPointIndices; + const hoverIndex = this.hoverPointIndex; + const neighbors = this.neighborsOfFirstSelectedPoint; + const pointColorer = this.legendPointColorer; + + const pointColors = this.generatePointColorArray( + dataSet, pointColorer, this.distanceMetric, selectedSet, neighbors, + hoverIndex, this.renderLabelsIn3D, this.getSpriteImageMode()); + const pointScaleFactors = this.generatePointScaleFactorArray( + dataSet, selectedSet, neighbors, hoverIndex); + const labels = this.generateVisibleLabelRenderParams( + dataSet, selectedSet, neighbors, hoverIndex); + const traceColors = this.generateLineSegmentColorMap(dataSet, pointColorer); + const traceOpacities = + this.generateLineSegmentOpacityArray(dataSet, selectedSet); + const traceWidths = + this.generateLineSegmentWidthArray(dataSet, selectedSet); + + this.scatterPlot.setPointColors(pointColors); + this.scatterPlot.setPointScaleFactors(pointScaleFactors); + this.scatterPlot.setLabels(labels); + this.scatterPlot.setTraceColors(traceColors); + this.scatterPlot.setTraceOpacities(traceOpacities); + this.scatterPlot.setTraceWidths(traceWidths); + } + + render() { + this.scatterPlot.render(); + } + generatePointPositionArray(ds: DataSet, pointAccessors: PointAccessors3D): Float32Array { if (ds == null) { @@ -116,19 +235,6 @@ export class ProjectorScatterPlotAdapter { return positions; } - private packRgbIntoUint8Array( - rgbArray: Uint8Array, labelIndex: number, r: number, g: number, - b: number) { - rgbArray[labelIndex * 3] = r; - rgbArray[labelIndex * 3 + 1] = g; - rgbArray[labelIndex * 3 + 2] = b; - } - - private styleRgbFromHexColor(hex: number): [number, number, number] { - const c = new THREE.Color(hex); - return [(c.r * 255) | 0, (c.g * 255) | 0, (c.b * 255) | 0]; - } - generateVisibleLabelRenderParams( ds: DataSet, selectedPointIndices: number[], neighborsOfFirstPoint: NearestEntry[], @@ -155,11 +261,11 @@ export class ProjectorScatterPlotAdapter { visibleLabels[dst] = hoverPointIndex; scale[dst] = LABEL_SCALE_LARGE; opacityFlags[dst] = 0; - const fillRgb = this.styleRgbFromHexColor(LABEL_FILL_COLOR_HOVER); - this.packRgbIntoUint8Array( + const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_HOVER); + packRgbIntoUint8Array( fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]); - const strokeRgb = this.styleRgbFromHexColor(LABEL_STROKE_COLOR_HOVER); - this.packRgbIntoUint8Array( + const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_HOVER); + packRgbIntoUint8Array( strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[1]); ++dst; } @@ -167,15 +273,15 @@ export class ProjectorScatterPlotAdapter { // Selected points { const n = selectedPointIndices.length; - const fillRgb = this.styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED); - const strokeRgb = this.styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED); + const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED); + const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED); for (let i = 0; i < n; ++i) { visibleLabels[dst] = selectedPointIndices[i]; scale[dst] = LABEL_SCALE_LARGE; opacityFlags[dst] = (n === 1) ? 0 : 1; - this.packRgbIntoUint8Array( + packRgbIntoUint8Array( fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]); - this.packRgbIntoUint8Array( + packRgbIntoUint8Array( strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]); ++dst; } @@ -184,13 +290,13 @@ export class ProjectorScatterPlotAdapter { // Neighbors { const n = neighborsOfFirstPoint.length; - const fillRgb = this.styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR); - const strokeRgb = this.styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR); + const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR); + const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR); for (let i = 0; i < n; ++i) { visibleLabels[dst] = neighborsOfFirstPoint[i].index; - this.packRgbIntoUint8Array( + packRgbIntoUint8Array( fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]); - this.packRgbIntoUint8Array( + packRgbIntoUint8Array( strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]); ++dst; } @@ -248,7 +354,6 @@ export class ProjectorScatterPlotAdapter { for (let i = 0; i < ds.traces.length; i++) { let dataTrace = ds.traces[i]; - let colors = new Float32Array(2 * (dataTrace.pointIndices.length - 1) * 3); let colorIndex = 0; @@ -262,21 +367,19 @@ export class ProjectorScatterPlotAdapter { colors[colorIndex++] = c1.r; colors[colorIndex++] = c1.g; colors[colorIndex++] = c1.b; - colors[colorIndex++] = c2.r; colors[colorIndex++] = c2.g; colors[colorIndex++] = c2.b; } } else { for (let j = 0; j < dataTrace.pointIndices.length - 1; j++) { - const c1 = this.getDefaultPointInTraceColor( - j, dataTrace.pointIndices.length); - const c2 = this.getDefaultPointInTraceColor( - j + 1, dataTrace.pointIndices.length); + const c1 = + getDefaultPointInTraceColor(j, dataTrace.pointIndices.length); + const c2 = + getDefaultPointInTraceColor(j + 1, dataTrace.pointIndices.length); colors[colorIndex++] = c1.r; colors[colorIndex++] = c1.g; colors[colorIndex++] = c1.b; - colors[colorIndex++] = c2.r; colors[colorIndex++] = c2.g; colors[colorIndex++] = c2.b; @@ -319,15 +422,6 @@ export class ProjectorScatterPlotAdapter { return widths; } - private getDefaultPointInTraceColor(index: number, totalPoints: number): - THREE.Color { - let hue = TRACE_START_HUE + - (TRACE_END_HUE - TRACE_START_HUE) * index / totalPoints; - - let rgb = d3.hsl(hue, TRACE_SATURATION, TRACE_LIGHTNESS).rgb(); - return new THREE.Color(rgb.r / 255, rgb.g / 255, rgb.b / 255); - } - generatePointColorArray( ds: DataSet, legendPointColorer: (index: number) => string, distFunc: DistanceFunction, selectedPointIndices: number[], @@ -419,6 +513,66 @@ export class ProjectorScatterPlotAdapter { return colors; } + + private updateScatterPlotWithNewProjection(projection: Projection) { + if (projection != null) { + this.scatterPlot.setDimensions(projection.dimensionality); + if (projection.dataSet.projectionCanBeRendered( + projection.projectionType)) { + this.updateScatterPlotAttributes(); + this.notifyProjectionPositionsUpdated(); + } + this.scatterPlot.setCameraParametersForNextCameraCreation(null, false); + } else { + this.updateScatterPlotAttributes(); + this.notifyProjectionPositionsUpdated(); + } + } + + private createVisualizers(inLabels3DMode: boolean) { + const scatterPlot = this.scatterPlot; + scatterPlot.removeAllVisualizers(); + if (inLabels3DMode) { + scatterPlot.addVisualizer(new ScatterPlotVisualizer3DLabels()); + } else { + scatterPlot.addVisualizer(new ScatterPlotVisualizerSprites()); + scatterPlot.addVisualizer( + new ScatterPlotVisualizerCanvasLabels(this.scatterPlotContainer)); + } + scatterPlot.addVisualizer(new ScatterPlotVisualizerTraces()); + } + + private getSpriteImageMode(): boolean { + if (this.projection == null) { + return false; + } + const ds = this.projection.dataSet; + if ((ds == null) || (ds.spriteAndMetadataInfo == null)) { + return false; + } + return ds.spriteAndMetadataInfo.spriteImage != null; + } +} + +function packRgbIntoUint8Array( + rgbArray: Uint8Array, labelIndex: number, r: number, g: number, b: number) { + rgbArray[labelIndex * 3] = r; + rgbArray[labelIndex * 3 + 1] = g; + rgbArray[labelIndex * 3 + 2] = b; +} + +function styleRgbFromHexColor(hex: number): [number, number, number] { + const c = new THREE.Color(hex); + return [(c.r * 255) | 0, (c.g * 255) | 0, (c.b * 255) | 0]; +} + +function getDefaultPointInTraceColor( + index: number, totalPoints: number): THREE.Color { + let hue = + TRACE_START_HUE + (TRACE_END_HUE - TRACE_START_HUE) * index / totalPoints; + + let rgb = d3.hsl(hue, TRACE_SATURATION, TRACE_LIGHTNESS).rgb(); + return new THREE.Color(rgb.r / 255, rgb.g / 255, rgb.b / 255); } /** diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts index 5be0d93c61..a381a997cc 100644 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts +++ b/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts @@ -117,14 +117,12 @@ export class ScatterPlot { private rectangleSelector: ScatterPlotRectangleSelector; constructor( - container: d3.Selection, labelAccessor: (index: number) => string, + container: d3.Selection, projectorEventContext: ProjectorEventContext) { this.containerNode = container.node() as HTMLElement; this.projectorEventContext = projectorEventContext; this.getLayoutValues(); - this.labelAccessor = labelAccessor; - this.scene = new THREE.Scene(); this.renderer = new THREE.WebGLRenderer({alpha: true, premultipliedAlpha: false}); @@ -457,7 +455,7 @@ export class ScatterPlot { return this.dimensionality === 3; } - private remove3dAxis(): THREE.Object3D { + private remove3dAxisFromScene(): THREE.Object3D { const axes = this.scene.getObjectByName('axes'); if (axes != null) { this.scene.remove(axes); @@ -481,7 +479,7 @@ export class ScatterPlot { const def = this.cameraDef || this.makeDefaultCameraDef(dimensionality); this.recreateCamera(def); - this.remove3dAxis(); + this.remove3dAxisFromScene(); if (dimensionality === 3) { this.add3dAxis(); } @@ -624,9 +622,11 @@ export class ScatterPlot { }); { - const axes = this.remove3dAxis(); + const axes = this.remove3dAxisFromScene(); this.renderer.render(this.scene, this.camera, this.pickingTexture); - this.scene.add(axes); + if (axes != null) { + this.scene.add(axes); + } } // Render second pass to color buffer, to be displayed on the canvas. diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts index 766659308d..5af3eea92b 100644 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts +++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts @@ -305,14 +305,15 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { onPointPositionsChanged(newPositions: Float32Array, dataSet: DataSet) { if (this.points != null) { const notEnoughSpace = (this.pickingColors.length < newPositions.length); - const newImage = + const newImage = (dataSet != null) && (this.image !== dataSet.spriteAndMetadataInfo.spriteImage); if (notEnoughSpace || newImage) { this.dispose(); } } - this.image = dataSet.spriteAndMetadataInfo.spriteImage; + this.image = + (dataSet != null) ? dataSet.spriteAndMetadataInfo.spriteImage : null; this.worldSpacePointPositions = newPositions; if (this.points == null) { diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts index 1056bfb3ce..9c172e4707 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts @@ -372,13 +372,14 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { const accessors = dataSet.getPointAccessors('tsne', [0, 1, this.tSNEis3d ? 2 : null]); const dimensionality = this.tSNEis3d ? 3 : 2; - const projection = new Projection('tsne', accessors, dimensionality); + const projection = + new Projection('tsne', accessors, dimensionality, dataSet); this.projector.setProjection(projection); if (!this.dataSet.hasTSNERun) { this.runTSNE(); } else { - this.projector.notifyProjectionsUpdated(); + this.projector.notifyProjectionPositionsUpdated(); } } @@ -390,7 +391,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { (iteration: number) => { if (iteration != null) { this.iterationLabel.text(iteration); - this.projector.notifyProjectionsUpdated(); + this.projector.notifyProjectionPositionsUpdated(); } else { this.runTsneButton.attr('disabled', null); this.stopTsneButton.attr('disabled', true); @@ -426,7 +427,8 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { 'pca', [this.pcaX, this.pcaY, this.pcaZ]); const dimensionality = this.pcaIs3d ? 3 : 2; - const projection = new Projection('pca', accessors, dimensionality); + const projection = + new Projection('pca', accessors, dimensionality, this.dataSet); this.projector.setProjection(projection); let numComponents = Math.min(NUM_PCA_COMPONENTS, this.dataSet.dim[1]); this.updateTotalVarianceMessage(); @@ -454,7 +456,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { this.dataSet.projectLinear(yDir, 'linear-y'); const accessors = this.dataSet.getPointAccessors('custom', ['x', 'y']); - const projection = new Projection('custom', accessors, 2); + const projection = new Projection('custom', accessors, 2, this.dataSet); this.projector.setProjection(projection); } diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts index 8900819048..f22bc3ad64 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts @@ -21,13 +21,9 @@ import {ProtoDataProvider} from './data-provider-proto'; import {ServerDataProvider} from './data-provider-server'; import * as knn from './knn'; import * as logging from './logging'; -import {HoverListener, ProjectionChangedListener, ProjectorEventContext, SelectionChangedListener} from './projectorEventContext'; +import {DistanceMetricChangedListener, HoverListener, ProjectionChangedListener, ProjectorEventContext, SelectionChangedListener} from './projectorEventContext'; import {ProjectorScatterPlotAdapter} from './projectorScatterPlotAdapter'; -import {Mode, ScatterPlot} from './scatterPlot'; -import {ScatterPlotVisualizer3DLabels} from './scatterPlotVisualizer3DLabels'; -import {ScatterPlotVisualizerCanvasLabels} from './scatterPlotVisualizerCanvasLabels'; -import {ScatterPlotVisualizerSprites} from './scatterPlotVisualizerSprites'; -import {ScatterPlotVisualizerTraces} from './scatterPlotVisualizerTraces'; +import {Mode} from './scatterPlot'; import * as util from './util'; import {BookmarkPanel} from './vz-projector-bookmark-panel'; import {DataPanel} from './vz-projector-data-panel'; @@ -69,11 +65,11 @@ export class Projector extends ProjectorPolymer implements private selectionChangedListeners: SelectionChangedListener[]; private hoverListeners: HoverListener[]; private projectionChangedListeners: ProjectionChangedListener[]; + private distanceMetricChangedListeners: DistanceMetricChangedListener[]; private originalDataSet: DataSet; private dom: d3.Selection; private projectorScatterPlotAdapter: ProjectorScatterPlotAdapter; - private scatterPlot: ScatterPlot; private dim: number; private dataSetFilterIndices: number[]; @@ -108,6 +104,7 @@ export class Projector extends ProjectorPolymer implements this.selectionChangedListeners = []; this.hoverListeners = []; this.projectionChangedListeners = []; + this.distanceMetricChangedListeners = []; this.selectedPointIndices = []; this.neighborsOfFirstPoint = []; this.dom = d3.select(this); @@ -133,14 +130,17 @@ export class Projector extends ProjectorPolymer implements .metadata[this.selectedLabelOption] as string; }; this.metadataCard.setLabelOption(this.selectedLabelOption); - this.scatterPlot.setLabelAccessor(labelAccessor); - this.scatterPlot.render(); + this.projectorScatterPlotAdapter.scatterPlot.setLabelAccessor( + labelAccessor); + this.projectorScatterPlotAdapter.render(); } setSelectedColorOption(colorOption: ColorOption) { this.selectedColorOption = colorOption; - this.updateScatterPlotAttributes(); - this.scatterPlot.render(); + this.projectorScatterPlotAdapter.setLegendPointColorer( + this.getLegendPointColorer(colorOption)); + this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); + this.projectorScatterPlotAdapter.render(); } setNormalizeData(normalizeData: boolean) { @@ -153,8 +153,7 @@ export class Projector extends ProjectorPolymer implements metadataFile?: string) { this.dataSetFilterIndices = null; this.originalDataSet = ds; - if (this.scatterPlot == null || ds == null) { - // We are not ready yet. + if (this.projectorScatterPlotAdapter == null || ds == null) { return; } this.normalizeData = this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE; @@ -176,8 +175,8 @@ export class Projector extends ProjectorPolymer implements // height can grow indefinitely. let container = this.dom.select('#container'); container.style('height', container.property('clientHeight') + 'px'); - this.scatterPlot.resize(); - this.scatterPlot.render(); + this.projectorScatterPlotAdapter.resize(); + this.projectorScatterPlotAdapter.render(); } setSelectedTensor(run: string, tensorInfo: EmbeddingInfo) { @@ -203,7 +202,7 @@ export class Projector extends ProjectorPolymer implements return this.dataSet.points[localIndex].index; }); this.setCurrentDataSet(this.originalDataSet.getSubset()); - this.updateScatterPlotPositions(); + this.projectorScatterPlotAdapter.updateScatterPlotPositions(); this.dataSetFilterIndices = []; this.adjustSelectionAndHover(originalPointIndices); } @@ -247,8 +246,16 @@ export class Projector extends ProjectorPolymer implements this.projectionChangedListeners.push(listener); } - notifyProjectionChanged(dataSet: DataSet) { - this.projectionChangedListeners.forEach(l => l(dataSet)); + notifyProjectionChanged(projection: Projection) { + this.projectionChangedListeners.forEach(l => l(projection)); + } + + registerDistanceMetricChangedListener(l: DistanceMetricChangedListener) { + this.distanceMetricChangedListeners.push(l); + } + + notifyDistanceMetricChanged(distMetric: DistanceFunction) { + this.distanceMetricChangedListeners.forEach(l => l(distMetric)); } _dataProtoChanged(dataProtoString: string) { @@ -324,11 +331,6 @@ export class Projector extends ProjectorPolymer implements return (label3DModeButton as any).active; } - private getSpriteImageMode(): boolean { - return this.dataSet && this.dataSet.spriteAndMetadataInfo && - this.dataSet.spriteAndMetadataInfo.spriteImage != null; - } - adjustSelectionAndHover(selectedPointIndices: number[], hoverIndex?: number) { this.notifySelectionChanged(selectedPointIndices); this.notifyHoverOverPoint(hoverIndex); @@ -338,8 +340,7 @@ export class Projector extends ProjectorPolymer implements private setMode(mode: Mode) { let selectModeButton = this.querySelector('#selectMode'); (selectModeButton as any).active = (mode === Mode.SELECT); - - this.scatterPlot.setMode(mode); + this.projectorScatterPlotAdapter.scatterPlot.setMode(mode); } private setCurrentDataSet(ds: DataSet) { @@ -360,14 +361,15 @@ export class Projector extends ProjectorPolymer implements this.projectionsPanel.dataSetUpdated( this.dataSet, this.originalDataSet, this.dim); - this.scatterPlot.setCameraParametersForNextCameraCreation(null, true); + this.projectorScatterPlotAdapter.scatterPlot + .setCameraParametersForNextCameraCreation(null, true); } private setupUIControls() { // View controls this.querySelector('#reset-zoom').addEventListener('click', () => { - this.scatterPlot.resetZoom(); - this.scatterPlot.startOrbitAnimation(); + this.projectorScatterPlotAdapter.scatterPlot.resetZoom(); + this.projectorScatterPlotAdapter.scatterPlot.startOrbitAnimation(); }); let selectModeButton = this.querySelector('#selectMode'); @@ -376,14 +378,13 @@ export class Projector extends ProjectorPolymer implements }); let nightModeButton = this.querySelector('#nightDayMode'); nightModeButton.addEventListener('click', () => { - this.scatterPlot.setDayNightMode((nightModeButton as any).active); + this.projectorScatterPlotAdapter.scatterPlot.setDayNightMode( + (nightModeButton as any).active); }); const labels3DModeButton = this.get3DLabelModeButton(); labels3DModeButton.addEventListener('click', () => { - this.createVisualizers(this.get3DLabelMode()); - this.updateScatterPlotAttributes(); - this.scatterPlot.render(); + this.projectorScatterPlotAdapter.set3DLabelMode(this.get3DLabelMode()); }); window.addEventListener('resize', () => { @@ -391,18 +392,19 @@ export class Projector extends ProjectorPolymer implements let parentHeight = (container.node().parentNode as HTMLElement).clientHeight; container.style('height', parentHeight + 'px'); - this.scatterPlot.resize(); + this.projectorScatterPlotAdapter.resize(); }); - this.projectorScatterPlotAdapter = new ProjectorScatterPlotAdapter(); - - this.scatterPlot = new ScatterPlot( - this.getScatterContainer(), - i => '' + this.dataSet.points[i].metadata[this.selectedLabelOption], - this as ProjectorEventContext); - this.createVisualizers(false); + { + const labelAccessor = i => + '' + this.dataSet.points[i].metadata[this.selectedLabelOption]; + this.projectorScatterPlotAdapter = new ProjectorScatterPlotAdapter( + this.getScatterContainer(), this as ProjectorEventContext); + this.projectorScatterPlotAdapter.scatterPlot.setLabelAccessor( + labelAccessor); + } - this.scatterPlot.onCameraMove( + this.projectorScatterPlotAdapter.scatterPlot.onCameraMove( (cameraPosition: THREE.Vector3, cameraTarget: THREE.Vector3) => this.bookmarkPanel.clearStateSelection()); @@ -425,75 +427,16 @@ export class Projector extends ProjectorPolymer implements hoverText = point.metadata[this.selectedLabelOption].toString(); } } - this.updateScatterPlotAttributes(); - this.scatterPlot.render(); if (this.selectedPointIndices.length === 0) { this.statusBar.style('display', hoverText ? null : 'none'); this.statusBar.text(hoverText); } } - private updateScatterPlotPositions() { - if (this.dataSet == null) { - return; - } - if (this.projection == null) { - return; - } - const newPositions = - this.projectorScatterPlotAdapter.generatePointPositionArray( - this.dataSet, this.projection.pointAccessors); - this.scatterPlot.setPointPositions(this.dataSet, newPositions); - } - - private updateScatterPlotAttributes() { - const dataSet = this.dataSet; - const selectedSet = this.selectedPointIndices; - const hoverIndex = this.hoverPointIndex; - const neighbors = this.neighborsOfFirstPoint; - const pointColorer = this.getLegendPointColorer(this.selectedColorOption); - const adapter = this.projectorScatterPlotAdapter; - - const pointColors = adapter.generatePointColorArray( - dataSet, pointColorer, this.inspectorPanel.distFunc, selectedSet, - neighbors, hoverIndex, this.get3DLabelMode(), - this.getSpriteImageMode()); - const pointScaleFactors = adapter.generatePointScaleFactorArray( - dataSet, selectedSet, neighbors, hoverIndex); - const labels = adapter.generateVisibleLabelRenderParams( - dataSet, selectedSet, neighbors, hoverIndex); - const traceColors = - adapter.generateLineSegmentColorMap(dataSet, pointColorer); - const traceOpacities = - adapter.generateLineSegmentOpacityArray(dataSet, selectedSet); - const traceWidths = - adapter.generateLineSegmentWidthArray(dataSet, selectedSet); - - this.scatterPlot.setPointColors(pointColors); - this.scatterPlot.setPointScaleFactors(pointScaleFactors); - this.scatterPlot.setLabels(labels); - this.scatterPlot.setTraceColors(traceColors); - this.scatterPlot.setTraceOpacities(traceOpacities); - this.scatterPlot.setTraceWidths(traceWidths); - } - private getScatterContainer(): d3.Selection { return this.dom.select('#scatter'); } - private createVisualizers(inLabels3DMode: boolean) { - const scatterPlot = this.scatterPlot; - scatterPlot.removeAllVisualizers(); - if (inLabels3DMode) { - scatterPlot.addVisualizer(new ScatterPlotVisualizer3DLabels()); - } else { - scatterPlot.addVisualizer(new ScatterPlotVisualizerSprites()); - scatterPlot.addVisualizer( - new ScatterPlotVisualizerCanvasLabels(this.getScatterContainer())); - } - scatterPlot.addVisualizer(new ScatterPlotVisualizerTraces()); - } - private onSelectionChanged( selectedPointIndices: number[], neighborsOfFirstPoint: knn.NearestEntry[]) { @@ -503,26 +446,18 @@ export class Projector extends ProjectorPolymer implements this.selectedPointIndices.length + neighborsOfFirstPoint.length; this.statusBar.text(`Selected ${totalNumPoints} points`) .style('display', totalNumPoints > 0 ? null : 'none'); - this.updateScatterPlotAttributes(); - this.scatterPlot.render(); } setProjection(projection: Projection) { this.projection = projection; - this.scatterPlot.setDimensions(projection.dimensionality); - this.analyticsLogger.logProjectionChanged(projection.projectionType); - if (this.dataSet.projectionCanBeRendered(projection.projectionType)) { - this.updateScatterPlotAttributes(); - this.notifyProjectionsUpdated(); + if (projection != null) { + this.analyticsLogger.logProjectionChanged(projection.projectionType); } - - this.scatterPlot.setCameraParametersForNextCameraCreation(null, false); - this.notifyProjectionChanged(this.dataSet); + this.notifyProjectionChanged(projection); } - notifyProjectionsUpdated() { - this.updateScatterPlotPositions(); - this.scatterPlot.render(); + notifyProjectionPositionsUpdated() { + this.projectorScatterPlotAdapter.notifyProjectionPositionsUpdated(); } /** @@ -547,7 +482,7 @@ export class Projector extends ProjectorPolymer implements state.tSNEIteration = this.dataSet.tSNEIteration; state.selectedPoints = this.selectedPointIndices; state.filteredPoints = this.dataSetFilterIndices; - state.cameraDef = this.scatterPlot.getCameraDef(); + this.projectorScatterPlotAdapter.populateBookmarkFromUI(state); state.selectedColorOptionName = this.dataPanel.selectedColorOptionName; state.selectedLabelOption = this.selectedLabelOption; this.projectionsPanel.populateBookmarkFromUI(state); @@ -556,6 +491,7 @@ export class Projector extends ProjectorPolymer implements /** Loads a State object into the world. */ loadState(state: State) { + this.setProjection(null); { this.projectionsPanel.disablePolymerChangesTriggerReprojection(); this.resetFilterDataset(); @@ -578,23 +514,17 @@ export class Projector extends ProjectorPolymer implements this.inspectorPanel.restoreUIFromBookmark(state); this.dataPanel.selectedColorOptionName = state.selectedColorOptionName; this.selectedLabelOption = state.selectedLabelOption; - this.scatterPlot.setCameraParametersForNextCameraCreation( - state.cameraDef, false); + this.projectorScatterPlotAdapter.restoreUIFromBookmark(state); { const dimensions = stateGetAccessorDimensions(state); const accessors = this.dataSet.getPointAccessors(state.selectedProjection, dimensions); const projection = new Projection( - state.selectedProjection, accessors, dimensions.length); + state.selectedProjection, accessors, dimensions.length, this.dataSet); this.setProjection(projection); } this.notifySelectionChanged(state.selectedPoints); } - - notifyDistanceMetricChanged(distMetric: DistanceFunction) { - this.updateScatterPlotAttributes(); - this.scatterPlot.render(); - } } document.registerElement(Projector.prototype.is, Projector); -- cgit v1.2.3