aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tensorboard/tensorboard_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tensorboard/tensorboard_handler.py')
-rw-r--r--tensorflow/tensorboard/tensorboard_handler.py379
1 files changed, 379 insertions, 0 deletions
diff --git a/tensorflow/tensorboard/tensorboard_handler.py b/tensorflow/tensorboard/tensorboard_handler.py
new file mode 100644
index 0000000000..cd50f43069
--- /dev/null
+++ b/tensorflow/tensorboard/tensorboard_handler.py
@@ -0,0 +1,379 @@
+"""TensorBoard server handler logic.
+
+TensorboardHandler contains all the logic for serving static files off of disk
+and for handling the API calls to endpoints like /tags that require information
+about loaded events.
+"""
+
+import BaseHTTPServer
+import csv
+import gzip
+import imghdr
+import json
+import mimetypes
+import os
+import StringIO
+import urllib
+import urlparse
+
+from google.protobuf import text_format
+import tensorflow.python.platform
+
+from tensorflow.python.platform import logging
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.summary import event_accumulator
+from tensorflow.tensorboard import float_wrapper
+
+RUNS_ROUTE = '/runs'
+SCALARS_ROUTE = '/' + event_accumulator.SCALARS
+IMAGES_ROUTE = '/' + event_accumulator.IMAGES
+HISTOGRAMS_ROUTE = '/' + event_accumulator.HISTOGRAMS
+COMPRESSED_HISTOGRAMS_ROUTE = '/' + event_accumulator.COMPRESSED_HISTOGRAMS
+INDIVIDUAL_IMAGE_ROUTE = '/individualImage'
+GRAPH_ROUTE = '/' + event_accumulator.GRAPH
+
+_IMGHDR_TO_MIMETYPE = {
+ 'bmp': 'image/bmp',
+ 'gif': 'image/gif',
+ 'jpeg': 'image/jpeg',
+ 'png': 'image/png'
+}
+_DEFAULT_IMAGE_MIMETYPE = 'application/octet-stream'
+
+
+def _content_type_for_image(encoded_image_string):
+ image_type = imghdr.what(None, encoded_image_string)
+ return _IMGHDR_TO_MIMETYPE.get(image_type, _DEFAULT_IMAGE_MIMETYPE)
+
+
+class _OutputFormat(object):
+ """An enum used to list the valid output formats for API calls.
+
+ Not all API calls support all formats (for example, only scalars and
+ compressed histograms support CSV).
+ """
+ JSON = 'json'
+ CSV = 'csv'
+
+
+class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
+ """Handler class for use with BaseHTTPServer.HTTPServer.
+
+ This is essentially a thin wrapper around calls to an EventMultiplexer object
+ as well as serving files off disk.
+ """
+
+ def __init__(self, multiplexer, *args):
+ self._multiplexer = multiplexer
+ BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, *args)
+
+ # We use underscore_names for consistency with inherited methods.
+
+ def _image_response_for_run(self, run_images, run, tag):
+ """Builds a JSON-serializable object with information about run_images.
+
+ Args:
+ run_images: A list of event_accumulator.ImageValueEvent objects.
+ run: The name of the run.
+ tag: The name of the tag the images all belong to.
+
+ Returns:
+ A list of dictionaries containing the wall time, step, URL, width, and
+ height for each image.
+ """
+ response = []
+ for index, run_image in enumerate(run_images):
+ response.append({
+ 'wall_time': run_image.wall_time,
+ 'step': run_image.step,
+ # We include the size so that the frontend can add that to the <img>
+ # tag so that the page layout doesn't change when the image loads.
+ 'width': run_image.width,
+ 'height': run_image.height,
+ 'query': self._query_for_individual_image(run, tag, index)
+ })
+ return response
+
+ def _path_is_safe(self, path):
+ """Check path is safe (stays within current directory).
+
+ This is for preventing directory-traversal attacks.
+
+ Args:
+ path: The path to check for safety.
+
+ Returns:
+ True if the given path stays within the current directory, and false
+ if it would escape to a higher directory. E.g. _path_is_safe('index.html')
+ returns true, but _path_is_safe('../../../etc/password') returns false.
+ """
+ base = os.path.abspath(os.curdir)
+ absolute_path = os.path.abspath(path)
+ prefix = os.path.commonprefix([base, absolute_path])
+ return prefix == base
+
+ def _send_gzip_response(self, content, content_type, code=200):
+ """Writes the given content as gzip response using the given content type.
+
+ Args:
+ content: The content to respond with.
+ content_type: The mime type of the content.
+ code: The numeric HTTP status code to use.
+ """
+ out = StringIO.StringIO()
+ f = gzip.GzipFile(fileobj=out, mode='w')
+ f.write(content)
+ f.close()
+ gzip_content = out.getvalue()
+ self.send_response(code)
+ self.send_header('Content-Type', content_type)
+ self.send_header('Content-Length', len(gzip_content))
+ self.send_header('Content-Encoding', 'gzip')
+ self.end_headers()
+ self.wfile.write(gzip_content)
+
+ def _send_json_response(self, obj, code=200):
+ """Writes out the given object as JSON using the given HTTP status code.
+
+ This also replaces special float values with stringified versions.
+
+ Args:
+ obj: The object to respond with.
+ code: The numeric HTTP status code to use.
+ """
+
+ output = json.dumps(float_wrapper.WrapSpecialFloats(obj))
+
+ self.send_response(code)
+ self.send_header('Content-Type', 'application/json')
+ self.send_header('Content-Length', len(output))
+ self.end_headers()
+ self.wfile.write(output)
+
+ def _send_csv_response(self, serialized_csv, code=200):
+ """Writes out the given string, which represents CSV data.
+
+ Unlike _send_json_response, this does *not* perform the CSV serialization
+ for you. It only sets the proper headers.
+
+ Args:
+ serialized_csv: A string containing some CSV data.
+ code: The numeric HTTP status code to use.
+ """
+
+ self.send_response(code)
+ self.send_header('Content-Type', 'text/csv')
+ self.send_header('Content-Length', len(serialized_csv))
+ self.end_headers()
+ self.wfile.write(serialized_csv)
+
+ def _serve_scalars(self, query_params):
+ """Given a tag and single run, return array of ScalarEvents."""
+ # TODO(cassandrax): return HTTP status code for malformed requests
+ tag = query_params.get('tag')
+ run = query_params.get('run')
+ values = self._multiplexer.Scalars(run, tag)
+
+ if query_params.get('format') == _OutputFormat.CSV:
+ string_io = StringIO.StringIO()
+ writer = csv.writer(string_io)
+ writer.writerow(['Wall time', 'Step', 'Value'])
+ writer.writerows(values)
+ self._send_csv_response(string_io.getvalue())
+ else:
+ self._send_json_response(values)
+
+ def _serve_graph(self, query_params):
+ """Given a single run, return the graph definition in json format."""
+ run = query_params.get('run', None)
+ if run is None:
+ self.send_error(400, 'query parameter "run" is required')
+ return
+
+ try:
+ graph = self._multiplexer.Graph(run)
+ except ValueError:
+ self.send_response(404)
+ return
+
+ # Serialize the graph to pbtxt format.
+ graph_pbtxt = text_format.MessageToString(graph)
+ # Gzip it and send it to the user.
+ self._send_gzip_response(graph_pbtxt, 'text/plain')
+
+ def _serve_histograms(self, query_params):
+ """Given a tag and single run, return an array of histogram values."""
+ tag = query_params.get('tag')
+ run = query_params.get('run')
+ values = self._multiplexer.Histograms(run, tag)
+ self._send_json_response(values)
+
+ def _serve_compressed_histograms(self, query_params):
+ """Given a tag and single run, return an array of compressed histograms."""
+ tag = query_params.get('tag')
+ run = query_params.get('run')
+ compressed_histograms = self._multiplexer.CompressedHistograms(run, tag)
+ if query_params.get('format') == _OutputFormat.CSV:
+ string_io = StringIO.StringIO()
+ writer = csv.writer(string_io)
+
+ # Build the headers; we have two columns for timing and two columns for
+ # each compressed histogram bucket.
+ headers = ['Wall time', 'Step']
+ if compressed_histograms:
+ bucket_count = len(compressed_histograms[0].compressed_histogram_values)
+ for i in xrange(bucket_count):
+ headers += ['Edge %d basis points' % i, 'Edge %d value' % i]
+ writer.writerow(headers)
+
+ for compressed_histogram in compressed_histograms:
+ row = [compressed_histogram.wall_time, compressed_histogram.step]
+ for value in compressed_histogram.compressed_histogram_values:
+ row += [value.rank_in_bps, value.value]
+ writer.writerow(row)
+ self._send_csv_response(string_io.getvalue())
+ else:
+ self._send_json_response(compressed_histograms)
+
+ def _serve_images(self, query_params):
+ """Given a tag and list of runs, serve a list of images.
+
+ Note that the images themselves are not sent; instead, we respond with URLs
+ to the images. The frontend should treat these URLs as opaque and should not
+ try to parse information about them or generate them itself, as the format
+ may change.
+
+ Args:
+ query_params: The query parameters as a dict.
+ """
+ tag = query_params.get('tag')
+ run = query_params.get('run')
+
+ images = self._multiplexer.Images(run, tag)
+ response = self._image_response_for_run(images, run, tag)
+ self._send_json_response(response)
+
+ def _serve_image(self, query_params):
+ """Serves an individual image."""
+ tag = query_params.get('tag')
+ run = query_params.get('run')
+ index = int(query_params.get('index'))
+ image = self._multiplexer.Images(run, tag)[index]
+ encoded_image_string = image.encoded_image_string
+ content_type = _content_type_for_image(encoded_image_string)
+
+ self.send_response(200)
+ self.send_header('Content-Type', content_type)
+ self.send_header('Content-Length', len(encoded_image_string))
+ self.end_headers()
+ self.wfile.write(encoded_image_string)
+
+ def _query_for_individual_image(self, run, tag, index):
+ """Builds a URL for accessing the specified image.
+
+ This should be kept in sync with _serve_image. Note that the URL is *not*
+ guaranteed to always return the same image, since images may be unloaded
+ from the reservoir as new images come in.
+
+ Args:
+ run: The name of the run.
+ tag: The tag.
+ index: The index of the image. Negative values are OK.
+
+ Returns:
+ A string representation of a URL that will load the index-th
+ sampled image in the given run with the given tag.
+ """
+ query_string = urllib.urlencode({
+ 'run': run,
+ 'tag': tag,
+ 'index': index
+ })
+ return query_string
+
+ def _serve_runs(self, unused_query_params):
+ """Return a JSON object about runs and tags.
+
+ Returns a mapping from runs to tagType to list of tags for that run.
+
+ Returns:
+ {runName: {images: [tag1, tag2, tag3],
+ scalars: [tagA, tagB, tagC],
+ histograms: [tagX, tagY, tagZ]}}
+ """
+ self._send_json_response(self._multiplexer.Runs())
+
+ def _serve_index(self, unused_query_params):
+ """Serves the index page (i.e., the tensorboard app itself)."""
+ self._serve_static_file('/dist/index.html')
+
+ def _serve_static_file(self, path):
+ """Serves the static file located at the given path.
+
+ Args:
+ path: The path of the static file, relative to the tensorboard/ directory.
+ """
+ # Strip off the leading forward slash.
+ path = path.lstrip('/')
+ if not self._path_is_safe(path):
+ logging.info('path %s not safe, sending 404' % path)
+ # Traversal attack, so 404.
+ self.send_error(404)
+ return
+
+ if path.startswith('external'):
+ path = os.path.join('../', path)
+ else:
+ path = os.path.join('tensorboard', path)
+ # Open the file and read it.
+ try:
+ contents = resource_loader.load_resource(path)
+ except IOError:
+ logging.info('path %s not found, sending 404' % path)
+ self.send_error(404)
+ return
+
+ self.send_response(200)
+
+ mimetype = mimetypes.guess_type(path)[0] or 'application/octet-stream'
+ self.send_header('Content-Type', mimetype)
+ self.end_headers()
+ self.wfile.write(contents)
+
+ def do_GET(self): # pylint: disable=invalid-name
+ """Handler for all get requests."""
+ parsed_url = urlparse.urlparse(self.path)
+
+ # Remove a trailing slash, if present.
+ clean_path = parsed_url.path
+ if clean_path.endswith('/'):
+ clean_path = clean_path[:-1]
+
+ handlers = {
+ SCALARS_ROUTE: self._serve_scalars,
+ GRAPH_ROUTE: self._serve_graph,
+ HISTOGRAMS_ROUTE: self._serve_histograms,
+ COMPRESSED_HISTOGRAMS_ROUTE: self._serve_compressed_histograms,
+ IMAGES_ROUTE: self._serve_images,
+ INDIVIDUAL_IMAGE_ROUTE: self._serve_image,
+ RUNS_ROUTE: self._serve_runs,
+ '': self._serve_index
+ }
+
+ if clean_path in handlers:
+ query_params = urlparse.parse_qs(parsed_url.query)
+ # parse_qs returns a list of values for each key; we're only interested in
+ # the first.
+ for key in query_params:
+ value_count = len(query_params[key])
+ if value_count != 1:
+ self.send_error(
+ 400,
+ 'query parameter %s should have exactly one value, had %d' %
+ (key, value_count))
+ return
+
+ query_params[key] = query_params[key][0]
+ handlers[clean_path](query_params)
+ else:
+ self._serve_static_file(clean_path)