diff options
-rw-r--r-- | tensorflow/python/saved_model/builder.py | 11 | ||||
-rw-r--r-- | tensorflow/tensorboard/components/vz-projector/data-loader.ts | 38 | ||||
-rw-r--r-- | tensorflow/tensorboard/components/vz-projector/data.ts | 6 | ||||
-rw-r--r-- | tensorflow/tensorboard/components/vz-projector/knn.ts | 5 | ||||
-rw-r--r-- | tensorflow/tensorboard/components/vz-projector/logging.ts (renamed from tensorflow/tensorboard/components/vz-projector/async.ts) | 62 | ||||
-rw-r--r-- | tensorflow/tensorboard/components/vz-projector/util.ts | 41 | ||||
-rw-r--r-- | tensorflow/tensorboard/components/vz-projector/vz-projector.ts | 2 |
7 files changed, 96 insertions, 69 deletions
diff --git a/tensorflow/python/saved_model/builder.py b/tensorflow/python/saved_model/builder.py index 4c3d04ac61..4f3e8e138e 100644 --- a/tensorflow/python/saved_model/builder.py +++ b/tensorflow/python/saved_model/builder.py @@ -28,6 +28,7 @@ from google.protobuf.any_pb2 import Any from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saved_model_pb2 +from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io @@ -252,7 +253,10 @@ class SavedModelBuilder(object): # Save asset files, if any. self._save_assets(assets_collection) - saver = tf_saver.Saver(variables.all_variables(), sharded=True) + saver = tf_saver.Saver( + variables.all_variables(), + sharded=True, + write_version=saver_pb2.SaverDef.V1) meta_graph_def = saver.export_meta_graph() # Tag the meta graph def and add it to the SavedModel. @@ -298,7 +302,10 @@ class SavedModelBuilder(object): compat.as_text(constants.VARIABLES_FILENAME)) # Save the variables and export meta graph def. - saver = tf_saver.Saver(variables.all_variables(), sharded=True) + saver = tf_saver.Saver( + variables.all_variables(), + sharded=True, + write_version=saver_pb2.SaverDef.V1) saver.save(sess, variables_path, write_meta_graph=False) meta_graph_def = saver.export_meta_graph() diff --git a/tensorflow/tensorboard/components/vz-projector/data-loader.ts b/tensorflow/tensorboard/components/vz-projector/data-loader.ts index 0db06c5a0b..ed1589c574 100644 --- a/tensorflow/tensorboard/components/vz-projector/data-loader.ts +++ b/tensorflow/tensorboard/components/vz-projector/data-loader.ts @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {runAsyncTask, updateMessage} from './async'; +import {runAsyncTask} from './util'; +import * as logging from './logging'; import {ColumnStats, DataPoint, DataSet, DatasetMetadata, MetadataInfo, PointMetadata, State, DataProto} from './data'; @@ -87,9 +88,9 @@ class ServerDataProvider implements DataProvider { } retrieveRuns(callback: (runs: string[]) => void): void { - let msgId = updateMessage('Fetching runs...'); + let msgId = logging.setModalMessage('Fetching runs...'); d3.json(`${this.routePrefix}/runs`, (err, runs) => { - updateMessage(null, msgId); + logging.setModalMessage(null, msgId); callback(runs); }); } @@ -101,9 +102,9 @@ class ServerDataProvider implements DataProvider { return; } - let msgId = updateMessage('Fetching checkpoint info...'); + let msgId = logging.setModalMessage('Fetching checkpoint info...'); d3.json(`${this.routePrefix}/info?run=${run}`, (err, checkpointInfo) => { - updateMessage(null, msgId); + logging.setModalMessage(null, msgId); this.runCheckpointInfoCache[run] = checkpointInfo; callback(checkpointInfo); }); @@ -111,7 +112,7 @@ class ServerDataProvider implements DataProvider { retrieveTensor(run: string, tensorName: string, callback: (ds: DataSet) => void) { // Get the tensor. - updateMessage('Fetching tensor values...', TENSORS_MSG_ID); + logging.setModalMessage('Fetching tensor values...', TENSORS_MSG_ID); d3.text( `${this.routePrefix}/tensor?run=${run}&name=${tensorName}`, (err: Error, tsv: string) => { @@ -127,7 +128,7 @@ class ServerDataProvider implements DataProvider { retrieveMetadata(run: string, tensorName: string, callback: (r: MetadataInfo) => void) { - updateMessage('Fetching metadata...', METADATA_MSG_ID); + logging.setModalMessage('Fetching metadata...', METADATA_MSG_ID); d3.text( `${this.routePrefix}/metadata?run=${run}&name=${tensorName}`, (err: Error, rawMetadata: string) => { @@ -156,11 +157,11 @@ class ServerDataProvider implements DataProvider { getBookmarks( run: string, tensorName: string, callback: (r: State[]) => void) { - let msgId = updateMessage('Fetching bookmarks...'); + let msgId = logging.setModalMessage('Fetching bookmarks...'); d3.json( `${this.routePrefix}/bookmarks?run=${run}&name=${tensorName}`, (err, bookmarks) => { - updateMessage(null, msgId); + logging.setModalMessage(null, msgId); if (!err) { callback(bookmarks as State[]); } @@ -320,18 +321,19 @@ function parseTensors(content: string, delim = '\t'): Promise<DataPoint[]> { numDim = dataPoint.vector.length; } if (numDim !== dataPoint.vector.length) { - updateMessage('Parsing failed. Vector dimensions do not match'); + logging.setModalMessage( + 'Parsing failed. Vector dimensions do not match'); throw Error('Parsing failed'); } if (numDim <= 1) { - updateMessage( + logging.setModalMessage( 'Parsing failed. Found a vector with only one dimension?'); throw Error('Parsing failed'); } }); return data; }, TENSORS_MSG_ID).then(dataPoints => { - updateMessage(null, TENSORS_MSG_ID); + logging.setModalMessage(null, TENSORS_MSG_ID); return dataPoints; }); } @@ -417,7 +419,7 @@ function parseMetadata(content: string): Promise<MetadataInfo> { pointsInfo: pointsMetadata } as MetadataInfo; }, METADATA_MSG_ID).then(metadata => { - updateMessage(null, METADATA_MSG_ID); + logging.setModalMessage(null, METADATA_MSG_ID); return metadata; }); } @@ -518,11 +520,11 @@ class DemoDataProvider implements DataProvider { let demoDataSet = DemoDataProvider.DEMO_DATASETS[tensorName]; let separator = demoDataSet.fpath.substr(-3) === 'tsv' ? '\t' : ' '; let url = `${DemoDataProvider.DEMO_FOLDER}/${demoDataSet.fpath}`; - updateMessage('Fetching tensors...', TENSORS_MSG_ID); + logging.setModalMessage('Fetching tensors...', TENSORS_MSG_ID); d3.text(url, (error: Error, dataString: string) => { if (error) { console.error(error); - updateMessage('Error loading data.'); + logging.setModalMessage('Error loading data.'); return; } parseTensors(dataString, separator).then(points => { @@ -537,7 +539,7 @@ class DemoDataProvider implements DataProvider { let dataSetPromise: Promise<MetadataInfo> = null; if (demoDataSet.metadata_path) { dataSetPromise = new Promise<MetadataInfo>((resolve, reject) => { - updateMessage('Fetching metadata...', METADATA_MSG_ID); + logging.setModalMessage('Fetching metadata...', METADATA_MSG_ID); d3.text( `${DemoDataProvider.DEMO_FOLDER}/${demoDataSet.metadata_path}`, (err: Error, rawMetadata: string) => { @@ -554,7 +556,7 @@ class DemoDataProvider implements DataProvider { let spritesPromise: Promise<HTMLImageElement> = null; if (demoDataSet.metadata && demoDataSet.metadata.image) { let spriteFilePath = demoDataSet.metadata.image.sprite_fpath; - spriteMsgId = updateMessage('Fetching sprite image...'); + spriteMsgId = logging.setModalMessage('Fetching sprite image...'); spritesPromise = fetchImage(`${DemoDataProvider.DEMO_FOLDER}/${spriteFilePath}`); } @@ -562,7 +564,7 @@ class DemoDataProvider implements DataProvider { // Fetch the metadata and the image in parallel. Promise.all([dataSetPromise, spritesPromise]).then(values => { if (spriteMsgId) { - updateMessage(null, spriteMsgId); + logging.setModalMessage(null, spriteMsgId); } let [metadata, spriteImage] = values; metadata.spriteImage = spriteImage; diff --git a/tensorflow/tensorboard/components/vz-projector/data.ts b/tensorflow/tensorboard/components/vz-projector/data.ts index 09b0c9cd08..a55bff00dd 100644 --- a/tensorflow/tensorboard/components/vz-projector/data.ts +++ b/tensorflow/tensorboard/components/vz-projector/data.ts @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {runAsyncTask, updateWarningMessage} from './async'; import {TSNE} from './bh_tsne'; import * as knn from './knn'; import * as scatterPlot from './scatterPlot'; -import {shuffle, getSearchPredicate} from './util'; +import {shuffle, getSearchPredicate, runAsyncTask} from './util'; +import * as logging from './logging'; import * as vector from './vector'; export type DistanceFunction = (a: number[], b: number[]) => number; @@ -327,7 +327,7 @@ export class DataSet implements scatterPlot.DataSet { mergeMetadata(metadata: MetadataInfo) { if (metadata.pointsInfo.length !== this.points.length) { - updateWarningMessage( + logging.setWarningMessage( `Number of tensors (${this.points.length}) do not match` + ` the number of lines in metadata (${metadata.pointsInfo.length}).`); } diff --git a/tensorflow/tensorboard/components/vz-projector/knn.ts b/tensorflow/tensorboard/components/vz-projector/knn.ts index a234bda3f2..3a47dd07b5 100644 --- a/tensorflow/tensorboard/components/vz-projector/knn.ts +++ b/tensorflow/tensorboard/components/vz-projector/knn.ts @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import {runAsyncTask, updateMessage} from './async'; +import {runAsyncTask} from './util'; +import * as logging from './logging'; import {KMin} from './heap'; import * as vector from './vector'; @@ -109,7 +110,7 @@ export function findKNNGPUCosine<T>( if (piece < numPieces) { step(resolve); } else { - updateMessage(null, KNN_GPU_MSG_ID); + logging.setModalMessage(null, KNN_GPU_MSG_ID); bigMatrix.delete(); resolve(nearest); } diff --git a/tensorflow/tensorboard/components/vz-projector/async.ts b/tensorflow/tensorboard/components/vz-projector/logging.ts index 88791e2927..2b01cabc11 100644 --- a/tensorflow/tensorboard/components/vz-projector/async.ts +++ b/tensorflow/tensorboard/components/vz-projector/logging.ts @@ -13,9 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -/** Delay for running async tasks, in milliseconds. */ -const ASYNC_DELAY_MS = 25; - /** Duration in ms for showing warning messages to the user */ const WARNING_DURATION_MS = 5000; @@ -25,43 +22,14 @@ const WARNING_DURATION_MS = 5000; */ const MSG_ANIMATION_DURATION_MSEC = 300 + 20; - -/** - * Runs an expensive task asynchronously with some delay - * so that it doesn't block the UI thread immediately. - * - * @param message The message to display to the user. - * @param task The expensive task to run. - * @param msgId Optional. ID of an existing message. If provided, will overwrite - * an existing message and won't automatically clear the message when the - * task is done. - * @return The value returned by the task. - */ -export function runAsyncTask<T>(message: string, task: () => T, - msgId: string = null): Promise<T> { - let autoClear = (msgId == null); - msgId = updateMessage(message, msgId); - return new Promise<T>((resolve, reject) => { - d3.timer(() => { - try { - let result = task(); - // Clearing the old message. - if (autoClear) { - updateMessage(null, msgId); - } - resolve(result); - } catch (ex) { - updateMessage('Error: ' + ex.message); - reject(ex); - } - return true; - }, ASYNC_DELAY_MS); - }); -} - +let dom: HTMLElement = null; let msgId = 0; let numActiveMessages = 0; +export function setDomContainer(domElement: HTMLElement) { + dom = domElement; +} + /** * Updates the user message with the provided id. * @@ -70,16 +38,21 @@ let numActiveMessages = 0; * is assigned. * @return The id of the message. */ -export function updateMessage(msg: string, id: string = null): string { - let dialog = d3.select('#wrapper-notify-msg').node() as any; +export function setModalMessage(msg: string, id: string = null): string { + if (dom == null) { + console.warn('Can\'t show modal message before the dom is initialized'); + return; + } if (id == null) { id = (msgId++).toString(); } + let dialog = dom.querySelector('#wrapper-notify-msg') as any; + let msgsContainer = dom.querySelector('#notify-msgs') as HTMLElement; let divId = `notify-msg-${id}`; - let msgDiv = d3.select('#' + divId); + let msgDiv = d3.select(dom.querySelector('#' + divId)); let exists = msgDiv.size() > 0; if (!exists) { - msgDiv = d3.select('#notify-msgs').insert('div', ':first-child') + msgDiv = d3.select(msgsContainer).insert('div', ':first-child') .attr('class', 'notify-msg') .attr('id', divId); numActiveMessages++; @@ -102,12 +75,13 @@ export function updateMessage(msg: string, id: string = null): string { /** * Shows a warning message to the user for a certain amount of time. */ -export function updateWarningMessage(msg: string): void { - let warningDiv = d3.select('#warning-msg'); +export function setWarningMessage(msg: string): void { + let warningMsg = dom.querySelector('#warning-msg') as HTMLElement; + let warningDiv = d3.select(warningMsg); warningDiv.style('display', 'block').text('Warning: ' + msg); // Hide the warning message after a certain timeout. setTimeout(() => { warningDiv.style('display', 'none'); }, WARNING_DURATION_MS); -} +}
\ No newline at end of file diff --git a/tensorflow/tensorboard/components/vz-projector/util.ts b/tensorflow/tensorboard/components/vz-projector/util.ts index f729ee16c2..f1712ffcf0 100644 --- a/tensorflow/tensorboard/components/vz-projector/util.ts +++ b/tensorflow/tensorboard/components/vz-projector/util.ts @@ -16,6 +16,14 @@ limitations under the License. import {DataSet} from './scatterPlot'; import {Point2D} from './vector'; import {DataPoint} from './data'; +import * as logging from './logging'; + +/** + * Delay for running expensive tasks, in milliseconds. + * The duration was empirically found so that it leaves enough time for the + * browser to update its UI state before starting an expensive UI-blocking task. + */ +const TASK_DELAY_MS = 25; /** Shuffles the array in-place in O(n) time using Fisher-Yates algorithm. */ export function shuffle<T>(array: T[]): T[] { @@ -122,4 +130,37 @@ export function getSearchPredicate(query: string, inRegexMode: boolean, }; } return predicate; +} + +/** + * Runs an expensive task asynchronously with some delay + * so that it doesn't block the UI thread immediately. + * + * @param message The message to display to the user. + * @param task The expensive task to run. + * @param msgId Optional. ID of an existing message. If provided, will overwrite + * an existing message and won't automatically clear the message when the + * task is done. + * @return The value returned by the task. + */ +export function runAsyncTask<T>(message: string, task: () => T, + msgId: string = null): Promise<T> { + let autoClear = (msgId == null); + msgId = logging.setModalMessage(message, msgId); + return new Promise<T>((resolve, reject) => { + d3.timer(() => { + try { + let result = task(); + // Clearing the old message. + if (autoClear) { + logging.setModalMessage(null, msgId); + } + resolve(result); + } catch (ex) { + logging.setModalMessage('Error: ' + ex.message); + reject(ex); + } + return true; + }, TASK_DELAY_MS); + }); }
\ No newline at end of file diff --git a/tensorflow/tensorboard/components/vz-projector/vz-projector.ts b/tensorflow/tensorboard/components/vz-projector/vz-projector.ts index 957886d360..2556108a2e 100644 --- a/tensorflow/tensorboard/components/vz-projector/vz-projector.ts +++ b/tensorflow/tensorboard/components/vz-projector/vz-projector.ts @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ import {ColorOption, DataProto, DataSet, MetadataInfo, Projection, State} from './data'; +import * as logging from './logging'; import {DataProvider, getDataProvider, ServingMode, TensorInfo} from './data-loader'; import {HoverContext, HoverListener} from './hoverContext'; import * as knn from './knn'; @@ -108,6 +109,7 @@ export class Projector extends ProjectorPolymer implements SelectionContext, this.selectedPointIndices = []; this.neighborsOfFirstPoint = []; this.dom = d3.select(this); + logging.setDomContainer(this); this.dataPanel = this.$['data-panel'] as DataPanel; this.inspectorPanel = this.$['inspector-panel'] as InspectorPanel; this.inspectorPanel.initialize(this); |