aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Charles Nicholson <nicholsonc@google.com>2016-11-17 13:06:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-17 13:24:20 -0800
commit815fa1b32ded0d351689177dca65485afc0f7dcf (patch)
treec7c519cf4a67fc538c4969f85c7578768a9d9a00
parent839ee165dc240655eb38d9adffc651705abc4155 (diff)
Start moving scatter plot methods out of vz-projector and into the scatter plot
adapter. Add DataSet reference to Projection class. The projector adapter now listens to the the distance metric changed event, as well as creates + owns scatter plot. Change: 139496759
-rw-r--r--tensorflow/tensorboard/components/vz_projector/data.ts3
-rw-r--r--tensorflow/tensorboard/components/vz_projector/projectorEventContext.ts11
-rw-r--r--tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts238
-rw-r--r--tensorflow/tensorboard/components/vz_projector/scatterPlot.ts14
-rw-r--r--tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts5
-rw-r--r--tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts12
-rw-r--r--tensorflow/tensorboard/components/vz_projector/vz-projector.ts174
7 files changed, 274 insertions, 183 deletions
diff --git a/tensorflow/tensorboard/components/vz_projector/data.ts b/tensorflow/tensorboard/components/vz_projector/data.ts
index 7d5e8c2cf2..4adbb56b80 100644
--- a/tensorflow/tensorboard/components/vz_projector/data.ts
+++ b/tensorflow/tensorboard/components/vz_projector/data.ts
@@ -415,7 +415,8 @@ export type ProjectionType = 'tsne' | 'pca' | 'custom';
export class Projection {
constructor(
public projectionType: ProjectionType,
- public pointAccessors: PointAccessors3D, public dimensionality: number) {}
+ public pointAccessors: PointAccessors3D, public dimensionality: number,
+ public dataSet: DataSet) {}
}
export interface ColorOption {
diff --git a/tensorflow/tensorboard/components/vz_projector/projectorEventContext.ts b/tensorflow/tensorboard/components/vz_projector/projectorEventContext.ts
index 1b546e0ba5..36f5c4c584 100644
--- a/tensorflow/tensorboard/components/vz_projector/projectorEventContext.ts
+++ b/tensorflow/tensorboard/components/vz_projector/projectorEventContext.ts
@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-import {DataSet, DistanceFunction} from './data';
+import {DistanceFunction, Projection} from './data';
import {NearestEntry} from './knn';
export type HoverListener = (index: number) => void;
export type SelectionChangedListener =
(selectedPointIndices: number[], neighborsOfFirstPoint: NearestEntry[]) =>
void;
-export type ProjectionChangedListener = (dataSet: DataSet) => void;
-
+export type ProjectionChangedListener = (projection: Projection) => void;
+export type DistanceMetricChangedListener =
+ (distanceMetric: DistanceFunction) => void;
export interface ProjectorEventContext {
/** Register a callback to be invoked when the mouse hovers over a point. */
registerHoverListener(listener: HoverListener);
@@ -37,6 +38,8 @@ export interface ProjectorEventContext {
/** Registers a callback to be invoked when the projection changes. */
registerProjectionChangedListener(listener: ProjectionChangedListener);
/** Notify listeners that a reprojection occurred. */
- notifyProjectionChanged(dataSet: DataSet);
+ notifyProjectionChanged(projection: Projection);
+ registerDistanceMetricChangedListener(listener:
+ DistanceMetricChangedListener);
notifyDistanceMetricChanged(distMetric: DistanceFunction);
}
diff --git a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts
index 7855f5c2b1..ecf34cb977 100644
--- a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts
+++ b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts
@@ -13,9 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-import {DataSet, DistanceFunction, PointAccessors3D} from './data';
+import {DataSet, DistanceFunction, PointAccessors3D, Projection, State} from './data';
import {NearestEntry} from './knn';
+import {ProjectorEventContext} from './projectorEventContext';
import {LabelRenderParams} from './renderContext';
+import {ScatterPlot} from './scatterPlot';
+import {ScatterPlotVisualizer3DLabels} from './scatterPlotVisualizer3DLabels';
+import {ScatterPlotVisualizerCanvasLabels} from './scatterPlotVisualizerCanvasLabels';
+import {ScatterPlotVisualizerSprites} from './scatterPlotVisualizerSprites';
+import {ScatterPlotVisualizerTraces} from './scatterPlotVisualizerTraces';
import * as vector from './vector';
const LABEL_FONT_SIZE = 10;
@@ -69,6 +75,119 @@ const NN_COLOR_SCALE =
* to use the ScatterPlot to render the current projected data set.
*/
export class ProjectorScatterPlotAdapter {
+ public scatterPlot: ScatterPlot;
+ private scatterPlotContainer: d3.Selection<any>;
+ private projection: Projection;
+ private hoverPointIndex: number;
+ private selectedPointIndices: number[];
+ private neighborsOfFirstSelectedPoint: NearestEntry[];
+ private renderLabelsIn3D: boolean = false;
+ private legendPointColorer: (index: number) => string;
+ private distanceMetric: DistanceFunction;
+
+ constructor(
+ scatterPlotContainer: d3.Selection<any>,
+ projectorEventContext: ProjectorEventContext) {
+ this.scatterPlot =
+ new ScatterPlot(scatterPlotContainer, projectorEventContext);
+ this.scatterPlotContainer = scatterPlotContainer;
+ projectorEventContext.registerProjectionChangedListener(projection => {
+ this.projection = projection;
+ this.updateScatterPlotWithNewProjection(projection);
+ });
+ projectorEventContext.registerSelectionChangedListener(
+ (selectedPointIndices, neighbors) => {
+ this.selectedPointIndices = selectedPointIndices;
+ this.neighborsOfFirstSelectedPoint = neighbors;
+ this.updateScatterPlotAttributes();
+ this.scatterPlot.render();
+ });
+ projectorEventContext.registerHoverListener(hoverPointIndex => {
+ this.hoverPointIndex = hoverPointIndex;
+ this.updateScatterPlotAttributes();
+ this.scatterPlot.render();
+ });
+ projectorEventContext.registerDistanceMetricChangedListener(
+ distanceMetric => {
+ this.distanceMetric = distanceMetric;
+ this.updateScatterPlotAttributes();
+ this.scatterPlot.render();
+ });
+ this.createVisualizers(false);
+ }
+
+ notifyProjectionPositionsUpdated() {
+ this.updateScatterPlotPositions();
+ this.scatterPlot.render();
+ }
+
+ set3DLabelMode(renderLabelsIn3D: boolean) {
+ this.renderLabelsIn3D = renderLabelsIn3D;
+ this.createVisualizers(renderLabelsIn3D);
+ this.updateScatterPlotAttributes();
+ this.scatterPlot.render();
+ }
+
+ setLegendPointColorer(legendPointColorer: (index: number) => string) {
+ this.legendPointColorer = legendPointColorer;
+ }
+
+ resize() {
+ this.scatterPlot.resize();
+ }
+
+ populateBookmarkFromUI(state: State) {
+ state.cameraDef = this.scatterPlot.getCameraDef();
+ }
+
+ restoreUIFromBookmark(state: State) {
+ this.scatterPlot.setCameraParametersForNextCameraCreation(
+ state.cameraDef, false);
+ }
+
+ updateScatterPlotPositions() {
+ const ds = (this.projection == null) ? null : this.projection.dataSet;
+ const accessors =
+ (this.projection == null) ? null : this.projection.pointAccessors;
+ const newPositions = this.generatePointPositionArray(ds, accessors);
+ this.scatterPlot.setPointPositions(ds, newPositions);
+ }
+
+ updateScatterPlotAttributes() {
+ if (this.projection == null) {
+ return;
+ }
+ const dataSet = this.projection.dataSet;
+ const selectedSet = this.selectedPointIndices;
+ const hoverIndex = this.hoverPointIndex;
+ const neighbors = this.neighborsOfFirstSelectedPoint;
+ const pointColorer = this.legendPointColorer;
+
+ const pointColors = this.generatePointColorArray(
+ dataSet, pointColorer, this.distanceMetric, selectedSet, neighbors,
+ hoverIndex, this.renderLabelsIn3D, this.getSpriteImageMode());
+ const pointScaleFactors = this.generatePointScaleFactorArray(
+ dataSet, selectedSet, neighbors, hoverIndex);
+ const labels = this.generateVisibleLabelRenderParams(
+ dataSet, selectedSet, neighbors, hoverIndex);
+ const traceColors = this.generateLineSegmentColorMap(dataSet, pointColorer);
+ const traceOpacities =
+ this.generateLineSegmentOpacityArray(dataSet, selectedSet);
+ const traceWidths =
+ this.generateLineSegmentWidthArray(dataSet, selectedSet);
+
+ this.scatterPlot.setPointColors(pointColors);
+ this.scatterPlot.setPointScaleFactors(pointScaleFactors);
+ this.scatterPlot.setLabels(labels);
+ this.scatterPlot.setTraceColors(traceColors);
+ this.scatterPlot.setTraceOpacities(traceOpacities);
+ this.scatterPlot.setTraceWidths(traceWidths);
+ }
+
+ render() {
+ this.scatterPlot.render();
+ }
+
generatePointPositionArray(ds: DataSet, pointAccessors: PointAccessors3D):
Float32Array {
if (ds == null) {
@@ -116,19 +235,6 @@ export class ProjectorScatterPlotAdapter {
return positions;
}
- private packRgbIntoUint8Array(
- rgbArray: Uint8Array, labelIndex: number, r: number, g: number,
- b: number) {
- rgbArray[labelIndex * 3] = r;
- rgbArray[labelIndex * 3 + 1] = g;
- rgbArray[labelIndex * 3 + 2] = b;
- }
-
- private styleRgbFromHexColor(hex: number): [number, number, number] {
- const c = new THREE.Color(hex);
- return [(c.r * 255) | 0, (c.g * 255) | 0, (c.b * 255) | 0];
- }
-
generateVisibleLabelRenderParams(
ds: DataSet, selectedPointIndices: number[],
neighborsOfFirstPoint: NearestEntry[],
@@ -155,11 +261,11 @@ export class ProjectorScatterPlotAdapter {
visibleLabels[dst] = hoverPointIndex;
scale[dst] = LABEL_SCALE_LARGE;
opacityFlags[dst] = 0;
- const fillRgb = this.styleRgbFromHexColor(LABEL_FILL_COLOR_HOVER);
- this.packRgbIntoUint8Array(
+ const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_HOVER);
+ packRgbIntoUint8Array(
fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]);
- const strokeRgb = this.styleRgbFromHexColor(LABEL_STROKE_COLOR_HOVER);
- this.packRgbIntoUint8Array(
+ const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_HOVER);
+ packRgbIntoUint8Array(
strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[1]);
++dst;
}
@@ -167,15 +273,15 @@ export class ProjectorScatterPlotAdapter {
// Selected points
{
const n = selectedPointIndices.length;
- const fillRgb = this.styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED);
- const strokeRgb = this.styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED);
+ const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED);
+ const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED);
for (let i = 0; i < n; ++i) {
visibleLabels[dst] = selectedPointIndices[i];
scale[dst] = LABEL_SCALE_LARGE;
opacityFlags[dst] = (n === 1) ? 0 : 1;
- this.packRgbIntoUint8Array(
+ packRgbIntoUint8Array(
fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]);
- this.packRgbIntoUint8Array(
+ packRgbIntoUint8Array(
strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]);
++dst;
}
@@ -184,13 +290,13 @@ export class ProjectorScatterPlotAdapter {
// Neighbors
{
const n = neighborsOfFirstPoint.length;
- const fillRgb = this.styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR);
- const strokeRgb = this.styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR);
+ const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR);
+ const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR);
for (let i = 0; i < n; ++i) {
visibleLabels[dst] = neighborsOfFirstPoint[i].index;
- this.packRgbIntoUint8Array(
+ packRgbIntoUint8Array(
fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]);
- this.packRgbIntoUint8Array(
+ packRgbIntoUint8Array(
strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]);
++dst;
}
@@ -248,7 +354,6 @@ export class ProjectorScatterPlotAdapter {
for (let i = 0; i < ds.traces.length; i++) {
let dataTrace = ds.traces[i];
-
let colors =
new Float32Array(2 * (dataTrace.pointIndices.length - 1) * 3);
let colorIndex = 0;
@@ -262,21 +367,19 @@ export class ProjectorScatterPlotAdapter {
colors[colorIndex++] = c1.r;
colors[colorIndex++] = c1.g;
colors[colorIndex++] = c1.b;
-
colors[colorIndex++] = c2.r;
colors[colorIndex++] = c2.g;
colors[colorIndex++] = c2.b;
}
} else {
for (let j = 0; j < dataTrace.pointIndices.length - 1; j++) {
- const c1 = this.getDefaultPointInTraceColor(
- j, dataTrace.pointIndices.length);
- const c2 = this.getDefaultPointInTraceColor(
- j + 1, dataTrace.pointIndices.length);
+ const c1 =
+ getDefaultPointInTraceColor(j, dataTrace.pointIndices.length);
+ const c2 =
+ getDefaultPointInTraceColor(j + 1, dataTrace.pointIndices.length);
colors[colorIndex++] = c1.r;
colors[colorIndex++] = c1.g;
colors[colorIndex++] = c1.b;
-
colors[colorIndex++] = c2.r;
colors[colorIndex++] = c2.g;
colors[colorIndex++] = c2.b;
@@ -319,15 +422,6 @@ export class ProjectorScatterPlotAdapter {
return widths;
}
- private getDefaultPointInTraceColor(index: number, totalPoints: number):
- THREE.Color {
- let hue = TRACE_START_HUE +
- (TRACE_END_HUE - TRACE_START_HUE) * index / totalPoints;
-
- let rgb = d3.hsl(hue, TRACE_SATURATION, TRACE_LIGHTNESS).rgb();
- return new THREE.Color(rgb.r / 255, rgb.g / 255, rgb.b / 255);
- }
-
generatePointColorArray(
ds: DataSet, legendPointColorer: (index: number) => string,
distFunc: DistanceFunction, selectedPointIndices: number[],
@@ -419,6 +513,66 @@ export class ProjectorScatterPlotAdapter {
return colors;
}
+
+ private updateScatterPlotWithNewProjection(projection: Projection) {
+ if (projection != null) {
+ this.scatterPlot.setDimensions(projection.dimensionality);
+ if (projection.dataSet.projectionCanBeRendered(
+ projection.projectionType)) {
+ this.updateScatterPlotAttributes();
+ this.notifyProjectionPositionsUpdated();
+ }
+ this.scatterPlot.setCameraParametersForNextCameraCreation(null, false);
+ } else {
+ this.updateScatterPlotAttributes();
+ this.notifyProjectionPositionsUpdated();
+ }
+ }
+
+ private createVisualizers(inLabels3DMode: boolean) {
+ const scatterPlot = this.scatterPlot;
+ scatterPlot.removeAllVisualizers();
+ if (inLabels3DMode) {
+ scatterPlot.addVisualizer(new ScatterPlotVisualizer3DLabels());
+ } else {
+ scatterPlot.addVisualizer(new ScatterPlotVisualizerSprites());
+ scatterPlot.addVisualizer(
+ new ScatterPlotVisualizerCanvasLabels(this.scatterPlotContainer));
+ }
+ scatterPlot.addVisualizer(new ScatterPlotVisualizerTraces());
+ }
+
+ private getSpriteImageMode(): boolean {
+ if (this.projection == null) {
+ return false;
+ }
+ const ds = this.projection.dataSet;
+ if ((ds == null) || (ds.spriteAndMetadataInfo == null)) {
+ return false;
+ }
+ return ds.spriteAndMetadataInfo.spriteImage != null;
+ }
+}
+
+function packRgbIntoUint8Array(
+ rgbArray: Uint8Array, labelIndex: number, r: number, g: number, b: number) {
+ rgbArray[labelIndex * 3] = r;
+ rgbArray[labelIndex * 3 + 1] = g;
+ rgbArray[labelIndex * 3 + 2] = b;
+}
+
+function styleRgbFromHexColor(hex: number): [number, number, number] {
+ const c = new THREE.Color(hex);
+ return [(c.r * 255) | 0, (c.g * 255) | 0, (c.b * 255) | 0];
+}
+
+function getDefaultPointInTraceColor(
+ index: number, totalPoints: number): THREE.Color {
+ let hue =
+ TRACE_START_HUE + (TRACE_END_HUE - TRACE_START_HUE) * index / totalPoints;
+
+ let rgb = d3.hsl(hue, TRACE_SATURATION, TRACE_LIGHTNESS).rgb();
+ return new THREE.Color(rgb.r / 255, rgb.g / 255, rgb.b / 255);
}
/**
diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts
index 5be0d93c61..a381a997cc 100644
--- a/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts
+++ b/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts
@@ -117,14 +117,12 @@ export class ScatterPlot {
private rectangleSelector: ScatterPlotRectangleSelector;
constructor(
- container: d3.Selection<any>, labelAccessor: (index: number) => string,
+ container: d3.Selection<any>,
projectorEventContext: ProjectorEventContext) {
this.containerNode = container.node() as HTMLElement;
this.projectorEventContext = projectorEventContext;
this.getLayoutValues();
- this.labelAccessor = labelAccessor;
-
this.scene = new THREE.Scene();
this.renderer =
new THREE.WebGLRenderer({alpha: true, premultipliedAlpha: false});
@@ -457,7 +455,7 @@ export class ScatterPlot {
return this.dimensionality === 3;
}
- private remove3dAxis(): THREE.Object3D {
+ private remove3dAxisFromScene(): THREE.Object3D {
const axes = this.scene.getObjectByName('axes');
if (axes != null) {
this.scene.remove(axes);
@@ -481,7 +479,7 @@ export class ScatterPlot {
const def = this.cameraDef || this.makeDefaultCameraDef(dimensionality);
this.recreateCamera(def);
- this.remove3dAxis();
+ this.remove3dAxisFromScene();
if (dimensionality === 3) {
this.add3dAxis();
}
@@ -624,9 +622,11 @@ export class ScatterPlot {
});
{
- const axes = this.remove3dAxis();
+ const axes = this.remove3dAxisFromScene();
this.renderer.render(this.scene, this.camera, this.pickingTexture);
- this.scene.add(axes);
+ if (axes != null) {
+ this.scene.add(axes);
+ }
}
// Render second pass to color buffer, to be displayed on the canvas.
diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts
index 766659308d..5af3eea92b 100644
--- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts
+++ b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts
@@ -305,14 +305,15 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer {
onPointPositionsChanged(newPositions: Float32Array, dataSet: DataSet) {
if (this.points != null) {
const notEnoughSpace = (this.pickingColors.length < newPositions.length);
- const newImage =
+ const newImage = (dataSet != null) &&
(this.image !== dataSet.spriteAndMetadataInfo.spriteImage);
if (notEnoughSpace || newImage) {
this.dispose();
}
}
- this.image = dataSet.spriteAndMetadataInfo.spriteImage;
+ this.image =
+ (dataSet != null) ? dataSet.spriteAndMetadataInfo.spriteImage : null;
this.worldSpacePointPositions = newPositions;
if (this.points == null) {
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts
index 1056bfb3ce..9c172e4707 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts
@@ -372,13 +372,14 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
const accessors =
dataSet.getPointAccessors('tsne', [0, 1, this.tSNEis3d ? 2 : null]);
const dimensionality = this.tSNEis3d ? 3 : 2;
- const projection = new Projection('tsne', accessors, dimensionality);
+ const projection =
+ new Projection('tsne', accessors, dimensionality, dataSet);
this.projector.setProjection(projection);
if (!this.dataSet.hasTSNERun) {
this.runTSNE();
} else {
- this.projector.notifyProjectionsUpdated();
+ this.projector.notifyProjectionPositionsUpdated();
}
}
@@ -390,7 +391,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
(iteration: number) => {
if (iteration != null) {
this.iterationLabel.text(iteration);
- this.projector.notifyProjectionsUpdated();
+ this.projector.notifyProjectionPositionsUpdated();
} else {
this.runTsneButton.attr('disabled', null);
this.stopTsneButton.attr('disabled', true);
@@ -426,7 +427,8 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
'pca', [this.pcaX, this.pcaY, this.pcaZ]);
const dimensionality = this.pcaIs3d ? 3 : 2;
- const projection = new Projection('pca', accessors, dimensionality);
+ const projection =
+ new Projection('pca', accessors, dimensionality, this.dataSet);
this.projector.setProjection(projection);
let numComponents = Math.min(NUM_PCA_COMPONENTS, this.dataSet.dim[1]);
this.updateTotalVarianceMessage();
@@ -454,7 +456,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
this.dataSet.projectLinear(yDir, 'linear-y');
const accessors = this.dataSet.getPointAccessors('custom', ['x', 'y']);
- const projection = new Projection('custom', accessors, 2);
+ const projection = new Projection('custom', accessors, 2, this.dataSet);
this.projector.setProjection(projection);
}
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts
index 8900819048..f22bc3ad64 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts
@@ -21,13 +21,9 @@ import {ProtoDataProvider} from './data-provider-proto';
import {ServerDataProvider} from './data-provider-server';
import * as knn from './knn';
import * as logging from './logging';
-import {HoverListener, ProjectionChangedListener, ProjectorEventContext, SelectionChangedListener} from './projectorEventContext';
+import {DistanceMetricChangedListener, HoverListener, ProjectionChangedListener, ProjectorEventContext, SelectionChangedListener} from './projectorEventContext';
import {ProjectorScatterPlotAdapter} from './projectorScatterPlotAdapter';
-import {Mode, ScatterPlot} from './scatterPlot';
-import {ScatterPlotVisualizer3DLabels} from './scatterPlotVisualizer3DLabels';
-import {ScatterPlotVisualizerCanvasLabels} from './scatterPlotVisualizerCanvasLabels';
-import {ScatterPlotVisualizerSprites} from './scatterPlotVisualizerSprites';
-import {ScatterPlotVisualizerTraces} from './scatterPlotVisualizerTraces';
+import {Mode} from './scatterPlot';
import * as util from './util';
import {BookmarkPanel} from './vz-projector-bookmark-panel';
import {DataPanel} from './vz-projector-data-panel';
@@ -69,11 +65,11 @@ export class Projector extends ProjectorPolymer implements
private selectionChangedListeners: SelectionChangedListener[];
private hoverListeners: HoverListener[];
private projectionChangedListeners: ProjectionChangedListener[];
+ private distanceMetricChangedListeners: DistanceMetricChangedListener[];
private originalDataSet: DataSet;
private dom: d3.Selection<any>;
private projectorScatterPlotAdapter: ProjectorScatterPlotAdapter;
- private scatterPlot: ScatterPlot;
private dim: number;
private dataSetFilterIndices: number[];
@@ -108,6 +104,7 @@ export class Projector extends ProjectorPolymer implements
this.selectionChangedListeners = [];
this.hoverListeners = [];
this.projectionChangedListeners = [];
+ this.distanceMetricChangedListeners = [];
this.selectedPointIndices = [];
this.neighborsOfFirstPoint = [];
this.dom = d3.select(this);
@@ -133,14 +130,17 @@ export class Projector extends ProjectorPolymer implements
.metadata[this.selectedLabelOption] as string;
};
this.metadataCard.setLabelOption(this.selectedLabelOption);
- this.scatterPlot.setLabelAccessor(labelAccessor);
- this.scatterPlot.render();
+ this.projectorScatterPlotAdapter.scatterPlot.setLabelAccessor(
+ labelAccessor);
+ this.projectorScatterPlotAdapter.render();
}
setSelectedColorOption(colorOption: ColorOption) {
this.selectedColorOption = colorOption;
- this.updateScatterPlotAttributes();
- this.scatterPlot.render();
+ this.projectorScatterPlotAdapter.setLegendPointColorer(
+ this.getLegendPointColorer(colorOption));
+ this.projectorScatterPlotAdapter.updateScatterPlotAttributes();
+ this.projectorScatterPlotAdapter.render();
}
setNormalizeData(normalizeData: boolean) {
@@ -153,8 +153,7 @@ export class Projector extends ProjectorPolymer implements
metadataFile?: string) {
this.dataSetFilterIndices = null;
this.originalDataSet = ds;
- if (this.scatterPlot == null || ds == null) {
- // We are not ready yet.
+ if (this.projectorScatterPlotAdapter == null || ds == null) {
return;
}
this.normalizeData = this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE;
@@ -176,8 +175,8 @@ export class Projector extends ProjectorPolymer implements
// height can grow indefinitely.
let container = this.dom.select('#container');
container.style('height', container.property('clientHeight') + 'px');
- this.scatterPlot.resize();
- this.scatterPlot.render();
+ this.projectorScatterPlotAdapter.resize();
+ this.projectorScatterPlotAdapter.render();
}
setSelectedTensor(run: string, tensorInfo: EmbeddingInfo) {
@@ -203,7 +202,7 @@ export class Projector extends ProjectorPolymer implements
return this.dataSet.points[localIndex].index;
});
this.setCurrentDataSet(this.originalDataSet.getSubset());
- this.updateScatterPlotPositions();
+ this.projectorScatterPlotAdapter.updateScatterPlotPositions();
this.dataSetFilterIndices = [];
this.adjustSelectionAndHover(originalPointIndices);
}
@@ -247,8 +246,16 @@ export class Projector extends ProjectorPolymer implements
this.projectionChangedListeners.push(listener);
}
- notifyProjectionChanged(dataSet: DataSet) {
- this.projectionChangedListeners.forEach(l => l(dataSet));
+ notifyProjectionChanged(projection: Projection) {
+ this.projectionChangedListeners.forEach(l => l(projection));
+ }
+
+ registerDistanceMetricChangedListener(l: DistanceMetricChangedListener) {
+ this.distanceMetricChangedListeners.push(l);
+ }
+
+ notifyDistanceMetricChanged(distMetric: DistanceFunction) {
+ this.distanceMetricChangedListeners.forEach(l => l(distMetric));
}
_dataProtoChanged(dataProtoString: string) {
@@ -324,11 +331,6 @@ export class Projector extends ProjectorPolymer implements
return (label3DModeButton as any).active;
}
- private getSpriteImageMode(): boolean {
- return this.dataSet && this.dataSet.spriteAndMetadataInfo &&
- this.dataSet.spriteAndMetadataInfo.spriteImage != null;
- }
-
adjustSelectionAndHover(selectedPointIndices: number[], hoverIndex?: number) {
this.notifySelectionChanged(selectedPointIndices);
this.notifyHoverOverPoint(hoverIndex);
@@ -338,8 +340,7 @@ export class Projector extends ProjectorPolymer implements
private setMode(mode: Mode) {
let selectModeButton = this.querySelector('#selectMode');
(selectModeButton as any).active = (mode === Mode.SELECT);
-
- this.scatterPlot.setMode(mode);
+ this.projectorScatterPlotAdapter.scatterPlot.setMode(mode);
}
private setCurrentDataSet(ds: DataSet) {
@@ -360,14 +361,15 @@ export class Projector extends ProjectorPolymer implements
this.projectionsPanel.dataSetUpdated(
this.dataSet, this.originalDataSet, this.dim);
- this.scatterPlot.setCameraParametersForNextCameraCreation(null, true);
+ this.projectorScatterPlotAdapter.scatterPlot
+ .setCameraParametersForNextCameraCreation(null, true);
}
private setupUIControls() {
// View controls
this.querySelector('#reset-zoom').addEventListener('click', () => {
- this.scatterPlot.resetZoom();
- this.scatterPlot.startOrbitAnimation();
+ this.projectorScatterPlotAdapter.scatterPlot.resetZoom();
+ this.projectorScatterPlotAdapter.scatterPlot.startOrbitAnimation();
});
let selectModeButton = this.querySelector('#selectMode');
@@ -376,14 +378,13 @@ export class Projector extends ProjectorPolymer implements
});
let nightModeButton = this.querySelector('#nightDayMode');
nightModeButton.addEventListener('click', () => {
- this.scatterPlot.setDayNightMode((nightModeButton as any).active);
+ this.projectorScatterPlotAdapter.scatterPlot.setDayNightMode(
+ (nightModeButton as any).active);
});
const labels3DModeButton = this.get3DLabelModeButton();
labels3DModeButton.addEventListener('click', () => {
- this.createVisualizers(this.get3DLabelMode());
- this.updateScatterPlotAttributes();
- this.scatterPlot.render();
+ this.projectorScatterPlotAdapter.set3DLabelMode(this.get3DLabelMode());
});
window.addEventListener('resize', () => {
@@ -391,18 +392,19 @@ export class Projector extends ProjectorPolymer implements
let parentHeight =
(container.node().parentNode as HTMLElement).clientHeight;
container.style('height', parentHeight + 'px');
- this.scatterPlot.resize();
+ this.projectorScatterPlotAdapter.resize();
});
- this.projectorScatterPlotAdapter = new ProjectorScatterPlotAdapter();
-
- this.scatterPlot = new ScatterPlot(
- this.getScatterContainer(),
- i => '' + this.dataSet.points[i].metadata[this.selectedLabelOption],
- this as ProjectorEventContext);
- this.createVisualizers(false);
+ {
+ const labelAccessor = i =>
+ '' + this.dataSet.points[i].metadata[this.selectedLabelOption];
+ this.projectorScatterPlotAdapter = new ProjectorScatterPlotAdapter(
+ this.getScatterContainer(), this as ProjectorEventContext);
+ this.projectorScatterPlotAdapter.scatterPlot.setLabelAccessor(
+ labelAccessor);
+ }
- this.scatterPlot.onCameraMove(
+ this.projectorScatterPlotAdapter.scatterPlot.onCameraMove(
(cameraPosition: THREE.Vector3, cameraTarget: THREE.Vector3) =>
this.bookmarkPanel.clearStateSelection());
@@ -425,75 +427,16 @@ export class Projector extends ProjectorPolymer implements
hoverText = point.metadata[this.selectedLabelOption].toString();
}
}
- this.updateScatterPlotAttributes();
- this.scatterPlot.render();
if (this.selectedPointIndices.length === 0) {
this.statusBar.style('display', hoverText ? null : 'none');
this.statusBar.text(hoverText);
}
}
- private updateScatterPlotPositions() {
- if (this.dataSet == null) {
- return;
- }
- if (this.projection == null) {
- return;
- }
- const newPositions =
- this.projectorScatterPlotAdapter.generatePointPositionArray(
- this.dataSet, this.projection.pointAccessors);
- this.scatterPlot.setPointPositions(this.dataSet, newPositions);
- }
-
- private updateScatterPlotAttributes() {
- const dataSet = this.dataSet;
- const selectedSet = this.selectedPointIndices;
- const hoverIndex = this.hoverPointIndex;
- const neighbors = this.neighborsOfFirstPoint;
- const pointColorer = this.getLegendPointColorer(this.selectedColorOption);
- const adapter = this.projectorScatterPlotAdapter;
-
- const pointColors = adapter.generatePointColorArray(
- dataSet, pointColorer, this.inspectorPanel.distFunc, selectedSet,
- neighbors, hoverIndex, this.get3DLabelMode(),
- this.getSpriteImageMode());
- const pointScaleFactors = adapter.generatePointScaleFactorArray(
- dataSet, selectedSet, neighbors, hoverIndex);
- const labels = adapter.generateVisibleLabelRenderParams(
- dataSet, selectedSet, neighbors, hoverIndex);
- const traceColors =
- adapter.generateLineSegmentColorMap(dataSet, pointColorer);
- const traceOpacities =
- adapter.generateLineSegmentOpacityArray(dataSet, selectedSet);
- const traceWidths =
- adapter.generateLineSegmentWidthArray(dataSet, selectedSet);
-
- this.scatterPlot.setPointColors(pointColors);
- this.scatterPlot.setPointScaleFactors(pointScaleFactors);
- this.scatterPlot.setLabels(labels);
- this.scatterPlot.setTraceColors(traceColors);
- this.scatterPlot.setTraceOpacities(traceOpacities);
- this.scatterPlot.setTraceWidths(traceWidths);
- }
-
private getScatterContainer(): d3.Selection<any> {
return this.dom.select('#scatter');
}
- private createVisualizers(inLabels3DMode: boolean) {
- const scatterPlot = this.scatterPlot;
- scatterPlot.removeAllVisualizers();
- if (inLabels3DMode) {
- scatterPlot.addVisualizer(new ScatterPlotVisualizer3DLabels());
- } else {
- scatterPlot.addVisualizer(new ScatterPlotVisualizerSprites());
- scatterPlot.addVisualizer(
- new ScatterPlotVisualizerCanvasLabels(this.getScatterContainer()));
- }
- scatterPlot.addVisualizer(new ScatterPlotVisualizerTraces());
- }
-
private onSelectionChanged(
selectedPointIndices: number[],
neighborsOfFirstPoint: knn.NearestEntry[]) {
@@ -503,26 +446,18 @@ export class Projector extends ProjectorPolymer implements
this.selectedPointIndices.length + neighborsOfFirstPoint.length;
this.statusBar.text(`Selected ${totalNumPoints} points`)
.style('display', totalNumPoints > 0 ? null : 'none');
- this.updateScatterPlotAttributes();
- this.scatterPlot.render();
}
setProjection(projection: Projection) {
this.projection = projection;
- this.scatterPlot.setDimensions(projection.dimensionality);
- this.analyticsLogger.logProjectionChanged(projection.projectionType);
- if (this.dataSet.projectionCanBeRendered(projection.projectionType)) {
- this.updateScatterPlotAttributes();
- this.notifyProjectionsUpdated();
+ if (projection != null) {
+ this.analyticsLogger.logProjectionChanged(projection.projectionType);
}
-
- this.scatterPlot.setCameraParametersForNextCameraCreation(null, false);
- this.notifyProjectionChanged(this.dataSet);
+ this.notifyProjectionChanged(projection);
}
- notifyProjectionsUpdated() {
- this.updateScatterPlotPositions();
- this.scatterPlot.render();
+ notifyProjectionPositionsUpdated() {
+ this.projectorScatterPlotAdapter.notifyProjectionPositionsUpdated();
}
/**
@@ -547,7 +482,7 @@ export class Projector extends ProjectorPolymer implements
state.tSNEIteration = this.dataSet.tSNEIteration;
state.selectedPoints = this.selectedPointIndices;
state.filteredPoints = this.dataSetFilterIndices;
- state.cameraDef = this.scatterPlot.getCameraDef();
+ this.projectorScatterPlotAdapter.populateBookmarkFromUI(state);
state.selectedColorOptionName = this.dataPanel.selectedColorOptionName;
state.selectedLabelOption = this.selectedLabelOption;
this.projectionsPanel.populateBookmarkFromUI(state);
@@ -556,6 +491,7 @@ export class Projector extends ProjectorPolymer implements
/** Loads a State object into the world. */
loadState(state: State) {
+ this.setProjection(null);
{
this.projectionsPanel.disablePolymerChangesTriggerReprojection();
this.resetFilterDataset();
@@ -578,23 +514,17 @@ export class Projector extends ProjectorPolymer implements
this.inspectorPanel.restoreUIFromBookmark(state);
this.dataPanel.selectedColorOptionName = state.selectedColorOptionName;
this.selectedLabelOption = state.selectedLabelOption;
- this.scatterPlot.setCameraParametersForNextCameraCreation(
- state.cameraDef, false);
+ this.projectorScatterPlotAdapter.restoreUIFromBookmark(state);
{
const dimensions = stateGetAccessorDimensions(state);
const accessors =
this.dataSet.getPointAccessors(state.selectedProjection, dimensions);
const projection = new Projection(
- state.selectedProjection, accessors, dimensions.length);
+ state.selectedProjection, accessors, dimensions.length, this.dataSet);
this.setProjection(projection);
}
this.notifySelectionChanged(state.selectedPoints);
}
-
- notifyDistanceMetricChanged(distMetric: DistanceFunction) {
- this.updateScatterPlotAttributes();
- this.scatterPlot.render();
- }
}
document.registerElement(Projector.prototype.is, Projector);