diff options
author | 2016-11-16 17:29:04 -0800 | |
---|---|---|
committer | 2016-11-16 17:50:34 -0800 | |
commit | a10fc607a372fb352a0a80d5cb35a44db8eaa47f (patch) | |
tree | 010ca4ca321fa46d419f72062e86e7295b2141c6 | |
parent | 16408372a6411c7f5a44139502d90bd7111d86a4 (diff) |
Rename Projection enum to ProjectionType. Add new Projection class that
contains a full description of a projection.
Change: 139402857
4 files changed, 58 insertions, 60 deletions
diff --git a/tensorflow/tensorboard/components/vz_projector/analyticsLogger.ts b/tensorflow/tensorboard/components/vz_projector/analyticsLogger.ts index 3187074491..4e6ef2c910 100644 --- a/tensorflow/tensorboard/components/vz_projector/analyticsLogger.ts +++ b/tensorflow/tensorboard/components/vz_projector/analyticsLogger.ts @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {Projection} from './data'; +import {ProjectionType} from './data'; export class AnalyticsLogger { private eventLogging: boolean; @@ -43,7 +43,7 @@ export class AnalyticsLogger { } } - logProjectionChanged(projection: Projection) { + logProjectionChanged(projection: ProjectionType) { if (this.eventLogging) { ga('send', { hitType: 'event', @@ -53,4 +53,4 @@ export class AnalyticsLogger { }); } } -}
\ No newline at end of file +} diff --git a/tensorflow/tensorboard/components/vz_projector/data.ts b/tensorflow/tensorboard/components/vz_projector/data.ts index 155f8bdccd..7d5e8c2cf2 100644 --- a/tensorflow/tensorboard/components/vz_projector/data.ts +++ b/tensorflow/tensorboard/components/vz_projector/data.ts @@ -25,19 +25,14 @@ export type DistanceFunction = (a: number[], b: number[]) => number; export type PointAccessor = (index: number) => number; export type PointAccessors3D = [PointAccessor, PointAccessor, PointAccessor]; -export interface PointMetadata { - [key: string]: number | string; -} +export interface PointMetadata { [key: string]: number|string; } export interface DataProto { shape: [number, number]; tensor: number[]; metadata: { - columns: Array<{ - name: string; - stringValues: string[]; - numericValues: number[]; - }>; + columns: Array< + {name: string; stringValues: string[]; numericValues: number[];}>; }; } @@ -139,8 +134,8 @@ export class DataSet { private tsne: TSNE; /** Creates a new Dataset */ - constructor(points: DataPoint[], - spriteAndMetadataInfo?: SpriteAndMetadataInfo) { + constructor( + points: DataPoint[], spriteAndMetadataInfo?: SpriteAndMetadataInfo) { this.points = points; this.sampledDataIndices = shuffle(d3.range(this.points.length)).slice(0, SAMPLE_SIZE); @@ -193,7 +188,7 @@ export class DataSet { return traces; } - getPointAccessors(projection: Projection, components: (number|string)[]): + getPointAccessors(projection: ProjectionType, components: (number|string)[]): [PointAccessor, PointAccessor, PointAccessor] { if (components.length > 3) { throw new RangeError('components length must be <= 3'); @@ -212,7 +207,7 @@ export class DataSet { return accessors; } - projectionCanBeRendered(projection: Projection): boolean { + projectionCanBeRendered(projection: ProjectionType): boolean { if (projection !== 'tsne') { return true; } @@ -228,8 +223,8 @@ export class DataSet { * @return A subset of the original dataset. */ getSubset(subset?: number[]): DataSet { - let pointsSubset = subset && subset.length ? - subset.map(i => this.points[i]) : this.points; + let pointsSubset = + subset && subset.length ? subset.map(i => this.points[i]) : this.points; let points = pointsSubset.map(dp => { return { metadata: dp.metadata, @@ -382,7 +377,9 @@ export class DataSet { .forEach((m, i) => this.points[i].metadata = m); } - stopTSNE() { this.tSNEShouldStop = true; } + stopTSNE() { + this.tSNEShouldStop = true; + } /** * Finds the nearest neighbors of the query point using a @@ -391,8 +388,8 @@ export class DataSet { findNeighbors(pointIndex: number, distFunc: DistanceFunction, numNN: number): knn.NearestEntry[] { // Find the nearest neighbors of a particular point. - let neighbors = knn.findKNNofPoint(this.points, pointIndex, numNN, - (d => d.vector), distFunc); + let neighbors = knn.findKNNofPoint( + this.points, pointIndex, numNN, (d => d.vector), distFunc); // TODO(smilkov): Figure out why we slice. let result = neighbors.slice(0, numNN); return result; @@ -413,7 +410,13 @@ export class DataSet { } } -export type Projection = 'tsne' | 'pca' | 'custom'; +export type ProjectionType = 'tsne' | 'pca' | 'custom'; + +export class Projection { + constructor( + public projectionType: ProjectionType, + public pointAccessors: PointAccessors3D, public dimensionality: number) {} +} export interface ColorOption { name: string; @@ -438,7 +441,7 @@ export class State { isSelected: boolean = false; /** The selected projection tab. */ - selectedProjection: Projection; + selectedProjection: ProjectionType; /** Dimensions of the DataSet. */ dataSetDimensions: [number, number]; diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts index 5d2f2f945d..1056bfb3ce 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {DataSet, SpriteAndMetadataInfo, PCA_SAMPLE_DIM, Projection, SAMPLE_SIZE, State} from './data'; -import {Vector} from './vector'; +import {DataSet, PCA_SAMPLE_DIM, Projection, ProjectionType, SAMPLE_SIZE, SpriteAndMetadataInfo, State} from './data'; import * as vector from './vector'; +import {Vector} from './vector'; import {Projector} from './vz-projector'; import {ProjectorInput} from './vz-projector-input'; // tslint:disable-next-line:no-unused-variable @@ -44,18 +44,14 @@ export let ProjectionsPanelPolymer = PolymerElement({ } }); -type InputControlName = 'xLeft' | 'xRight' | 'yUp' | 'yDown'; +type InputControlName = 'xLeft'|'xRight'|'yUp'|'yDown'; type CentroidResult = { - centroid?: Vector; - numMatches?: number; + centroid?: Vector; numMatches?: number; }; type Centroids = { - [key: string]: Vector; - xLeft: Vector; - xRight: Vector; - yUp: Vector; + [key: string]: Vector; xLeft: Vector; xRight: Vector; yUp: Vector; yDown: Vector; }; @@ -64,12 +60,9 @@ type Centroids = { */ export class ProjectionsPanel extends ProjectionsPanelPolymer { private projector: Projector; - private pcaComponents: Array<{ - id: number, - componentNumber: number, - percVariance: string - }>; - private currentProjection: Projection; + private pcaComponents: + Array<{id: number, componentNumber: number, percVariance: string}>; + private currentProjection: ProjectionType; private polymerChangesTriggerReprojection: boolean; private dataSet: DataSet; private originalDataSet: DataSet; @@ -184,7 +177,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { // TODO: figure out why `--paper-input-container-input` css mixin didn't // work. this.dom.selectAll('paper-dropdown-menu paper-input input') - .style('font-size', '14px'); + .style('font-size', '14px'); } restoreUIFromBookmark(bookmark: State) { @@ -332,7 +325,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { this.searchByMetadataOptions[Math.max(0, searchByMetadataIndex)]; } - public showTab(id: Projection) { + public showTab(id: ProjectionType) { this.currentProjection = id; let tab = this.dom.select('.ink-tab[data-tab="' + id + '"]'); @@ -355,7 +348,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { this.beginProjection(id); } - private beginProjection(projection: string) { + private beginProjection(projection: ProjectionType) { if (this.polymerChangesTriggerReprojection === false) { return; } @@ -378,7 +371,9 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { } const accessors = dataSet.getPointAccessors('tsne', [0, 1, this.tSNEis3d ? 2 : null]); - this.projector.setProjection('tsne', this.tSNEis3d ? 3 : 2, accessors); + const dimensionality = this.tSNEis3d ? 3 : 2; + const projection = new Projection('tsne', accessors, dimensionality); + this.projector.setProjection(projection); if (!this.dataSet.hasTSNERun) { this.runTSNE(); @@ -430,7 +425,9 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { const accessors = this.dataSet.getPointAccessors( 'pca', [this.pcaX, this.pcaY, this.pcaZ]); - this.projector.setProjection('pca', this.pcaIs3d ? 3 : 2, accessors); + const dimensionality = this.pcaIs3d ? 3 : 2; + const projection = new Projection('pca', accessors, dimensionality); + this.projector.setProjection(projection); let numComponents = Math.min(NUM_PCA_COMPONENTS, this.dataSet.dim[1]); this.updateTotalVarianceMessage(); this.pcaComponents = d3.range(0, numComponents).map(i => { @@ -457,7 +454,8 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer { this.dataSet.projectLinear(yDir, 'linear-y'); const accessors = this.dataSet.getPointAccessors('custom', ['x', 'y']); - this.projector.setProjection('custom', 2, accessors); + const projection = new Projection('custom', accessors, 2); + this.projector.setProjection(projection); } clearCentroids(): void { diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts index f7be1c55fa..8900819048 100644 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts +++ b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ import {AnalyticsLogger} from './analyticsLogger'; -import {ColorOption, ColumnStats, DataPoint, DataProto, DataSet, DistanceFunction, PointAccessors3D, PointMetadata, Projection, SpriteAndMetadataInfo, State, stateGetAccessorDimensions} from './data'; +import {ColorOption, ColumnStats, DataPoint, DataProto, DataSet, DistanceFunction, PointMetadata, Projection, SpriteAndMetadataInfo, State, stateGetAccessorDimensions} from './data'; import {DataProvider, EmbeddingInfo, ServingMode} from './data-provider'; import {DemoDataProvider} from './data-provider-demo'; import {ProtoDataProvider} from './data-provider-proto'; @@ -88,8 +88,7 @@ export class Projector extends ProjectorPolymer implements private selectedLabelOption: string; private routePrefix: string; private normalizeData: boolean; - private selectedProjection: Projection; - private selectedProjectionPointAccessors: PointAccessors3D; + private projection: Projection; /** Polymer component panels */ private dataPanel: DataPanel; @@ -356,7 +355,7 @@ export class Projector extends ProjectorPolymer implements this.dom.select('span.numDataPoints').text(this.dataSet.dim[0]); this.dom.select('span.dim').text(this.dataSet.dim[1]); - this.selectedProjectionPointAccessors = null; + this.projection = null; this.projectionsPanel.dataSetUpdated( this.dataSet, this.originalDataSet, this.dim); @@ -438,12 +437,12 @@ export class Projector extends ProjectorPolymer implements if (this.dataSet == null) { return; } - if (this.selectedProjectionPointAccessors == null) { + if (this.projection == null) { return; } const newPositions = this.projectorScatterPlotAdapter.generatePointPositionArray( - this.dataSet, this.selectedProjectionPointAccessors); + this.dataSet, this.projection.pointAccessors); this.scatterPlot.setPointPositions(this.dataSet, newPositions); } @@ -508,14 +507,11 @@ export class Projector extends ProjectorPolymer implements this.scatterPlot.render(); } - setProjection( - projection: Projection, dimensionality: number, - pointAccessors: PointAccessors3D) { - this.selectedProjection = projection; - this.selectedProjectionPointAccessors = pointAccessors; - this.scatterPlot.setDimensions(dimensionality); - this.analyticsLogger.logProjectionChanged(projection); - if (this.dataSet.projectionCanBeRendered(projection)) { + setProjection(projection: Projection) { + this.projection = projection; + this.scatterPlot.setDimensions(projection.dimensionality); + this.analyticsLogger.logProjectionChanged(projection.projectionType); + if (this.dataSet.projectionCanBeRendered(projection.projectionType)) { this.updateScatterPlotAttributes(); this.notifyProjectionsUpdated(); } @@ -546,7 +542,7 @@ export class Projector extends ProjectorPolymer implements } state.projections.push(projections); } - state.selectedProjection = this.selectedProjection; + state.selectedProjection = this.projection.projectionType; state.dataSetDimensions = this.dataSet.dim; state.tSNEIteration = this.dataSet.tSNEIteration; state.selectedPoints = this.selectedPointIndices; @@ -588,8 +584,9 @@ export class Projector extends ProjectorPolymer implements const dimensions = stateGetAccessorDimensions(state); const accessors = this.dataSet.getPointAccessors(state.selectedProjection, dimensions); - this.setProjection( - state.selectedProjection, dimensions.length, accessors); + const projection = new Projection( + state.selectedProjection, accessors, dimensions.length); + this.setProjection(projection); } this.notifySelectionChanged(state.selectedPoints); } |