diff options
30 files changed, 6314 insertions, 16 deletions
@@ -228,6 +228,13 @@ new_git_repository( ) new_git_repository( + name = "numericjs", + build_file = "bower.BUILD", + remote = "https://github.com/sloisel/numeric.git", + tag = "v1.2.6", +) + +new_git_repository( name = "paper_behaviors", build_file = "bower.BUILD", remote = "https://github.com/polymerelements/paper-behaviors.git", @@ -298,6 +305,13 @@ new_git_repository( ) new_git_repository( + name = "paper_listbox", + build_file = "bower.BUILD", + remote = "https://github.com/polymerelements/paper-listbox.git", + tag = "v1.1.2", +) + +new_git_repository( name = "paper_material", build_file = "bower.BUILD", remote = "https://github.com/polymerelements/paper-material.git", @@ -399,7 +413,7 @@ new_git_repository( name = "polymer", build_file = "bower.BUILD", remote = "https://github.com/polymer/polymer.git", - tag = "v1.6.0", + tag = "v1.6.1", ) new_git_repository( @@ -410,6 +424,13 @@ new_git_repository( ) new_git_repository( + name = "three_js", + build_file = "bower.BUILD", + remote = "https://github.com/mrdoob/three.js.git", + tag = "r77", +) + +new_git_repository( name = "web_animations_js", build_file = "bower.BUILD", remote = "https://github.com/web-animations/web-animations-js.git", @@ -422,3 +443,10 @@ new_git_repository( remote = "https://github.com/webcomponents/webcomponentsjs.git", tag = "v0.7.22", ) + +new_git_repository( + name = "weblas", + build_file = "bower.BUILD", + remote = "https://github.com/waylonflinn/weblas.git", + tag = "v0.9.0", +) diff --git a/bower.BUILD b/bower.BUILD index cb3309a36e..e01fb6d7e9 100644 --- a/bower.BUILD +++ b/bower.BUILD @@ -303,6 +303,41 @@ filegroup( ) filegroup( + name = "numericjs", + srcs = [ + "benchmark.html", + "benchmark2.html", + "demo.html", + "documentation.html", + "myworker.js", + "resources/style.css", + "resources/style-ie.css", + "src/documentation.html", + "src/numeric.js", + "src/quadprog.js", + "src/seedrandom.js", + "src/sparse2.js", + "src/svd.js", + "tools/XMLHttpRequest.js", + "tools/closurelib.js", + "tools/excanvas.min.js", + "tools/goog-require.js", + "tools/jquery.flot.image.js", + "tools/jquery.flot.image.min.js", + "tools/jquery.flot.js", + "tools/jquery.flot.min.js", + "tools/jquery-1.7.1.js", + "tools/jquery-1.7.1.min.js", + "tools/json2.js", + "tools/megalib.js", + "tools/mytest.html", + "tools/sylvester.js", + "tools/unit2.js", + "tools/workshop.html", + ], +) + +filegroup( name = "paper_behaviors", srcs = [ "index.html", @@ -405,6 +440,14 @@ filegroup( ) filegroup( + name = "paper_listbox", + srcs = [ + "index.html", + "paper-listbox.html", + ], +) + +filegroup( name = "paper_material", srcs = [ "index.html", @@ -556,6 +599,275 @@ filegroup( ) filegroup( + name = "three_js", + srcs = [ + "build/three.js", + "build/three.min.js", + "examples/js/AnimationClipCreator.js", + "examples/js/BlendCharacter.js", + "examples/js/BlendCharacterGui.js", + "examples/js/BufferGeometryUtils.js", + "examples/js/Car.js", + "examples/js/Cloth.js", + "examples/js/CurveExtras.js", + "examples/js/Detector.js", + "examples/js/Encodings.js", + "examples/js/GPUParticleSystem.js", + "examples/js/Gyroscope.js", + "examples/js/Half.js", + "examples/js/ImprovedNoise.js", + "examples/js/MD2Character.js", + "examples/js/MD2CharacterComplex.js", + "examples/js/MarchingCubes.js", + "examples/js/Mirror.js", + "examples/js/MorphAnimMesh.js", + "examples/js/MorphAnimation.js", + "examples/js/Ocean.js", + "examples/js/Octree.js", + "examples/js/PRNG.js", + "examples/js/ParametricGeometries.js", + "examples/js/RollerCoaster.js", + "examples/js/ShaderGodRays.js", + "examples/js/ShaderSkin.js", + "examples/js/ShaderTerrain.js", + "examples/js/ShaderToon.js", + "examples/js/SimplexNoise.js", + "examples/js/SimulationRenderer.js", + "examples/js/SkyShader.js", + "examples/js/TimelinerController.js", + "examples/js/TypedArrayUtils.js", + "examples/js/UCSCharacter.js", + "examples/js/Volume.js", + "examples/js/VolumeSlice.js", + "examples/js/WaterShader.js", + "examples/js/WebVR.js", + "examples/js/animation/CCDIKSolver.js", + "examples/js/animation/MMDPhysics.js", + "examples/js/cameras/CinematicCamera.js", + "examples/js/cameras/CombinedCamera.js", + "examples/js/controls/DeviceOrientationControls.js", + "examples/js/controls/DragControls.js", + "examples/js/controls/EditorControls.js", + "examples/js/controls/FirstPersonControls.js", + "examples/js/controls/FlyControls.js", + "examples/js/controls/MouseControls.js", + "examples/js/controls/OrbitControls.js", + "examples/js/controls/OrthographicTrackballControls.js", + "examples/js/controls/PointerLockControls.js", + "examples/js/controls/TrackballControls.js", + "examples/js/controls/TransformControls.js", + "examples/js/controls/VRControls.js", + "examples/js/crossfade/gui.js", + "examples/js/crossfade/scenes.js", + "examples/js/crossfade/transition.js", + "examples/js/curves/NURBSCurve.js", + "examples/js/curves/NURBSSurface.js", + "examples/js/curves/NURBSUtils.js", + "examples/js/effects/AnaglyphEffect.js", + "examples/js/effects/AsciiEffect.js", + "examples/js/effects/ParallaxBarrierEffect.js", + "examples/js/effects/PeppersGhostEffect.js", + "examples/js/effects/StereoEffect.js", + "examples/js/effects/VREffect.js", + "examples/js/exporters/OBJExporter.js", + "examples/js/exporters/STLBinaryExporter.js", + "examples/js/exporters/STLExporter.js", + "examples/js/exporters/TypedGeometryExporter.js", + "examples/js/geometries/ConvexGeometry.js", + "examples/js/geometries/DecalGeometry.js", + "examples/js/geometries/TeapotBufferGeometry.js", + "examples/js/geometries/hilbert2D.js", + "examples/js/geometries/hilbert3D.js", + "examples/js/libs/ammo.js", + "examples/js/libs/charsetencoder.min.js", + "examples/js/libs/dat.gui.min.js", + "examples/js/libs/earcut.js", + "examples/js/libs/inflate.min.js", + "examples/js/libs/jszip.min.js", + "examples/js/libs/msgpack-js.js", + "examples/js/libs/pnltri.min.js", + "examples/js/libs/stats.min.js", + "examples/js/libs/system.min.js", + "examples/js/libs/timeliner_gui.min.js", + "examples/js/libs/tween.min.js", + "examples/js/libs/zlib_and_gzip.min.js", + "examples/js/loaders/3MFLoader.js", + "examples/js/loaders/AMFLoader.js", + "examples/js/loaders/AWDLoader.js", + "examples/js/loaders/AssimpJSONLoader.js", + "examples/js/loaders/BabylonLoader.js", + "examples/js/loaders/BinaryLoader.js", + "examples/js/loaders/ColladaLoader.js", + "examples/js/loaders/ColladaLoader2.js", + "examples/js/loaders/DDSLoader.js", + "examples/js/loaders/FBXLoader.js", + "examples/js/loaders/HDRCubeTextureLoader.js", + "examples/js/loaders/KMZLoader.js", + "examples/js/loaders/MD2Loader.js", + "examples/js/loaders/MMDLoader.js", + "examples/js/loaders/MTLLoader.js", + "examples/js/loaders/NRRDLoader.js", + "examples/js/loaders/OBJLoader.js", + "examples/js/loaders/PCDLoader.js", + "examples/js/loaders/PDBLoader.js", + "examples/js/loaders/PLYLoader.js", + "examples/js/loaders/PVRLoader.js", + "examples/js/loaders/PlayCanvasLoader.js", + "examples/js/loaders/RGBELoader.js", + "examples/js/loaders/STLLoader.js", + "examples/js/loaders/SVGLoader.js", + "examples/js/loaders/TGALoader.js", + "examples/js/loaders/UTF8Loader.js", + "examples/js/loaders/VRMLLoader.js", + "examples/js/loaders/VTKLoader.js", + "examples/js/loaders/collada/Animation.js", + "examples/js/loaders/collada/AnimationHandler.js", + "examples/js/loaders/collada/KeyFrameAnimation.js", + "examples/js/loaders/ctm/CTMLoader.js", + "examples/js/loaders/ctm/CTMWorker.js", + "examples/js/loaders/ctm/ctm.js", + "examples/js/loaders/ctm/lzma.js", + "examples/js/loaders/deprecated/SceneLoader.js", + "examples/js/loaders/gltf/glTF-parser.js", + "examples/js/loaders/gltf/glTFAnimation.js", + "examples/js/loaders/gltf/glTFLoader.js", + "examples/js/loaders/gltf/glTFLoaderUtils.js", + "examples/js/loaders/gltf/glTFShaders.js", + "examples/js/loaders/gltf/gltfUtilities.js", + "examples/js/loaders/sea3d/SEA3D.js", + "examples/js/loaders/sea3d/SEA3DDeflate.js", + "examples/js/loaders/sea3d/SEA3DLZMA.js", + "examples/js/loaders/sea3d/SEA3DLegacy.js", + "examples/js/loaders/sea3d/SEA3DLoader.js", + "examples/js/math/ColorConverter.js", + "examples/js/math/Lut.js", + "examples/js/modifiers/BufferSubdivisionModifier.js", + "examples/js/modifiers/ExplodeModifier.js", + "examples/js/modifiers/SubdivisionModifier.js", + "examples/js/modifiers/TessellateModifier.js", + "examples/js/nodes/BuilderNode.js", + "examples/js/nodes/ConstNode.js", + "examples/js/nodes/FunctionCallNode.js", + "examples/js/nodes/FunctionNode.js", + "examples/js/nodes/GLNode.js", + "examples/js/nodes/InputNode.js", + "examples/js/nodes/NodeLib.js", + "examples/js/nodes/NodeMaterial.js", + "examples/js/nodes/RawNode.js", + "examples/js/nodes/TempNode.js", + "examples/js/nodes/accessors/CameraNode.js", + "examples/js/nodes/accessors/ColorsNode.js", + "examples/js/nodes/accessors/LightNode.js", + "examples/js/nodes/accessors/NormalNode.js", + "examples/js/nodes/accessors/PositionNode.js", + "examples/js/nodes/accessors/ReflectNode.js", + "examples/js/nodes/accessors/ScreenUVNode.js", + "examples/js/nodes/accessors/UVNode.js", + "examples/js/nodes/inputs/ColorNode.js", + "examples/js/nodes/inputs/CubeTextureNode.js", + "examples/js/nodes/inputs/FloatNode.js", + "examples/js/nodes/inputs/IntNode.js", + "examples/js/nodes/inputs/Matrix4Node.js", + "examples/js/nodes/inputs/MirrorNode.js", + "examples/js/nodes/inputs/ScreenNode.js", + "examples/js/nodes/inputs/TextureNode.js", + "examples/js/nodes/inputs/Vector2Node.js", + "examples/js/nodes/inputs/Vector3Node.js", + "examples/js/nodes/inputs/Vector4Node.js", + "examples/js/nodes/materials/PhongNode.js", + "examples/js/nodes/materials/PhongNodeMaterial.js", + "examples/js/nodes/materials/StandardNode.js", + "examples/js/nodes/materials/StandardNodeMaterial.js", + "examples/js/nodes/math/Math1Node.js", + "examples/js/nodes/math/Math2Node.js", + "examples/js/nodes/math/Math3Node.js", + "examples/js/nodes/math/OperatorNode.js", + "examples/js/nodes/postprocessing/NodePass.js", + "examples/js/nodes/utils/ColorAdjustmentNode.js", + "examples/js/nodes/utils/JoinNode.js", + "examples/js/nodes/utils/LuminanceNode.js", + "examples/js/nodes/utils/NoiseNode.js", + "examples/js/nodes/utils/NormalMapNode.js", + "examples/js/nodes/utils/ResolutionNode.js", + "examples/js/nodes/utils/RoughnessToBlinnExponentNode.js", + "examples/js/nodes/utils/SwitchNode.js", + "examples/js/nodes/utils/TimerNode.js", + "examples/js/nodes/utils/VelocityNode.js", + "examples/js/objects/ShadowMesh.js", + "examples/js/pmrem/PMREMCubeUVPacker.js", + "examples/js/pmrem/PMREMGenerator.js", + "examples/js/postprocessing/AdaptiveToneMappingPass.js", + "examples/js/postprocessing/BloomPass.js", + "examples/js/postprocessing/BokehPass.js", + "examples/js/postprocessing/ClearPass.js", + "examples/js/postprocessing/DotScreenPass.js", + "examples/js/postprocessing/EffectComposer.js", + "examples/js/postprocessing/FilmPass.js", + "examples/js/postprocessing/GlitchPass.js", + "examples/js/postprocessing/ManualMSAARenderPass.js", + "examples/js/postprocessing/MaskPass.js", + "examples/js/postprocessing/RenderPass.js", + "examples/js/postprocessing/SMAAPass.js", + "examples/js/postprocessing/SavePass.js", + "examples/js/postprocessing/ShaderPass.js", + "examples/js/postprocessing/TAARenderPass.js", + "examples/js/postprocessing/TexturePass.js", + "examples/js/renderers/CSS2DRenderer.js", + "examples/js/renderers/CSS3DRenderer.js", + "examples/js/renderers/CanvasRenderer.js", + "examples/js/renderers/Projector.js", + "examples/js/renderers/RaytracingRenderer.js", + "examples/js/renderers/RaytracingWorker.js", + "examples/js/renderers/SVGRenderer.js", + "examples/js/renderers/SoftwareRenderer.js", + "examples/js/shaders/BasicShader.js", + "examples/js/shaders/BleachBypassShader.js", + "examples/js/shaders/BlendShader.js", + "examples/js/shaders/BokehShader.js", + "examples/js/shaders/BokehShader2.js", + "examples/js/shaders/BrightnessContrastShader.js", + "examples/js/shaders/ColorCorrectionShader.js", + "examples/js/shaders/ColorifyShader.js", + "examples/js/shaders/ConvolutionShader.js", + "examples/js/shaders/CopyShader.js", + "examples/js/shaders/DOFMipMapShader.js", + "examples/js/shaders/DigitalGlitch.js", + "examples/js/shaders/DotScreenShader.js", + "examples/js/shaders/EdgeShader.js", + "examples/js/shaders/EdgeShader2.js", + "examples/js/shaders/FXAAShader.js", + "examples/js/shaders/FilmShader.js", + "examples/js/shaders/FocusShader.js", + "examples/js/shaders/FresnelShader.js", + "examples/js/shaders/GammaCorrectionShader.js", + "examples/js/shaders/HorizontalBlurShader.js", + "examples/js/shaders/HorizontalTiltShiftShader.js", + "examples/js/shaders/HueSaturationShader.js", + "examples/js/shaders/KaleidoShader.js", + "examples/js/shaders/LuminosityShader.js", + "examples/js/shaders/MirrorShader.js", + "examples/js/shaders/NormalMapShader.js", + "examples/js/shaders/OceanShaders.js", + "examples/js/shaders/ParallaxShader.js", + "examples/js/shaders/RGBShiftShader.js", + "examples/js/shaders/SMAAShader.js", + "examples/js/shaders/SSAOShader.js", + "examples/js/shaders/SepiaShader.js", + "examples/js/shaders/TechnicolorShader.js", + "examples/js/shaders/ToneMapShader.js", + "examples/js/shaders/TriangleBlurShader.js", + "examples/js/shaders/UnpackDepthRGBAShader.js", + "examples/js/shaders/VerticalBlurShader.js", + "examples/js/shaders/VerticalTiltShiftShader.js", + "examples/js/shaders/VignetteShader.js", + "examples/js/utils/GeometryUtils.js", + "examples/js/utils/ImageUtils.js", + "examples/js/utils/ShadowMapViewer.js", + "examples/js/utils/UVsDebug.js", + ], +) + +filegroup( name = "web_animations_js", srcs = [ "web-animations.html", @@ -582,3 +894,25 @@ filegroup( "webcomponents-lite.min.js", ], ) + +filegroup( + name = "weblas", + srcs = [ + "benchmark.html", + "benchmark/sgemm.js", + "dist/weblas.js", + "index.js", + "lib/globals.js", + "lib/pipeline.js", + "lib/saxpycalculator.js", + "lib/sclmpcalculator.js", + "lib/sdwnscalculator.js", + "lib/sgemmcalculator.js", + "lib/sscalcalculator.js", + "lib/tensor.js", + "lib/test.js", + "lib/webgl.js", + "test.html", + "test/data/generate.js", + ], +) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index add5d97169..2221d26c37 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -148,6 +148,7 @@ filegroup( "//tensorflow/tensorboard/app:all_files", "//tensorflow/tensorboard/backend:all_files", "//tensorflow/tensorboard/components:all_files", + "//tensorflow/tensorboard/components/vz-projector:all_files", "//tensorflow/tensorboard/lib:all_files", "//tensorflow/tensorboard/lib/python:all_files", "//tensorflow/tensorboard/scripts:all_files", diff --git a/tensorflow/tensorboard/bower.json b/tensorflow/tensorboard/bower.json index 20095d9275..368b0f8abd 100644 --- a/tensorflow/tensorboard/bower.json +++ b/tensorflow/tensorboard/bower.json @@ -60,6 +60,7 @@ "iron-validatable-behavior": "PolymerElements/iron-validatable-behavior#1.1.1", "lodash": "3.8.0", "neon-animation": "PolymerElements/neon-animation#1.2.2", + "numericjs": "1.2.6", "paper-behaviors": "PolymerElements/paper-behaviors#1.0.11", "paper-button": "PolymerElements/paper-button#1.0.11", "paper-checkbox": "PolymerElements/paper-checkbox#1.1.3", @@ -70,6 +71,7 @@ "paper-icon-button": "PolymerElements/paper-icon-button#1.1.1", "paper-input": "PolymerElements/paper-input#1.1.14", "paper-item": "PolymerElements/paper-item#1.1.4", + "paper-listbox": "PolymerElements/paper-listbox#1.1.2", "paper-material": "PolymerElements/paper-material#1.0.6", "paper-menu": "PolymerElements/paper-menu#1.2.2", "paper-menu-button": "PolymerElements/paper-menu-button#1.5.0", @@ -84,10 +86,12 @@ "paper-toolbar": "PolymerElements/paper-toolbar#1.1.4", "paper-tooltip": "PolymerElements/paper-tooltip#1.1.2", "plottable": "1.16.1", - "polymer": "1.6.0", + "polymer": "1.6.1", "promise-polyfill": "polymerlabs/promise-polyfill#1.0.0", + "three.js": "threejs#r77", "web-animations-js": "web-animations/web-animations-js#2.2.1", - "webcomponentsjs": "webcomponents/webcomponentsjs#0.7.22" + "webcomponentsjs": "webcomponents/webcomponentsjs#0.7.22", + "weblas": "0.9.0" }, "description": "TensorBoard: Visualizations for TensorFlow", "devDependencies": { @@ -136,6 +140,7 @@ "iron-validatable-behavior": "1.1.1", "lodash": "3.8.0", "neon-animation": "1.2.2", + "numericjs": "1.2.6", "paper-behaviors": "1.0.11", "paper-button": "1.0.11", "paper-checkbox": "1.1.3", @@ -146,6 +151,7 @@ "paper-icon-button": "1.1.1", "paper-input": "1.1.14", "paper-item": "1.1.4", + "paper-listbox": "1.1.2", "paper-material": "1.0.6", "paper-menu": "1.2.2", "paper-menu-button": "1.5.0", @@ -160,10 +166,12 @@ "paper-toolbar": "1.1.4", "paper-tooltip": "1.1.2", "plottable": "1.16.1", - "polymer": "1.6.0", + "polymer": "1.6.1", "promise-polyfill": "1.0.0", + "three.js": "threejs#r77", "web-animations-js": "2.2.1", - "webcomponentsjs": "0.7.22" + "webcomponentsjs": "0.7.22", + "weblas": "0.9.0" }, "version": "0.0.0" } diff --git a/tensorflow/tensorboard/bower/BUILD b/tensorflow/tensorboard/bower/BUILD index 4b43b55844..88b9314080 100644 --- a/tensorflow/tensorboard/bower/BUILD +++ b/tensorflow/tensorboard/bower/BUILD @@ -34,6 +34,7 @@ filegroup( "@iron_validatable_behavior//:iron_validatable_behavior", "@lodash//:lodash", "@neon_animation//:neon_animation", + "@numericjs//:numericjs", "@paper_behaviors//:paper_behaviors", "@paper_button//:paper_button", "@paper_checkbox//:paper_checkbox", @@ -44,6 +45,7 @@ filegroup( "@paper_icon_button//:paper_icon_button", "@paper_input//:paper_input", "@paper_item//:paper_item", + "@paper_listbox//:paper_listbox", "@paper_material//:paper_material", "@paper_menu//:paper_menu", "@paper_menu_button//:paper_menu_button", @@ -60,7 +62,9 @@ filegroup( "@plottable//:plottable", "@polymer//:polymer", "@promise_polyfill//:promise_polyfill", + "@three_js//:three_js", "@web_animations_js//:web_animations_js", "@webcomponentsjs//:webcomponentsjs", + "@weblas//:weblas", ], ) diff --git a/tensorflow/tensorboard/components/vz-projector/BUILD b/tensorflow/tensorboard/components/vz-projector/BUILD new file mode 100644 index 0000000000..8c222be10e --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/BUILD @@ -0,0 +1,19 @@ +# Description: +# Package for the Embedding Projector component. +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/tensorboard/components/vz-projector/async.ts b/tensorflow/tensorboard/components/vz-projector/async.ts new file mode 100644 index 0000000000..9ee6746b8d --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/async.ts @@ -0,0 +1,51 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** Delay for running async tasks, in milliseconds. */ +const ASYNC_DELAY = 15; + +/** + * Runs an expensive task asynchronously with some delay + * so that it doesn't block the UI thread immediately. + */ +export function runAsyncTask<T>(message: string, task: () => T): Promise<T> { + updateMessage(message); + return new Promise<T>((resolve, reject) => { + d3.timer(() => { + try { + let result = task(); + // Clearing the old message. + updateMessage(); + resolve(result); + } catch (ex) { + updateMessage('Error: ' + ex.message); + reject(ex); + } + return true; + }, ASYNC_DELAY); + }); +} + +/** + * Updates the user message at the top of the page. If the provided msg is + * null, the message box is hidden from the user. + */ +export function updateMessage(msg?: string): void { + if (msg == null) { + d3.select('#notify-msg').style('display', 'none'); + } else { + d3.select('#notify-msg').style('display', 'block').text(msg); + } +} diff --git a/tensorflow/tensorboard/components/vz-projector/bh_tsne.ts b/tensorflow/tensorboard/components/vz-projector/bh_tsne.ts new file mode 100644 index 0000000000..95c207ec6c --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/bh_tsne.ts @@ -0,0 +1,472 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** + * This is a fork of the Karpathy's TSNE.js (original license below). + * This fork implements Barnes-Hut approximation and runs in O(NlogN) + * time, as opposed to the Karpathy's O(N^2) version. + * + * @author smilkov@google.com (Daniel Smilkov) + */ + +/** + * The MIT License (MIT) + * Copyright (c) 2015 Andrej Karpathy + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +import {SPNode, SPTree} from './sptree'; + +type AugmSPNode = SPNode&{numCells: number, yCell: number[], rCell: number}; + +/** + * Barnes-hut approximation level. Higher means more approximation and faster + * results. Recommended value mentioned in the paper is 0.8. + */ +const THETA = 0.8; + +// Variables used for memorizing the second random number since running +// gaussRandom() generates two random numbers at the cost of 1 atomic +// computation. This optimization results in 2X speed-up of the generator. +let return_v = false; +let v_val = 0.0; + +/** Returns a vector filled with zeros */ +function zerosArray(length: number): number[] { + let result = new Array(length); + for (let i = 0; i < length; ++i) { + result[i] = 0; + } + return result; +} + + +/** Returns the square euclidean distance between two vectors. */ +export function dist2(a: number[], b: number[]): number { + if (a.length != b.length) { + throw new Error('Vectors a and b must be of same length'); + } + + let result = 0; + for (let i = 0; i < a.length; ++i) { + let diff = a[i] - b[i]; + result += diff * diff; + } + return result; +} + +/** Returns the square euclidean distance between two 2D points. */ +export function dist2_2D(a: number[], b: number[]): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + return dX * dX + dY * dY; +} + +/** Returns the square euclidean distance between two 3D points. */ +export function dist2_3D(a: number[], b: number[]): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + let dZ = a[2] - b[2]; + return dX * dX + dY * dY + dZ * dZ; +} + +function gaussRandom(rng: () => number): number { + if (return_v) { + return_v = false; + return v_val; + } + let u = 2 * rng() - 1; + let v = 2 * rng() - 1; + let r = u * u + v * v; + if (r == 0 || r > 1) { + return gaussRandom(rng); + } + let c = Math.sqrt(-2 * Math.log(r) / r); + v_val = v * c; // cache this for next function call for efficiency + return_v = true; + return u * c; +}; + +// return random normal number +function randn(rng: () => number, mu: number, std: number) { + return mu + gaussRandom(rng) * std; +}; + +// utilitity that creates contiguous vector of zeros of size n +function zeros(n: number): Float64Array { + return new Float64Array(n); +}; + +// utility that returns a matrix filled with random numbers +// generated by the provided generator. +function randnMatrix(n: number, d: number, rng: () => number) { + let nd = n * d; + let x = zeros(nd); + for (let i = 0; i < nd; ++i) { + x[i] = randn(rng, 0.0, 1E-4); + } + return x; +}; + +// utility that returns a matrix filled with the provided value. +function arrayofs(n: number, d: number, val: number) { + let x: number[][] = []; + for (let i = 0; i < n; ++i) { + let row = new Array(d); + for (let j = 0; j < d; ++j) { + row[j] = val; + } + x.push(row); + } + return x; +}; + +// compute (p_{i|j} + p_{j|i})/(2n) +function nearest2P( + nearest: {index: number, dist: number}[][], perplexity: number, + tol: number) { + let N = nearest.length; + let Htarget = Math.log(perplexity); // target entropy of distribution + let P = zeros(N * N); // temporary probability matrix + let K = nearest[0].length; + let pRow: number[] = new Array(K); // pij[]. + + for (let i = 0; i < N; ++i) { + let neighbors = nearest[i]; + let betaMin = -Infinity; + let betaMax = Infinity; + let beta = 1; // initial value of precision + let maxTries = 50; + + // perform binary search to find a suitable precision beta + // so that the entropy of the distribution is appropriate + let numTries = 0; + while (true) { + // compute entropy and kernel row with beta precision + let psum = 0.0; + for (let k = 0; k < neighbors.length; ++k) { + let neighbor = neighbors[k]; + let pij = (i == neighbor.index) ? 0 : Math.exp(-neighbor.dist * beta); + pRow[k] = pij; + psum += pij; + } + // normalize p and compute entropy + let Hhere = 0.0; + for (let k = 0; k < pRow.length; ++k) { + pRow[k] /= psum; + let pij = pRow[k]; + if (pij > 1E-7) { + Hhere -= pij * Math.log(pij); + }; + } + + // adjust beta based on result + if (Hhere > Htarget) { + // entropy was too high (distribution too diffuse) + // so we need to increase the precision for more peaky distribution + betaMin = beta; // move up the bounds + if (betaMax === Infinity) { + beta = beta * 2; + } else { + beta = (beta + betaMax) / 2; + } + + } else { + // converse case. make distrubtion less peaky + betaMax = beta; + if (betaMin === -Infinity) { + beta = beta / 2; + } else { + beta = (beta + betaMin) / 2; + } + } + numTries++; + // stopping conditions: too many tries or got a good precision + if (numTries >= maxTries || Math.abs(Hhere - Htarget) < tol) { + break; + } + } + + // copy over the final prow to P at row i + for (let k = 0; k < pRow.length; ++k) { + let pij = pRow[k]; + let j = neighbors[k].index; + P[i * N + j] = pij; + } + } // end loop over examples i + + // symmetrize P and normalize it to sum to 1 over all ij + let N2 = N * 2; + for (let i = 0; i < N; ++i) { + for (let j = i + 1; j < N; ++j) { + let i_j = i * N + j; + let j_i = j * N + i; + let value = (P[i_j] + P[j_i]) / N2; + P[i_j] = value; + P[j_i] = value; + } + } + return P; +}; + +// helper function +function sign(x: number) { + return x > 0 ? 1 : x < 0 ? -1 : 0; +} + +export interface TSNEOptions { + /** How many dimensions. */ + dim: number; + /** Roughly how many neighbors each point influences. */ + perplexity?: number; + /** Learning rate. */ + epsilon?: number; + /** A random number generator. */ + rng?: () => number; +} + +export class TSNE { + private perplexity: number; + private epsilon: number; + /** Random generator */ + private rng: () => number; + private iter = 0; + private Y: Float64Array; + private N: number; + private P: Float64Array; + private gains: number[][]; + private ystep: number[][]; + private nearest: {index: number, dist: number}[][]; + private dim: number; + private dist2: (a: number[], b: number[]) => number; + + constructor(opt: TSNEOptions) { + opt = opt || {dim: 2}; + this.perplexity = opt.perplexity || 30; + this.epsilon = opt.epsilon || 10; + this.rng = opt.rng || Math.random; + this.dim = opt.dim; + if (opt.dim == 2) { + this.dist2 = dist2_2D; + } else if (opt.dim == 3) { + this.dist2 = dist2_3D; + } else { + this.dist2 = dist2; + } + } + + // this function takes a fattened distance matrix and creates + // matrix P from them. + // D is assumed to be provided as an array of size N^2. + initDataDist(nearest: {index: number, dist: number}[][]) { + let N = nearest.length; + this.nearest = nearest; + this.P = nearest2P(nearest, this.perplexity, 1E-4); + this.N = N; + this.initSolution(); // refresh this + } + + // (re)initializes the solution to random + initSolution() { + // generate random solution to t-SNE + this.Y = randnMatrix(this.N, this.dim, this.rng); // the solution + this.gains = arrayofs(this.N, this.dim, 1.0); // step gains + // to accelerate progress in unchanging directions + this.ystep = arrayofs(this.N, this.dim, 0.0); // momentum accumulator + this.iter = 0; + } + + // return pointer to current solution + getSolution() { return this.Y; } + + // perform a single step of optimization to improve the embedding + step() { + this.iter += 1; + let N = this.N; + + let grad = this.costGrad(this.Y); // evaluate gradient + + // perform gradient step + let ymean = zerosArray(this.dim); + for (let i = 0; i < N; ++i) { + for (let d = 0; d < this.dim; ++d) { + let gid = grad[i][d]; + let sid = this.ystep[i][d]; + let gainid = this.gains[i][d]; + + // compute gain update + let newgain = sign(gid) === sign(sid) ? gainid * 0.8 : gainid + 0.2; + if (newgain < 0.01) { + newgain = 0.01; // clamp + } + this.gains[i][d] = newgain; // store for next turn + + // compute momentum step direction + let momval = this.iter < 250 ? 0.5 : 0.8; + let newsid = momval * sid - this.epsilon * newgain * grad[i][d]; + this.ystep[i][d] = newsid; // remember the step we took + + // step! + let i_d = i * this.dim + d; + this.Y[i_d] += newsid; + ymean[d] += this.Y[i_d]; // accumulate mean so that we + // can center later + } + } + + // reproject Y to be zero mean + for (let i = 0; i < N; ++i) { + for (let d = 0; d < this.dim; ++d) { + this.Y[i * this.dim + d] -= ymean[d] / N; + } + } + } + + // return cost and gradient, given an arrangement + costGrad(Y: Float64Array): number[][] { + let N = this.N; + let P = this.P; + + // Trick that helps with local optima. + let alpha = this.iter < 100 ? 4 : 1; + + // Make data for the SP tree. + let points: number[][] = new Array(N); // (x, y)[] + for (let i = 0; i < N; ++i) { + let iTimesD = i * this.dim; + let row = new Array(this.dim); + for (let d = 0; d < this.dim; ++d) { + row[d] = Y[iTimesD + d]; + } + points[i] = row; + } + + // Make a tree. + let tree = new SPTree(points, 1); + let root = tree.root as AugmSPNode; + // Annotate the tree. + + let annotateTree = + (node: AugmSPNode): {numCells: number, yCell: number[]} => { + let numCells = node.points ? node.points.length : 0; + if (node.children == null) { + // Update the current node and tell the parent. + node.numCells = numCells; + // TODO(smilkov): yCell should be average across all points. + node.yCell = node.points[0]; + return {numCells, yCell: node.yCell}; + } + // TODO(smilkov): yCell should be average across all points. + let yCell = + node.points ? node.points[0].slice() : zerosArray(this.dim); + for (let i = 0; i < node.children.length; ++i) { + let child = node.children[i]; + if (child == null) { + continue; + } + let result = annotateTree(child as AugmSPNode); + numCells += result.numCells; + for (let d = 0; d < this.dim; ++d) { + yCell[d] += result.yCell[d]; + } + } + // Update the node and tell the parent. + node.numCells = numCells; + node.yCell = yCell.map(v => v / numCells); + return {numCells, yCell}; + }; + + // Augment the tree with more info. + annotateTree(root); + tree.visit((node: AugmSPNode, low: number[], high: number[]) => { + node.rCell = high[0] - low[0]; + return false; + }); + // compute current Q distribution, unnormalized first + let grad: number[][] = []; + let Z = 0; + let forces: [number[], number[]][] = new Array(N); + for (let i = 0; i < N; ++i) { + let pointI = points[i]; + // Compute the positive forces for the i-th node. + let Fpos = zerosArray(this.dim); + let neighbors = this.nearest[i]; + for (let k = 0; k < neighbors.length; ++k) { + let j = neighbors[k].index; + let pij = P[i * N + j]; + let pointJ = points[j]; + let squaredDistItoJ = this.dist2(pointI, pointJ); + let premult = pij / (1 + squaredDistItoJ); + for (let d = 0; d < this.dim; ++d) { + Fpos[d] += premult * (pointI[d] - pointJ[d]); + } + } + // Compute the negative forces for the i-th node. + let FnegZ = zerosArray(this.dim); + tree.visit((node: AugmSPNode) => { + let squaredDistToCell = this.dist2(pointI, node.yCell); + // Squared distance from point i to cell. + if (node.children == null || + (node.rCell / Math.sqrt(squaredDistToCell) < THETA)) { + let qijZ = 1 / (1 + squaredDistToCell); + let dZ = node.numCells * qijZ; + Z += dZ; + dZ *= qijZ; + for (let d = 0; d < this.dim; ++d) { + FnegZ[d] += dZ * (pointI[d] - node.yCell[d]); + } + return true; + } + if (node.points != null) { + // TODO(smilkov): Iterate over all points. + let squaredDistToPoint = this.dist2(pointI, node.points[0]); + let qijZ = 1 / (1 + squaredDistToPoint); + Z += qijZ; + qijZ *= qijZ; + for (let d = 0; d < this.dim; ++d) { + FnegZ[d] += qijZ * (pointI[d] - node.points[0][d]); + } + } + return false; + }, true); + forces[i] = [Fpos, FnegZ]; + } + // Normalize the negative forces and compute the gradient. + for (let i = 0; i < N; ++i) { + let [FPos, FNegZ] = forces[i]; + let gsum = new Array(this.dim); + for (let d = 0; d < this.dim; ++d) { + gsum[d] = 4 * (alpha * FPos[d] - FNegZ[d] / Z); + } + grad.push(gsum); + } + return grad; + } +} diff --git a/tensorflow/tensorboard/components/vz-projector/data.ts b/tensorflow/tensorboard/components/vz-projector/data.ts new file mode 100644 index 0000000000..dc23defb4a --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/data.ts @@ -0,0 +1,324 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {runAsyncTask} from './async'; +import {TSNE} from './bh_tsne'; +import * as knn from './knn'; +import * as scatter from './scatter'; +import {shuffle} from './util'; +import * as vector from './vector'; + + +/** + * A DataSource is our ground truth data. The original parsed data should never + * be modified, only copied out. + */ +export class DataSource { + originalDataSet: DataSet; + spriteImage: HTMLImageElement; + metadata: DatasetMetadata; + + /** A shallow-copy constructor. */ + makeShallowCopy(): DataSource { + let copy = new DataSource(); + copy.originalDataSet = this.originalDataSet; + copy.spriteImage = this.spriteImage; + copy.metadata = this.metadata; + return copy; + } + + /** Returns a new dataset. */ + getDataSet(subset?: number[]): DataSet { + let pointsSubset = subset ? + subset.map(i => this.originalDataSet.points[i]) : + this.originalDataSet.points; + return new DataSet(pointsSubset); + } +} + +export interface DataPoint extends scatter.DataPoint { + /** The point in the original space. */ + vector: number[]; + + /* + * Metadata for each point. Each metadata is a set of key/value pairs + * where the value can be a string or a number. + */ + metadata: {[key: string]: number | string}; + + /** This is where the calculated projections space are cached */ + projections: {[key: string]: number}; +} + +/** Checks to see if the browser supports webgl. */ +function hasWebGLSupport(): boolean { + try { + let c = document.createElement('canvas'); + let gl = c.getContext('webgl') || c.getContext('experimental-webgl'); + return gl != null && typeof weblas !== 'undefined'; + } catch (e) { + return false; + } +} + +const WEBGL_SUPPORT = hasWebGLSupport(); +const MAX_TSNE_ITERS = 500; +/** + * Sampling is used when computing expensive operations such as PCA, or T-SNE. + */ +const SAMPLE_SIZE = 10000; +/** Number of dimensions to sample when doing approximate PCA. */ +const PCA_SAMPLE_DIM = 100; +/** Number of pca components to compute. */ +const NUM_PCA_COMPONENTS = 10; +/** Reserved metadata attribute used for trace information. */ +const TRACE_METADATA_ATTR = '__next__'; + +/** + * Dataset contains a DataPoints array that should be treated as immutable. This + * acts as a working subset of the original data, with cached properties + * from computationally expensive operations. Because creating a subset + * requires normalizing and shifting the vector space, we make a copy of the + * data so we can still always create new subsets based on the original data. + */ +export class DataSet implements scatter.DataSet { + points: DataPoint[]; + traces: scatter.DataTrace[]; + + sampledDataIndices: number[] = []; + + /** + * This keeps a list of all current projections so you can easily test to see + * if it's been calculated already. + */ + projections = d3.set(); + nearest: knn.NearestEntry[][]; + nearestK: number; + tSNEShouldStop = true; + dim = [0, 0]; + private tsne: TSNE; + + /** + * Creates a new Dataset by copying out data from an array of datapoints. + * We make a copy because we have to modify the vectors by normalizing them. + */ + constructor(points: DataPoint[]) { + // Keep a list of indices seen so we don't compute traces for a given + // point twice. + let indicesSeen: boolean[] = []; + + this.points = []; + points.forEach(dp => { + this.points.push({ + metadata: dp.metadata, + dataSourceIndex: dp.dataSourceIndex, + vector: dp.vector.slice(), + projectedPoint: { + x: 0, + y: 0, + z: 0, + }, + projections: {} + }); + indicesSeen.push(false); + }); + + this.sampledDataIndices = + shuffle(d3.range(this.points.length)).slice(0, SAMPLE_SIZE); + this.traces = this.computeTraces(points, indicesSeen); + + this.normalize(); + this.dim = [this.points.length, this.points[0].vector.length]; + } + + private computeTraces(points: DataPoint[], indicesSeen: boolean[]) { + // Compute traces. + let indexToTrace: {[index: number]: scatter.DataTrace} = {}; + let traces: scatter.DataTrace[] = []; + for (let i = 0; i < points.length; i++) { + if (indicesSeen[i]) { + continue; + } + indicesSeen[i] = true; + + // Ignore points without a trace attribute. + let next = points[i].metadata[TRACE_METADATA_ATTR]; + if (next == null || next === '') { + continue; + } + if (next in indexToTrace) { + let existingTrace = indexToTrace[+next]; + // Pushing at the beginning of the array. + existingTrace.pointIndices.unshift(i); + indexToTrace[i] = existingTrace; + continue; + } + // The current point is pointing to a new/unseen trace. + let newTrace: scatter.DataTrace = {pointIndices: []}; + indexToTrace[i] = newTrace; + traces.push(newTrace); + let currentIndex = i; + while (points[currentIndex]) { + newTrace.pointIndices.push(currentIndex); + let next = points[currentIndex].metadata[TRACE_METADATA_ATTR]; + if (next != null && next !== '') { + indicesSeen[+next] = true; + currentIndex = +next; + } else { + currentIndex = -1; + } + } + } + return traces; + } + + + /** + * Computes the centroid, shifts all points to that centroid, + * then makes them all unit norm. + */ + private normalize() { + // Compute the centroid of all data points. + let centroid = + vector.centroid(this.points, () => true, a => a.vector).centroid; + if (centroid == null) { + throw Error('centroid should not be null'); + } + // Shift all points by the centroid and make them unit norm. + for (let id = 0; id < this.points.length; ++id) { + let dataPoint = this.points[id]; + dataPoint.vector = vector.sub(dataPoint.vector, centroid); + vector.unit(dataPoint.vector); + } + } + + /** Projects the dataset onto a given vector and caches the result. */ + projectLinear(dir: vector.Vector, label: string) { + this.projections.add(label); + this.points.forEach(dataPoint => { + dataPoint.projections[label] = vector.dot(dataPoint.vector, dir); + }); + } + + /** Projects the dataset along the top 10 principal components. */ + projectPCA(): Promise<void> { + if (this.projections.has('pca-0')) { + return Promise.resolve<void>(); + } + return runAsyncTask('Computing PCA...', () => { + // Approximate pca vectors by sampling the dimensions. + let numDim = Math.min(this.points[0].vector.length, PCA_SAMPLE_DIM); + let reducedDimData = + vector.projectRandom(this.points.map(d => d.vector), numDim); + let sigma = numeric.div( + numeric.dot(numeric.transpose(reducedDimData), reducedDimData), + reducedDimData.length); + let U: any; + U = numeric.svd(sigma).U; + let pcaVectors = reducedDimData.map(vector => { + let newV: number[] = []; + for (let d = 0; d < NUM_PCA_COMPONENTS; d++) { + let dot = 0; + for (let i = 0; i < vector.length; i++) { + dot += vector[i] * U[i][d]; + } + newV.push(dot); + } + return newV; + }); + for (let j = 0; j < NUM_PCA_COMPONENTS; j++) { + let label = 'pca-' + j; + this.projections.add(label); + this.points.forEach( + (d, i) => { d.projections[label] = pcaVectors[i][j]; }); + } + }); + } + + /** Runs tsne on the data. */ + projectTSNE( + perplexity: number, learningRate: number, tsneDim: number, + stepCallback: (iter: number) => void) { + let k = Math.floor(3 * perplexity); + let opt = {epsilon: learningRate, perplexity: perplexity, dim: tsneDim}; + this.tsne = new TSNE(opt); + this.tSNEShouldStop = false; + let iter = 0; + + let step = () => { + if (this.tSNEShouldStop || iter > MAX_TSNE_ITERS) { + stepCallback(null); + return; + } + this.tsne.step(); + let result = this.tsne.getSolution(); + this.sampledDataIndices.forEach((index, i) => { + let dataPoint = this.points[index]; + + dataPoint.projections['tsne-0'] = result[i * tsneDim + 0]; + dataPoint.projections['tsne-1'] = result[i * tsneDim + 1]; + if (tsneDim === 3) { + dataPoint.projections['tsne-2'] = result[i * tsneDim + 2]; + } + }); + iter++; + stepCallback(iter); + requestAnimationFrame(step); + }; + + // Nearest neighbors calculations. + let knnComputation: Promise<knn.NearestEntry[][]>; + + if (this.nearest != null && k === this.nearestK) { + // We found the nearest neighbors before and will reuse them. + knnComputation = Promise.resolve(this.nearest); + } else { + let sampledData = this.sampledDataIndices.map(i => this.points[i]); + this.nearestK = k; + knnComputation = WEBGL_SUPPORT ? + knn.findKNNGPUCosine(sampledData, k, (d => d.vector)) : + knn.findKNN( + sampledData, k, (d => d.vector), + (a, b, limit) => vector.cosDistNorm(a, b)); + } + knnComputation.then(nearest => { + this.nearest = nearest; + runAsyncTask('Initializing T-SNE...', () => { + this.tsne.initDataDist(this.nearest); + }).then(step); + + + }); + } + + stopTSNE() { this.tSNEShouldStop = true; } +} + +export interface DatasetMetadata { + /** + * Metadata for an associated image sprite. The sprite should be a matrix + * of smaller images, filled in a row-by-row order. + * + * E.g. the image for the first data point should be in the upper-left + * corner, and to the right of it should be the image of the second data + * point. + */ + image?: { + /** The file path pointing to the sprite image. */ + sprite_fpath: string; + /** The dimensions of the image for a single data point. */ + single_image_dim: [number, number]; + }; +} diff --git a/tensorflow/tensorboard/components/vz-projector/data_test.ts b/tensorflow/tensorboard/components/vz-projector/data_test.ts new file mode 100644 index 0000000000..07286f8bcc --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/data_test.ts @@ -0,0 +1,66 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {DataPoint, DataSet} from './data'; + + +/** + * Helper method that makes a list of points given an array of + * trace indexes. + * + * @param traces The i-th entry holds the 'next' attribute for the i-th point. + */ +function makePointsWithTraces(traces: number[]) { + let nextAttr = '__next__'; + let points: DataPoint[] = []; + traces.forEach((t, i) => { + let metadata: {[key: string]: any} = {}; + metadata[nextAttr] = t >= 0 ? t : null; + points.push({ + vector: [], + metadata: metadata, + projections: {}, + projectedPoint: null, + dataSourceIndex: i + }); + }); + return points; +} + +const assert = chai.assert; + +it('Simple forward pointing traces', () => { + // The input is: 0->2, 1->None, 2->3, 3->None. This should return + // one trace 0->2->3. + let points = makePointsWithTraces([2, -1, 3, -1]); + let dataset = new DataSet(points); + assert.equal(dataset.traces.length, 1); + assert.deepEqual(dataset.traces[0].pointIndices, [0, 2, 3]); +}); + +it('No traces', () => { + let points = makePointsWithTraces([-1, -1, -1, -1]); + let dataset = new DataSet(points); + assert.equal(dataset.traces.length, 0); +}); + +it('A trace that goes backwards and forward in the array', () => { + // The input is: 0->2, 1->0, 2->nothing, 3->1. This should return + // one trace 3->1->0->2. + let points = makePointsWithTraces([2, 0, -1, 1]); + let dataset = new DataSet(points); + assert.equal(dataset.traces.length, 1); + assert.deepEqual(dataset.traces[0].pointIndices, [3, 1, 0, 2]); +}); diff --git a/tensorflow/tensorboard/components/vz-projector/demo/index.html b/tensorflow/tensorboard/components/vz-projector/demo/index.html new file mode 100644 index 0000000000..17ca273a3f --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/demo/index.html @@ -0,0 +1,73 @@ +<!DOCTYPE html> +<!-- +@license +Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--> +<html> +<head> + <link rel="icon" type="image/png" href="favicon.png"> + <!-- Polyfill for non-Chrome browsers --> + <script src="../../webcomponentsjs/webcomponents-lite.js"></script> + <link rel="import" href="../vz-projector.html"> + <link rel="import" href="../../paper-icon-button/paper-icon-button.html"> + <!-- TODO(jart): Refactor js_binary rule into ../ rather than ../../ --> + <script src="../../bundle.js"></script> + <title>Embedding projector - visualization of high-dimensional data</title> + <style> + html { + width: 100%; + height: 100%; + } + + body { + display: flex; + flex-direction: column; + font-family: "Roboto", "Helvetica", "Arial", sans-serif; + margin: 0; + width: 100%; + height: 100%; + } + + #appbar { + display: flex; + align-items: center; + justify-content: space-between; + padding: 0 24px; + height: 60px; + color: white; + background: black; + } + + #appbar .logo { + font-size: 18px; + font-weight: 300; + } + + .icons { + display: flex; + } + + .icons a { + color: white; + } + </style> +</head> +<body> + <div id="appbar"> + <div>Embedding Projector</div> + </div> + <vz-projector></vz-projector> +</body> +</html> diff --git a/tensorflow/tensorboard/components/vz-projector/external.d.ts b/tensorflow/tensorboard/components/vz-projector/external.d.ts new file mode 100644 index 0000000000..673896557b --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/external.d.ts @@ -0,0 +1,25 @@ +// TODO(smilkov): Split into weblas.d.ts and numeric.d.ts and write +// typings for numeric. +interface Tensor { + new(size: [number, number], data: Float32Array); + transfer(): Float32Array; + delete(): void; +} + +interface Weblas { + sgemm(M: number, N: number, K: number, alpha: number, + A: Float32Array, B: Float32Array, beta: number, C: Float32Array): + Float32Array; + pipeline: { + Tensor: Tensor; + sgemm(alpha: number, A: Tensor, B: Tensor, beta: number, + C: Tensor): Tensor; + }; + util: { + transpose(M: number, N: number, data: Float32Array): Tensor; + }; + +} + +declare let numeric: any; +declare let weblas: Weblas;
\ No newline at end of file diff --git a/tensorflow/tensorboard/components/vz-projector/heap.ts b/tensorflow/tensorboard/components/vz-projector/heap.ts new file mode 100644 index 0000000000..35f178e000 --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/heap.ts @@ -0,0 +1,146 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** Min key heap. */ +export type HeapItem<T> = { + key: number, + value: T +}; + +/** + * Min-heap data structure. Provides O(1) for peek, returning the smallest key. + */ +// TODO(jart): Rename to Heap and use Comparator. +export class MinHeap<T> { + private arr: HeapItem<T>[] = []; + + /** Push an element with the provided key. */ + push(key: number, value: T): void { + this.arr.push({key, value}); + this.bubbleUp(this.arr.length - 1); + } + + /** Pop the element with the smallest key. */ + pop(): HeapItem<T> { + if (this.arr.length === 0) { + throw new Error('pop() called on empty binary heap'); + } + let item = this.arr[0]; + let last = this.arr.length - 1; + this.arr[0] = this.arr[last]; + this.arr.pop(); + if (last > 0) { + this.bubbleDown(0); + } + return item; + }; + + /** Returns, but doesn't remove the element with the smallest key */ + peek(): HeapItem<T> { return this.arr[0]; } + + /** + * Pops the element with the smallest key and at the same time + * adds the newly provided element. This is faster than calling + * pop() and push() separately. + */ + popPush(key: number, value: T): HeapItem<T> { + if (this.arr.length === 0) { + throw new Error('pop() called on empty binary heap'); + } + let item = this.arr[0]; + this.arr[0] = {key, value}; + if (this.arr.length > 0) { + this.bubbleDown(0); + } + return item; + } + + /** Returns the number of elements in the heap. */ + size(): number { return this.arr.length; } + + /** Returns all the items in the heap. */ + items(): HeapItem<T>[] { return this.arr; } + + private swap(a: number, b: number) { + let temp = this.arr[a]; + this.arr[a] = this.arr[b]; + this.arr[b] = temp; + } + + private bubbleDown(pos: number) { + let left = (pos << 1) + 1; + let right = left + 1; + let largest = pos; + if (left < this.arr.length && this.arr[left].key < this.arr[largest].key) { + largest = left; + } + if (right < this.arr.length && + this.arr[right].key < this.arr[largest].key) { + largest = right; + } + if (largest != pos) { + this.swap(largest, pos); + this.bubbleDown(largest); + } + } + + private bubbleUp(pos: number) { + if (pos <= 0) { + return; + } + let parent = ((pos - 1) >> 1); + if (this.arr[pos].key < this.arr[parent].key) { + this.swap(pos, parent); + this.bubbleUp(parent); + } + } +} + +/** List that keeps the K elements with the smallest keys. */ +export class KMin<T> { + private k: number; + private maxHeap = new MinHeap<T>(); + + /** Constructs a new k-min data structure with the provided k. */ + constructor(k: number) { this.k = k; } + + /** Adds an element to the list. */ + add(key: number, value: T) { + if (this.maxHeap.size() < this.k) { + this.maxHeap.push(-key, value); + return; + } + let largest = this.maxHeap.peek(); + // If the new element is smaller, replace the largest with the new element. + if (key < -largest.key) { + this.maxHeap.popPush(-key, value); + } + } + + /** Returns the k items with the smallest keys. */ + getMinKItems(): T[] { + let items = this.maxHeap.items(); + items.sort((a, b) => b.key - a.key); + return items.map(a => a.value); + } + + /** Returns the size of the list. */ + getSize(): number { return this.maxHeap.size(); } + + /** Returns the largest key in the list. */ + getLargestKey(): number { + return this.maxHeap.size() == 0 ? null : -this.maxHeap.peek().key; + } +} diff --git a/tensorflow/tensorboard/components/vz-projector/knn.ts b/tensorflow/tensorboard/components/vz-projector/knn.ts new file mode 100644 index 0000000000..4d64595a3f --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/knn.ts @@ -0,0 +1,223 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {runAsyncTask} from './async'; +import {KMin} from './heap'; +import * as vector from './vector'; + +export type NearestEntry = { + index: number, + dist: number +}; + +/** + * Optimal size for the height of the matrix when doing computation on the GPU + * using WebGL. This was found experimentally. + * + * This also guarantees that for computing pair-wise distance for up to 10K + * vectors, no more than 40MB will be allocated in the GPU. Without the + * allocation limit, we can freeze the graphics of the whole OS. + */ +const OPTIMAL_GPU_BLOCK_SIZE = 512; + +/** + * Returns the K nearest neighbors for each vector where the distance + * computation is done on the GPU (WebGL) using cosine distance. + * + * @param dataPoints List of data points, where each data point holds an + * n-dimensional vector. + * @param k Number of nearest neighbors to find. + * @param accessor A method that returns the vector, given the data point. + */ +export function findKNNGPUCosine<T>( + dataPoints: T[], k: number, + accessor: (dataPoint: T) => number[]): Promise<NearestEntry[][]> { + let N = dataPoints.length; + let dim = accessor(dataPoints[0]).length; + + // The goal is to compute a large matrix multiplication A*A.T where A is of + // size NxD and A.T is its transpose. This results in a NxN matrix which + // could be too big to store on the GPU memory. To avoid memory overflow, we + // compute multiple A*partial_A.T where partial_A is of size BxD (B is much + // smaller than N). This results in storing only NxB size matrices on the GPU + // at a given time. + + // A*A.T will give us NxN matrix holding the cosine distance between every + // pair of points, which we sort using KMin data structure to obtain the + // K nearest neighbors for each point. + let typedArray = vector.toTypedArray(dataPoints, accessor); + let bigMatrix = new weblas.pipeline.Tensor([N, dim], typedArray); + let nearest: NearestEntry[][] = new Array(N); + let numPieces = Math.ceil(N / OPTIMAL_GPU_BLOCK_SIZE); + let M = Math.floor(N / numPieces); + let modulo = N % numPieces; + let offset = 0; + let progress = 0; + let progressDiff = 1 / (2 * numPieces); + let piece = 0; + + function step(resolve: (result: NearestEntry[][]) => void) { + let progressMsg = + 'Finding nearest neighbors: ' + (progress * 100).toFixed() + '%'; + runAsyncTask(progressMsg, () => { + let B = piece < modulo ? M + 1 : M; + let typedB = new Float32Array(B * dim); + for (let i = 0; i < B; ++i) { + let vector = accessor(dataPoints[offset + i]); + for (let d = 0; d < dim; ++d) { + typedB[i * dim + d] = vector[d]; + } + } + let partialMatrix = new weblas.pipeline.Tensor([B, dim], typedB); + // Result is N x B matrix. + let result = + weblas.pipeline.sgemm(1, bigMatrix, partialMatrix, null, null); + let partial = result.transfer(); + partialMatrix.delete(); + result.delete(); + progress += progressDiff; + for (let i = 0; i < B; i++) { + let kMin = new KMin<NearestEntry>(k); + let iReal = offset + i; + for (let j = 0; j < N; j++) { + if (j === iReal) { + continue; + } + let cosDist = 1 - partial[j * B + i]; // [j, i]; + kMin.add(cosDist, {index: j, dist: cosDist}); + } + nearest[iReal] = kMin.getMinKItems(); + } + progress += progressDiff; + offset += B; + piece++; + }).then(() => { + if (piece < numPieces) { + step(resolve); + } else { + bigMatrix.delete(); + resolve(nearest); + } + }); + } + return new Promise<NearestEntry[][]>(resolve => step(resolve)); +} + +/** + * Returns the K nearest neighbors for each vector where the distance + * computation is done on the CPU using a user-specified distance method. + * + * @param dataPoints List of data points, where each data point holds an + * n-dimensional vector. + * @param k Number of nearest neighbors to find. + * @param accessor A method that returns the vector, given the data point. + * @param dist Method that takes two vectors and a limit, and computes the + * distance between two vectors, with the ability to stop early if the + * distance is above the limit. + */ +export function findKNN<T>( + dataPoints: T[], k: number, accessor: (dataPoint: T) => number[], + dist: (a: number[], b: number[], limit: number) => + number): Promise<NearestEntry[][]> { + return runAsyncTask<NearestEntry[][]>('Finding nearest neighbors...', () => { + let N = dataPoints.length; + let nearest: NearestEntry[][] = new Array(N); + // Find the distances from node i. + let kMin: KMin<NearestEntry>[] = new Array(N); + for (let i = 0; i < N; i++) { + kMin[i] = new KMin<NearestEntry>(k); + } + for (let i = 0; i < N; i++) { + let a = accessor(dataPoints[i]); + let kMinA = kMin[i]; + for (let j = i + 1; j < N; j++) { + let kMinB = kMin[j]; + let limitI = kMinA.getSize() === k ? + kMinA.getLargestKey() || Number.MAX_VALUE : + Number.MAX_VALUE; + let limitJ = kMinB.getSize() === k ? + kMinB.getLargestKey() || Number.MAX_VALUE : + Number.MAX_VALUE; + let limit = Math.max(limitI, limitJ); + let dist2ItoJ = dist(a, accessor(dataPoints[j]), limit); + if (dist2ItoJ >= 0) { + kMinA.add(dist2ItoJ, {index: j, dist: dist2ItoJ}); + kMinB.add(dist2ItoJ, {index: i, dist: dist2ItoJ}); + } + } + } + for (let i = 0; i < N; i++) { + nearest[i] = kMin[i].getMinKItems(); + } + return nearest; + }); +} + +/** Calculates the minimum distance between a search point and a rectangle. */ +function minDist( + point: [number, number], x1: number, y1: number, x2: number, y2: number) { + let x = point[0]; + let y = point[1]; + let dx1 = x - x1; + let dx2 = x - x2; + let dy1 = y - y1; + let dy2 = y - y2; + + if (dx1 * dx2 <= 0) { // x is between x1 and x2 + if (dy1 * dy2 <= 0) { // (x,y) is inside the rectangle + return 0; // return 0 as point is in rect + } + return Math.min(Math.abs(dy1), Math.abs(dy2)); + } + if (dy1 * dy2 <= 0) { // y is between y1 and y2 + // We know it is already inside the rectangle + return Math.min(Math.abs(dx1), Math.abs(dx2)); + } + let corner: [number, number]; + if (x > x2) { + // Upper-right vs lower-right. + corner = y > y2 ? [x2, y2] : [x2, y1]; + } else { + // Upper-left vs lower-left. + corner = y > y2 ? [x1, y2] : [x1, y1]; + } + return Math.sqrt(vector.dist22D([x, y], corner)); +} + +/** + * Returns the nearest neighbors of a particular point. + * + * @param dataPoints List of data points. + * @param pointIndex The index of the point we need the nearest neighbors of. + * @param k Number of nearest neighbors to search for. + * @param accessor Method that maps a data point => vector (array of numbers). + * @param distance Method that takes two vectors and returns their distance. + */ +export function findKNNofPoint<T>( + dataPoints: T[], pointIndex: number, k: number, + accessor: (dataPoint: T) => number[], + distance: (a: number[], b: number[]) => number) { + let kMin = new KMin<NearestEntry>(k); + let a = accessor(dataPoints[pointIndex]); + for (let i = 0; i < dataPoints.length; ++i) { + if (i == pointIndex) { + continue; + } + let b = accessor(dataPoints[i]); + let dist = distance(a, b); + kMin.add(dist, {index: i, dist: dist}); + } + return kMin.getMinKItems(); +} diff --git a/tensorflow/tensorboard/components/vz-projector/label.ts b/tensorflow/tensorboard/components/vz-projector/label.ts new file mode 100644 index 0000000000..9689ef5869 --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/label.ts @@ -0,0 +1,151 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +export interface BoundingBox { + loX: number; + loY: number; + hiX: number; + hiY: number; +} + +/** + * Accelerates label placement by dividing the view into a uniform grid. + * Labels only need to be tested for collision with other labels that overlap + * the same grid cells. This is a fork of {@code amoeba.CollisionGrid}. + */ +export class CollisionGrid { + private numHorizCells: number; + private numVertCells: number; + private grid: BoundingBox[][]; + private bound: BoundingBox; + private cellWidth: number; + private cellHeight: number; + + /** + * Constructs a new Collision grid. + * + * @param bound The bound of the grid. Labels out of bounds will be rejected. + * @param cellWidth Width of a cell in the grid. + * @param cellHeight Height of a cell in the grid. + */ + constructor(bound: BoundingBox, cellWidth: number, cellHeight: number) { + /** The bound of the grid. Labels out of bounds will be rejected. */ + this.bound = bound; + + /** Width of a cell in the grid. */ + this.cellWidth = cellWidth; + + /** Height of a cell in the grid. */ + this.cellHeight = cellHeight; + + /** Number of grid cells along the x axis. */ + this.numHorizCells = Math.ceil(this.boundWidth(bound) / cellWidth); + + /** Number of grid cells along the y axis. */ + this.numVertCells = Math.ceil(this.boundHeight(bound) / cellHeight); + + /** + * The 2d grid (stored as a 1d array.) Each cell consists of an array of + * BoundingBoxes for objects that are in the cell. + */ + this.grid = new Array(this.numHorizCells * this.numVertCells); + } + + private boundWidth(bound: BoundingBox) { return bound.hiX - bound.loX; } + + private boundHeight(bound: BoundingBox) { return bound.hiY - bound.loY; } + + private boundsIntersect(a: BoundingBox, b: BoundingBox) { + return !(a.loX > b.hiX || a.loY > b.hiY || a.hiX < b.loX || a.hiY < b.loY); + } + + /** + * Checks if a given bounding box has any conflicts in the grid and inserts it + * if none are found. + * + * @param bound The bound to insert. + * @param justTest If true, just test if it conflicts, without inserting. + * @return True if the bound was successfully inserted; false if it + * could not be inserted due to a conflict. + */ + insert(bound: BoundingBox, justTest = false): boolean { + // Reject if the label is out of bounds. + if (bound.loX < this.bound.loX || bound.hiX > this.bound.hiX || + bound.loY < this.bound.loY || bound.hiY > this.bound.hiY) { + return false; + } + + let minCellX = this.getCellX(bound.loX); + let maxCellX = this.getCellX(bound.hiX); + let minCellY = this.getCellY(bound.loY); + let maxCellY = this.getCellY(bound.hiY); + + // Check all overlapped cells to verify that we can insert. + let baseIdx = minCellY * this.numHorizCells + minCellX; + let idx = baseIdx; + for (let j = minCellY; j <= maxCellY; j++) { + for (let i = minCellX; i <= maxCellX; i++) { + let cell = this.grid[idx++]; + if (cell) { + for (let k = 0; k < cell.length; k++) { + if (this.boundsIntersect(bound, cell[k])) { + return false; + } + } + } + } + idx += this.numHorizCells - (maxCellX - minCellX + 1); + } + + if (justTest) { + return true; + } + + // Insert into the overlapped cells. + idx = baseIdx; + for (let j = minCellY; j <= maxCellY; j++) { + for (let i = minCellX; i <= maxCellX; i++) { + if (!this.grid[idx]) { + this.grid[idx] = [bound]; + } else { + this.grid[idx].push(bound); + } + idx++; + } + idx += this.numHorizCells - (maxCellX - minCellX + 1); + } + return true; + } + + /** + * Returns the x index of the grid cell where the given x coordinate falls. + * + * @param x the coordinate, in world space. + * @return the x index of the cell. + */ + private getCellX(x: number) { + return Math.floor((x - this.bound.loX) / this.cellWidth); + }; + + /** + * Returns the y index of the grid cell where the given y coordinate falls. + * + * @param y the coordinate, in world space. + * @return the y index of the cell. + */ + private getCellY(y: number) { + return Math.floor((y - this.bound.loY) / this.cellHeight); + }; +}
\ No newline at end of file diff --git a/tensorflow/tensorboard/components/vz-projector/scatter.ts b/tensorflow/tensorboard/components/vz-projector/scatter.ts new file mode 100644 index 0000000000..392d8085c6 --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/scatter.ts @@ -0,0 +1,135 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +export interface Point3D { + /** Original x coordinate. */ + x: number; + /** Original y coordinate. */ + y: number; + /** Original z coordinate. */ + z: number; +} +; + +/** The spacial data of points and lines that will be shown in the projector. */ +export interface DataSet { + points: DataPoint[]; + traces: DataTrace[]; +} + +/** + * Points in 3D space that will be used in the projector. If the projector is + * in 2D mode, the Z coordinate of the point will be 0. + */ +export interface DataPoint { + projectedPoint: Point3D; + /** index of the trace, used for highlighting on click */ + traceIndex?: number; + /** index in the original data source */ + dataSourceIndex: number; +} + +/** A single collection of points which make up a trace through space. */ +export interface DataTrace { + /** Indices into the DataPoints array in the Data object. */ + pointIndices: number[]; +} + +export type OnHoverListener = (index: number) => void; +export type OnSelectionListener = (indexes: number[]) => void; + +/** Supported modes of interaction. */ +export enum Mode { + SELECT, + SEARCH, + HOVER +} + +export interface Scatter { + /** Sets the data for the scatter plot. */ + setDataSet(dataSet: DataSet, spriteImage?: HTMLImageElement): void; + /** Called with each data point in order to get its color. */ + setColorAccessor(colorAccessor: ((index: number) => string)): void; + /** Called with each data point in order to get its label. */ + setLabelAccessor(labelAccessor: ((index: number) => string)): void; + /** Called with each data point in order to get its x coordinate. */ + setXAccessor(xAccessor: ((index: number) => number)): void; + /** Called with each data point in order to get its y coordinate. */ + setYAccessor(yAccessor: ((index: number) => number)): void; + /** Called with each data point in order to get its z coordinate. */ + setZAccessor(zAccessor: ((index: number) => number)): void; + /** Sets the interaction mode (search, select or hover). */ + setMode(mode: Mode): void; + /** Returns the interaction mode. */ + getMode(): Mode; + /** Resets the zoom level to 1.*/ + resetZoom(): void; + /** + * Increases/decreases the zoom level. + * + * @param multiplier New zoom level = old zoom level * multiplier. + */ + zoomStep(multiplier: number): void; + /** + * Highlights the provided points. + * + * @param pointIndexes List of point indexes to highlight. If null, + * un-highlights all the points. + * @param stroke The stroke color used to highlight the point. + * @param favorLabels Whether to favor plotting the labels of the + * highlighted point. Default is false for all points. + */ + highlightPoints( + pointIndexes: number[], highlightStroke?: (index: number) => string, + favorLabels?: (index: number) => boolean): void; + /** Whether to show labels or not. */ + showLabels(show: boolean): void; + /** Toggle between day and night modes. */ + setDayNightMode(isNight: boolean): void; + /** Show/hide tick labels. */ + showTickLabels(show: boolean): void; + /** Whether to show axes or not. */ + showAxes(show: boolean): void; + /** Sets the axis labels. */ + setAxisLabels(xLabel: string, yLabel: string): void; + /** + * Recreates the scene (demolishes all datastructures, etc.) + */ + recreateScene(): void; + /** + * Redraws the data. Should be called anytime the accessor method + * for x and y coordinates changes, which means a new projection + * exists and the scatter plot should repaint the points. + */ + update(): void; + /** + * Should be called to notify the scatter plot that the container + * was resized and it should resize and redraw itself. + */ + resize(): void; + /** Registers a listener that will be called when selects some points. */ + onSelection(listener: OnSelectionListener): void; + /** + * Registers a listener that will be called when the user hovers over + * a point. + */ + onHover(listener: OnHoverListener): void; + /** + * Should emulate the same behavior as if the user clicked on the point. + * This is used to trigger a click from an external event, such as + * a search query. + */ + clickOnPoint(pointIndex: number): void; +} diff --git a/tensorflow/tensorboard/components/vz-projector/scatterWebGL.ts b/tensorflow/tensorboard/components/vz-projector/scatterWebGL.ts new file mode 100644 index 0000000000..79cee64c5b --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/scatterWebGL.ts @@ -0,0 +1,1510 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {BoundingBox, CollisionGrid} from './label'; +import {DataSet, Mode, OnHoverListener, OnSelectionListener, Point3D, Scatter} from './scatter'; +import {shuffle} from './util'; + +const FONT_SIZE = 10; + +// Colors (in various necessary formats). +const BACKGROUND_COLOR_DAY = 0xffffff; +const BACKGROUND_COLOR_NIGHT = 0x000000; +const AXIS_COLOR = 0xb3b3b3; +const LABEL_COLOR_DAY = 0x000000; +const LABEL_COLOR_NIGHT = 0xffffff; +const LABEL_STROKE_DAY = 0xffffff; +const LABEL_STROKE_NIGHT = 0x000000; +const POINT_COLOR = 0x7575D9; +const POINT_COLOR_GRAYED = 0x888888; +const BLENDING_DAY = THREE.MultiplyBlending; +const BLENDING_NIGHT = THREE.AdditiveBlending; +const TRACE_START_HUE = 60; +const TRACE_END_HUE = 360; +const TRACE_SATURATION = 1; +const TRACE_LIGHTNESS = .3; +const TRACE_DEFAULT_OPACITY = .2; +const TRACE_DEFAULT_LINEWIDTH = 2; +const TRACE_SELECTED_OPACITY = .9; +const TRACE_SELECTED_LINEWIDTH = 3; +const TRACE_DESELECTED_OPACITY = .05; + +// Various distance bounds. +const MAX_ZOOM = 10; +const MIN_ZOOM = .05; +const NUM_POINTS_FOG_THRESHOLD = 5000; +const MIN_POINT_SIZE = 5.0; +const IMAGE_SIZE = 30; + +// Constants relating to the camera parameters. +/** Camera frustum vertical field of view. */ +const FOV = 70; +const NEAR = 0.01; +const FAR = 100; + +// Constants relating to the indices of buffer arrays. +/** Item size of a single point in a bufferArray representing colors */ +const RGB_NUM_BYTES = 3; +/** Item size of a single point in a bufferArray representing indices */ +const INDEX_NUM_BYTES = 1; +/** Item size of a single point in a bufferArray representing locations */ +const XYZ_NUM_BYTES = 3; + +// Key presses. +const SHIFT_KEY = 16; +const CTRL_KEY = 17; + +// Original positions of camera and camera target, in 2d and 3d +const POS_3D = { + x: 1.5, + y: 1.5, + z: 1.5 +}; + +// Target for the camera in 3D is the center of the 1, 1, 1 square, as all our +// data is scaled to this. +const TAR_3D = { + x: 0, + y: 0, + z: 0 +}; + +const POS_2D = { + x: 0, + y: 0, + z: 2 +}; + +// In 3D, the target is the center of the xy plane. +const TAR_2D = { + x: 0, + y: 0, + z: 0 +}; + +// The maximum number of labels to draw to keep the frame rate up. +const SAMPLE_SIZE = 10000; + +// Shaders for images. +const VERTEX_SHADER = ` + // Index of the specific vertex (passed in as bufferAttribute), and the + // variable that will be used to pass it to the fragment shader. + attribute float vertexIndex; + varying vec2 xyIndex; + + // Similar to above, but for colors. + attribute vec3 color; + varying vec3 vColor; + + // If the point is highlighted, this will be 1.0 (else 0.0). + attribute float isHighlight; + + // Uniform passed in as a property from THREE.ShaderMaterial. + uniform bool sizeAttenuation; + uniform float pointSize; + uniform float imageWidth; + uniform float imageHeight; + + void main() { + // Pass index and color values to fragment shader. + vColor = color; + xyIndex = vec2(mod(vertexIndex, imageWidth), + floor(vertexIndex / imageWidth)); + + // Transform current vertex by modelViewMatrix (model world position and + // camera world position matrix). + vec4 mvPosition = modelViewMatrix * vec4(position, 1.0); + + // Project vertex in camera-space to screen coordinates using the camera's + // projection matrix. + gl_Position = projectionMatrix * mvPosition; + + // Create size attenuation (if we're in 3D mode) by making the size of + // each point inversly proportional to its distance to the camera. + float attenuatedSize = - pointSize / mvPosition.z; + gl_PointSize = sizeAttenuation ? attenuatedSize : pointSize; + + // If the point is a highlight, make it slightly bigger than the other + // points, and also don't let it get smaller than some threshold. + if (isHighlight == 1.0) { + gl_PointSize = max(gl_PointSize * 1.2, ${MIN_POINT_SIZE.toFixed(1)}); + }; + }`; + +const FRAGMENT_SHADER = ` + // Values passed in from the vertex shader. + varying vec2 xyIndex; + varying vec3 vColor; + + // Adding in the THREEjs shader chunks for fog. + ${THREE.ShaderChunk['common']} + ${THREE.ShaderChunk['fog_pars_fragment']} + + // Uniforms passed in as properties from THREE.ShaderMaterial. + uniform sampler2D texture; + uniform float imageWidth; + uniform float imageHeight; + uniform bool isImage; + + void main() { + // A mystery variable that is required to make the THREE shaderchunk for fog + // work correctly. + vec3 outgoingLight = vec3(0.0); + + if (isImage) { + // Coordinates of the vertex within the entire sprite image. + vec2 coords = (gl_PointCoord + xyIndex) / vec2(imageWidth, imageHeight); + // Determine the color of the spritesheet at the calculate spot. + vec4 fromTexture = texture2D(texture, coords); + + // Finally, set the fragment color. + gl_FragColor = vec4(vColor, 1.0) * fromTexture; + } else { + // Discard pixels outside the radius so points are rendered as circles. + vec2 uv = gl_PointCoord.xy - 0.5; + if (length(uv) > 0.5) discard; + + // If the point is not an image, just color it. + gl_FragColor = vec4(vColor, 1.0); + } + ${THREE.ShaderChunk['fog_fragment']} + }`; + +export class ScatterWebGL implements Scatter { + // MISC UNINITIALIZED VARIABLES. + + // Colors and options that are changed between Day and Night modes. + private backgroundColor: number; + private labelColor: number; + private labelStroke: number; + private blending: THREE.Blending; + private isNight: boolean; + + // THREE.js necessities. + private scene: THREE.Scene; + private perspCamera: THREE.PerspectiveCamera; + private renderer: THREE.WebGLRenderer; + private cameraControls: any; + private light: THREE.PointLight; + private fog: THREE.Fog; + + // Data structures (and THREE.js objects) associated with points. + private geometry: THREE.BufferGeometry; + /** Texture for rendering offscreen in order to enable interactive hover. */ + private pickingTexture: THREE.WebGLRenderTarget; + /** Array of unique colors for each point used in detecting hover. */ + private uniqueColArr: Float32Array; + private materialOptions: THREE.ShaderMaterialParameters; + private points: THREE.Points; + private traces: THREE.Line[]; + private dataSet: DataSet; + private shuffledData: number[]; + /** Holds the indexes of the points to be labeled. */ + private labeledPoints: number[] = []; + private highlightedPoints: number[] = []; + private nearestPoint: number; + private pointSize2D: number; + private pointSize3D: number; + /** The buffer attribute that holds the positions of the points. */ + private positionBufferArray: THREE.BufferAttribute; + + // Accessors for rendering and labeling the points. + private xAccessor: (index: number) => number; + private yAccessor: (index: number) => number; + private zAccessor: (index: number) => number; + private labelAccessor: (index: number) => string; + private colorAccessor: (index: number) => string; + private highlightStroke: (i: number) => string; + private favorLabels: (i: number) => boolean; + + // Scaling functions for each axis. + private xScale: d3.scale.Linear<number, number>; + private yScale: d3.scale.Linear<number, number>; + private zScale: d3.scale.Linear<number, number>; + + // Listeners + private onHoverListeners: OnHoverListener[] = []; + private onSelectionListeners: OnSelectionListener[] = []; + private lazySusanAnimation: number; + + // Other variables associated with layout and interaction. + private height: number; + private width: number; + private mode: Mode; + /** Whether the user has turned labels on or off. */ + private labelsAreOn = true; + /** Whether the label canvas has been already cleared. */ + private labelCanvasIsCleared = true; + + private animating: boolean; + private axis3D: THREE.AxisHelper; + private axis2D: THREE.LineSegments; + private dpr: number; // The device pixelratio + private selecting = false; // whether or not we are selecting points. + private mouseIsDown = false; + // Whether the current click sequence contains a drag, so we can determine + // whether to update the selection. + private isDragSequence = false; + private selectionSphere: THREE.Mesh; + private image: HTMLImageElement; + private animationID: number; + /** Color of any point not selected (or NN of selected) */ + private defaultPointColor = POINT_COLOR; + + // HTML elements. + private gc: CanvasRenderingContext2D; + private containerNode: HTMLElement; + private canvas: HTMLCanvasElement; + + /** Get things started up! */ + constructor( + container: d3.Selection<any>, labelAccessor: (index: number) => string) { + this.labelAccessor = labelAccessor; + this.xScale = d3.scale.linear(); + this.yScale = d3.scale.linear(); + this.zScale = d3.scale.linear(); + + // Set up non-THREEjs layout. + this.containerNode = container.node() as HTMLElement; + this.getLayoutValues(); + + // For now, labels are drawn on this transparent canvas with no touch events + // rather than being rendered in webGL. + this.canvas = container.append('canvas').node() as HTMLCanvasElement; + this.gc = this.canvas.getContext('2d'); + d3.select(this.canvas).style({position: 'absolute', left: 0, top: 0}); + this.canvas.style.pointerEvents = 'none'; + + // Set up THREE.js. + this.createSceneAndRenderer(); + this.setDayNightMode(false); + this.createLight(); + this.makeCamera(); + this.resize(false); + // Render now so no black background appears during startup. + this.renderer.render(this.scene, this.perspCamera); + // Add interaction listeners. + this.addInteractionListeners(); + } + + // SET UP + private addInteractionListeners() { + this.containerNode.addEventListener( + 'mousemove', this.onMouseMove.bind(this)); + this.containerNode.addEventListener( + 'mousedown', this.onMouseDown.bind(this)); + this.containerNode.addEventListener('mouseup', this.onMouseUp.bind(this)); + this.containerNode.addEventListener('click', this.onClick.bind(this)); + window.addEventListener('keydown', this.onKeyDown.bind(this), false); + window.addEventListener('keyup', this.onKeyUp.bind(this), false); + } + + /** Updates the positions buffer array to reflect the actual data. */ + private updatePositionsArray() { + for (let i = 0; i < this.dataSet.points.length; i++) { + // Set position based on projected point. + let pp = this.dataSet.points[i].projectedPoint; + this.positionBufferArray.setXYZ(i, pp.x, pp.y, pp.z); + } + if (this.geometry) { + this.positionBufferArray.needsUpdate = true; + } + } + + /** + * Returns an x, y, z value for each item of our data based on the accessor + * methods. + */ + private getPointsCoordinates() { + // Determine max and min of each axis of our data. + let xExtent = d3.extent(this.dataSet.points, (p, i) => this.xAccessor(i)); + let yExtent = d3.extent(this.dataSet.points, (p, i) => this.yAccessor(i)); + this.xScale.domain(xExtent).range([-1, 1]); + this.yScale.domain(yExtent).range([-1, 1]); + if (this.zAccessor) { + let zExtent = d3.extent(this.dataSet.points, (p, i) => this.zAccessor(i)); + this.zScale.domain(zExtent).range([-1, 1]); + } + + // Determine 3d coordinates of each data point. + this.dataSet.points.forEach((d, i) => { + d.projectedPoint.x = this.xScale(this.xAccessor(i)); + d.projectedPoint.y = this.yScale(this.yAccessor(i)); + d.projectedPoint.z = + (this.zAccessor ? this.zScale(this.zAccessor(i)) : 0); + }); + } + + private createLight() { + this.light = new THREE.PointLight(0xFFECBF, 1, 0); + this.scene.add(this.light); + } + + /** General setup of scene and renderer. */ + private createSceneAndRenderer() { + this.scene = new THREE.Scene(); + this.renderer = new THREE.WebGLRenderer(); + // Accouting for retina displays. + this.renderer.setPixelRatio(window.devicePixelRatio || 1); + this.renderer.setSize(this.width, this.height); + this.containerNode.appendChild(this.renderer.domElement); + this.pickingTexture = new THREE.WebGLRenderTarget(this.width, this.height); + this.pickingTexture.texture.minFilter = THREE.LinearFilter; + } + + /** Set up camera and camera's controller. */ + private makeCamera() { + this.perspCamera = + new THREE.PerspectiveCamera(FOV, this.width / this.height, NEAR, FAR); + this.cameraControls = + new (THREE as any) + .OrbitControls(this.perspCamera, this.renderer.domElement); + this.cameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT; + this.cameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT; + // Start is called when the user stars interacting with + // orbit controls. + this.cameraControls.addEventListener('start', () => { + this.cameraControls.autoRotate = false; + cancelAnimationFrame(this.lazySusanAnimation); + }); + // Change is called everytime the user interacts with the + // orbit controls. + this.cameraControls.addEventListener('change', () => { + this.removeAllLabels(); + this.render(); + }); + // End is called when the user stops interacting with the + // orbit controls (e.g. on mouse up, after dragging). + this.cameraControls.addEventListener('end', () => { this.makeLabels(); }); + } + + /** Sets up camera to work in 3D (called after makeCamera()). */ + private makeCamera3D() { + // Set up the camera position at a skewed angle from the xy plane, looking + // toward the origin + this.cameraControls.position0.set(POS_3D.x, POS_3D.y, POS_3D.z); + this.cameraControls.target0.set(TAR_3D.x, TAR_3D.y, TAR_3D.z); + this.cameraControls.enableRotate = true; + let position = new THREE.Vector3(POS_3D.x, POS_3D.y, POS_3D.z); + let target = new THREE.Vector3(TAR_3D.x, TAR_3D.y, TAR_3D.z); + this.animate(position, target, () => { + // Start lazy susan after the animation is done. + this.startLazySusanAnimation(); + }); + } + + /** Sets up camera to work in 2D (called after makeCamera()). */ + private makeCamera2D(animate?: boolean) { + // Set the camera position in the middle of the screen, looking directly + // toward the middle of the xy plane + this.cameraControls.position0.set(POS_2D.x, POS_2D.y, POS_2D.z); + this.cameraControls.target0.set(TAR_2D.x, TAR_2D.y, TAR_2D.z); + let position = new THREE.Vector3(POS_2D.x, POS_2D.y, POS_2D.z); + let target = new THREE.Vector3(TAR_2D.x, TAR_2D.y, TAR_2D.z); + this.animate(position, target); + this.cameraControls.enableRotate = false; + } + + /** + * Set up buffer attributes to be used for the points/images. + */ + private createBufferAttributes() { + // Set up buffer attribute arrays. + let numPoints = this.dataSet.points.length; + let colArr = new Float32Array(numPoints * RGB_NUM_BYTES); + this.uniqueColArr = new Float32Array(numPoints * RGB_NUM_BYTES); + let colors = new THREE.BufferAttribute(this.uniqueColArr, RGB_NUM_BYTES); + // Assign each point a unique color in order to identify when the user + // hovers over a point. + for (let i = 0; i < numPoints; i++) { + let color = new THREE.Color(i); + colors.setXYZ(i, color.r, color.g, color.b); + } + colors.array = colArr; + let hiArr = new Float32Array(numPoints); + + /** Indices cooresponding to highlighted points. */ + let highlights = new THREE.BufferAttribute(hiArr, INDEX_NUM_BYTES); + + // Note that we need two index arrays. + + /** + * The actual indices of the points which we use for sizeAttenuation in + * the shader. + */ + let indicesShader = + new THREE.BufferAttribute(new Float32Array(numPoints), 1); + + for (let i = 0; i < numPoints; i++) { + // Create the array of indices. + indicesShader.setX(i, this.dataSet.points[i].dataSourceIndex); + } + + // Finally, add all attributes to the geometry. + this.geometry.addAttribute('position', this.positionBufferArray); + this.positionBufferArray.needsUpdate = true; + this.geometry.addAttribute('color', colors); + this.geometry.addAttribute('vertexIndex', indicesShader); + this.geometry.addAttribute('isHighlight', highlights); + + // For now, nothing is highlighted. + this.colorSprites(null); + } + + /** + * Generate a texture for the points/images and sets some initial params + */ + private createTexture(image: HTMLImageElement| + HTMLCanvasElement): THREE.Texture { + let tex = new THREE.Texture(image); + tex.needsUpdate = true; + // Used if the texture isn't a power of 2. + tex.minFilter = THREE.LinearFilter; + tex.generateMipmaps = false; + tex.flipY = false; + return tex; + } + + /** + * Create points, set their locations and actually instantiate the + * geometry. + */ + private addSprites() { + // Create geometry. + this.geometry = new THREE.BufferGeometry(); + this.createBufferAttributes(); + let canvas = document.createElement('canvas'); + let image = this.image || canvas; + // TODO(smilkov): Pass sprite dim to the renderer. + let spriteDim = 28.0; + let tex = this.createTexture(image); + let pointSize = (this.zAccessor ? this.pointSize3D : this.pointSize2D); + if (this.image) { + pointSize = IMAGE_SIZE; + } + let uniforms = { + texture: {type: 't', value: tex}, + imageWidth: {type: 'f', value: image.width / spriteDim}, + imageHeight: {type: 'f', value: image.height / spriteDim}, + fogColor: {type: 'c', value: this.fog.color}, + fogNear: {type: 'f', value: this.fog.near}, + fogFar: {type: 'f', value: this.fog.far}, + sizeAttenuation: {type: 'bool', value: !!this.zAccessor}, + isImage: {type: 'bool', value: !!this.image}, + pointSize: {type: 'f', value: pointSize} + }; + this.materialOptions = { + uniforms: uniforms, + vertexShader: VERTEX_SHADER, + fragmentShader: FRAGMENT_SHADER, + transparent: (this.image ? false : true), + // When rendering points with blending, we want depthTest/Write + // turned off. + depthTest: (this.image ? true : false), + depthWrite: (this.image ? true : false), + fog: true, + blending: (this.image ? THREE.NormalBlending : this.blending), + }; + // Give it some material. + let material = new THREE.ShaderMaterial(this.materialOptions); + + // And finally initialize it and add it to the scene. + this.points = new THREE.Points(this.geometry, material); + this.scene.add(this.points); + } + + /** + * Create line traces between connected points and instantiate the geometry. + */ + private addTraces() { + if (!this.dataSet || !this.dataSet.traces) { + return; + } + + this.traces = []; + + for (let i = 0; i < this.dataSet.traces.length; i++) { + let dataTrace = this.dataSet.traces[i]; + + let geometry = new THREE.BufferGeometry(); + let vertices: number[] = []; + let colors: number[] = []; + + for (let j = 0; j < dataTrace.pointIndices.length - 1; j++) { + this.dataSet.points[dataTrace.pointIndices[j]].traceIndex = i; + this.dataSet.points[dataTrace.pointIndices[j + 1]].traceIndex = i; + + let point1 = this.dataSet.points[dataTrace.pointIndices[j]]; + let point2 = this.dataSet.points[dataTrace.pointIndices[j + 1]]; + + vertices.push( + point1.projectedPoint.x, point1.projectedPoint.y, + point1.projectedPoint.z); + vertices.push( + point2.projectedPoint.x, point2.projectedPoint.y, + point2.projectedPoint.z); + + let color1 = + this.getPointInTraceColor(j, dataTrace.pointIndices.length); + let color2 = + this.getPointInTraceColor(j + 1, dataTrace.pointIndices.length); + + colors.push( + color1.r / 255, color1.g / 255, color1.b / 255, color2.r / 255, + color2.g / 255, color2.b / 255); + } + + geometry.addAttribute( + 'position', + new THREE.BufferAttribute(new Float32Array(vertices), XYZ_NUM_BYTES)); + geometry.addAttribute( + 'color', + new THREE.BufferAttribute(new Float32Array(colors), RGB_NUM_BYTES)); + + // We use the same material for every line. + let material = new THREE.LineBasicMaterial({ + linewidth: TRACE_DEFAULT_LINEWIDTH, + opacity: TRACE_DEFAULT_OPACITY, + transparent: true, + vertexColors: THREE.VertexColors + }); + + let trace = new THREE.LineSegments(geometry, material); + this.traces.push(trace); + this.scene.add(trace); + } + } + + /** + * Returns the color of a point along a trace. + */ + private getPointInTraceColor(index: number, totalPoints: number) { + let hue = TRACE_START_HUE + + (TRACE_END_HUE - TRACE_START_HUE) * index / totalPoints; + + return d3.hsl(hue, TRACE_SATURATION, TRACE_LIGHTNESS).rgb(); + } + + /** Clean up any old axes that we may have made previously. */ + private removeOldAxes() { + if (this.axis3D) { + this.scene.remove(this.axis3D); + } + if (this.axis2D) { + this.scene.remove(this.axis2D); + } + } + + /** Add axis. */ + private addAxis3D() { + this.axis3D = new THREE.AxisHelper(); + this.scene.add(this.axis3D); + } + + /** Manually make axis if we're in 2d. */ + private addAxis2D() { + let vertices = new Float32Array([ + 0, + 0, + 0, + this.xScale(1), + 0, + 0, + 0, + 0, + 0, + 0, + this.yScale(1), + 0, + ]); + + let axisColor = new THREE.Color(AXIS_COLOR); + let axisColors = new Float32Array([ + axisColor.r, + axisColor.b, + axisColor.g, + axisColor.r, + axisColor.b, + axisColor.g, + axisColor.r, + axisColor.b, + axisColor.g, + axisColor.r, + axisColor.b, + axisColor.g, + ]); + + // Create line geometry based on above position and color. + let lineGeometry = new THREE.BufferGeometry(); + lineGeometry.addAttribute( + 'position', new THREE.BufferAttribute(vertices, XYZ_NUM_BYTES)); + lineGeometry.addAttribute( + 'color', new THREE.BufferAttribute(axisColors, RGB_NUM_BYTES)); + + // And use it to create the actual object and add this new axis to the + // scene! + let axesMaterial = + new THREE.LineBasicMaterial({vertexColors: THREE.VertexColors}); + this.axis2D = new THREE.LineSegments(lineGeometry, axesMaterial); + this.scene.add(this.axis2D); + } + + // DYNAMIC (post-load) CHANGES + + /** When we stop dragging/zooming, return to normal behavior. */ + private onClick(e?: MouseEvent) { + if (e && this.selecting || !this.points) { + this.resetTraces(); + return; + } + let selection = this.nearestPoint || null; + this.defaultPointColor = (selection ? POINT_COLOR_GRAYED : POINT_COLOR); + // Only call event handlers if the click originated from the scatter plot. + if (e && !this.isDragSequence) { + this.onSelectionListeners.forEach(l => l(selection ? [selection] : [])); + } + this.isDragSequence = false; + this.labeledPoints = + this.highlightedPoints.filter((id, i) => this.favorLabels(i)); + + this.resetTraces(); + if (selection && this.dataSet.points[selection].traceIndex) { + for (let i = 0; i < this.traces.length; i++) { + this.traces[i].material.opacity = TRACE_DESELECTED_OPACITY; + this.traces[i].material.needsUpdate = true; + } + this.traces[this.dataSet.points[selection].traceIndex].material.opacity = + TRACE_SELECTED_OPACITY; + (this.traces[this.dataSet.points[selection].traceIndex].material as + THREE.LineBasicMaterial) + .linewidth = TRACE_SELECTED_LINEWIDTH; + this.traces[this.dataSet.points[selection].traceIndex] + .material.needsUpdate = true; + } + this.render(); + this.makeLabels(); + } + + private resetTraces() { + if (!this.traces) { + return; + } + for (let i = 0; i < this.traces.length; i++) { + this.traces[i].material.opacity = TRACE_DEFAULT_OPACITY; + (this.traces[i].material as THREE.LineBasicMaterial).linewidth = + TRACE_DEFAULT_LINEWIDTH; + this.traces[i].material.needsUpdate = true; + } + } + + /** When dragging, do not redraw labels. */ + private onMouseDown(e: MouseEvent) { + this.animating = false; + this.isDragSequence = false; + this.mouseIsDown = true; + // If we are in selection mode, and we have in fact clicked a valid point, + // create a sphere so we can select things + if (this.selecting) { + this.cameraControls.enabled = false; + this.setNearestPointToMouse(e); + if (this.nearestPoint) { + this.createSelectionSphere(); + } + } else if ( + !e.ctrlKey && + this.cameraControls.mouseButtons.ORBIT == THREE.MOUSE.RIGHT) { + // The user happened to press the ctrl key when the tab was active, + // unpressed the ctrl when the tab was inactive, and now he/she + // is back to the projector tab. + this.cameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT; + this.cameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT; + } else if ( + e.ctrlKey && + this.cameraControls.mouseButtons.ORBIT == THREE.MOUSE.LEFT) { + // Similarly to the situation above. + this.cameraControls.mouseButtons.ORBIT = THREE.MOUSE.RIGHT; + this.cameraControls.mouseButtons.PAN = THREE.MOUSE.LEFT; + } + } + + + /** When we stop dragging/zooming, return to normal behavior. */ + private onMouseUp(e: any) { + if (this.selecting) { + this.cameraControls.enabled = true; + this.scene.remove(this.selectionSphere); + this.selectionSphere = null; + this.render(); + } + this.mouseIsDown = false; + } + + /** + * When the mouse moves, find the nearest point (if any) and send it to the + * hoverlisteners (usually called from embedding.ts) + */ + private onMouseMove(e: MouseEvent) { + if (this.cameraControls.autoRotate) { + // Cancel the lazy susan. + this.cameraControls.autoRotate = false; + cancelAnimationFrame(this.lazySusanAnimation); + this.makeLabels(); + } + + // A quick check to make sure data has come in. + if (!this.points) { + return; + } + this.isDragSequence = this.mouseIsDown; + // Depending if we're selecting or just navigating, handle accordingly. + if (this.selecting && this.mouseIsDown) { + if (this.selectionSphere) { + this.adjustSelectionSphere(e); + } + this.render(); + } else if (!this.mouseIsDown) { + let lastNearestPoint = this.nearestPoint; + this.setNearestPointToMouse(e); + if (lastNearestPoint != this.nearestPoint) { + this.onHoverListeners.forEach(l => l(this.nearestPoint)); + } + } + } + + /** For using ctrl + left click as right click, and for circle select */ + private onKeyDown(e: any) { + // If ctrl is pressed, use left click to orbit + if (e.keyCode === CTRL_KEY) { + this.cameraControls.mouseButtons.ORBIT = THREE.MOUSE.RIGHT; + this.cameraControls.mouseButtons.PAN = THREE.MOUSE.LEFT; + } + + // If shift is pressed, start selecting + if (e.keyCode === SHIFT_KEY) { + this.selecting = true; + this.containerNode.style.cursor = 'crosshair'; + } + } + + /** For using ctrl + left click as right click, and for circle select */ + private onKeyUp(e: any) { + if (e.keyCode === CTRL_KEY) { + this.cameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT; + this.cameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT; + } + + // If shift is released, stop selecting + if (e.keyCode === SHIFT_KEY) { + this.selecting = (this.getMode() === Mode.SELECT); + if (!this.selecting) { + this.containerNode.style.cursor = 'default'; + } + this.scene.remove(this.selectionSphere); + this.selectionSphere = null; + this.render(); + } + } + + private setNearestPointToMouse(e: MouseEvent) { + // Create buffer for reading a single pixel. + let pixelBuffer = new Uint8Array(4); + // No need to account for dpr (device pixel ratio) since the pickingTexture + // has the same coordinates as the mouse (flipped on y). + let x = e.offsetX; + let y = e.offsetY; + + // Read the pixel under the mouse from the texture. + this.renderer.readRenderTargetPixels( + this.pickingTexture, x, this.pickingTexture.height - y, 1, 1, + pixelBuffer); + + // Interpret the pixel as an ID. + let id = (pixelBuffer[0] << 16) | (pixelBuffer[1] << 8) | pixelBuffer[2]; + this.nearestPoint = + id != 0xffffff && id < this.dataSet.points.length ? id : null; + } + + /** Returns the squared distance to the mouse for the i-th point. */ + private getDist2ToMouse(i: number, e: MouseEvent) { + let point = this.getProjectedPointFromIndex(i); + let screenCoords = this.vector3DToScreenCoords(point); + return this.dist2D( + [e.offsetX * this.dpr, e.offsetY * this.dpr], + [screenCoords.x, screenCoords.y]); + } + + private adjustSelectionSphere(e: MouseEvent) { + let dist2 = this.getDist2ToMouse(this.nearestPoint, e) / 100; + this.selectionSphere.scale.set(dist2, dist2, dist2); + this.selectPoints(dist2); + } + + private getProjectedPointFromIndex(i: number): THREE.Vector3 { + return new THREE.Vector3( + this.dataSet.points[i].projectedPoint.x, + this.dataSet.points[i].projectedPoint.y, + this.dataSet.points[i].projectedPoint.z); + } + + private calibratePointSize() { + let numPts = this.dataSet.points.length; + let scaleConstant = 200; + let logBase = 8; + // Scale point size inverse-logarithmically to the number of points. + this.pointSize3D = scaleConstant / Math.log(numPts) / Math.log(logBase); + this.pointSize2D = this.pointSize3D / 1.5; + } + + private setFogDistances() { + let dists = this.getNearFarPoints(); + this.fog.near = dists.shortestDist; + // If there are fewer points we want less fog. We do this + // by making the "far" value (that is, the distance from the camera to the + // far edge of the fog) proportional to the number of points. + let multiplier = 2 - + Math.min(this.dataSet.points.length, NUM_POINTS_FOG_THRESHOLD) / + NUM_POINTS_FOG_THRESHOLD; + this.fog.far = dists.furthestDist * multiplier; + } + + private getNearFarPoints() { + let shortestDist: number = Infinity; + let furthestDist: number = 0; + for (let i = 0; i < this.dataSet.points.length; i++) { + let point = this.getProjectedPointFromIndex(i); + if (!this.isPointWithinCameraView(point)) { + continue; + }; + let distToCam = this.dist3D(point, this.perspCamera.position); + furthestDist = Math.max(furthestDist, distToCam); + shortestDist = Math.min(shortestDist, distToCam); + } + return {shortestDist, furthestDist}; + } + + /** + * Renders the scene and updates the label for the point, which is rendered + * as a div on top of WebGL. + */ + private render() { + if (!this.dataSet) { + return; + } + let lightPos = new THREE.Vector3().copy(this.perspCamera.position); + lightPos.x += 1; + lightPos.y += 1; + this.light.position.set(lightPos.x, lightPos.y, lightPos.z); + + // We want to determine which point the user is hovering over. So, rather + // than linearly iterating through each point to see if it is under the + // mouse, we render another set of the points offscreen, where each point is + // at full opacity and has its id encoded in its color. Then, we see the + // color of the pixel under the mouse, decode the color, and get the id of + // of the point. + let shaderMaterial = this.points.material as THREE.ShaderMaterial; + let colors = this.geometry.getAttribute('color') as THREE.BufferAttribute; + // Make shallow copy of the shader options and modify the necessary values. + let offscreenOptions = + Object.create(this.materialOptions) as THREE.ShaderMaterialParameters; + // Since THREE.js errors if we remove the fog, the workaround is to set the + // near value to very far, so no points have fog. + this.fog.near = 1000; + this.fog.far = 10000; + // Render offscreen as non transparent points (even when we have images). + offscreenOptions.uniforms.isImage.value = false; + offscreenOptions.transparent = false; + offscreenOptions.depthTest = true; + offscreenOptions.depthWrite = true; + shaderMaterial.setValues(offscreenOptions); + // Give each point a unique color. + let origColArr = colors.array; + colors.array = this.uniqueColArr; + colors.needsUpdate = true; + this.renderer.render(this.scene, this.perspCamera, this.pickingTexture); + + // Now render onscreen. + + // Change to original color array. + colors.array = origColArr; + colors.needsUpdate = true; + // Bring back the fog. + if (this.zAccessor && this.geometry) { + this.setFogDistances(); + } + offscreenOptions.uniforms.isImage.value = !!this.image; + // Bring back the standard shader material options. + shaderMaterial.setValues(this.materialOptions); + // Render onscreen. + this.renderer.render(this.scene, this.perspCamera); + } + + /** + * Make sure that the point is in view of the camera (as opposed to behind + * this is a problem because we are projecting to the camera) + */ + private isPointWithinCameraView(point: THREE.Vector3): boolean { + let camToTarget = new THREE.Vector3() + .copy(this.perspCamera.position) + .sub(this.cameraControls.target); + let camToPoint = + new THREE.Vector3().copy(this.perspCamera.position).sub(point); + // If the angle between the camera-target and camera-point vectors is more + // than 90, the point is behind the camera + if (camToPoint.angleTo(camToTarget) > Math.PI / 2) { + return false; + }; + return true; + } + + private vector3DToScreenCoords(v: THREE.Vector3) { + let vector = new THREE.Vector3().copy(v).project(this.perspCamera); + let coords = { + // project() returns the point in perspCamera's coordinates, with the + // origin in the center and a positive upward y. To get it into screen + // coordinates, normalize by adding 1 and dividing by 2. + x: ((vector.x + 1) / 2 * this.width) * this.dpr, + y: -((vector.y - 1) / 2 * this.height) * this.dpr + }; + return coords; + } + + /** Checks if a given label will be within the screen's bounds. */ + private isLabelInBounds(labelWidth: number, coords: {x: number, y: number}) { + let padding = 7; + if ((coords.x < 0) || (coords.y < 0) || + (coords.x > this.width * this.dpr - labelWidth - padding) || + (coords.y > this.height * this.dpr)) { + return false; + }; + return true; + } + + /** Add a specific label to the canvas. */ + private formatLabel( + text: string, point: {x: number, y: number}, opacity: number) { + let ls = new THREE.Color(this.labelStroke); + let lc = new THREE.Color(this.labelColor); + this.gc.strokeStyle = 'rgba(' + ls.r * 255 + ',' + ls.g * 255 + ',' + + ls.b * 255 + ',' + opacity + ')'; + this.gc.fillStyle = 'rgba(' + lc.r * 255 + ',' + lc.g * 255 + ',' + + lc.b * 255 + ',' + opacity + ')'; + this.gc.strokeText(text, point.x + 4, point.y); + this.gc.fillText(text, point.x + 4, point.y); + } + + /** + * Reset the positions of all labels, and check for overlapps using the + * collision grid. + */ + private makeLabels() { + // Don't make labels if they are turned off. + if (!this.labelsAreOn || this.points == null) { + return; + } + // First, remove all old labels. + this.removeAllLabels(); + + this.labelCanvasIsCleared = false; + // If we are passed no points to label (that is, not mousing over any + // points) then want to label ALL the points that we can. + if (!this.labeledPoints.length) { + this.labeledPoints = this.shuffledData; + } + + // We never render more than ~500 labels, so when we get much past that + // point, just break. + let numRenderedLabels: number = 0; + let labelHeight = parseInt(this.gc.font, 10); + + // Bounding box for collision grid. + let boundingBox: BoundingBox = { + loX: 0, + hiX: this.width * this.dpr, + loY: 0, + hiY: this.height * this.dpr + }; + + // Make collision grid with cells proportional to window dimensions. + let grid = + new CollisionGrid(boundingBox, this.width / 25, this.height / 50); + + let dists = this.getNearFarPoints(); + let opacityRange = dists.furthestDist - dists.shortestDist; + + // Setting styles for the labeled font. + this.gc.lineWidth = 6; + this.gc.textBaseline = 'middle'; + this.gc.font = (FONT_SIZE * this.dpr).toString() + 'px roboto'; + + for (let i = 0; + (i < this.labeledPoints.length) && !(numRenderedLabels > SAMPLE_SIZE); + i++) { + let index = this.labeledPoints[i]; + let point = this.getProjectedPointFromIndex(index); + if (!this.isPointWithinCameraView(point)) { + continue; + }; + let screenCoords = this.vector3DToScreenCoords(point); + // Have extra space between neighboring labels. Don't pack too tightly. + let labelMargin = 2; + // Shift the label to the right of the point circle. + let xShift = 3; + let textBoundingBox = { + loX: screenCoords.x + xShift - labelMargin, + // Computing the width of the font is expensive, + // so we assume width of 1 at first. Then, if the label doesn't + // conflict with other labels, we measure the actual width. + hiX: screenCoords.x + xShift + /* labelWidth - 1 */ +1 + labelMargin, + loY: screenCoords.y - labelHeight / 2 - labelMargin, + hiY: screenCoords.y + labelHeight / 2 + labelMargin + }; + + if (grid.insert(textBoundingBox, true)) { + let dataSet = this.dataSet; + let text = this.labelAccessor(index); + let labelWidth = this.gc.measureText(text).width; + + // Now, check with properly computed width. + textBoundingBox.hiX += labelWidth - 1; + if (grid.insert(textBoundingBox) && + this.isLabelInBounds(labelWidth, screenCoords)) { + let lenToCamera = this.dist3D(point, this.perspCamera.position); + // Opacity is scaled between 0.2 and 1, based on how far a label is + // from the camera (Unless we are in 2d mode, in which case opacity is + // just 1!) + let opacity = this.zAccessor ? + 1.2 - (lenToCamera - dists.shortestDist) / opacityRange : + 1; + this.formatLabel(text, screenCoords, opacity); + numRenderedLabels++; + } + } + } + + if (this.highlightedPoints.length > 0) { + // Force-draw the first favored point with increased font size. + let index = this.highlightedPoints[0]; + let point = this.dataSet.points[index]; + this.gc.font = (FONT_SIZE * this.dpr * 1.7).toString() + 'px roboto'; + let coords = new THREE.Vector3( + point.projectedPoint.x, point.projectedPoint.y, + point.projectedPoint.z); + let screenCoords = this.vector3DToScreenCoords(coords); + let text = this.labelAccessor(index); + this.formatLabel(text, screenCoords, 255); + } + } + + /** Returns the distance between two points in 3d space */ + private dist3D(a: Point3D, b: Point3D): number { + let dx = a.x - b.x; + let dy = a.y - b.y; + let dz = a.z - b.z; + return Math.sqrt(dx * dx + dy * dy + dz * dz); + } + + private dist2D(a: [number, number], b: [number, number]): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + return Math.sqrt(dX * dX + dY * dY); + } + + /** Cancels current animation */ + private cancelAnimation() { + if (this.animationID) { + cancelAnimationFrame(this.animationID); + } + } + + private startLazySusanAnimation() { + this.cameraControls.autoRotate = true; + this.cameraControls.update(); + this.lazySusanAnimation = + requestAnimationFrame(() => this.startLazySusanAnimation()); + } + + /** + * Animates the camera between one location and another. + * If callback is specified, it gets called when the animation is done. + */ + private animate( + pos: THREE.Vector3, target: THREE.Vector3, callback?: () => void) { + this.cameraControls.autoRotate = false; + cancelAnimationFrame(this.lazySusanAnimation); + + let currPos = this.perspCamera.position; + let currTarget = this.cameraControls.target; + let speed = 3; + this.animating = true; + let interp = (a: THREE.Vector3, b: THREE.Vector3) => { + let x = (a.x - b.x) / speed + b.x; + let y = (a.y - b.y) / speed + b.y; + let z = (a.z - b.z) / speed + b.z; + return {x: x, y: y, z: z}; + }; + // If we're still relatively far away from the target, go closer + if (this.dist3D(currPos, pos) > .03) { + let newTar = interp(target, currTarget); + this.cameraControls.target.set(newTar.x, newTar.y, newTar.z); + + let newPos = interp(pos, currPos); + this.perspCamera.position.set(newPos.x, newPos.y, newPos.z); + this.cameraControls.update(); + this.render(); + this.animationID = + requestAnimationFrame(() => this.animate(pos, target, callback)); + } else { + // Once we get close enough, update flags and stop moving + this.animating = false; + this.cameraControls.target.set(target.x, target.y, target.z); + this.cameraControls.update(); + this.makeLabels(); + this.render(); + if (callback) { + callback(); + } + } + } + + /** Removes all points geometry from the scene. */ + private removeAll() { + this.scene.remove(this.points); + this.removeOldAxes(); + this.removeAllLabels(); + this.removeAllTraces(); + } + + /** Removes all traces from the scene. */ + private removeAllTraces() { + if (!this.traces) { + return; + } + + for (let i = 0; i < this.traces.length; i++) { + this.scene.remove(this.traces[i]); + } + this.traces = []; + } + + /** Removes all the labels. */ + private removeAllLabels() { + // If labels are already removed, do not spend compute power to clear the + // canvas. + if (!this.labelCanvasIsCleared) { + this.gc.clearRect(0, 0, this.width * this.dpr, this.height * this.dpr); + this.labelCanvasIsCleared = true; + } + } + + private colorSprites(highlightStroke: ((index: number) => string)) { + // Update attributes to change colors + let colors = this.geometry.getAttribute('color') as THREE.BufferAttribute; + let highlights = + this.geometry.getAttribute('isHighlight') as THREE.BufferAttribute; + for (let i = 0; i < this.dataSet.points.length; i++) { + let unhighlightedColor = this.image ? + new THREE.Color() : + new THREE.Color( + this.colorAccessor ? this.colorAccessor(i) : + (this.defaultPointColor as any)); + colors.setXYZ( + i, unhighlightedColor.r, unhighlightedColor.g, unhighlightedColor.b); + highlights.setX(i, 0.0); + } + if (highlightStroke) { + // Traverse in reverse so that the point we are hovering over + // (highlightedPoints[0]) is painted last. + for (let i = this.highlightedPoints.length - 1; i >= 0; i--) { + let assocPoint = this.highlightedPoints[i]; + let color = new THREE.Color(highlightStroke(i)); + // Fill colors array (single array of numPoints*3 elements, + // triples of which refer to the rgb values of a single vertex). + colors.setXYZ(assocPoint, color.r, color.g, color.b); + highlights.setX(assocPoint, 1.0); + } + } + colors.needsUpdate = true; + highlights.needsUpdate = true; + } + + /** + * This is called when we update the data to make sure we don't have stale + * data lying around. + */ + private cleanVariables() { + this.removeAll(); + if (this.geometry) { + this.geometry.dispose(); + } + this.geometry = null; + this.points = null; + this.labeledPoints = []; + this.highlightedPoints = []; + } + + /** Select the points inside the sphere of radius dist */ + private selectPoints(dist: number) { + let selectedPoints: Array<number> = new Array(); + this.dataSet.points.forEach(point => { + let pt = point.projectedPoint; + let pointVect = new THREE.Vector3(pt.x, pt.y, pt.z); + let distPointToSphereOrigin = new THREE.Vector3() + .copy(this.selectionSphere.position) + .sub(pointVect) + .length(); + if (distPointToSphereOrigin < dist) { + selectedPoints.push(this.dataSet.points.indexOf(point)); + } + }); + this.labeledPoints = selectedPoints; + // Whenever anything is selected, we want to set the corect point color. + this.defaultPointColor = POINT_COLOR_GRAYED; + this.onSelectionListeners.forEach(l => l(selectedPoints)); + } + + private createSelectionSphere() { + let geometry = new THREE.SphereGeometry(1, 300, 100); + let material = new THREE.MeshPhongMaterial({ + color: 0x000000, + specular: (this.zAccessor && 0xffffff), // In 2d, make sphere look flat. + emissive: 0x000000, + shininess: 10, + shading: THREE.SmoothShading, + opacity: 0.125, + transparent: true, + }); + this.selectionSphere = new THREE.Mesh(geometry, material); + this.selectionSphere.scale.set(0, 0, 0); + let pos = this.dataSet.points[this.nearestPoint].projectedPoint; + this.scene.add(this.selectionSphere); + this.selectionSphere.position.set(pos.x, pos.y, pos.z); + } + + private getLayoutValues() { + this.width = this.containerNode.offsetWidth; + this.height = Math.max(1, this.containerNode.offsetHeight); + this.dpr = window.devicePixelRatio; + } + + // PUBLIC API + + /** Sets the data for the scatter plot. */ + setDataSet(dataSet: DataSet, spriteImage: HTMLImageElement) { + this.dataSet = dataSet; + this.calibratePointSize(); + let positions = + new Float32Array(this.dataSet.points.length * XYZ_NUM_BYTES); + this.positionBufferArray = + new THREE.BufferAttribute(positions, XYZ_NUM_BYTES); + this.image = spriteImage; + this.shuffledData = new Array(this.dataSet.points.length); + for (let i = 0; i < this.dataSet.points.length; i++) { + this.shuffledData[i] = i; + } + shuffle(this.shuffledData); + this.cleanVariables(); + } + + setColorAccessor(colorAccessor: (index: number) => string) { + this.colorAccessor = colorAccessor; + // Render only if there is a geometry. + if (this.geometry) { + this.colorSprites(this.highlightStroke); + this.render(); + } + } + + setXAccessor(xAccessor: (index: number) => number) { + this.xAccessor = xAccessor; + } + + setYAccessor(yAccessor: (index: number) => number) { + this.yAccessor = yAccessor; + } + + setZAccessor(zAccessor: (index: number) => number) { + this.zAccessor = zAccessor; + } + + setLabelAccessor(labelAccessor: (index: number) => string) { + this.labelAccessor = labelAccessor; + this.render(); + } + + setMode(mode: Mode) { + this.mode = mode; + if (mode === Mode.SELECT) { + this.selecting = true; + this.containerNode.style.cursor = 'crosshair'; + } else { + this.selecting = false; + this.containerNode.style.cursor = 'default'; + } + } + + getMode(): Mode { return this.mode; } + + resetZoom() { + if (this.animating) { + return; + } + let resetPos = this.cameraControls.position0; + let resetTarget = this.cameraControls.target0; + this.removeAllLabels(); + this.animate(resetPos, resetTarget, () => { + // Start rotating when the animation is done, if we are in 3D mode. + if (this.zAccessor) { + this.startLazySusanAnimation(); + } + }); + } + + /** Zoom by moving the camera toward the target. */ + zoomStep(multiplier: number) { + let additiveZoom = Math.log(multiplier); + if (this.animating) { + return; + } + + // Zoomvect is the vector along which we want to move the camera + // It is the (normalized) vector from the camera to its target + let zoomVect = new THREE.Vector3() + .copy(this.cameraControls.target) + .sub(this.perspCamera.position) + .multiplyScalar(additiveZoom); + let position = + new THREE.Vector3().copy(this.perspCamera.position).add(zoomVect); + + // Make sure that we're not too far zoomed in. If not, zoom! + if ((this.dist3D(position, this.cameraControls.target) > MIN_ZOOM) && + (this.dist3D(position, this.cameraControls.target) < MAX_ZOOM)) { + this.removeAllLabels(); + this.animate(position, this.cameraControls.target); + } + } + + highlightPoints( + pointIndexes: number[], highlightStroke: (i: number) => string, + favorLabels: (i: number) => boolean): void { + this.favorLabels = favorLabels; + this.highlightedPoints = pointIndexes; + this.labeledPoints = pointIndexes; + this.highlightStroke = highlightStroke; + this.colorSprites(highlightStroke); + this.render(); + this.makeLabels(); + } + + getHighlightedPoints(): number[] { return this.highlightedPoints; } + + showLabels(show: boolean) { + this.labelsAreOn = show; + if (this.labelsAreOn) { + this.makeLabels(); + } else { + this.removeAllLabels(); + } + } + + /** + * Toggles between day and night mode (resets corresponding variables for + * color, etc.) + */ + setDayNightMode(isNight: boolean) { + this.isNight = isNight; + this.labelColor = (isNight ? LABEL_COLOR_NIGHT : LABEL_COLOR_DAY); + this.labelStroke = (isNight ? LABEL_STROKE_NIGHT : LABEL_STROKE_DAY); + this.backgroundColor = + (isNight ? BACKGROUND_COLOR_NIGHT : BACKGROUND_COLOR_DAY); + this.blending = (isNight ? BLENDING_NIGHT : BLENDING_DAY); + this.renderer.setClearColor(this.backgroundColor); + } + + showAxes(show: boolean) { + // TODO(ereif): implement + } + + setAxisLabels(xLabel: string, yLabel: string) { + // TODO(ereif): implement + } + /** + * Recreates the scene in its entirety, not only resetting the point + * locations but also demolishing and recreating the THREEjs structures. + */ + recreateScene() { + this.removeAll(); + this.cancelAnimation(); + this.fog = this.zAccessor ? + new THREE.Fog(this.backgroundColor) : + new THREE.Fog(this.backgroundColor, Infinity, Infinity); + this.scene.fog = this.fog; + this.addSprites(); + this.addTraces(); + if (this.zAccessor) { + this.addAxis3D(); + this.makeCamera3D(); + } else { + this.addAxis2D(); + this.makeCamera2D(); + } + this.render(); + } + + /** + * Redraws the data. Should be called anytime the accessor method + * for x and y coordinates changes, which means a new projection + * exists and the scatter plot should repaint the points. + */ + update() { + this.cancelAnimation(); + this.getPointsCoordinates(); + this.updatePositionsArray(); + if (this.geometry) { + this.makeLabels(); + this.render(); + } + } + + resize(render = true) { + this.getLayoutValues(); + this.perspCamera.aspect = this.width / this.height; + this.perspCamera.updateProjectionMatrix(); + d3.select(this.canvas) + .attr('width', this.width * this.dpr) + .attr('height', this.height * this.dpr) + .style({width: this.width + 'px', height: this.height + 'px'}); + this.renderer.setSize(this.width, this.height); + this.pickingTexture = new THREE.WebGLRenderTarget(this.width, this.height); + this.pickingTexture.texture.minFilter = THREE.LinearFilter; + if (render) { + this.render(); + }; + } + + showTickLabels(show: boolean) { + // TODO(ereif): implement + } + + onSelection(listener: OnSelectionListener) { + this.onSelectionListeners.push(listener); + } + + onHover(listener: OnHoverListener) { this.onHoverListeners.push(listener); } + + clickOnPoint(pointIndex: number) { + this.nearestPoint = pointIndex; + this.onClick(); + } +} diff --git a/tensorflow/tensorboard/components/vz-projector/sptree.ts b/tensorflow/tensorboard/components/vz-projector/sptree.ts new file mode 100644 index 0000000000..f9310d2f1e --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/sptree.ts @@ -0,0 +1,189 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** How many elements can be stored in each node of the tree. */ +const NODE_CAPACITY = 4; + +/** N-dimensional point. Usually 2D or 3D. */ +export type Point = number[]; + +export interface BBox { + center: Point; + halfDim: number; +} + +/** A node in a space-partitioning tree. */ +export interface SPNode { + /** The children of this node. */ + children?: SPNode[]; + /** The bounding box of the region this node occupies. */ + box: BBox; + /** One or more points this node has. */ + points?: Point[]; +} + +/** + * A Space-partitioning tree (https://en.wikipedia.org/wiki/Space_partitioning) + * that recursively divides the space into regions of equal sizes. This data + * structure can act both as a Quad tree and an Octree when the data is 2 or + * 3 dimensional respectively. One usage is in t-SNE in order to do Barnes-Hut + * approximation. + */ +export class SPTree { + root: SPNode; + + private masks: number[]; + private capacity: number; + private dim: number; + + /** + * Constructs a new tree with the provided data. + * + * @param data List of n-dimensional data points. + * @param capacity Number of data points to store in a single node. + */ + constructor(data: Point[], capacity = NODE_CAPACITY) { + if (data.length < 1) { + throw new Error('There should be at least 1 data point'); + } + this.capacity = capacity; + // Make a bounding box based on the extent of the data. + this.dim = data[0].length; + // Each node has 2^d children, where d is the dimension of the space. + // Binary masks (e.g. 000, 001, ... 111 in 3D) are used to determine in + // which child (e.g. quadron in 2D) the new point is going to be assigned. + // For more details, see the insert() method and its comments. + this.masks = new Array(Math.pow(2, this.dim)); + for (let d = 0; d < this.masks.length; ++d) { + this.masks[d] = (1 << d); + } + let min: Point = new Array(this.dim); + fillArray(min, Number.POSITIVE_INFINITY); + let max: Point = new Array(this.dim); + fillArray(max, Number.NEGATIVE_INFINITY); + + for (let i = 0; i < data.length; ++i) { + // For each dim get the min and max. + // E.g. For 2-D, get the x_min, x_max, y_min, y_max. + for (let d = 0; d < this.dim; ++d) { + min[d] = Math.min(min[d], data[i][d]); + max[d] = Math.max(max[d], data[i][d]); + } + } + // Create a bounding box with the center of the largest span. + let center: Point = new Array(this.dim); + let halfDim = 0; + for (let d = 0; d < this.dim; ++d) { + let span = max[d] - min[d]; + center[d] = min[d] + span / 2; + halfDim = Math.max(halfDim, span / 2); + } + this.root = {box: {center: center, halfDim: halfDim}}; + for (let i = 0; i < data.length; ++i) { + this.insert(this.root, data[i]); + } + } + + /** + * Visits every node in the tree. Each node can store 1 or more points, + * depending on the node capacity provided in the constructor. + * + * @param accessor Method that takes the currently visited node, and the + * low and high point of the region that this node occupies. E.g. in 2D, + * the low and high points will be the lower-left corner and the upper-right + * corner. + */ + visit( + accessor: (node: SPNode, lowPoint: Point, highPoint: Point) => boolean, + noBox = false) { + this.visitNode(this.root, accessor, noBox); + } + + private visitNode( + node: SPNode, + accessor: (node: SPNode, lowPoint?: Point, highPoint?: Point) => boolean, + noBox: boolean) { + let skipChildren: boolean; + if (noBox) { + skipChildren = accessor(node); + } else { + let lowPoint = new Array(this.dim); + let highPoint = new Array(this.dim); + for (let d = 0; d < this.dim; ++d) { + lowPoint[d] = node.box.center[d] - node.box.halfDim; + highPoint[d] = node.box.center[d] + node.box.halfDim; + } + skipChildren = accessor(node, lowPoint, highPoint); + } + if (!node.children || skipChildren) { + return; + } + for (let i = 0; i < node.children.length; ++i) { + let child = node.children[i]; + if (child) { + this.visitNode(child, accessor, noBox); + } + } + } + + private insert(node: SPNode, p: Point): boolean { + if (node.points == null) { + node.points = []; + } + // If there is space in this node, add the object here. + if (node.points.length < this.capacity) { + node.points.push(p); + return true; + } + // Otherwise, subdivide and then add the point to whichever node will + // accept it. + if (node.children == null) { + node.children = new Array(this.masks.length); + } + + // Decide which child will get the new point by constructing a D-bits binary + // signature (D=3 for 3D) where the k-th bit is 1 if the point's k-th + // coordinate is greater than the node's k-th coordinate, 0 otherwise. + // Then the binary signature in decimal system gives us the index of the + // child where the new point should be. + let index = 0; + for (let d = 0; d < this.dim; ++d) { + if (p[d] > node.box.center[d]) { + index |= this.masks[d]; + } + } + if (node.children[index] == null) { + this.makeChild(node, index); + } + this.insert(node.children[index], p); + return true; + } + + private makeChild(node: SPNode, index: number): void { + let oldC = node.box.center; + let h = node.box.halfDim / 2; + let newC: Point = new Array(this.dim); + for (let d = 0; d < this.dim; ++d) { + newC[d] = (index & (1 << d)) ? oldC[d] + h : oldC[d] - h; + } + node.children[index] = {box: {center: newC, halfDim: h}}; + } +} + +function fillArray<T>(arr: T[], value: T): void { + for (let i = 0; i < arr.length; ++i) { + arr[i] = value; + } +} diff --git a/tensorflow/tensorboard/components/vz-projector/sptree_test.ts b/tensorflow/tensorboard/components/vz-projector/sptree_test.ts new file mode 100644 index 0000000000..e48ae761e1 --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/sptree_test.ts @@ -0,0 +1,106 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {SPTree} from './sptree'; + +const assert = chai.assert; + +it('simple 2D data', () => { + let data = [ + [0, 1], + [1, 0], + [1, 1], + [0, 0], + ]; + let tree = new SPTree(data, 1); + // Check that each point is within the bound. + tree.visit((node, low, high) => { + assert.equal(low.length, 2); + assert.equal(high.length, 2); + node.points.forEach(point => { + assert.equal(point.length, 2); + // Each point should be in the node's bounding box. + assert.equal( + point[0] >= low[0] && point[0] <= high[0] && point[1] >= low[1] && + point[1] <= high[1], + true); + }); + return false; + }); +}); + +it('simple 3D data', () => { + let data = [ + [0, 1, 0], + [1, 0.4, 2], + [1, 1, 3], + [0, 0, 5], + ]; + let tree = new SPTree(data, 1); + // Check that each point is within the bound. + tree.visit((node, low, high) => { + assert.equal(low.length, 3); + assert.equal(high.length, 3); + node.points.forEach(point => { + assert.equal(point.length, 3); + // Each point should be in the node's bounding box. + assert.equal( + point[0] >= low[0] && point[0] <= high[0] && point[1] >= low[1] && + point[1] <= high[1] && point[2] >= low[2] && point[2] <= high[2], + true); + }); + return false; + }); +}); + +it('Only visit root', () => { + let data = [ + [0, 1, 0], + [1, 0.4, 2], + [1, 1, 3], + [0, 0, 5], + ]; + let tree = new SPTree(data, 1); + let numVisits = 0; + tree.visit((node, low, high) => { + numVisits++; + return true; + }); + assert.equal(numVisits, 1); +}); + +it('Search in random data', () => { + let N = 10000; + let data = new Array(N); + for (let i = 0; i < N; i++) { + data[i] = [Math.random(), Math.random()]; + } + let tree = new SPTree(data, 1); + let numVisits = 0; + let query = data[Math.floor(Math.random() * N)]; + let found = false; + tree.visit((node, low, high) => { + numVisits++; + if (node.points.length > 0 && node.points[0] === query) { + found = true; + return true; + } + let outOfBounds = query[0] < low[0] || query[0] > high[0] || + query[1] < low[1] || query[1] > high[1]; + return outOfBounds; + }); + assert.equal(found, true); + assert.isBelow(numVisits, N / 4); +}); diff --git a/tensorflow/tensorboard/components/vz-projector/util.ts b/tensorflow/tensorboard/components/vz-projector/util.ts new file mode 100644 index 0000000000..975f7e122c --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/util.ts @@ -0,0 +1,43 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** Shuffles the array in-place in O(n) time using Fisher-Yates algorithm. */ +export function shuffle<T>(array: T[]): T[] { + let m = array.length; + let t: T; + let i: number; + + // While there remain elements to shuffle. + while (m) { + // Pick a remaining element + i = Math.floor(Math.random() * m--); + // And swap it with the current element. + t = array[m]; + array[m] = array[i]; + array[i] = t; + } + return array; +} + +/** + * Assert that the condition is satisfied; if not, log user-specified message + * to the console. + */ +export function assert(condition: boolean, message?: string) { + if (!condition) { + message = message || 'Assertion failed'; + throw new Error(message); + } +}
\ No newline at end of file diff --git a/tensorflow/tensorboard/components/vz-projector/vector.ts b/tensorflow/tensorboard/components/vz-projector/vector.ts new file mode 100644 index 0000000000..edb6e9bdd0 --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/vector.ts @@ -0,0 +1,270 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {assert} from './util'; + + +/** + * @fileoverview Useful vector utilities. + */ + +export type Vector = number[]; +export type Point2D = [number, number]; + +/** Returns the dot product of two vectors. */ +export function dot(a: Vector, b: Vector): number { + assert(a.length == b.length, 'Vectors a and b must be of same length'); + let result = 0; + for (let i = 0; i < a.length; ++i) { + result += a[i] * b[i]; + } + return result; +} + +/** Sums all the elements in the vector */ +export function sum(a: Vector): number { + let result = 0; + for (let i = 0; i < a.length; ++i) { + result += a[i]; + } + return result; +} + +/** Returns the sum of two vectors, i.e. a + b */ +export function add(a: Vector, b: Vector): Vector { + assert(a.length == b.length, 'Vectors a and b must be of same length'); + let result = new Array(a.length); + for (let i = 0; i < a.length; ++i) { + result[i] = a[i] + b[i]; + } + return result; +} + +/** Subtracts vector b from vector a, i.e. returns a - b */ +export function sub(a: Vector, b: Vector): Vector { + assert(a.length == b.length, 'Vectors a and b must be of same length'); + let result = new Array(a.length); + for (let i = 0; i < a.length; ++i) { + result[i] = a[i] - b[i]; + } + return result; +} + +/** Returns the square norm of the vector */ +export function norm2(a: Vector): number { + let result = 0; + for (let i = 0; i < a.length; ++i) { + result += a[i] * a[i]; + } + return result; +} + +/** Returns the euclidean distance between two vectors. */ +export function dist(a: Vector, b: Vector): number { + return Math.sqrt(dist2(a, b)); +} + +/** Returns the square euclidean distance between two vectors. */ +export function dist2(a: Vector, b: Vector): number { + assert(a.length == b.length, 'Vectors a and b must be of same length'); + let result = 0; + for (let i = 0; i < a.length; ++i) { + let diff = a[i] - b[i]; + result += diff * diff; + } + return result; +} + +/** Returns the square euclidean distance between two 2D points. */ +export function dist2_2D(a: Vector, b: Vector): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + return dX * dX + dY * dY; +} + +/** Returns the square euclidean distance between two 3D points. */ +export function dist2_3D(a: Vector, b: Vector): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + let dZ = a[2] - b[2]; + return dX * dX + dY * dY + dZ * dZ; +} + +/** + * Returns the square euclidean distance between two vectors, with an early + * exit (returns -1) if the distance is >= to the provided limit. + */ +export function dist2WithLimit(a: Vector, b: Vector, limit: number): number { + assert(a.length == b.length, 'Vectors a and b must be of same length'); + let result = 0; + for (let i = 0; i < a.length; ++i) { + let diff = a[i] - b[i]; + result += diff * diff; + if (result >= limit) { + return -1; + } + } + return result; +} + +/** Returns the square euclidean distance between two 2D points. */ +export function dist22D(a: Point2D, b: Point2D): number { + let dX = a[0] - b[0]; + let dY = a[1] - b[1]; + return dX * dX + dY * dY; +} + +/** Modifies the vector in-place to have unit norm. */ +export function unit(a: Vector): void { + let norm = Math.sqrt(norm2(a)); + assert(norm >= 0, 'Norm of the vector must be > 0'); + for (let i = 0; i < a.length; ++i) { + a[i] /= norm; + } +} + +/** + * Projects the vectors to a lower dimension + * + * @param vectors Array of vectors to be projected. + * @param newDim The resulting dimension of the vectors. + */ +export function projectRandom(vectors: number[][], newDim: number): number[][] { + let dim = vectors[0].length; + let N = vectors.length; + let newVectors: number[][] = new Array(N); + for (let i = 0; i < N; ++i) { + newVectors[i] = new Array(newDim); + } + // Make nDim projections. + for (let k = 0; k < newDim; ++k) { + let randomVector = rn(dim); + for (let i = 0; i < N; ++i) { + newVectors[i][k] = dot(vectors[i], randomVector); + } + } + return newVectors; +} + +/** + * Projects a vector onto a 2D plane specified by the two direction vectors. + */ +export function project2d(a: Vector, dir1: Vector, dir2: Vector): Point2D { + return [dot(a, dir1), dot(a, dir2)]; +} + +/** Returns a vector filled with zeros */ +export function zeros(length: number): Vector { + let result = new Array(length); + for (let i = 0; i < length; ++i) { + result[i] = 0; + } + return result; +} + +export type Predicate<T> = (a: T) => boolean; + +/** + * Computes the centroid of the data points that pass the specified predicate. + * If the provided data points are not vectors, an accessor function needs + * to be provided. + */ +export function centroid<T>( + dataPoints: T[], predicate: Predicate<T>, + accessor?: (a: T) => Vector): {centroid: Vector, numMatches: number} { + if (accessor == null) { + accessor = (a: T) => <any>a; + } + assert(dataPoints.length >= 0, '`vectors` must be of length >= 1'); + let n = 0; + let centroid = zeros(accessor(dataPoints[0]).length); + for (let i = 0; i < dataPoints.length; ++i) { + let dataPoint = dataPoints[i]; + if (!predicate(dataPoint)) { + continue; + } + ++n; + let vector = accessor(dataPoint); + for (let j = 0; j < centroid.length; ++j) { + centroid[j] += vector[j]; + } + } + if (n == 0) { + return {centroid: null, numMatches: 0}; + } + for (let j = 0; j < centroid.length; ++j) { + centroid[j] /= n; + } + return {centroid: centroid, numMatches: n}; +} + +/** + * Generates a vector of the specified size where each component is drawn from + * a random (0, 1) gaussian distribution. + */ +export function rn(size: number): Vector { + let normal = d3.random.normal(); + let result = new Array(size); + for (let i = 0; i < size; ++i) { + result[i] = normal(); + } + return result; +} + +/** + * Returns the cosine distance ([0, 2]) between two vectors + * that have been normalized to unit norm. + */ +export function cosDistNorm(a: Vector, b: Vector): number { + return 1 - dot(a, b); +} + +/** Returns the cosine similarity ([-1, 1]) between two vectors. */ +export function cosSim(a: Vector, b: Vector): number { + return dot(a, b) / Math.sqrt(norm2(a) * norm2(b)); +} + +/** + * Converts list of vectors (matrix) into a 1-dimensional + * typed array with row-first order. + */ +export function toTypedArray<T>( + dataPoints: T[], accessor: (dataPoint: T) => number[]): Float32Array { + let N = dataPoints.length; + let dim = accessor(dataPoints[0]).length; + let result = new Float32Array(N * dim); + for (let i = 0; i < N; ++i) { + let vector = accessor(dataPoints[i]); + for (let d = 0; d < dim; ++d) { + result[i * dim + d] = vector[d]; + } + } + return result; +} + +/** + * Transposes an RxC matrix represented as a flat typed array + * into a CxR matrix, again represented as a flat typed array. + */ +export function transposeTypedArray( + r: number, c: number, typedArray: Float32Array) { + let result = new Float32Array(r * c); + for (let i = 0; i < r; ++i) { + for (let j = 0; j < c; ++j) { + result[j * r + i] = typedArray[i * c + j]; + } + } + return result; +} diff --git a/tensorflow/tensorboard/components/vz-projector/vz-projector-data-loader.html b/tensorflow/tensorboard/components/vz-projector/vz-projector-data-loader.html new file mode 100644 index 0000000000..e6beaa837e --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/vz-projector-data-loader.html @@ -0,0 +1,144 @@ +<!-- +@license +Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--> + +<link rel="import" href="../polymer/polymer.html"> +<link rel="import" href="../paper-dropdown-menu/paper-dropdown-menu.html"> +<link rel="import" href="../paper-listbox/paper-listbox.html"> +<link rel="import" href="../paper-item/paper-item.html"> + +<dom-module id='vz-projector-data-loader'> +<template> +<style> +:host { +} + +input[type=file] { + display: none; +} + +.file-name { + margin-right: 10px; +} + +.dirs { + display: flex; + flex-direction: column; + margin-right: 10px; + line-height: 20px; +} + +.dir { + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; +} + +paper-item { + --paper-item-disabled: { + border-bottom: 1px solid black; + justify-content: center; + font-size: 12px; + line-height: normal; + min-height: 0px; + }; +} +</style> + +<!-- Server-mode UI --> +<div class="server-controls" style="display:none;"> + <div class="dirs"> + <div class="dir">Checkpoint: <span id="checkpoint-file"></span></div> + <div class="dir">Metadata: <span id="metadata-file"></span></div> + </div> + <!-- List of tensors in checkpoint --> + <paper-dropdown-menu noink no-animations label="[[getNumTensorsLabel(tensorNames)]] found"> + <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{selectedTensor}}"> + <template is="dom-repeat" items="[[tensorNames]]"> + <paper-item style="justify-content: space-between;" value="[[item.name]]" label="[[item.name]]"> + [[item.name]] + <span style="margin-left: 5px; color:gray; font-size: 12px;">[[item.shape.0]]x[[item.shape.1]]</span> + </paper-item> + </template> + </paper-listbox> + </paper-dropdown-menu> +</div> + +<!-- Standalone-mode UI --> +<div class="standalone-controls" style="display:none;"> + + <!-- Upload buttons --> + <div style="display: flex; justify-content: space-between;"> + <!-- Upload data --> + <div> + <button id="upload" title="Upload a TSV file" class="ink-button">Upload data</button> + <span id="file-name" class="file-name dir"></span> + <input type="file" id="file" name="file"/> + </div> + + <!-- Upload metadata --> + <div> + <button id="upload-metadata" title="Upload a TSV metadata file" class="ink-button">Upload Metadata</button> + <span id="file-metadata-name" class="file-name dir"></span> + <input type="file" id="file-metadata" name="file-metadata"/> + </div> + </div> + + <!-- Demo datasets --> + <paper-dropdown-menu style="width: 100%" noink no-animations label="Select a demo dataset"> + <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{selectedDemo}}"> + <paper-item value="smartreply_full">SmartReply All</paper-item> + <paper-item value="smartreply_5k">SmartReply 5K</paper-item> + <paper-item value="wiki_5k">Glove Wiki 5K</paper-item> + <paper-item value="wiki_10k">Glove Wiki 10K</paper-item> + <paper-item value="wiki_40k">Glove Wiki 40K</paper-item> + <paper-item value="mnist_10k">MNIST 10K</paper-item> + <paper-item value="iris">Iris</paper-item> + </paper-listbox> + </paper-dropdown-menu> + +</div> + +<!-- Label by --> +<template is="dom-if" if="[[labelOptions.length]]"> + <paper-dropdown-menu style="width: 100%" noink no-animations label="Label by"> + <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{labelOption}}"> + <template is="dom-repeat" items="[[labelOptions]]"> + <paper-item style="justify-content: space-between;" value="[[item]]" label="[[item]]"> + [[item]] + </paper-item> + </template> + </paper-listbox> + </paper-dropdown-menu> +</template> + +<!-- Color by --> +<template is="dom-if" if="[[colorOptions.length]]"> + <paper-dropdown-menu id="colorby" style="width: 100%" noink no-animations label="Color by"> + <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{colorOption}}"> + <template is="dom-repeat" items="[[colorOptions]]"> + <paper-item style="justify-content: space-between;" class$="[[getSeparatorClass(item.isSeparator)]]" value="[[item]]" label="[[item.name]]" disabled="[[item.isSeparator]]"> + [[item.name]] + <span style="margin-left: 5px; color:gray; font-size: 12px;">[[item.desc]]</span> + </paper-item> + </template> + </paper-listbox> + </paper-dropdown-menu> +</template> + +<!-- Closing global template --> +</template> +</dom-module> diff --git a/tensorflow/tensorboard/components/vz-projector/vz-projector-data-loader.ts b/tensorflow/tensorboard/components/vz-projector/vz-projector-data-loader.ts new file mode 100644 index 0000000000..f44bb6bbe9 --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/vz-projector-data-loader.ts @@ -0,0 +1,500 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {runAsyncTask, updateMessage} from './async'; +import {DataPoint, DataSet, DatasetMetadata, DataSource} from './data'; +import {PolymerElement} from './vz-projector-util'; + + +/** Prefix added to the http requests when asking the server for data. */ +const DATA_URL = 'data'; + +type DemoDataset = { + fpath: string; metadata_path?: string; metadata?: DatasetMetadata; +}; + +type Metadata = { + [key: string]: (number|string); +}; + +/** List of compiled demo datasets for showing the capabilities of the tool. */ +const DEMO_DATASETS: {[name: string]: DemoDataset} = { + 'wiki_5k': { + fpath: 'wiki_5000_50d_tensors.ssv', + metadata_path: 'wiki_5000_50d_labels.ssv' + }, + 'wiki_10k': { + fpath: 'wiki_10000_100d_tensors.ssv', + metadata_path: 'wiki_10000_100d_labels.ssv' + }, + 'wiki_40k': { + fpath: 'wiki_40000_100d_tensors.ssv', + metadata_path: 'wiki_40000_100d_labels.ssv' + }, + 'smartreply_5k': { + fpath: 'smartreply_5000_256d_tensors.tsv', + metadata_path: 'smartreply_5000_256d_labels.tsv' + }, + 'smartreply_full': { + fpath: 'smartreply_full_256d_tensors.tsv', + metadata_path: 'smartreply_full_256d_labels.tsv' + }, + 'mnist_10k': { + fpath: 'mnist_10k_784d_tensors.tsv', + metadata_path: 'mnist_10k_784d_labels.tsv', + metadata: { + image: + {sprite_fpath: 'mnist_10k_sprite.png', single_image_dim: [28, 28]} + }, + }, + 'iris': {fpath: 'iris_tensors.tsv', metadata_path: 'iris_labels.tsv'} +}; + +/** Maximum number of colors supported in the color map. */ +const NUM_COLORS_COLOR_MAP = 20; + +interface ServerInfo { + tensors: {[name: string]: [number, number]}; + tensors_file: string; + checkpoint_file: string; + checkpoint_dir: string; + metadata_file: string; +} + +let DataLoaderPolymer = PolymerElement({ + is: 'vz-projector-data-loader', + properties: { + dataSource: { + type: Object, // DataSource + notify: true + }, + selectedDemo: {type: String, value: 'wiki_5k', notify: true}, + selectedTensor: {type: String, notify: true}, + labelOption: {type: String, notify: true}, + colorOption: {type: Object, notify: true}, + // Private. + tensorNames: Array + } +}); + +export type ColorOption = { + name: string; desc?: string; map?: (value: string | number) => string; + isSeparator?: boolean; +}; + +class DataLoader extends DataLoaderPolymer { + dataSource: DataSource; + selectedDemo: string; + labelOption: string; + labelOptions: string[]; + colorOption: ColorOption; + colorOptions: ColorOption[]; + selectedTensor: string; + tensorNames: {name: string, shape: number[]}[]; + + private dom: d3.Selection<any>; + + ready() { + this.dom = d3.select(this); + if (this.dataSource) { + // There is data already. + return; + } + // Check to see if there is a server. + d3.json(`${DATA_URL}/info`, (err, serverInfo) => { + if (err) { + // No server was found, thus operate in standalone mode. + this.setupStandaloneMode(); + return; + } + // Server was found, thus show the checkpoint dir and the tensors. + this.setupServerMode(serverInfo); + }); + } + + getSeparatorClass(isSeparator: boolean): string { + return isSeparator ? 'separator' : null; + } + + private setupServerMode(info: ServerInfo) { + // Display the server-mode controls. + this.dom.select('.server-controls').style('display', null); + this.dom.select('#checkpoint-file') + .text(info.checkpoint_file) + .attr('title', info.checkpoint_file); + this.dom.select('#metadata-file') + .text(info.metadata_file) + .attr('title', info.metadata_file); + + // Handle the list of checkpoint tensors. + this.dom.on('selected-tensor-changed', () => { + this.selectedTensorChanged(this.selectedTensor); + }); + let names = Object.keys(info.tensors) + .filter(name => { + let shape = info.tensors[name]; + return shape.length == 2 && shape[0] > 1 && shape[1] > 1; + }) + .sort((a, b) => info.tensors[b][0] - info.tensors[a][0]); + this.tensorNames = + names.map(name => { return {name, shape: info.tensors[name]}; }); + } + + private updateMetadataUI(columnStats: ColumnStats[]) { + // Label by options. + let labelIndex = -1; + this.labelOptions = columnStats.length > 1 ? columnStats.map((stats, i) => { + // Make the default label by the first non-numeric column. + if (!stats.isNumeric && labelIndex == -1) { + labelIndex = i; + } + return stats.name; + }) : + ['label']; + this.labelOption = this.labelOptions[Math.max(0, labelIndex)]; + + // Color by options. + let standardColorOption: ColorOption[] = [ + {name: 'No color map'}, + // TODO(smilkov): Implement this. + //{name: 'Distance of neighbors', + // desc: 'How far is each point from its neighbors'} + ]; + let metadataColorOption: ColorOption[] = + columnStats + .filter(stats => { + return !stats.tooManyUniqueValues || stats.isNumeric; + }) + .map(stats => { + let map: (v: string|number) => string; + if (!stats.tooManyUniqueValues) { + let scale = d3.scale.category20(); + let range = scale.range(); + // Re-order the range. + let newRange = range.map((color, i) => { + let index = (i * 2) % (range.length - 1); + if (index == 0) { + index = range.length - 1; + } + return range[index]; + }); + scale.range(newRange).domain(stats.uniqueValues); + map = scale; + } else { + map = d3.scale.linear<string>() + .domain([stats.min, stats.max]) + .range(['white', 'black']); + } + let desc = stats.tooManyUniqueValues ? + 'gradient' : + stats.uniqueValues.length + ' colors'; + return {name: stats.name, desc: desc, map: map}; + }); + if (metadataColorOption.length > 0) { + // Add a separator line between built-in color maps + // and those based on metadata columns. + standardColorOption.push({name: 'Metadata', isSeparator: true}); + } + this.colorOptions = standardColorOption.concat(metadataColorOption); + this.colorOption = this.colorOptions[0]; + } + + private setupStandaloneMode() { + // Display the standalone UI controls. + this.dom.select('.standalone-controls').style('display', null); + + // Demo dataset dropdown + let demoDatasetChanged = (demoDataSet: DemoDataset) => { + if (demoDataSet == null) { + return; + } + + this.dom.selectAll('.file-name').style('display', 'none'); + let separator = demoDataSet.fpath.substr(-3) == 'tsv' ? '\t' : ' '; + fetchDemoData(`${DATA_URL}/${demoDataSet.fpath}`, separator) + .then(points => { + + let p1 = demoDataSet.metadata_path ? + new Promise<ColumnStats[]>((resolve, reject) => { + updateMessage('Fetching metadata...'); + d3.text( + `${DATA_URL}/${demoDataSet.metadata_path}`, + (err: Error, rawMetadata: string) => { + if (err) { + console.error(err); + reject(err); + return; + } + resolve(parseAndMergeMetadata(rawMetadata, points)); + }); + }) : + null; + + let p2 = demoDataSet.metadata && demoDataSet.metadata.image ? + fetchImage( + `${DATA_URL}/${demoDataSet.metadata.image.sprite_fpath}`) : + null; + + Promise.all([p1, p2]).then(values => { + this.updateMetadataUI(values[0]); + let dataSource = new DataSource(); + dataSource.originalDataSet = new DataSet(points); + dataSource.spriteImage = values[1]; + dataSource.metadata = demoDataSet.metadata; + this.dataSource = dataSource; + }); + }); + }; + + this.dom.on('selected-demo-changed', () => { + demoDatasetChanged(DEMO_DATASETS[this.selectedDemo]); + }); + demoDatasetChanged(DEMO_DATASETS[this.selectedDemo]); + + // Show and setup the upload button. + let fileInput = this.dom.select('#file'); + fileInput.on('change', () => { + let file: File = (<any>d3.event).target.files[0]; + this.dom.select('#file-name') + .style('display', null) + .text(file.name) + .attr('title', file.name); + // Clear out the value of the file chooser. This ensures that if the user + // selects the same file, we'll re-read it. + (<any>d3.event).target.value = ''; + // Clear the value of the datasets dropdown. + this.selectedDemo = null; + let fileReader = new FileReader(); + fileReader.onload = evt => { + let str: string = (evt.target as any).result; + parseTensors(str).then(data => { + let dataSource = new DataSource(); + dataSource.originalDataSet = new DataSet(data); + this.dataSource = dataSource; + }); + }; + fileReader.readAsText(file); + }); + + let uploadButton = this.dom.select('#upload'); + uploadButton.on( + 'click', () => { (<HTMLInputElement>fileInput.node()).click(); }); + + // Show and setup the upload metadata button. + let fileMetadataInput = this.dom.select('#file-metadata'); + fileMetadataInput.on('change', () => { + let file: File = (<any>d3.event).target.files[0]; + this.dom.select('#file-metadata-name') + .style('display', null) + .text(file.name) + .attr('title', file.name); + // Clear out the value of the file chooser. This ensures that if the user + // selects the same file, we'll re-read it. + (<any>d3.event).target.value = ''; + // Clear the value of the datasets dropdown. + this.selectedDemo = null; + let fileReader = new FileReader(); + fileReader.onload = evt => { + let str: string = (evt.target as any).result; + parseAndMergeMetadata(str, this.dataSource.originalDataSet.points) + .then(columnStats => { + this.updateMetadataUI(columnStats); + // Must make a shallow copy, otherwise polymer will not + // fire the 'data-changed' event, even if we explicitly + // call this.fire(). + this.dataSource = this.dataSource.makeShallowCopy(); + }); + }; + fileReader.readAsText(file); + }); + + let uploadMetadataButton = this.dom.select('#upload-metadata'); + uploadMetadataButton.on('click', () => { + (<HTMLInputElement>fileMetadataInput.node()).click(); + }); + } + + private selectedTensorChanged(name: string) { + // Get the tensor. + updateMessage('Fetching tensor values...'); + d3.text(`${DATA_URL}/tensor?name=${name}`, (err: Error, tsv: string) => { + if (err) { + console.error(err); + return; + } + parseTensors(tsv).then(dataPoints => { + updateMessage('Fetching metadata...'); + d3.text(`${DATA_URL}/metadata`, (err: Error, rawMetadata: string) => { + if (err) { + console.error(err); + return; + } + parseAndMergeMetadata(rawMetadata, dataPoints).then(columnStats => { + this.updateMetadataUI(columnStats); + let dataSource = new DataSource(); + dataSource.originalDataSet = new DataSet(dataPoints); + this.dataSource = dataSource; + }); + }); + }); + }); + } + + private getNumTensorsLabel(tensorNames: string[]) { + return tensorNames.length === 1 ? '1 tensor' : + tensorNames.length + ' tensors'; + } +} + +function fetchImage(url: string): Promise<HTMLImageElement> { + return new Promise<HTMLImageElement>((resolve, reject) => { + let image = new Image(); + image.onload = () => resolve(image); + image.onerror = (err) => reject(err); + image.src = url; + }); +} + +/** Makes a network request for a delimited text file. */ +function fetchDemoData(url: string, separator: string): Promise<DataPoint[]> { + return new Promise<DataPoint[]>((resolve, reject) => { + updateMessage('Fetching tensors...'); + d3.text(url, (error: Error, dataString: string) => { + if (error) { + console.error(error); + updateMessage('Error loading data.'); + reject(error); + } else { + parseTensors(dataString, separator).then(data => resolve(data)); + } + }); + }); +} + +/** Parses a tsv text file. */ +function parseTensors(content: string, delim = '\t'): Promise<DataPoint[]> { + let data: DataPoint[] = []; + let numDim: number; + return runAsyncTask('Parsing tensors...', () => { + let lines = content.split('\n'); + lines.forEach(line => { + line = line.trim(); + if (line == '') { + return; + } + let row = line.split(delim); + let dataPoint: DataPoint = { + metadata: {}, + vector: null, + dataSourceIndex: data.length, + projections: null, + projectedPoint: null + }; + // If the first label is not a number, take it as the label. + if (isNaN(row[0] as any) || numDim == row.length - 1) { + dataPoint.metadata['label'] = row[0]; + dataPoint.vector = row.slice(1).map(Number); + } else { + dataPoint.vector = row.map(Number); + } + data.push(dataPoint); + if (numDim == null) { + numDim = dataPoint.vector.length; + } + if (numDim != dataPoint.vector.length) { + updateMessage('Parsing failed. Vector dimensions do not match'); + throw Error('Parsing failed'); + } + if (numDim <= 1) { + updateMessage( + 'Parsing failed. Found a vector with only one dimension?'); + throw Error('Parsing failed'); + } + }); + return data; + }); +} + +/** Statistics for a metadata column. */ +type ColumnStats = { + name: string; isNumeric: boolean; tooManyUniqueValues: boolean; + uniqueValues?: string[]; + min: number; + max: number; +}; + +function parseAndMergeMetadata( + content: string, data: DataPoint[]): Promise<ColumnStats[]> { + return runAsyncTask('Parsing metadata...', () => { + let lines = content.split('\n').filter(line => line.trim().length > 0); + let hasHeader = (lines.length - 1 == data.length); + + // Dimension mismatch. + if (lines.length != data.length && !hasHeader) { + throw Error('Dimensions do not match'); + } + + // If the first row doesn't contain metadata keys, we assume that the values + // are labels. + let columnNames: string[] = ['label']; + if (hasHeader) { + columnNames = lines[0].split('\t'); + lines = lines.slice(1); + } + let columnStats: ColumnStats[] = columnNames.map(name => { + return { + name: name, + isNumeric: true, + tooManyUniqueValues: false, + min: Number.POSITIVE_INFINITY, + max: Number.NEGATIVE_INFINITY + }; + }); + let setOfValues = columnNames.map(() => d3.set()); + lines.forEach((line: string, i: number) => { + let rowValues = line.split('\t'); + data[i].metadata = {}; + columnNames.forEach((name: string, colIndex: number) => { + let value = rowValues[colIndex]; + let set = setOfValues[colIndex]; + let stats = columnStats[colIndex]; + data[i].metadata[name] = value; + + // Update stats. + if (!stats.tooManyUniqueValues) { + set.add(value); + if (set.size() > NUM_COLORS_COLOR_MAP) { + stats.tooManyUniqueValues = true; + } + } + if (isNaN(value as any)) { + stats.isNumeric = false; + } else { + stats.min = Math.min(stats.min, +value); + stats.max = Math.max(stats.max, +value); + } + }); + }); + columnStats.forEach((stats, colIndex) => { + let set = setOfValues[colIndex]; + if (!stats.tooManyUniqueValues) { + stats.uniqueValues = set.values(); + } + }); + return columnStats; + }); +} + +document.registerElement(DataLoader.prototype.is, DataLoader); diff --git a/tensorflow/tensorboard/components/vz-projector/vz-projector-util.ts b/tensorflow/tensorboard/components/vz-projector/vz-projector-util.ts new file mode 100644 index 0000000000..cca1ea3988 --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/vz-projector-util.ts @@ -0,0 +1,33 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +export type Spec = { + is: string; properties: { + [key: string]: + (Function | + { + type: Function, value?: any; + readonly?: boolean; + notify?: boolean; + observer?: string; + }) + }; +}; + +export function PolymerElement(spec: Spec) { + return Polymer.Class(spec as any) as{new (): PolymerHTMLElement}; +} + +export interface PolymerHTMLElement extends HTMLElement, polymer.Base {} diff --git a/tensorflow/tensorboard/components/vz-projector/vz-projector.html b/tensorflow/tensorboard/components/vz-projector/vz-projector.html new file mode 100644 index 0000000000..25a90bb3d8 --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/vz-projector.html @@ -0,0 +1,628 @@ +<!-- +@license +Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--> + +<link rel="import" href="../polymer/polymer.html"> +<link rel="import" href="../paper-toggle-button/paper-toggle-button.html"> +<link rel="import" href="../paper-listbox/paper-listbox.html"> +<link rel="import" href="../paper-item/paper-item.html"> +<link rel="import" href="../paper-checkbox/paper-checkbox.html"> +<link rel="import" href="vz-projector-data-loader.html"> + +<dom-module id='vz-projector'> +<template> +<style> + +:host { + display: flex; + width: 100%; + height: 100%; +} + +#container { + display: flex; + width: 100%; + height: 100%; +} + +.hidden { + display: none !important; +} + +/* Main */ + +#main { + position: relative; + flex-grow: 2; +} + +#main .stage { + position: relative; + flex-grow: 2; +} + +#scatter { + position: absolute; + top: 0; + left: 0; + right: 0; + bottom: 0; +} + +#left-pane { + display: flex; + flex-direction: column; + min-width: 312px; + width: 312px; + border-right: 1px solid rgba(0, 0, 0, 0.1); + background: #fafafa; +} + +#right-pane { + min-width: 300px; + width: 300px; + border-left: 1px solid rgba(0, 0, 0, 0.1); + background: #fafafa; +} + +.file-name { + margin-right: 5px; +} + +.control label { + font-size: 12px; + color: rgba(0, 0, 0, 0.7); + margin-top: 10px; + font-weight: 500; +} + +.control .info { + display: block; + font-size: 12px; + color: rgba(0, 0, 0, 0.2); + margin-bottom: 18px; + white-space: nowrap; +} + +.control input[type=text] { + font-weight: 300; + font-size: 16px; + display: block; + padding: 8px 0; + margin: 0 0 8px 0; + width: 100%; + box-sizing: border-box; + border: none; + border-bottom: 1px solid rgba(0, 0, 0, 0.2); + background: none; +} + +.slider { + display: flex; + align-items: center; + margin-bottom: 10px; + justify-content: space-between; +} + +.slider span { + width: 35px; + text-align: right; +} + +.control input[type=text]:focus { + outline: none; + border-bottom: 1px solid rgba(0, 0, 0, 1); +} + +.control { + display: inline-block; + width: 45%; + vertical-align: top; + margin-right: 10px; + overflow-x: hidden; +} + +.control.last { + margin-right: 0; +} + +#wrapper-notify-msg { + z-index: 1; + position: fixed; + top: 10px; + width: 100%; + display: flex; + justify-content: center; +} + +#notify-msg { + display: none; + font-weight: 500; + color: black; + background-color: #FFF9C4; + padding: 5px; + border: 1px solid #FBC02D; +} + +.brush .extent { + stroke: #fff; + fill-opacity: .125; + shape-rendering: crispEdges; +} + +.ink-panel-content { + display: none; +} + +.ink-panel-content.active { + display: block; +} + +.nn-list .neighbor { + font-size: 12px; + margin-bottom: 6px; +} + +.nn-list .neighbor .value { + float: right; + color: #666; + font-weight: 300; +} + +.nn-list .neighbor .bar { + position: relative; + border-top: 1px solid rgba(0, 0, 0, 0.15); + margin: 2px 0; +} + +.nn-list .neighbor .bar .fill { + position: absolute; + top: -1px; + border-top: 1px solid white; +} + +.nn-list .neighbor .tick { + position: absolute; + top: 0px; + height: 3px; + border-left: 1px solid rgba(0, 0, 0, 0.15); +} + +.nn-list .neighbor-link:hover { + cursor: pointer; +} + +.origin text { + font-size: 12px; + font-weight: 500; +} + +.origin line { + stroke: black; + stroke-opacity: 0.2; +} + +/* Ink Framework */ + +/* - Buttons */ +.ink-button, ::shadow .ink-button { + border: none; + border-radius: 2px; + font-size: 13px; + padding: 10px; + min-width: 100px; + flex-shrink: 0; + background: #e3e3e3; +} + +/* - Tabs */ + +.ink-tab-group { + display: flex; + justify-content: space-around; + box-sizing: border-box; + height: 100%; + margin: 0 12px; +} + +.ink-tab-group .ink-tab { + font-weight: 300; + color: rgba(0, 0, 0, 0.5); + text-align: center; + text-transform: uppercase; + line-height: 60px; + cursor: pointer; + padding: 0 12px; +} + +.ink-tab-group .ink-tab:hover { + color: black; +} + +.ink-tab-group .ink-tab.active { + font-weight: 500; + color: black; + border-bottom: 2px solid black; + +} + +/* - Panel */ + +.ink-panel { + display: flex; + flex-direction: column; + font-size: 14px; + line-height: 1.45em; +} + +.ink-panel h4 { + font-size: 14px; + font-weight: 500; + margin: 0; + border-bottom: 1px solid #ddd; + padding-bottom: 5px; + margin-bottom: 10px; +} + +.ink-panel-header { + height: 60px; + border-bottom: 1px solid rgba(0, 0, 0, 0.1); +} + +.ink-panel-metadata-container span { + font-size: 16px; +} + +.ink-panel-metadata { + border-bottom: 1px solid #ccc; + display: table; + padding: 10px 0; + width: 100%; +} + +.ink-panel-metadata-row { + display: table-row; +} + +.ink-panel-metadata-key { + font-weight: bold; +} + +.ink-panel-metadata-key, +.ink-panel-metadata-value { + display: table-cell; + padding-right: 10px; +} + +.ink-panel-buttons { + margin-bottom: 10px; +} + +.ink-panel-content { + padding: 24px; + overflow-y: auto; +} + +.ink-panel-content .distance a { + text-decoration: none; + color: black; +} + +.ink-panel-content .distance a.selected { + color: black; + border-bottom: 2px solid black; +} + +.ink-panel-footer { + display: flex; + align-items: center; + border-top: solid 1px #eee; + height: 60px; + padding: 0 24px; + color: rgba(0, 0, 0, 0.5); +} + +.ink-panel-content h3 { + font-weight: 500; + font-size: 14px; + text-transform: uppercase; + margin-top: 20px; + margin-bottom: 5px; +} + +.ink-panel-header h3 { + margin: 0; + font-weight: 500; + font-size: 14px; + line-height: 60px; + text-transform: uppercase; + padding: 0 24px; +} + +/* - Menubar */ + +.ink-panel-menubar { + position: relative; + height: 60px; + border-bottom: solid 1px #eee; + padding: 0 24px; +} + +.ink-panel-menubar .material-icons { + color: black; +} + +.ink-panel-menubar .menu-button { + margin-right: 12px; + cursor: pointer; + line-height: 60px; + border: none; + background: none; + font-size: 13px; + font-weight: 200; + padding: 0; + margin: 0 20px 0 0; + outline: none; + color: #666; +} + +.ink-panel-menubar button .material-icons { + position: relative; + top: 7px; + margin-right: 8px; + color: #999; +} + +.ink-panel-menubar button.selected, +.ink-panel-menubar button.selected .material-icons { + color: #880E4F; +} + +.ink-panel-menubar .ink-fabs { + position: absolute; + right: 24px; + top: 60px; + z-index: 1; +} + +.ink-panel-menubar .ink-fabs .ink-fab { + position: relative; + top: -20px; + width: 40px; + height: 40px; + border: 1px solid rgba(0, 0, 0, 0.02); + border-radius: 50%; + display: inline-block; + background: white; + margin-left: 8px; + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.3); + cursor: pointer; +} + +.ink-panel-menubar .ink-fabs .ink-fab .material-icons { + margin: 0; + display: block; + line-height: 24px; + position: absolute; + top: calc(50% - 12px); + left: calc(50% - 12px); +} + +.ink-panel-menubar .search-box { + transition: width .2s; + margin-left: -65px; + width: 0; + margin-right: 65px; + background: white; +} + +.two-way-toggle { + display: flex; + flex-direction: row; +} + +.two-way-toggle span { + padding-right: 7px; +} + +paper-listbox .pca-item { + cursor: pointer; + min-height: 17px; + font-size: 12px; + line-height: 17px; +} + +.has-border { + border: 1px solid rgba(0, 0, 0, 0.1); +} + +</style> +<link href="https://fonts.googleapis.com/css?family=Roboto:300,400,500|Material+Icons" rel="stylesheet" type="text/css"> +<div id="wrapper-notify-msg"> + <div id="notify-msg">Loading...</div> +</div> +<div id="container"> + <div id="left-pane" class="ink-panel"> + <div class="ink-panel-header"> + <div class="ink-tab-group"> + <div data-tab="tsne" class="ink-tab" title="t-distributed stochastic neighbor embedding">t-SNE</div> + <div data-tab="pca" class="ink-tab" title="Principal component analysis">PCA</div> + <div data-tab="custom" class="ink-tab" title="Linear projection of two custom vectors">Custom</div> + </div> + </div> + <!-- TSNE Controls --> + <div data-panel="tsne" class="ink-panel-content"> + <p><a href="https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding">t-distributed stochastic neighbor embedding</a> is a dimensionality reduction technique</p> + <p style="color: #880E4F; font-weight: bold;">For fast results, your data will be sampled down to 10,000 points.</p> + <div class="slider"><label>Dimension</label><div class="two-way-toggle"><span>2D</span><paper-toggle-button id="tsne-toggle" noink checked>3D</paper-toggle-button></div></div> + <div class="slider tsne-perplexity"> + <label>Perplexity</label> + <input type="range" min="2" max="100"></input> + <span></span> + </div> + <div class="slider tsne-learning-rate"> + <label>Learning rate</label> + <input type="range" min="-3" max="2"></input> + <span></span> + </div> + <p>The most appropriate perplexity value depends on the density of your data. Loosely speaking, one could say that a larger / denser dataset requires a larger perplexity. Typical values for the perplexity range between 5 and 50.</p> + <p>The most appropriate learning rate depends on the size of your data, with smaller datasets requiring smaller learning rates.</p> + <p> + <button class="run-tsne ink-button">Run</button> + <button class="stop-tsne ink-button">Stop</button> + </p> + <p>Iteration: <span class="run-tsne-iter">0</span></p> + </div> + <!-- PCA Controls --> + <div data-panel="pca" class="ink-panel-content"> + <p><a href="https://en.wikipedia.org/wiki/Principal_component_analysis">Principal component analysis</a> is a dimensionality reduction technique</p> + <label>X</label> + <paper-listbox class="has-border" selected="{{pcaX}}"> + <template is="dom-repeat" items="[[pcaComponents]]"> + <paper-item class="pca-item">Component #[[item]]</paper-item> + </template> + </paper-listbox> + <br/> + <label>Y</label> + <paper-listbox class="has-border" selected="{{pcaY}}"> + <template is="dom-repeat" items="[[pcaComponents]]"> + <paper-item class="pca-item">Component #[[item]]</paper-item> + </template> + </paper-listbox> + <br/> + <paper-checkbox noink id="z-checkbox" checked="{{hasPcaZ}}">Z</paper-checkbox> + <paper-listbox class="has-border" disabled="true" selected="{{pcaZ}}"> + <template is="dom-repeat" items="[[pcaComponents]]"> + <paper-item disabled="[[!hasPcaZ]]" class="pca-item">Component #[[item]]</paper-item> + </template> + </paper-listbox> + </div> + <!-- Custom Controls --> + <div data-panel="custom" class="ink-panel-content"> + <p>Search for two vectors upon which to project all points. Use <code>/regex/</code> to signal a regular expression, otherwise does an exact match.<p> + <h3>Horizontal</h3> + <div class="control xLeft"> + <label>Left</label> + <input type="text" value="/\./"></input> + <span class="info"></span> + </div> + <div class="control xRight last"> + <label>Right</label> + <input type="text" value="/!/"></input> + <span class="info"> </span> + </div> + <h3>Vertical</h3> + <div class="control yUp"> + <label>Up</label> + <input type="text" value="/./"></input> + <span class="info"> </span> + </div> + <div class="control yDown last"> + <label>Down</label> + <input type="text" value="/\?/"></input> + <span class="info"> </span> + </div> + </div> + </div> + <div id="main" class="ink-panel"> + <div class="ink-panel-menubar"> + <button class="menu-button search" title="Search"> + <i class="material-icons">search</i> + <span class="button-label">Search</span> + </button> + <div class="control search-box"> + <input type="text" value=""> + <span class="info"></span> + </div> + <button class="menu-button selectMode" title="Bounding box selection"> + <i class="material-icons">photo_size_select_small</i> + Select + </button> + <button class="menu-button show-labels selected" title="Show/hide labels"> + <i class="material-icons">text_fields</i> + Labels + </button> + <button class="menu-button nightDayMode" title="Toggle between night and day mode"> + <i class="material-icons">brightness_2</i> + Night Mode + </button> + <div class="ink-fabs"> + <div class="ink-fab reset-zoom" title="Zoom to fit all"> + <i class="material-icons resetZoom">home</i> + </div> + <div class="ink-fab zoom-in" title="Zoom in"> + <i class="material-icons">add</i> + </div> + <div class="ink-fab zoom-out" title="Zoom out"> + <i class="material-icons">remove</i> + </div> + </div> + </div> + <div class="stage"> + <div id="scatter"></div> + </div> + <div id="info-panel" class="ink-panel-footer"> + <div> + Number of data points: <span class="numDataPoints"></span>, dimension of embedding: <span class="dim"></span> + | <span id="hoverInfo"></span> + </div> + </div> + </div> + <div id="right-pane" class="ink-panel"> + <div class="ink-panel-header"> + <div class="ink-tab-group"> + <div data-tab="data" class="active ink-tab" title="Setup data">Data</div> + <div data-tab="inspector" class="ink-tab" title="Inspect data">Inspector</div> + </div> + </div> + + <!-- Inspector UI controls --> + <div data-panel="inspector" class="ink-panel-content"> + <div class="ink-panel-metadata-container" style="display: none"> + <span>Metadata</span> + <div class="ink-panel-metadata"></div> + </div> + <div class="ink-panel-buttons"> + <div style="margin-bottom: 10px"> + <button style="display: none;" class="ink-button reset-filter">Show All Data</button> + </div> + <button style="display: none;" class="ink-button set-filter">Isolate selection</button> + <button class="ink-button clear-selection" style="display: none;">Clear selection</button> + </div> + <div class="slider num-nn"> + <label>Number of neighbors</label> + <input type="range" min="5" max="1000"></input> + <span></span> + </div> + <div class="distance"> + Distance: + <div style="float:right"> + <a class="selected cosine" href="javascript:void(0);">cosine</a> | + <a class="euclidean" href="javascript:void(0);">euclidean</a> + </div> + </div> + <p>Nearest points to <b id="nn-title"></b></p> + <div class="nn-list"></div> + </div> + + <!-- Data UI controls --> + <div data-panel="data" class="active ink-panel-content"> + <vz-projector-data-loader data-source="{{dataSource}}" label-option="{{labelOption}}" color-option="{{colorOption}}"></vz-projector-data-loader> + </div> + </div> +</div> +</template> +</dom-module> diff --git a/tensorflow/tensorboard/components/vz-projector/vz-projector.ts b/tensorflow/tensorboard/components/vz-projector/vz-projector.ts new file mode 100644 index 0000000000..e8b1a882ec --- /dev/null +++ b/tensorflow/tensorboard/components/vz-projector/vz-projector.ts @@ -0,0 +1,762 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import {DataPoint, DataSet, DataSource} from './data'; +import * as knn from './knn'; +import {Mode, Scatter} from './scatter'; +import {ScatterWebGL} from './scatterWebGL'; +import * as vector from './vector'; +import {ColorOption} from './vz-projector-data-loader'; +import {PolymerElement} from './vz-projector-util'; + + +/** T-SNE perplexity. Roughly how many neighbors each point influences. */ +let perplexity: number = 30; +/** T-SNE learning rate. */ +let learningRate: number = 10; +/** Number of dimensions for the scatter plot. */ +let dimension = 3; +/** Number of nearest neighbors to highlight around the selected point. */ +let numNN = 100; + +/** Highlight stroke color for the nearest neighbors. */ +const NN_HIGHLIGHT_COLOR = '#6666FA'; + +/** Highlight stroke color for the selected point */ +const POINT_HIGHLIGHT_COLOR_DAY = 'black'; +const POINT_HIGHLIGHT_COLOR_NIGHT = new THREE.Color(0xFFE11F).getStyle(); + +/** Color scale for nearest neighbors. */ +const NN_COLOR_SCALE = + d3.scale.linear<string>() + .domain([1, 0.7, 0.4]) + .range(['hsl(285, 80%, 40%)', 'hsl(0, 80%, 65%)', 'hsl(40, 70%, 60%)']) + .clamp(true); + +/** Text color used for error/important messages. */ +const CALLOUT_COLOR = '#880E4F'; + +type Centroids = { + [key: string]: number[]; xLeft: number[]; xRight: number[]; yUp: number[]; + yDown: number[]; +}; + +let ProjectorPolymer = PolymerElement({ + is: 'vz-projector', + properties: { + // A data source. + dataSource: { + type: Object, // DataSource + observer: 'dataSourceChanged' + }, + + // Private. + pcaComponents: {type: Array, value: d3.range(1, 11)}, + pcaX: { + type: Number, + value: 0, + notify: true, + }, + pcaY: { + type: Number, + value: 1, + notify: true, + }, + pcaZ: { + type: Number, + value: 2, + notify: true, + }, + hasPcaZ: {type: Boolean, value: true, notify: true}, + labelOption: {type: String, observer: 'labelOptionChanged'}, + colorOption: {type: Object, observer: 'colorOptionChanged'}, + } +}); + +class Projector extends ProjectorPolymer { + // Public API. + dataSource: DataSource; + + private dom: d3.Selection<any>; + private pcaX: number; + private pcaY: number; + private pcaZ: number; + private hasPcaZ: boolean; + // The working subset of the data source's original data set. + private currentDataSet: DataSet; + private scatter: Scatter; + private dim: number; + private selectedDistance: (a: number[], b: number[]) => number; + private highlightedPoints: {index: number, color: string}[]; + private selectedPoints: number[]; + private centroidValues: any; + private centroids: Centroids; + /** The centroid across all points. */ + private allCentroid: number[]; + private labelOption: string; + private colorOption: ColorOption; + + ready() { + this.hasPcaZ = true; + this.selectedDistance = vector.cosDistNorm; + this.highlightedPoints = []; + this.selectedPoints = []; + this.centroidValues = {xLeft: null, xRight: null, yUp: null, yDown: null}; + this.centroids = {xLeft: null, xRight: null, yUp: null, yDown: null}; + // Dynamically creating elements inside .nn-list. + this.scopeSubtree(this.$$('.nn-list'), true); + this.dom = d3.select(this); + // Sets up all the UI. + this.setupUIControls(); + if (this.dataSource) { + this.dataSourceChanged(); + } + } + + labelOptionChanged() { + let labelAccessor = (i: number): string => { + return this.points[i].metadata[this.labelOption] as string; + }; + this.scatter.setLabelAccessor(labelAccessor); + } + + colorOptionChanged() { + let colorMap = this.colorOption.map; + if (colorMap == null) { + this.scatter.setColorAccessor(null); + return; + }; + let colors = (i: number) => { + return colorMap(this.points[i].metadata[this.colorOption.name]); + }; + this.scatter.setColorAccessor(colors); + } + + dataSourceChanged() { + if (this.scatter == null || this.dataSource == null) { + // We are not ready yet. + return; + } + this.initFromSource(this.dataSource); + // Set the container to a fixed height, otherwise in Colab the + // height can grow indefinitely. + let container = this.dom.select('#container'); + container.style('height', container.property('clientHeight') + 'px'); + } + + /** + * Normalizes the distance so it can be visually encoded with color. + * The normalization depends on the distance metric (cosine vs euclidean). + */ + private normalizeDist(d: number, minDist: number): number { + return this.selectedDistance === vector.cosDistNorm ? 1 - d : minDist / d; + } + + /** Normalizes and encodes the provided distance with color. */ + private dist2color(d: number, minDist: number): string { + return NN_COLOR_SCALE(this.normalizeDist(d, minDist)); + } + + private initFromSource(source: DataSource) { + this.dataSource = source; + this.setDataSet(this.dataSource.getDataSet()); + this.dom.select('.reset-filter').style('display', 'none'); + // Regexp inputs. + this.setupInput('xLeft'); + this.setupInput('xRight'); + this.setupInput('yUp'); + this.setupInput('yDown'); + } + + private setDataSet(ds: DataSet) { + this.currentDataSet = ds; + this.scatter.setDataSet(this.currentDataSet, this.dataSource.spriteImage); + this.updateMenuButtons(); + this.dim = this.currentDataSet.dim[1]; + this.dom.select('span.numDataPoints').text(this.currentDataSet.dim[0]); + this.dom.select('span.dim').text(this.currentDataSet.dim[1]); + this.showTab('pca'); + } + + private setupInput(name: string) { + let control = this.dom.select('.control.' + name); + let info = control.select('.info'); + + let updateInput = (value: string) => { + if (value.trim() === '') { + info.style('color', CALLOUT_COLOR).text('Enter a regex.'); + return; + } + let result = this.getCentroid(value); + if (result.error) { + info.style('color', CALLOUT_COLOR) + .text('Invalid regex. Using a random vector.'); + result.centroid = vector.rn(this.dim); + } else if (result.numMatches === 0) { + info.style('color', CALLOUT_COLOR) + .text('0 matches. Using a random vector.'); + result.centroid = vector.rn(this.dim); + } else { + info.style('color', null).text(`${result.numMatches} matches.`); + } + this.centroids[name] = result.centroid; + this.centroidValues[name] = value; + }; + let self = this; + + let input = control.select('input').on('input', function() { + updateInput(this.value); + self.showCustom(); + }); + this.allCentroid = null; + // Init the control with the current input. + updateInput((input.node() as HTMLInputElement).value); + } + + private setupUIControls() { + let self = this; + // Global tabs + d3.selectAll('.ink-tab').on('click', function() { + let id = this.getAttribute('data-tab'); + self.showTab(id); + }); + + // Unknown why, but the polymer toggle button stops working + // as soon as you do d3.select() on it. + let tsneToggle = this.querySelector('#tsne-toggle') as HTMLInputElement; + let zCheckbox = this.querySelector('#z-checkbox') as HTMLInputElement; + + // PCA controls. + zCheckbox.addEventListener('change', () => { + // Make sure tsne stays in the same dimension as PCA. + dimension = this.hasPcaZ ? 3 : 2; + tsneToggle.checked = this.hasPcaZ; + this.showPCA(() => { this.scatter.recreateScene(); }); + }); + this.dom.on('pca-x-changed', () => this.showPCA()); + this.dom.on('pca-y-changed', () => this.showPCA()); + this.dom.on('pca-z-changed', () => this.showPCA()); + + // TSNE controls. + + tsneToggle.addEventListener('change', () => { + // Make sure PCA stays in the same dimension as tsne. + this.hasPcaZ = tsneToggle.checked; + dimension = tsneToggle.checked ? 3 : 2; + if (this.scatter) { + this.showTSNE(); + this.scatter.recreateScene(); + } + }); + + this.dom.select('.run-tsne').on('click', () => this.runTSNE()); + this.dom.select('.stop-tsne').on('click', () => { + this.currentDataSet.stopTSNE(); + }); + + let updatePerplexity = () => { + perplexity = +perplexityInput.property('value'); + this.dom.select('.tsne-perplexity span').text(perplexity); + }; + let perplexityInput = this.dom.select('.tsne-perplexity input') + .property('value', perplexity) + .on('input', updatePerplexity); + updatePerplexity(); + + let updateLearningRate = () => { + let val = +learningRateInput.property('value'); + learningRate = Math.pow(10, val); + this.dom.select('.tsne-learning-rate span').text(learningRate); + }; + let learningRateInput = this.dom.select('.tsne-learning-rate input') + .property('value', 1) + .on('input', updateLearningRate); + updateLearningRate(); + + // Nearest neighbors controls. + let updateNumNN = () => { + numNN = +numNNInput.property('value'); + this.dom.select('.num-nn span').text(numNN); + }; + let numNNInput = this.dom.select('.num-nn input') + .property('value', numNN) + .on('input', updateNumNN); + updateNumNN(); + + // View controls + this.dom.select('.reset-zoom').on('click', () => { + this.scatter.resetZoom(); + }); + this.dom.select('.zoom-in').on('click', () => { + this.scatter.zoomStep(2); + }); + this.dom.select('.zoom-out').on('click', () => { + this.scatter.zoomStep(0.5); + }); + + // Toolbar controls + let searchBox = this.dom.select('.control.search-box'); + let searchBoxInfo = searchBox.select('.info'); + + let searchByRegEx = + (pattern: string): {error?: Error, indices: number[]} => { + let regEx: RegExp; + try { + regEx = new RegExp(pattern, 'i'); + } catch (e) { + return {error: e.message, indices: null}; + } + let indices: number[] = []; + for (let id = 0; id < this.points.length; ++id) { + if (regEx.test('' + this.points[id].metadata['label'])) { + indices.push(id); + } + } + return {indices: indices}; + }; + + // Called whenever the search text input changes. + let searchInputChanged = (value: string) => { + if (value.trim() === '') { + searchBoxInfo.style('color', CALLOUT_COLOR).text('Enter a regex.'); + if (this.scatter != null) { + this.selectedPoints = []; + this.selectionWasUpdated(); + } + return; + } + let result = searchByRegEx(value); + let indices = result.indices; + if (result.error) { + searchBoxInfo.style('color', CALLOUT_COLOR).text('Invalid regex.'); + } + if (indices) { + if (indices.length === 0) { + searchBoxInfo.style('color', CALLOUT_COLOR).text(`0 matches.`); + } else { + searchBoxInfo.style('color', null).text(`${indices.length} matches.`); + this.showTab('inspector'); + let neighbors = this.findNeighbors(indices[0]); + if (indices.length === 1) { + this.scatter.clickOnPoint(indices[0]); + } + this.selectedPoints = indices; + this.updateNNList(neighbors); + } + this.selectionWasUpdated(); + } + }; + + searchBox.select('input').on( + 'input', function() { searchInputChanged(this.value); }); + let searchButton = this.dom.select('.search'); + + searchButton.on('click', () => { + let mode = this.scatter.getMode(); + this.scatter.setMode(mode === Mode.SEARCH ? Mode.HOVER : Mode.SEARCH); + if (this.scatter.getMode() == Mode.HOVER) { + this.selectedPoints = []; + this.selectionWasUpdated(); + } else { + searchInputChanged(searchBox.select('input').property('value')); + } + this.updateMenuButtons(); + }); + // Init the control with an empty input. + searchInputChanged(''); + + this.dom.select('.distance a.euclidean').on('click', function() { + d3.selectAll('.distance a').classed('selected', false); + d3.select(this).classed('selected', true); + self.selectedDistance = vector.dist; + if (self.selectedPoints.length > 0) { + let neighbors = self.findNeighbors(self.selectedPoints[0]); + self.updateNNList(neighbors); + } + }); + + this.dom.select('.distance a.cosine').on('click', function() { + d3.selectAll('.distance a').classed('selected', false); + d3.select(this).classed('selected', true); + self.selectedDistance = vector.cosDistNorm; + if (self.selectedPoints.length > 0) { + let neighbors = self.findNeighbors(self.selectedPoints[0]); + self.updateNNList(neighbors); + } + }); + + let selectModeButton = this.dom.select('.selectMode'); + + selectModeButton.on('click', () => { + let mode = this.scatter.getMode(); + this.scatter.setMode(mode === Mode.SELECT ? Mode.HOVER : Mode.SELECT); + this.updateMenuButtons(); + }); + + let showLabels = true; + let showLabelsButton = this.dom.select('.show-labels'); + showLabelsButton.on('click', () => { + showLabels = !showLabels; + this.scatter.showLabels(showLabels); + showLabelsButton.classed('selected', showLabels); + }); + + let dayNightModeButton = this.dom.select('.nightDayMode'); + let modeIsNight = dayNightModeButton.classed('selected'); + dayNightModeButton.on('click', () => { + modeIsNight = !modeIsNight; + this.scatter.setDayNightMode(modeIsNight); + this.scatter.update(); + dayNightModeButton.classed('selected', modeIsNight); + }); + + // Resize + window.addEventListener('resize', () => { this.scatter.resize(); }); + + // Canvas + this.scatter = new ScatterWebGL( + this.dom.select('#scatter'), + i => '' + this.points[i].metadata['label']); + this.scatter.onHover(hoveredIndex => { + if (hoveredIndex == null) { + this.highlightedPoints = []; + } else { + let point = this.points[hoveredIndex]; + this.dom.select('#hoverInfo').text(point.metadata['label']); + let neighbors = this.findNeighbors(hoveredIndex); + let minDist = neighbors[0].dist; + let pointIndices = [hoveredIndex].concat(neighbors.map(d => d.index)); + let pointHighlightColor = modeIsNight ? POINT_HIGHLIGHT_COLOR_NIGHT : + POINT_HIGHLIGHT_COLOR_DAY; + this.highlightedPoints = pointIndices.map((index, i) => { + let color = i == 0 ? pointHighlightColor : + this.dist2color(neighbors[i - 1].dist, minDist); + return {index: index, color: color}; + }); + } + this.selectionWasUpdated(); + }); + + this.scatter.onSelection( + selectedPoints => this.updateSelection(selectedPoints)); + + // Selection controls + this.dom.select('.set-filter').on('click', () => { + let highlighted = this.selectedPoints; + let highlightedOrig: number[] = + highlighted.map(d => { return this.points[d].dataSourceIndex; }); + let subset = this.dataSource.getDataSet(highlightedOrig); + this.setDataSet(subset); + this.dom.select('.reset-filter').style('display', null); + this.selectedPoints = []; + this.scatter.recreateScene(); + this.selectionWasUpdated(); + this.updateIsolateButton(); + }); + + this.dom.select('.reset-filter').on('click', () => { + let subset = this.dataSource.getDataSet(); + this.setDataSet(subset); + this.dom.select('.reset-filter').style('display', 'none'); + }); + + this.dom.select('.clear-selection').on('click', () => { + this.selectedPoints = []; + this.scatter.setMode(Mode.HOVER); + this.scatter.clickOnPoint(null); + this.updateMenuButtons(); + this.selectionWasUpdated(); + }); + } + + private updateSelection(selectedPoints: number[]) { + // If no points are selected, unselect everything. + if (!selectedPoints.length) { + this.selectedPoints = []; + this.updateNNList([]); + } + // If only one point is selected, we want to get its nearest neighbors + // and change the UI accordingly. + else if (selectedPoints.length === 1) { + let selectedPoint = selectedPoints[0]; + this.showTab('inspector'); + let neighbors = this.findNeighbors(selectedPoint); + this.selectedPoints = [selectedPoint].concat(neighbors.map(n => n.index)); + this.updateNNList(neighbors); + } + // Otherwise, select all points and hide nearest neighbors list. + else { + this.selectedPoints = selectedPoints as number[]; + this.highlightedPoints = []; + this.updateNNList([]); + } + this.updateMetadata(); + this.selectionWasUpdated(); + } + + private showPCA(callback?: () => void) { + this.currentDataSet.projectPCA().then(() => { + this.scatter.showTickLabels(false); + let x = this.pcaX; + let y = this.pcaY; + let z = this.pcaZ; + let hasZ = dimension == 3; + this.scatter.setXAccessor(i => this.points[i].projections['pca-' + x]); + this.scatter.setYAccessor(i => this.points[i].projections['pca-' + y]); + this.scatter.setZAccessor( + hasZ ? (i => this.points[i].projections['pca-' + z]) : null); + this.scatter.setAxisLabels('pca-' + x, 'pca-' + y); + this.scatter.update(); + if (callback) { + callback(); + } + }); + } + + private showTab(id: string) { + let tab = this.dom.select('.ink-tab[data-tab="' + id + '"]'); + let pane = + d3.select((tab.node() as HTMLElement).parentNode.parentNode.parentNode); + pane.selectAll('.ink-tab').classed('active', false); + tab.classed('active', true); + pane.selectAll('.ink-panel-content').classed('active', false); + pane.select('.ink-panel-content[data-panel="' + id + '"]') + .classed('active', true); + if (id === 'pca') { + this.showPCA(() => this.scatter.recreateScene()); + } else if (id === 'tsne') { + this.showTSNE(); + } else if (id === 'custom') { + this.showCustom(); + } + } + + private showCustom() { + this.scatter.showTickLabels(true); + let xDir = vector.sub(this.centroids.xRight, this.centroids.xLeft); + this.currentDataSet.projectLinear(xDir, 'linear-x'); + this.scatter.setXAccessor(i => this.points[i].projections['linear-x']); + + let yDir = vector.sub(this.centroids.yUp, this.centroids.yDown); + this.currentDataSet.projectLinear(yDir, 'linear-y'); + this.scatter.setYAccessor(i => this.points[i].projections['linear-y']); + + // Scatter is only in 2D in projection mode. + this.scatter.setZAccessor(null); + + let xLabel = this.centroidValues.xLeft + ' → ' + this.centroidValues.xRight; + let yLabel = this.centroidValues.yUp + ' → ' + this.centroidValues.yDown; + this.scatter.setAxisLabels(xLabel, yLabel); + this.scatter.update(); + this.scatter.recreateScene(); + } + + private get points() { return this.currentDataSet.points; } + + private showTSNE() { + this.scatter.showTickLabels(false); + this.scatter.setXAccessor(i => this.points[i].projections['tsne-0']); + this.scatter.setYAccessor(i => this.points[i].projections['tsne-1']); + this.scatter.setZAccessor( + dimension === 3 ? (i => this.points[i].projections['tsne-2']) : null); + this.scatter.setAxisLabels('tsne-0', 'tsne-1'); + } + + private runTSNE() { + this.currentDataSet.projectTSNE( + perplexity, learningRate, dimension, (iteration: number) => { + if (iteration != null) { + this.dom.select('.run-tsne-iter').text(iteration); + this.scatter.update(); + } + }); + } + + // Updates the displayed metadata for the selected point. + private updateMetadata() { + let metadataContainerElement = this.dom.select('.ink-panel-metadata'); + metadataContainerElement.selectAll('*').remove(); + + let display = false; + if (this.selectedPoints.length >= 1) { + let selectedPoint = this.points[this.selectedPoints[0]]; + + for (let metadataKey in selectedPoint.metadata) { + let rowElement = document.createElement('div'); + rowElement.className = 'ink-panel-metadata-row vz-projector'; + + let keyElement = document.createElement('div'); + keyElement.className = 'ink-panel-metadata-key vz-projector'; + keyElement.textContent = metadataKey; + + let valueElement = document.createElement('div'); + valueElement.className = 'ink-panel-metadata-value vz-projector'; + valueElement.textContent = '' + selectedPoint.metadata[metadataKey]; + + rowElement.appendChild(keyElement); + rowElement.appendChild(valueElement); + + metadataContainerElement.append(function() { + return this.appendChild(rowElement); + }); + } + + display = true; + } + + this.dom.select('.ink-panel-metadata-container') + .style('display', display ? '' : 'none'); + } + + private selectionWasUpdated() { + this.dom.select('#hoverInfo') + .text(`Selected ${this.selectedPoints.length} points`); + let allPoints = + this.highlightedPoints.map(x => x.index).concat(this.selectedPoints); + let stroke = (i: number) => { + return i < this.highlightedPoints.length ? + this.highlightedPoints[i].color : + NN_HIGHLIGHT_COLOR; + }; + let favor = (i: number) => { + return i == 0 || (i < this.highlightedPoints.length ? false : true); + }; + this.scatter.highlightPoints(allPoints, stroke, favor); + this.updateIsolateButton(); + } + + private updateMenuButtons() { + let searchBox = this.dom.select('.control.search-box'); + this.dom.select('.search').classed( + 'selected', this.scatter.getMode() === Mode.SEARCH); + let searchMode = this.scatter.getMode() === Mode.SEARCH; + this.dom.select('.control.search-box') + .style('width', searchMode ? '110px' : null) + .style('margin-right', searchMode ? '10px' : null); + (searchBox.select('input').node() as HTMLInputElement).focus(); + this.dom.select('.selectMode') + .classed('selected', this.scatter.getMode() === Mode.SELECT); + } + + /** + * Finds the nearest neighbors of the currently selected point using the + * currently selected distance method. + */ + private findNeighbors(pointIndex: number): knn.NearestEntry[] { + // Find the nearest neighbors of a particular point. + let neighbors = knn.findKNNofPoint( + this.points, pointIndex, numNN, (d => d.vector), this.selectedDistance); + let result = neighbors.slice(0, numNN); + return result; + } + + /** Updates the nearest neighbors list in the inspector. */ + private updateNNList(neighbors: knn.NearestEntry[]) { + let nnlist = this.dom.select('.nn-list'); + nnlist.html(''); + + if (neighbors.length == 0) { + this.dom.select('#nn-title').text(''); + return; + } + + let selectedPoint = this.points[this.selectedPoints[0]]; + this.dom.select('#nn-title') + .text(selectedPoint != null ? selectedPoint.metadata['label'] : ''); + + let minDist = neighbors.length > 0 ? neighbors[0].dist : 0; + let n = nnlist.selectAll('.neighbor') + .data(neighbors) + .enter() + .append('div') + .attr('class', 'neighbor') + .append('a') + .attr('class', 'neighbor-link'); + + n.append('span') + .attr('class', 'label') + .style('color', d => this.dist2color(d.dist, minDist)) + .text(d => this.points[d.index].metadata['label']); + + n.append('span').attr('class', 'value').text(d => d.dist.toFixed(2)); + + let bar = n.append('div').attr('class', 'bar'); + + bar.append('div') + .attr('class', 'fill') + .style('border-top-color', d => this.dist2color(d.dist, minDist)) + .style('width', d => this.normalizeDist(d.dist, minDist) * 100 + '%'); + + bar.selectAll('.tick') + .data(d3.range(1, 4)) + .enter() + .append('div') + .attr('class', 'tick') + .style('left', d => d * 100 / 4 + '%'); + + n.on('click', d => { this.updateSelection([d.index]); }); + } + + private updateIsolateButton() { + let numPoints = this.selectedPoints.length; + let isolateButton = this.dom.select('.set-filter'); + let clearButton = this.dom.select('button.clear-selection'); + if (numPoints > 1) { + isolateButton.text(`Isolate ${numPoints} points`).style('display', null); + clearButton.style('display', null); + } else { + isolateButton.style('display', 'none'); + clearButton.style('display', 'none'); + } + } + + private getCentroid(pattern: string): CentroidResult { + let accessor = (a: DataPoint) => a.vector; + if (pattern == null) { + return {numMatches: 0}; + } + if (pattern == '') { + if (this.allCentroid == null) { + this.allCentroid = + vector.centroid(this.points, () => true, accessor).centroid; + } + return {centroid: this.allCentroid, numMatches: this.points.length}; + } + + let regExp: RegExp; + let predicate: (a: DataPoint) => boolean; + // Check for a regex. + if (pattern.charAt(0) == '/' && pattern.charAt(pattern.length - 1) == '/') { + pattern = pattern.slice(1, pattern.length - 1); + try { + regExp = new RegExp(pattern, 'i'); + } catch (e) { + return {error: e.message}; + } + predicate = + (a: DataPoint) => { return regExp.test('' + a.metadata['label']); }; + // else does an exact match + } else { + predicate = (a: DataPoint) => { return a.metadata['label'] == pattern; }; + } + return vector.centroid(this.points, predicate, accessor); + } +} + +type CentroidResult = { + centroid?: number[]; numMatches?: number; error?: string +}; + +document.registerElement(Projector.prototype.is, Projector); diff --git a/tensorflow/tensorboard/gulp_tasks/compile.js b/tensorflow/tensorboard/gulp_tasks/compile.js index 4877468ab3..93d3e50c3f 100644 --- a/tensorflow/tensorboard/gulp_tasks/compile.js +++ b/tensorflow/tensorboard/gulp_tasks/compile.js @@ -19,20 +19,67 @@ var typescript = require('typescript'); var gutil = require('gulp-util'); var filter = require('gulp-filter'); var merge = require('merge2'); +var browserify = require('browserify'); +var tsify = require('tsify'); +var source = require('vinyl-source-stream'); +var glob = require('glob').sync; +var concat = require('gulp-concat'); var tsProject = ts.createProject('./tsconfig.json', { typescript: typescript, noExternalResolve: true, // opt-in for faster compilation! }); +/** List of components (and their external deps) that are using es6 modules. */ +var ES6_COMPONENTS = [{ + name: 'vz-projector', + deps: [ + 'd3/d3.min.js', 'weblas/dist/weblas.js', 'three.js/build/three.min.js', + 'three.js/examples/js/controls/OrbitControls.js', + 'numericjs/lib/numeric-1.2.6.js' + ] +}]; module.exports = function() { + // Compile all components that are using ES6 modules into a bundle.js + // using browserify. + var entries = ['typings/index.d.ts']; + var deps = {}; + ES6_COMPONENTS.forEach(function(component) { + // Collect all the typescript files across the components. + entries = entries.concat(glob( + 'components/' + component.name + '/**/*.ts', + // Do not include tests. + {ignore: 'components/' + component.name + '/**/*_test.ts'})); + // Collect the unique external deps across all components using es6 modules. + component.deps.forEach(function(dep) { deps['components/' + dep] = true; }); + }); + deps = Object.keys(deps); + + // Compile, bundle all the typescript files and prepend their deps. + browserify(entries) + .plugin(tsify) + .bundle() + .on('error', function(error) { console.error(error.toString()); }) + .pipe(source('app.js')) + .pipe(gulp.dest('components')) + .on('end', function() { + // Typescript was compiled and bundled. Now we need to prepend + // the external dependencies. + gulp.src(deps.concat(['components/app.js'])) + .pipe(concat('bundle.js')) + .pipe(gulp.dest('components')); + }); + + // Compile components that are using global namespaces producing 1 js file + // for each ts file. var isComponent = filter([ - 'components/tf-*/**/*.ts', - 'components/vz-*/**/*.ts', - 'typings/**/*.ts', + 'components/tf-*/**/*.ts', 'components/vz-*/**/*.ts', 'typings/**/*.ts', 'components/plottable/plottable.d.ts' - ]); + // Ignore components that use es6 modules. + ].concat(ES6_COMPONENTS.map(function(component) { + return '!components/' + component.name + '/**/*.ts'; + }))); return tsProject.src() .pipe(isComponent) diff --git a/tensorflow/tensorboard/package.json b/tensorflow/tensorboard/package.json index ce104cb2f8..8c8f4ac5dc 100644 --- a/tensorflow/tensorboard/package.json +++ b/tensorflow/tensorboard/package.json @@ -13,9 +13,14 @@ "author": "Google", "license": "Apache-2.0", "devDependencies": { + "browserify": "^13.1.0", "gulp": "~3.9.0", + "gulp-bower": "0.0.13", "gulp-cli": "^1.1.0", + "gulp-concat": "^2.6.0", "gulp-filter": "~3.0.1", + "gulp-header": "~1.7.1", + "gulp-rename": "~1.2.2", "gulp-replace": "~0.5.4", "gulp-server-livereload": "~1.5.4", "gulp-tslint": "~4.2.2", @@ -24,13 +29,12 @@ "gulp-vulcanize": "~6.1.0", "merge2": "~0.3.6", "minimist": "~1.2.0", + "tsify": "^0.15.6", "tslint": "^3.2.1", - "typescript": "1.8.0", + "typescript": "^2.0.0", + "typings": "~1.0.4", + "vinyl-source-stream": "^1.1.0", "vulcanize": "^1.14.0", - "web-component-tester": "4.2.2", - "gulp-header": "~1.7.1", - "gulp-rename": "~1.2.2", - "gulp-bower": "0.0.13", - "typings": "~1.0.4" + "web-component-tester": "4.2.2" } } diff --git a/tensorflow/tensorboard/tsconfig.json b/tensorflow/tensorboard/tsconfig.json index 1ecec4f922..e51e70f848 100644 --- a/tensorflow/tensorboard/tsconfig.json +++ b/tensorflow/tensorboard/tsconfig.json @@ -2,7 +2,8 @@ "compilerOptions": { "noImplicitAny": false, "noEmitOnError": true, - "target": "ES5" + "target": "ES5", + "module": "commonjs" }, "compileOnSave": false, "exclude": [ diff --git a/tensorflow/tensorboard/typings.json b/tensorflow/tensorboard/typings.json index 7e7e7f9e88..4679f7acfe 100644 --- a/tensorflow/tensorboard/typings.json +++ b/tensorflow/tensorboard/typings.json @@ -9,6 +9,7 @@ "mocha": "registry:dt/mocha#2.2.5+20160317120654", "polymer": "registry:dt/polymer#1.1.6+20160317120654", "sinon": "registry:dt/sinon#1.16.0+20160517064723", + "three": "registry:dt/three#0.0.0+20160802154944", "webcomponents.js": "registry:dt/webcomponents.js#0.6.0+20160317120654" } } |