aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@andyselle.com>2017-11-14 11:39:36 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-14 11:46:10 -0800
commit7be5ab5ddbfd7d81ffd7e2022633908a14a52ff1 (patch)
treee4ec8ef99ca705020d07582770706cd7a247967c
parent1b0a9ff5e863da4349f100bead90b0032805eafc (diff)
Add tflite documentation
Merged commit includes the following changes: 175703479 by yifeif: Internal change. PiperOrigin-RevId: 175703479 (This is 1 of the 3 commits in from staging c674e27bfd68a6c990e694b6afd901bfeeaa006d)
-rw-r--r--tensorflow/contrib/lite/README.md200
-rw-r--r--tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpgbin0 -> 48710 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/apis.md359
-rw-r--r--tensorflow/contrib/lite/g3doc/custom_operators.md91
-rw-r--r--tensorflow/contrib/lite/g3doc/ios.md67
-rw-r--r--tensorflow/contrib/lite/g3doc/models.md22
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md417
-rw-r--r--tensorflow/contrib/lite/java/demo/README.md36
-rw-r--r--tensorflow/contrib/lite/models/smartreply/g3doc/README.md146
-rw-r--r--tensorflow/contrib/lite/models/testdata/g3doc/README.md102
-rw-r--r--tensorflow/contrib/lite/models/testdata/g3doc/asr_am.svg4
-rwxr-xr-xtensorflow/contrib/lite/models/testdata/g3doc/hotword.svg4
-rwxr-xr-xtensorflow/contrib/lite/models/testdata/g3doc/speakerid.svg4
-rwxr-xr-xtensorflow/contrib/lite/models/testdata/g3doc/tts.svg4
-rw-r--r--tensorflow/contrib/lite/models/testdata/smartreply_samples.tsv50
-rw-r--r--tensorflow/contrib/lite/nnapi/README.md15
-rw-r--r--tensorflow/contrib/lite/toco/README.md26
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md509
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md238
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/python_api.md62
20 files changed, 2356 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md
new file mode 100644
index 0000000000..b173936f5b
--- /dev/null
+++ b/tensorflow/contrib/lite/README.md
@@ -0,0 +1,200 @@
+# TensorFlow Lite
+TensorFlow Lite is TensorFlow’s lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration.
+
+TensorFlow Lite uses many techniques for achieving low latency like optimizing the kernels for specific mobile apps, pre-fused activations, quantized kernels that allow smaller and faster (fixed-point math) models, and in the future, leverage specialized machine learning hardware to get the best possible performance for a particular model on a particular device.
+
+![image](g3doc/TFLite-Architecture.jpg)
+# Getting Started with a Demo App
+
+This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo.
+
+There are 3 ways to get the demo app to your device
+ - Download the prebuilt binary or
+ - Use Android Studio to build the application or
+ - Download the source code for TensorFlow Lite and the demo and build it using bazel
+
+## Description
+In the demo app, inference is done using the TensorFlow Lite Java API. The demo app classifies frames in real-time, displaying the top most probable classifications. It also displays the time taken to detect the object.
+
+## Downloading the pre-built binary
+The fastest path to trying the demo, is to download the pre-built binary
+[TfLiteCameraDemo.apk](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk)
+
+Once the apk is installed, click the app icon to start the app. The first-time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera’s field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified.
+
+## Building in Android Studio using TensorFlow Lite AAR from JCenter
+The simplest way to compile the demo app, and try out changes to the project code is to use AndroidStudio.
+
+ - Install the latest version of Android Studio 3 as specified [here](https://developer.android.com/studio/index.html).
+ - Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings).
+ - Import the tensorflow/contrib/lite/java/demo directory as a new Android Studio project.
+ - Click through installing all the Gradle extensions it requests.
+ - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip)
+ - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory:
+ tensorflow/contrib/lite/java/demo/app/src/main/assets/
+ - Build and run the demo app
+
+## Building TensorFlow Lite and the demo app from source
+
+### Clone the TensorFlow repo
+- git clone
+ [https://github.com/tensorflow/tensorflow](https://github.com/tensorflow/tensorflow)
+
+### Install Bazel
+If bazel is not installed on your system, install it now by following [these directions](https://bazel.build/versions/master/docs/install.html)
+
+NOTE: Bazel does not currently support building for Android on Windows. Full support for gradle/cmake builds is coming soon, but in the meantime Windows users should download the [prebuilt binary](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/demo/TfLiteCameraDemo.apk) instead.
+
+### Install Android NDK and SDK
+Bazel is the primary build system for TensorFlow. Bazel and the Android NDK and SDK must be installed on your system.
+ - Install the latest version of Bazel as per the instructions on the [Bazel website](https://bazel.build/versions/master/docs/install.html)
+ - The Android NDK is required to build the native (C/C++) TensorFlow code. The current recommended version is 14b, which may be found [here](https://developer.android.com/tools/revisions/build-tools.html).
+ - The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android Studio](https://developer.android.com/studio/index.html). Build tools API >= 23 is required to build the TensorFlow Android demo (though it will run on API >= 21 devices).
+
+ - The Android NDK is required to build the native (C/C++) TensorFlow Lite code. The current recommended version is 14b, which can be found [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads).
+
+ - The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android Studio](https://developer.android.com/studio/index.html). Build tools API >= 23 is required to build the TF Android demo (though it will run on API >= 21 devices).
+ - In the root of the TensorFlow repository update the `WORKSPACE` file with the `api_level` and location of the SDK and NDK. If you installed it with AndroidStudio the SDK path can be found in the SDK manager, and the default NDK path is:`{SDK path}/ndk-bundle.`
+
+```
+ Android_sdk_repository (
+ name = "androidsdk",
+ api_level = 23,
+ build_tools_version = "23.0.2",
+ path = "/home/xxxx/android-sdk-linux/", )
+
+android_ndk_repository(
+ name="androidndk",
+ path="/home/xxxx/android-ndk-r10e/",
+ api_level=19)
+
+```
+Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md)
+
+### Build the source code
+Run bazel with the following command to build the demo.
+
+Build the demo app:
+bazel build --cxxopt='--std=c++11' //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo
+
+### More about the demo
+The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app.
+
+# TensorFlow Lite Quick Start
+
+## Step 1. Decide which GraphDef to use
+ Depending on the use case, the developer may choose to use one of the popular
+ open-sourced models such as InceptionV3 or MobileNets, re-train these models
+ with their own custom data set or even build their own custom model.
+
+### Using a pre-trained model
+
+[MobileNets](https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html) is a family of mobile-first computer vision models for [TensorFlow](https://www.tensorflow.org/) designed to effectively maximize accuracy while being mindful of the restricted resources for an on-device or embedded application. MobileNets are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as [Inception](https://arxiv.org/pdf/1602.07261.pdf), are used. Google provides 16 pre-trained [ImageNet](http://www.image-net.org/challenges/LSVRC/) classification checkpoints for MobileNets for use in mobile projects of all sizes.
+
+[Inception-v3](https://arxiv.org/abs/1512.00567) is an image recognition model which achieves fairly high accuracy in recognizing general objects with 1000 classes, like "Zebra", "Dalmatian", and "Dishwasher". The model extracts general features from input images using a convolutional neural network and classifies them based on those features with fully-connected and softmax layers.
+
+[On Device Smart Reply](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html) is an on-device model which provides one-touch replies for an incoming text message by suggesting contextually relevant messages. The model is built specifically for memory constrained devices such as watches & phones and it has been successfully used to surface [Smart Replies on Android Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html). Note that this model only works on Android as of now.
+
+These pre-trained models can be downloaded from [here](models.md).
+
+### Retrain Inception-V3 or MobileNet for a custom data set
+The above pre-trained models have been trained on the ImageNet data set, which consists of 1000 predefined classes. A model will need to be re-trained if these classes are not relevant or useful for a given use case. This technique is called transfer learning, which starts with a model that has been already trained on a problem and will then be retrained on a similar problem. Deep learning from scratch can take days, but transfer learning can be done fairly quickly. In order to do this, a developer will need to generate their custom data set labeled with the relevant classes.
+
+The [TensorFlow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/) codelab walks through this process step-by-step. The retraining code supports retraining for both floating point and quantized inference.
+
+
+### Train a custom model
+A developer may choose to train a custom model using Tensorflow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow’s Slim Framework the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model.
+
+TensorFlow Lite currently supports a subset of TensorFlow operators. Please refer to [this document](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for details of supported operators and their usage. This
+set will continue to expand in future releases of Tensorflow Lite.
+
+
+## Step 2. Model format conversion
+
+The model generated in Step 1 is a standard Tensorflow model. After the completion of Step 1 a user should have a standard .pb or .pbtxt GraphDef file. If the application developer is using a pre-trained model (as defined in Step 1 above), they can download a ready to use, already converted model for use from [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/models.md). Models generated using retraining (aka transfer learning) or custom models will need to be converted using the steps mentioned below.
+
+A prerequisite to converting the model to the Tensorflow Lite format is to freeze the graph.
+
+Since we employ several formats, the following definitions may be useful:
+ - GraphDef (.pb) - a protobuf that represents the TensorFlow training and or computation graph. This contains operators, tensors, and variables definitions.
+
+ - CheckPoint (.ckpt) - Serialized variables from a TensorFlow graph. Note, this does not contain the graph structure, so alone it cannot typically be interpreted.
+
+ - FrozenGraphDef - a subclass of GraphDef that contains no variables. A GraphDef can be converted to a frozen graphdef by taking a checkpoint and a graphdef and converting every variable into a constant with the value looked up in the checkpoint.
+
+ - SavedModel - A collection of GraphDef and CheckPoint together with a signature that labels input and output arguments to a model. A GraphDef and Checkpoint can be extracted from a saved model.
+
+ - TensorFlow lite model (.lite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs.
+
+### Freeze Graph
+To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as “freezing” the graph.
+
+The developer should know where the checkpoints folder is present or checkpoints can also be downloaded for a pre-trained model (Example: Here is a link to the [MobileNets](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md)
+
+Graph freezing can be done using the command below (and modifying the arguments appropriately)
+
+```
+bazel build tensorflow/python/tools:freeze_graph
+
+bazel-bin/tensorflow/python/tools/freeze_graph\
+ --input_graph=/tmp/mobilenet_v1_224.pb \
+ --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \
+ --input_binary=true --output_graph=/tmp/frozen_mobilenet_v1_224.pb \
+ --output_node_names=MobileNet/Predictions/Reshape_1
+```
+
+The user has to first build the freeze_graph script using bazel and then run the script. The input_binary flag has to be enabled to ensure that the protobuf is read and written in binary format. The user has to input the .pb and the .ckpt files to freeze the graph The output_node_names may not be obvious outside of the code that built the model. The easiest way to find them is to visualize the graph, either with
+graphviz, or [in tensorboard](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/#3).
+
+This frozen Graphdef is now ready to be converted to flatbuffer format (.lite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool.
+
+Here is a sample command line to convert the frozen Graphdef to '.lite' format for The Tensorflow Optimizing Converter supports both float and quantized models, however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used.
+
+```
+bazel build tensorflow/contrib/lite/toco:toco
+
+bazel run --config=opt tensorflow/contrib/lite/toco:toco -- \
+ --input_file=(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
+ --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
+ --output_file=/tmp/mobilenet_v1_1.0_224.lite --inference_type=FLOAT \
+ --input_type=FLOAT --input_arrays=input \
+ --output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3
+```
+
+- The input_file argument should point to the frozen GraphDef file that holds the model architecture.
+- The output_file argument should point to where the TensorFlow Lite model file should be generated.
+- The input_type and inference_type arguments should be set to FLOAT, unless converted a [quantized](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/) model.
+- Setting the input_array, output_array and input_shape arguments are a bit trickier. The easiest way to find these values is to explore the graph in tensorboard . The user should reuse the arguments that were used for specifying the output nodes for inference in the `freeze_graph`step.
+
+Note, it is also possible to use the Tensorflow Optimizing Converter through protos either from Python or from the command line see the
+documentation [here](https://github.com/tensorflow/tensorflow/tree/mastertensorflow/contrib/lite/python:toco_from_protos target) A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example,
+
+```
+import tensorflow as tf
+
+img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
+val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
+out = tf.identity(val, name="out")
+with tf.Session() as sess:
+ tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
+ open("converteds_model.tflite", "wb").write(tflite_model)
+
+```
+For detailed instructions on how to use the Tensorflow Optimizing Converter, please see [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md).
+
+You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tf_ops_compatibility.md) for troubleshooting help. If that doesn’t help, please file an [issue](https://github.com/tensorflow/tensorflow/issues).
+
+## Step 3. Use the TensorFlow Lite model for inference in a mobile app
+
+After completion of Step 2 the developer should have a .lite model.
+
+### For Android
+Because Android apps need to be written in Java, and core TensorFlow is in C++, a JNI library is provided to interface between the two. Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. The full documentation for the set of methods can be seen [here](https://github.com/TensorFlow/TensorFlow/blob/master/TensorFlow/contrib/lite/g3doc/). The demo app is also open sourced on [github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app).
+
+The [demo app] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app) uses this interface, so it’s a good place to look for example usage. You can also download the prebuilt binary [here](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
+
+Note that you’d need to follow instructions for installing TensorFlow on Android, setting up bazel and Android Studio outlined [here](https://www.tensorflow.org/mobile/android_build).
+
+### For iOS
+Follow the documentation [here](https://github.com/TensorFlow/TensorFlow/blob/master/TensorFlow/contrib/lite/g3doc/ios.md) to get integrate a TFLite model into your app.
diff --git a/tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpg b/tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpg
new file mode 100644
index 0000000000..bc83946647
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpg
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md
new file mode 100644
index 0000000000..311fc69696
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/apis.md
@@ -0,0 +1,359 @@
+# TensorFlow Lite APIs
+
+TensorFlow Lite provides programming APIs in C++ and Java, and in both cases
+the API design reflects a preference for performance over ease of use.
+TensorFlow Lite is designed for fast inference on small devices so it should be
+no surprise that the APIs try to avoid unnecessary copies at the expense of
+convenience. Similarly, consistency with TensorFlow APIs was not an explicit
+goal and some variance is to be expected.
+
+## C++
+
+In order to run the inference model in TensorFlow Lite, one has to load the
+model into a `FlatBufferModel` object which then can be executed by an
+`Interpreter`. The `FlatBufferModel` needs to remain valid for the whole
+lifetime of the `Interpreter`, and a single `FlatBufferModel` can be
+simultaneously used by more than one `Interpreter`. In concrete terms, the
+`FlatBufferModel` object must be created before any `Interpreter` objects that
+use it, and must be kept around until they have all been destroyed.
+
+The simplest usage of TensorFlow Lite will look like this:
+
+```c++
+tflite::FlatBufferModel model(path_to_model);
+tflite::ops::builtin::BuiltinOpResolver resolver;
+std::unique_ptr<tflite::Interpreter> interpreter;
+tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+// Resize input tensors, if desired.
+interpreter->AllocateTensors();
+float* input = interpreter->typed_input_tensor<float>(0);
+// Fill `input`.
+interpreter->Invoke();
+float* output = interpreter->type_output_tensor<float>(0);
+```
+### Data Alignment
+
+TensorFlow Lite data is usually aligned to 32-bit boundaries. It is recommended
+that all data provided to TensorFlow Lite be aligned that way.
+
+### Error Reporting
+
+In many places TensorFlow Lite returns status information through
+`TfLiteStatus` objects:
+
+```c++
+typedef enum {
+ kTfLiteOk = 0,
+ kTfLiteError = 1
+} TfLiteStatus;
+
+```
+
+Failures can be easily verified with:
+```c++
+if (status != kTfLiteOk) {
+ // ... error handling here ...
+}
+```
+
+In order to obtain detailed error information an ErrorReporter must be
+provided:
+
+```c++
+class ErrorReporter {
+ virtual int Report(const char* format, va_list args) = 0;
+};
+```
+
+The `DefaultErrorReporter` takes care of reporting to `stderr`.
+
+### Loading a Model
+
+The `FlatBufferModel` class encapsulates a model and can be built in a couple of
+slightly different ways depending on where the model is stored:
+
+```c++
+class FlatBufferModel {
+  // Build a model based on a file. Return a nullptr in case of failure.
+  static std::unique_ptr<FlatBufferModel> BuildFromFile(
+      const char* filename,
+      ErrorReporter* error_reporter);
+
+  // Build a model based on a pre-loaded flatbuffer. The caller retains
+  // ownership of the buffer and should keep it alive until the returned object
+  // is destroyed. Return a nullptr in case of failure.
+  static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
+      const char* buffer,
+      size_t buffer_size,
+      ErrorReporter* error_reporter);
+};
+```
+
+Note that if TensorFlow Lite detects the presence of Android's NNAPI it will
+automatically try to use shared memory to store the FlatBufferModel.
+
+### Running a Model
+
+Running a model involves a few simple steps:
+
+ * Build an `Interpreter` based on an existing `FlatBufferModel`
+ * Optionally resize input tensors if the predefined sizes are not desired.
+ * Set input tensor values
+ * Invoke inference
+ * Read output tensor values
+
+The important parts of public interface of the `Interpreter` are provided
+below. It should be noted that:
+
+ * Tensors are represented by integers, in order to avoid string comparisons
+ (and any fixed dependency on string libraries).
+ * An interpreter must not be accessed from concurrent threads
+ * Memory allocation for input and output tensors must be triggered
+ by calling AllocateTensors() right after resizing tensors.
+
+```c++
+class Interpreter {
+ Interpreter(ErrorReporter* error_reporter);
+
+ // Read only access to list of inputs.
+ const std::vector<int>& inputs() const;
+
+ // Read only access to list of outputs.
+ const std::vector<int>& outputs() const;
+
+ // Change the dimensionality of a given tensor.
+ TfLiteStatus ResizeInputTensor(int tensor_index,
+ const std::vector<int>& dims);
+
+ // Returns status of success or failure.
+ TfLiteStatus AllocateTensors();
+
+ // Return a pointer into the data of a given input tensor.
+ template <class T>
+ T* typed_input_tensor(int index) {
+ return typed_tensor<T>(inputs_[index]);
+ }
+
+ // Return a pointer into the data of a given output tensor.
+ template <class T>
+ T* typed_output_tensor(int index) {
+ return typed_tensor<T>(outputs_[index]);
+ }
+
+ // Execute the model, populating output tensors.
+ TfLiteStatus Invoke();
+};
+```
+
+### Writing Custom Operators
+
+All TensorFlow Lite operators (both custom and builtin) are defined using a
+simple pure-C interface that consists of four functions:
+
+```c++
+typedef struct {
+ void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+ void (*free)(TfLiteContext* context, void* buffer);
+ TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+ TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+} TfLiteRegistration;
+```
+
+Refer to `context.h` for details on `TfLiteContext` and `TfLiteNode`. The
+former provides error reporting facilities and access to global objects,
+including all the tensors. The latter allows implementations to access their
+inputs and outputs.
+
+When the interpreter loads a model, it calls init() once for each node in the
+graph. A given `init()` will be called more than once if the op is used
+multiple times in the graph. For custom ops a configuration buffer will be
+provided, containing a flexbuffer that maps parameter names to their values.
+The buffer is empty for builtin ops because the interpreter has already parsed
+the op parameters. Kernel implementation that require state should initialize
+it here and transfer ownership to the caller. For each `init()` call, there
+will be a corresponding call to `free()`, allowing implementations to dispose
+of the buffer they might have allocated in `init()`.
+
+Whenever the input tensors are resized the interpreter will go through the
+graph notifying implementations of the change. This gives them the chance to
+resize their internal buffer, check validity of input shapes and types, and
+recalculate output shapes. This is all done through `prepare()` and
+implementation can access their state using `node->user_data`.
+
+Finally, each time inference runs the interpreter traverses the graph calling
+`invoke()`, and here too the state is available as `node->user_data`.
+
+Custom ops can be implemented in exactly the same way as builtin ops, by
+defined those four functions and a global registration function that usually
+looks like this:
+
+```c++
+namespace tflite {
+namespace ops {
+namespace custom {
+ TfLiteRegistration* Register_MY_CUSTOM_OP() {
+ static TfLiteRegistration r = {my_custom_op::Init,
+ my_custom_op::Free,
+ my_custom_op::Prepare,
+ my_custom_op::Eval};
+ return &r;
+ }
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+```
+
+Note that registration is not automatic and an explicit call to
+`Register_MY_CUSTOM_OP` should be made somewhere. While the standard
+`:builtin_ops` takes care of the registration of builtins, custom ops will have
+to be collected in separated custom libraries.
+
+### Customizing the kernel library
+
+Behind the scenes the interpreter will load a library of kernels which will be
+assigned to execute each of the operators in the model. While the default
+library only contains builtin kernels, it is possible to replace it with a
+custom library.
+
+The interpreter uses an `OpResolver` to translate operator codes and names into
+actual code:
+
+```c++
+class OpResolver {
+ virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0;
+ virtual TfLiteRegistration* FindOp(const char* op) const = 0;
+ virtual void AddOp(tflite::BuiltinOperator op, TfLiteRegistration* registration) = 0;
+ virtual void AddOp(const char* op, TfLiteRegistration* registration) = 0;
+};
+```
+
+The regular usage will require the developer to use the `BuiltinOpResolver` and
+write:
+
+```c++
+tflite::ops::builtin::BuiltinOpResolver resolver;
+```
+
+They can then optionally register custom ops:
+
+```c++
+resolver.AddOp("MY_CUSTOM_OP", Register_MY_CUSTOM_OP());
+```
+
+before the resolver is passed to the `InterpreterBuilder`.
+
+If the set of builtin ops is deemed to be too large, a new `OpResolver` could
+be code-generated based on a given subset of ops, possibly only the ones
+contained in a given model. This is the equivalent of TensorFlow's selective
+registration (and a simple version of it is available in the `tools`
+directory).
+
+## Java
+
+TensorFlow Lite's Java API supports on-device inference and is provided as an
+Android Studio Library that allows loading models, feeding inputs, and
+retrieving inference outputs.
+
+The simplest usage of Tensorflow Lite Java API looks like this:
+
+```java
+try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ interpreter.run(input, output);
+}
+```
+
+### Loading a Model
+
+The `Interpreter.java` class drives model inference with TensorFlow Lite. In
+most of the cases, this is the only class an app developer will need.
+
+#### Initializing an `Interpreter` Mith a Model Mile
+
+The `Interpreter` can be initialized with a model file using the constructor:
+
+```java
+public Interpreter(@NotNull File modelFile);
+```
+
+or with a `MappedByteBuffer`:
+
+```java
+public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer);
+```
+
+In both cases a valid TensorFlow Lite must be provided or an
+`IllegalArgumentException` with be thrown. If a `MappedByteBuffer` is used to
+initialize an Interpreter, it should remain unchanged for the whole lifetime of
+the `Interpreter`.
+
+### Running a Model
+
+#### Supported Data Types
+
+To use TensorFlow Lite, the data types of the input and output tensors must be
+one of the following primitive types:
+
+* `float`
+* `int`
+* `long`
+* `byte`
+
+If other data types, including boxed types like `Integer` and `Float`, are used,
+an `IllegalArgumentException` will be thrown.
+
+#### Inputs
+
+Each input should be an array, a multi-dimensional array, or a `ByteBuffer` of
+the supported primitive types.
+
+The use of `ByteBuffer` is preferred since it allows the `Interpreter` to avoid
+unnecessary copies. Each `ByteBuffer` needs to be a direct byte buffer, and its
+order must be `ByteOrder.nativeOrder()`. After it is used for a model inference,
+it must remain unchanged until the model inference is finished.
+
+#### Outputs
+
+Each output should be an array, or a multi-dimensional array of the supported
+primitive types.
+
+#### Running Model Inference
+
+If a model takes only one input and returns only one output, the following will
+trigger an inference run:
+
+```java
+interpreter.run(input, output);
+```
+
+For models with multiple inputs, or multiple outputs, use:
+
+```java
+interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
+```
+
+where each entry in `inputs` corresponds to an input tensor and
+`map_of_indices_to_outputs` maps indices of output tensors to the
+corresponding output data. In both cases the tensor indices should correspond to
+the values given to the `TensorFlow Lite Optimized Converter` when the model was
+created. Be aware that the order of tensors in `input` must match the order
+given to the `TensorFlow Lite Optimized Converter`.
+
+The Java API also provides convenient functions for app developers to get the
+index of any model input or output using a tensor name:
+
+```java
+public int getInputIndex(String tensorName);
+public int getOutputIndex(String tensorName);
+```
+
+If tensorName is not a valid name in model, an `IllegalArgumentException` will
+be thrown.
+
+### Releasing Resources After Use
+
+An `Interpreter` owns resources. To avoid memory leak, the resources must be
+released after use by:
+
+```java
+interpreter.close();
+```
diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md
new file mode 100644
index 0000000000..204a489a93
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/custom_operators.md
@@ -0,0 +1,91 @@
+# How to use custom operators
+
+TensorFlow Lite currently supports a subset of TensorFlow operators. However, it
+does support the use of user-provided implementations (as known as custom
+implementations) if the model contains an operator that is not supported.
+
+Let’s walk through this via an example. Assume we are using the `Sin` operator
+and that we are building a very simple model for a function `y = sin(x +
+offset)`, where `offset` is trainable.
+
+The code to train the TensorFlow model will be something like:
+
+```python
+offset = tf.get_variable("offset", [1,], tf.float32)
+x = tf.placeholder(tf.float32, shape=(None,))
+y = tf.sin(x + offset)
+y_ = tf.placeholder(tf.float32, shape=(None,))
+loss = tf.reduce_sum(tf.square(y - y_))
+optimizer = tf.train.GradientDescentOptimizer(0.001)
+train = optimizer.minimize(loss)
+```
+
+If you convert this model to Tensorflow Lite format using the TensorFlow Lite
+Optimizing Converter with `--allow_custom_ops` argument, and run it with the
+default interpreter, the interpreter will raise the following error messages:
+
+```
+Didn't find custom op for name 'Sin'
+Registration failed.
+```
+
+All we need to do to use the op in TensorFlow Lite is define two functions
+(`Prepare` and `Eval`), and construct a `TfLiteRegistration`. This code would
+look something like this:
+
+```cpp
+TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
+ using namespace tflite;
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+
+ int num_dims = NumDimensions(input);
+
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(num_dims);
+ for (int i=0; i<num_dims; ++i) {
+ output_size->data[i] = input->dims->data[i];
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+ using namespace tflite;
+ TfLiteTensor* input = GetInput(context, node,0);
+ TfLiteTensor* output = GetOutput(context, node,0);
+
+ float* input_data = input->data.f;
+ float* output_data = output->data.f;
+
+ size_t count = 1;
+ int num_dims = NumDimensions(input);
+ for (int i = 0; i < num_dims; ++i) {
+ count *= input->dims->data[i];
+ }
+
+ for (size_t i=0; i<count; ++i) {
+ output_data[i] = sin(input_data[i]);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteRegistration* Register_SIN() {
+ static TfLiteRegistration r = {nullptr, nullptr, SinResize, SinEval};
+ return &r;
+}
+```
+
+When initializing the OpResolver, add the custom op into the resolver, this will
+register the operator with Tensorflow Lite so that TensorFlow Lite can use the
+new implementation.
+
+```cpp
+tflite::ops::builtin::BuiltinOpResolver builtins;
+builtins.AddCustom("Sin", Register_SIN());
+```
+
+Note that a similar process as above can be followed for supporting for a set of
+operations instead of a single operator.
diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md
new file mode 100644
index 0000000000..ce8b37fbf9
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/ios.md
@@ -0,0 +1,67 @@
+# TensorFlow Lite for iOS
+
+## Building
+
+To create a universal iOS library for TensorFlow Lite, you need to build it
+using Xcode's command line tools on a MacOS machine. If you have not already,
+you will need to install Xcode 8 or later and the tools using `xcode-select`:
+
+```bash
+xcode-select --install
+```
+
+If this is a new install, you will need to run XCode once to agree to the
+license before continuing.
+
+(You will also need to have [Homebrew](http://brew.sh/) installed.)
+
+Then install
+[automake](https://en.wikipedia.org/wiki/Automake)/[libtool](https://en.wikipedia.org/wiki/GNU_Libtool):
+
+```bash
+brew install automake
+brew install libtool
+```
+
+Then you need to run a shell script to download the dependencies you need:
+
+```bash
+tensorflow/contrib/lite/download_dependencies.sh
+```
+
+This will fetch copies of libraries and data from the web and install them in
+`tensorflow/contrib/lite/downloads`.
+
+With all of the dependencies set up, you can now build the library for all five
+supported architectures on iOS:
+
+```bash
+tensorflow/contrib/lite/build_ios_universal_lib.sh
+```
+
+Under the hood this uses a makefile in `tensorflow/contrib/lite` to build the
+different versions of the library, followed by a call to `lipo` to bundle them
+into a universal file containing armv7, armv7s, arm64, i386, and x86_64
+architectures. The resulting library is in
+`tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a`.
+
+## Using in your own application
+
+You'll need to update various settings in your app to link against TensorFlow
+Lite. You can view them in the example project at
+`tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj` but here's a full
+rundown:
+
+- You'll need to add the library at
+ `tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a` to your linking build
+ stage, and in Search Paths add `tensorflow/contrib/lite/gen/lib` to the
+ Library Search Paths setting.
+
+- The _Header Search_ paths needs to contain:
+
+ - the root folder of tensorflow,
+ - `tensorflow/contrib/lite/downloads`
+ - `tensorflow/contrib/lite/downloads/flatbuffers/include`
+
+- C++11 support (or later) should be enabled by setting `C++ Language Dialect`
+ to `GNU++11` (or `GNU++14`), and `C++ Standard Library` to `libc++`.
diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md
new file mode 100644
index 0000000000..0508c160c6
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/models.md
@@ -0,0 +1,22 @@
+#List of Hosted Models
+
+* [Inception V3 2015](https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2015_2017_11_10.zip)
+* [Inception V3 Slim 2016](https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_slim_2016_android_2017_11_10.zip)
+* [Mobilenet 0.25 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.25_128_float_2017_11_08.zip)
+* [Mobilenet 0.25 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.25_160_float_2017_11_08.zip)
+* [Mobilenet 0.25 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.25_192_float_2017_11_08.zip)
+* [Mobilenet 0.25 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.25_224_float_2017_11_08.zip)
+* [Mobilenet 0.50 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_float_2017_11_08.zip)
+* [Mobilenet 0.50 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_160_float_2017_11_08.zip)
+* [Mobilenet 0.50 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_192_float_2017_11_08.zip)
+* [Mobilenet 0.50 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_224_float_2017_11_08.zip)
+* [Mobilenet 0.75 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.75_128_float_2017_11_08.zip)
+* [Mobilenet 0.75 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.75_160_float_2017_11_08.zip)
+* [Mobilenet 0.75 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.75_192_float_2017_11_08.zip)
+* [Mobilenet 0.75 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.75_224_float_2017_11_08.zip)
+* [Mobilenet 1.0 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_128_float_2017_11_08.zip)
+* [Mobilenet 1.0 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_160_float_2017_11_08.zip)
+* [Mobilenet 1.0 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_192_float_2017_11_08.zip)
+* [Mobilenet 1.0 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_float_2017_11_08.zip)
+* [Mobilenet 1.0 224 Quant](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_224_android_quant_2017_11_08.zip)
+* [Smart Reply 1.0 Android ](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip)
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
new file mode 100644
index 0000000000..121c4c2c95
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -0,0 +1,417 @@
+# TensorFlow Compatibility Guide
+
+TensorFlow Lite supports a number of TensorFlow operations used in common
+inference models. As they are processed by the TensorFlow Lite Optimizing
+Converter, those operations may be elided or fused, before the supported
+operations are mapped to their TensorFlow Lite counterparts.
+
+Since the set of TensorFlow Lite operations is smaller than TensorFlow's, not
+every model is convertible. Even for supported operations, very specific usage
+patterns are sometimes expected, for performance reasons. We expect to expand
+the set of supported operations in future TensorFlow Lite releases.
+
+The best way to understand how to build a TensorFlow model that can be used with
+TensorFlow Lite is to carefully consider how operations are converted and
+optimized, along with the limitations imposed by this process.
+
+## Supported Types
+
+Most TensorFlow Lite operations target both floating-point (float32) and
+quantized (uint8) inference, but usually there is little or no support for other
+types like tf.float16 and strings.
+
+Apart from using different version of the operations, the other difference
+between floating-point and quantized models lies in the way they are converted.
+Quantized conversion expect the models to be annotated with "fake quantization"
+nodes that record the dynamic range of the tensors. Without that information TF
+Lite is not able to accurately quantize a model, which means that proper
+quantized training is necessary before conversion.
+
+## Data Format and Broadcasting
+
+At the moment TensorFlow Lite supports only TensorFlow's "NHWC" format, and
+broadcasting in operations like tf.add and tf.mul is generally not supported.
+
+## Compatible Operations
+
+The following TensorFlow operations are usually mapped to their TensorFlow Lite
+counterparts:
+
+* [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul) - *as long
+ as the second argument is constant and transposition is not used*
+* [tf.nn.avg_pool](https://www.tensorflow.org/api_docs/python/tf/nn/avg_pool)
+* [tf.nn.conv2d](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d) -
+ *as long as the filter is constant*
+* [tf.nn.depthwise_conv2d](https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d) -
+ *as long as the filter is constant and rate is [1,1]*
+* [tf.nn.l2_normalize](https://www.tensorflow.org/api_docs/python/tf/nn/l2_normalize) -
+ *as long as normalization is done along the last dimension*
+* [tf.nn.local_response_normalization](https://www.tensorflow.org/api_docs/python/tf/nn/local_response_normalization)
+* [tf.nn.max_pool](https://www.tensorflow.org/api_docs/python/tf/nn/max_pool)
+* [tf.nn.softmax](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) -
+ *as long as tensors are 2D and axis is the last dimension*
+* [tf.reshape](https://www.tensorflow.org/api_docs/python/tf/reshape)
+* [tf.sigmoid](https://www.tensorflow.org/api_docs/python/tf/sigmoid)
+* [tf.space_to_depth](https://www.tensorflow.org/api_docs/python/tf/space_to_depth)
+
+## Straighforward Conversions, Constant-Folding and Fusing
+
+A number of TensorFlow operations can be processed by TensorFlow Lite even
+though they have no direct equivalent. This is the case for operations that can
+be simply removed from the graph (tf.identity), replaced by tensors
+(tf.placeholder), or fused into more complex operations (tf.nn.bias_add). Even
+some supported operations may sometimes be removed through one of these
+processes.
+
+Here is a list of TensorFlow operations that are usually removed from the graph:
+
+* [tf.add](https://www.tensorflow.org/api_docs/python/tf/add)
+* [tf.check_numerics](https://www.tensorflow.org/api_docs/python/tf/check_numerics)
+* [tf.constant](https://www.tensorflow.org/api_docs/python/tf/constant)
+* [tf.div](https://www.tensorflow.org/api_docs/python/tf/div)
+* [tf.divide](https://www.tensorflow.org/api_docs/python/tf/divide)
+* [tf.fake_quant_with_min_max_args](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args)
+* [tf.fake_quant_with_min_max_vars](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars)
+* [tf.greater](https://www.tensorflow.org/api_docs/python/tf/greater)
+* [tf.greater_equal](https://www.tensorflow.org/api_docs/python/tf/greater_equal)
+* [tf.identity](https://www.tensorflow.org/api_docs/python/tf/identity)
+* [tf.less](https://www.tensorflow.org/api_docs/python/tf/less)
+* [tf.less_equal](https://www.tensorflow.org/api_docs/python/tf/less_equal)
+* [tf.maximum](https://www.tensorflow.org/api_docs/python/tf/maximum)
+* [tf.minimum](https://www.tensorflow.org/api_docs/python/tf/minimum)
+* [tf.multiply](https://www.tensorflow.org/api_docs/python/tf/multiply)
+* [tf.no_op](https://www.tensorflow.org/api_docs/python/tf/no_op)
+* [tf.placeholder](https://www.tensorflow.org/api_docs/python/tf/placeholder)
+* [tf.placeholder_with_default](https://www.tensorflow.org/api_docs/python/tf/placeholder_with_default)
+* [tf.realdiv](https://www.tensorflow.org/api_docs/python/tf/realdiv)
+* [tf.reduce_max](https://www.tensorflow.org/api_docs/python/tf/reduce_max)
+* [tf.reduce_min](https://www.tensorflow.org/api_docs/python/tf/reduce_min)
+* [tf.reduce_sum](https://www.tensorflow.org/api_docs/python/tf/reduce_sum)
+* [tf.rsqrt](https://www.tensorflow.org/api_docs/python/tf/rsqrt)
+* [tf.shape](https://www.tensorflow.org/api_docs/python/tf/shape)
+* [tf.sqrt](https://www.tensorflow.org/api_docs/python/tf/sqrt)
+* [tf.square](https://www.tensorflow.org/api_docs/python/tf/square)
+* [tf.squeeze](https://www.tensorflow.org/api_docs/python/tf/squeeze)
+* [tf.subtract](https://www.tensorflow.org/api_docs/python/tf/subtract)
+* [tf.tile](https://www.tensorflow.org/api_docs/python/tf/tile)
+* [tf.nn.batch_norm_with_global_normalization](https://www.tensorflow.org/api_docs/python/tf/nn/batch_norm_with_global_normalization)
+* [tf.nn.bias_add](https://www.tensorflow.org/api_docs/python/tf/nn/bias_add)
+* [tf.nn.fused_batch_norm](https://www.tensorflow.org/api_docs/python/tf/nn/fused_batch_norm)
+* [tf.nn.relu](https://www.tensorflow.org/api_docs/python/tf/nn/relu)
+* [tf.nn.relu6](https://www.tensorflow.org/api_docs/python/tf/nn/relu6)
+
+Note that many of those operations don't have TensorFlow Lite equivalents and
+the corresponding model will not be convertible if they can't be elided or
+fused.
+
+## Unsupported Operations
+
+TensorFlow operation not listed above are likely unsupported. Notably, the
+following common ops are not supported at the moment:
+
+* [tf.batch_to_space_nd](https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd)
+* [tf.depth_to_space](https://www.tensorflow.org/api_docs/python/tf/depth_to_space)
+* [tf.floor](https://www.tensorflow.org/api_docs/python/tf/floor)
+* [tf.gather](https://www.tensorflow.org/api_docs/python/tf/gather)
+* [tf.image.resize_bilinear](https://www.tensorflow.org/api_docs/python/tf/image/resize_bilinear)
+* [tf.pad](https://www.tensorflow.org/api_docs/python/tf/pad)
+* [tf.reduce_mean](https://www.tensorflow.org/api_docs/python/tf/reduce_mean)
+* [tf.slice](https://www.tensorflow.org/api_docs/python/tf/slice)
+* [tf.space_to_batch_nd](https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd)
+* [tf.split](https://www.tensorflow.org/api_docs/python/tf/split)
+* [tf.strided_slice](https://www.tensorflow.org/api_docs/python/tf/strided_slice)
+* [tf.tanh](https://www.tensorflow.org/api_docs/python/tf/tanh)
+
+## TensorFlow Lite Operations
+
+The following TensorFlow Lite operations are fully supported and used in place
+of the TensorFlow operations listed above:
+
+**ADD**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: elementwise sum of the input tensors
+}
+Options {
+ fused_activation_function: NONE|RELU|RELU6
+}
+```
+
+**AVERAGE_POOL_2D**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a tensor where each entry is the mean of the input values in the
+ corresponding window.
+}
+Options {
+ fused_activation_function: NONE|RELU|RELU6
+ padding: SAME|VALID
+ stride_w,stride_h: stride of the sliding window
+ filter_width,filter_height: size of the sliding window
+}
+```
+
+**CONCATENATION**
+
+```
+Inputs {
+ 0-N: any number of tensors
+}
+Outputs {
+ 0: concatenation of the input tensors along the given axis.
+}
+Options {
+ fused_activation_function: NONE|RELU|RELU6
+ axis: dimension along which the concatenation is performed
+}
+```
+
+**CONV_2D**
+
+```
+Inputs {
+ 0: 4D tensor
+ 1: filter
+ 2: bias (optional)
+}
+Outputs {
+ 0: result of 2D convolution of the input tensor
+}
+Options {
+ fused_activation_function: NONE|RELU|RELU6
+ padding: SAME|VALID
+ stride_w,stride_h: stride of the filter window
+}
+```
+
+**DEPTHWISE_CONV_2D**
+
+```
+Inputs {
+ 0: 4D tensor
+ 1: filter
+ 2: bias (optional)
+}
+Outputs {
+ 0: result of a depthwise-2D convolution of the input tensor
+}
+Options {
+ fused_activation_function: NONE|RELU|RELU6
+ padding: SAME|VALID
+ stride_w,stride_h: stride of the filter window
+ depth_multiplier: relation between the last dimension of the input and output
+ tensors
+}
+```
+
+**FULLY_CONNECTED**
+
+```
+Inputs {
+ 0: 4D tensor
+ 1: filter
+ 2: bias (optional)
+}
+Outputs {
+ 0: output of a fully (densely) connected layer, which connects all
+ elements in the input tensor with each element in this tensor.
+}
+Options {
+ fused_activation_function: NONE|RELU|RELU6
+}
+```
+
+**L2_NORMALIZATION**
+
+```
+Inputs {
+ 0: input tensor
+}
+Outputs {
+ 0: normalized tensor (along the last dimension)
+}
+Options {
+ fused_activation_function: NONE|RELU|RELU6
+}
+```
+
+**L2_POOL_2D**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a tensor equivalent to tf.sqrt(tf.nn.ave_pool(tf.square(input))
+}
+Options {
+ fused_activation_function: NONE|RELU|RELU6
+ padding: SAME|VALID
+ stride_w,stride_h: stride of the sliding window
+ filter_width,filter_height: size of the sliding window
+}
+```
+
+**LOCAL_RESPONSE_NORMALIZATION**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a tensor equivalent to tf.nn.local_response_normalization
+}
+Options {
+ radius
+ bias
+ alpha
+ beta
+}
+```
+
+**LOGISTIC**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a tensor equivalent to 1 / (1 + exp(-input))
+}
+```
+
+**MAX_POOL_2D**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a tensor where each entry is the maximum of the input values in the
+ corresponding window.
+}
+Options {
+ fused_activation_function: NONE|RELU|RELU6
+ padding: SAME|VALID
+ stride_w,stride_h: stride of the sliding window
+ filter_width,filter_height: size of the sliding window
+}
+```
+
+**MUL**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: elementwise multiplication of the input tensors
+}
+Options {
+ fused_activation_function: NONE|RELU|RELU6
+}
+```
+
+**RELU**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a tensor equivalent to max(0, min(input, 1)
+}
+```
+
+**RELU1**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a tensor equivalent to max(-1, min(input, 6)
+}
+```
+
+**RELU6**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a tensor equivalent to max(0, min(input, 6)
+}
+```
+
+**RESHAPE**
+
+```
+Inputs {
+ 0: a tensor
+ 1: ignored
+}
+Outputs {
+ 0: a tensor with the same elements as the input but with the new shape
+}
+Options {
+ new_shape
+}
+```
+
+**SOFTMAX**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a tensor equivalent to exp(input) / tf.reduce_sum(exp(input * beta), dim),
+ where dim is always the last dimension of the input tensor.
+}
+Options {
+ beta
+}
+```
+
+**SPACE_TO_DEPTH**
+
+```
+Inputs {
+ 0: a 4D tensor
+}
+Outputs {
+ 0: a tensor rearranged using block_size. See tf.space_to_depth for details.
+}
+Options {
+ block_size
+}
+```
+
+And these are TensorFlow Lite operations that are present but not ready for
+custom models yet:
+
+* CALL
+* CONCAT_EMBEDDINGS
+* CUSTOM
+* EMBEDDING_LOOKUP
+* EMBEDDING_LOOKUP_SPARSE
+* HASHTABLE_LOOKUP
+* LSH_PROJECTION
+* LSTM
+* RESIZE_BILINEAR
+* RNN
+* SKIP_GRAM
+* SVDF
+* TANH
diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md
new file mode 100644
index 0000000000..71b633c577
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/README.md
@@ -0,0 +1,36 @@
+# TF Lite Android App
+
+## Building from Source with Bazel
+
+1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel):
+
+ 1. [Install Bazel and Android Prerequisites](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites).
+ It's easiest with Android Studio.
+
+ - You'll need at least SDK version 23.
+ - Bazel requires Android Build Tools `26.0.1` or higher.
+ - You also need to install the Android Support Repository, available
+ through Android Studio under `Android SDK Manager -> SDK Tools ->
+ Android Support Repository`.
+
+ 2. [Edit your `WORKSPACE`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#edit-workspace)
+ to add SDK and NDK targets.
+
+ - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that
+ you have installed.
+ - By default, Android Studio will install the SDK to `~/Android/Sdk` and
+ the NDK to `~/Android/Sdk/ndk-bundle`.
+
+2. Build the app with Bazel. The demo needs C++11:
+
+ ```shell
+ bazel build -c opt --cxxopt='--std=c++11' \
+ //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo
+ ```
+
+3. Install the demo on a
+ [debug-enabled device](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install):
+
+ ```shell
+ adb install bazel-bin/tensorflow/contrib/lite/java/demo/app/src/main/TfLiteCameraDemo.apk
+ ```
diff --git a/tensorflow/contrib/lite/models/smartreply/g3doc/README.md b/tensorflow/contrib/lite/models/smartreply/g3doc/README.md
new file mode 100644
index 0000000000..cab5dcca43
--- /dev/null
+++ b/tensorflow/contrib/lite/models/smartreply/g3doc/README.md
@@ -0,0 +1,146 @@
+# Smart Reply Model
+
+## What is On-Device Smart Reply Model?
+
+Smart Replies are contextually relevant, one-touch responses that help the user
+to reply to an incoming text message (or email) efficiently and effortlessly.
+Smart Replies have been highly successful across several Google products
+including
+[Gmail](https://www.blog.google/products/gmail/save-time-with-smart-reply-in-gmail/),
+[Inbox](https://www.blog.google/products/gmail/computer-respond-to-this-email/)
+and
+[Allo](https://blog.google/products/allo/google-allo-smarter-messaging-app/).
+
+The On-device Smart Reply model is targeted towards text chat use cases. It has
+a completely different architecture from its cloud-based counterparts, and is
+built specifically for memory constraints devices such as phones & watches. It
+has been successfully used to provide [Smart Replies on Android
+Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html)
+to all first- & third-party apps.
+
+The on-device model comes with several benefits. It is:
+
+* **Faster**: The model resides on the device and does not require internet
+ connectivity. Thus, the inference is very fast and has an average latency of
+ only a few milliseconds.
+* **Resource efficient**: The model has a small memory footprint on
+ the device.
+* **Privacy-friendly**: The user data never leaves the device and this
+ eliminates any privacy restrictions.
+
+A caveat, though, is that the on-device model has lower triggering rate than its
+cloud counterparts (triggering rate is the percentage of times the model
+suggests a response for an incoming message).
+
+## When to use this Model?
+
+The On-Device Smart Reply model is aimed towards improving the messaging
+experience for day-to-day conversational chat messages. We recommend using this
+model for similar use cases. Some sample messages on which the model does well
+are provided in this [tsv
+file](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/smartreply_samples.tsv)
+for reference. The file format is:
+
+```
+ {incoming_message smart_reply1 [smart_reply2] [smart_reply3]}
+```
+
+For the current model, we see a triggering rate of about 30-40% for messages
+which are similar to those provided in the tsv file above.
+
+In case the model does not trigger any response, the system falls back to
+suggesting replies from a fixed back-off set that was compiled from popular
+response intents observed in chat conversations. Some of the fallback responses
+are `Ok, Yes, No, 👍, ☺`.
+
+The model can only be used for inference at this time (i.e. it cannot be custom
+trained). If you are interested to know how the model was trained, please refer
+to this [blog
+post](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html)
+and [research paper](https://arxiv.org/pdf/1708.00630).
+
+## How to use this Model?
+
+We have provided a pre-built demo APK that you can download, install and test on
+your phone ([demo APK
+here](http://download.tensorflow.org/deps/tflite/SmartReplyDemo.apk)).
+
+The On-Device Smart Reply demo App works in the following way:
+
+1. Android app links to the JNI binary with a predictor library.
+
+2. In the predictor library, `GetSegmentPredictions` is called with a list of input
+ strings.
+
+ 2.1 The input string can be 1-3 most recent messages of the conversations in
+ form of string vector. The model will run on these input sentences and
+ provide Smart Replies corresponding to them.
+
+ 2.2 The function performs some preprocessing on input data which includes:
+
+ * Sentence splitting: The input message will be split into sentences if
+ message has more than one sentence. Eg: a message like “How are you?
+ Want to grab lunch?” will be broken down into 2 different sentences.
+ * Normalization: The individual sentences will be normalized by converting
+ them into lower cases, removing unnecessary punctuations, etc. Eg: “how
+ are you????” will be converted to “how are you?” (refer for NORMALIZE op
+ for more details).
+
+ The input string content will be converted to tensors.
+
+ 2.3 The function then runs the prediction model on the input tensors.
+
+ 2.4 The function also performs some post-processing which includes
+ aggregating the model predictions for the input sentences from 2.2 and
+ returning the appropriate responses.
+
+3. Finally, it gets response(s) from `std::vector<PredictorResponse>`, and
+ returns back to Android app. Responses are sorted in descending order of
+ confidence score.
+
+## Ops and Functionality Supported
+
+Following are the ops supported for using On-Device Smart Reply model:
+
+* **NORMALIZE**
+
+ This is a custom op which normalizes the sentences by:
+
+ * Converting all sentences into lower case.
+ * Removing unnecessary punctuations (eg: “how are you????” → “how are
+ you?”).
+ * Expanding sentences wherever necessary (eg: “ I’m home” → “I am home”).
+
+* **SKIP_GRAM**
+
+ This is an op inside TensorFlow Lite that converts sentences into a list of
+ skip grams. The configurable parameters are `ngram_size` and
+ `max_skip_size`. For the model provided, the values for these parameters are
+ set to 3 & 2 respectively.
+
+* **EXTRACT_FEATURES**
+
+ This is a custom op that hashes skip grams to features represented as
+ integers. Longer skip-grams are allocated higher weights.
+
+* **LSH_PROJECTION**
+
+ This is an op inside TensorFlow Lite that projects input features to a
+ corresponding bit vector space using Locality Sensitive Hashing (LSH).
+
+* **PREDICT**
+
+ This is a custom op that runs the input features through the projection
+ model (details [here](https://arxiv.org/pdf/1708.00630.pdf)), computes the
+ appropriate response labels along with weights for the projected features,
+ and aggregates the response labels and weights together.
+
+* **HASHTABLE_LOOKUP**
+
+ This is a custom op that uses label id from predict op and looks up the
+ response text from the given label id.
+
+## Further Information
+
+* Open source code
+ [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/smartreply/).
diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/README.md b/tensorflow/contrib/lite/models/testdata/g3doc/README.md
new file mode 100644
index 0000000000..d0c21d2833
--- /dev/null
+++ b/tensorflow/contrib/lite/models/testdata/g3doc/README.md
@@ -0,0 +1,102 @@
+## Speech Model Tests
+
+Sample test data has been provided for speech related models in Tensorflow Lite
+to help users working with speech models to verify and test their models.
+
+For the hotword, speaker-id and automatic speech recognition sample models, the
+architecture assumes that the models receive their input from a speech
+pre-processing module. The speech pre-processing module receives the audio
+signal and produces features for the encoder neural network and uses some
+typical signal processing algorithms, like FFT and spectral subtraction, and
+ultimately produces a log-mel filterbank (the log of the triangular mel filters
+applied to the power spectra). The text-to-speech model assumes that the inputs
+are linguistic features describing characteristics of phonemes, syllables,
+words, phrases, and sentence. The outputs are acoustic features including
+mel-cepstral coefficients, log fundamental frequency, and band aperiodicity.
+The pre-processing modules for these models are not provided in the open source
+version of TensorFlow Lite.
+
+The following sections describe the architecture of the sample models at a high
+level:
+
+### Hotword Model
+
+The hotword model is the neural network model we use for keyphrase/hotword
+spotting (i.e. "okgoogle" detection). It is the entry point for voice
+interaction (e.g. Google search app on Android devices or Google Home, etc.).
+The speech hotword model block diagram is shown in Figure below. It has an input
+size of 40 (float), an output size of 7 (float), one Svdf layer, and four fully
+connected layers with the corresponding parameters as shown in figure below.
+
+![hotword_model](hotword.svg "Hotword model")
+
+### Speaker-id Model
+
+The speaker-id model is the neural network model we use for speaker
+verification. It runs after the hotword triggers. The speech speaker-id model
+block diagram is shown in Figure below. It has an input size of 80 (float), an
+output size of 64 (float), three Lstm layers, and one fully connected layers
+with the corresponding parameters as shown in figure below.
+
+![speakerid_model](speakerid.svg "Speaker-id model")
+
+### Text-to-speech (TTS) Model
+
+The text-to-speech model is the neural network model used to generate speech
+from text. The speech text-to-speech model’s block diagram is shown
+in Figure below. It has and input size of 334 (float), an output size of 196
+(float), two fully connected layers, three Lstm layers, and one recurrent layer
+with the corresponding parameters as shown in the figure.
+
+![tts_model](tts.svg "TTS model")
+
+### Automatic Speech Recognizer (ASR) Acoustic Model (AM)
+
+The acoustic model for automatic speech recognition is the neural network model
+for matching phonemes to the input autio features. It generates posterior
+probabilities of phonemes from speech frontend features (log-mel filterbanks).
+It has an input size of 320 (float), an output size of 42 (float), five LSTM
+layers and one fully connected layers with a Softmax activation function, with
+the corresponding parameters as shown in the figure.
+
+![asr_am_model](asr_am.svg "ASR AM model")
+
+## Speech models test input/output generation
+
+As mentioned above the input to models are generated from a pre-processing
+module (output of a log-mel filterbank, or linguistic features), and the outputs
+are generated by running the equivalent TensorFlow model by feeding them the
+same input.
+
+## Link to the open source code
+
+### Models:
+
+[Speech hotword model (Svdf rank=1)] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/speech_hotword_model_rank1.tflite)
+
+[Speech hotword model (Svdf rank=2)] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/speech_hotword_model_rank2.tflite)
+
+[Speaker-id model] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/speech_speakerid_model.tflite)
+
+[TTS model] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/speech_tts_model.tflite)
+
+[ASR AM model] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/speech_terse_am_model.tflite)
+
+### Test benches
+
+[Speech hotword model test] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_hotword_model_test.cc)
+
+[Speaker-id model test] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc)
+
+[TTS model test] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_tts_model_test.cc)
+
+[ASR AM model test] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc)
+
+## Android Support
+The models have been tested on Android phones, using the following tests:
+
+[Hotword] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/android/BUILD?rcl=172930882&l=25)
+
+[Speaker-id] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/android/BUILD?rcl=172930882&l=36)
+
+
diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/asr_am.svg b/tensorflow/contrib/lite/models/testdata/g3doc/asr_am.svg
new file mode 100644
index 0000000000..ca96556422
--- /dev/null
+++ b/tensorflow/contrib/lite/models/testdata/g3doc/asr_am.svg
@@ -0,0 +1,4 @@
+<?xml version="1.0" standalone="yes"?>
+
+<svg version="1.1" viewBox="0.0 0.0 960.0 720.0" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"><clipPath id="p.0"><path d="m0 0l960.0 0l0 720.0l-960.0 0l0 -720.0z" clip-rule="nonzero"></path></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l960.0 0l0 720.0l-960.0 0z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m392.0 30.700842l166.01575 0l0 42.110237l-166.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m392.0 30.700842l166.01575 0l0 42.110237l-166.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m404.43954 57.620842l0 -13.59375l1.8125 0l0 13.59375l-1.8125 0zm4.6676636 0l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm10.375732 3.78125l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm15.313202 4.875l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm7.5788574 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm9.897858 5.5q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.353302 -6.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.254181 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm23.074646 -2.125l-8.96875 0l0 -1.5625l8.96875 0l0 1.5625zm0 4.125l-8.96875 0l0 -1.546875l8.96875 0l0 1.546875zm6.8439026 0.28125l1.671875 -0.21875q0.28125 1.421875 0.96875 2.046875q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.8125q0.8125 -0.828125 0.8125 -2.03125q0 -1.140625 -0.765625 -1.890625q-0.75 -0.75 -1.90625 -0.75q-0.46875 0 -1.171875 0.1875l0.1875 -1.46875q0.15625 0.015625 0.265625 0.015625q1.0625 0 1.90625 -0.546875q0.859375 -0.5625 0.859375 -1.71875q0 -0.921875 -0.625 -1.515625q-0.609375 -0.609375 -1.59375 -0.609375q-0.96875 0 -1.625 0.609375q-0.640625 0.609375 -0.828125 1.84375l-1.671875 -0.296875q0.296875 -1.6875 1.375 -2.609375q1.09375 -0.921875 2.71875 -0.921875q1.109375 0 2.046875 0.484375q0.9375 0.46875 1.421875 1.296875q0.5 0.828125 0.5 1.75q0 0.890625 -0.46875 1.609375q-0.46875 0.71875 -1.40625 1.15625q1.21875 0.265625 1.875 1.15625q0.671875 0.875 0.671875 2.1875q0 1.78125 -1.296875 3.015625q-1.296875 1.234375 -3.28125 1.234375q-1.796875 0 -2.984375 -1.0625q-1.171875 -1.0625 -1.34375 -2.765625zm19.141296 1.984375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm1.5944824 -5.09375q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm10.219482 10.703125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m359.0 102.02362l232.18896 0l0 58.992126l-232.18896 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m359.0 102.02362l232.18896 0l0 58.992126l-232.18896 0z" fill-rule="evenodd"></path><path fill="#000000" d="m401.82367 128.94362l0 -13.59375l1.796875 0l0 11.984375l6.703125 0l0 1.609375l-8.5 0zm9.844482 -4.375l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75 -0.46875 -1.6875q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46875 2.703125 0.96875q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm16.506073 4.375l0 -12.0l-4.46875 0l0 -1.59375l10.765625 0l0 1.59375l-4.5 0l0 12.0l-1.796875 0zm7.8803406 0l0 -13.59375l2.71875 0l3.21875 9.625q0.4375 1.34375 0.640625 2.015625q0.234375 -0.75 0.734375 -2.1875l3.25 -9.453125l2.421875 0l0 13.59375l-1.734375 0l0 -11.390625l-3.953125 11.390625l-1.625 0l-3.9375 -11.578125l0 11.578125l-1.734375 0zm21.212677 0l-1.671875 0l0 -10.640625q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.5625 -1.765625 0.859375l0 -1.625q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.65625zm12.918396 4.0q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.572052 -7.59375l1.671875 -0.21875q0.28125 1.421875 0.96875 2.046875q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.8125q0.8125 -0.828125 0.8125 -2.03125q0 -1.140625 -0.765625 -1.890625q-0.75 -0.75 -1.90625 -0.75q-0.46875 0 -1.171875 0.1875l0.1875 -1.46875q0.15625 0.015625 0.265625 0.015625q1.0625 0 1.90625 -0.546875q0.859375 -0.5625 0.859375 -1.71875q0 -0.921875 -0.625 -1.515625q-0.609375 -0.609375 -1.59375 -0.609375q-0.96875 0 -1.625 0.609375q-0.640625 0.609375 -0.828125 1.84375l-1.671875 -0.296875q0.296875 -1.6875 1.375 -2.609375q1.09375 -0.921875 2.71875 -0.921875q1.109375 0 2.046875 0.484375q0.9375 0.46875 1.421875 1.296875q0.5 0.828125 0.5 1.75q0 0.890625 -0.46875 1.609375q-0.46875 0.71875 -1.40625 1.15625q1.21875 0.265625 1.875 1.15625q0.671875 0.875 0.671875 2.1875q0 1.78125 -1.296875 3.015625q-1.296875 1.234375 -3.28125 1.234375q-1.796875 0 -2.984375 -1.0625q-1.171875 -1.0625 -1.34375 -2.765625zm19.141357 1.984375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm1.5944519 -5.09375q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.016357 6.703125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.9844055 0l-3.3906555 4.640625l3.6562805 5.21875l-2.0469055 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm9.9687805 -3.5625l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm10.375671 -3.140625q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656982 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm10.219421 10.703125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" d="m395.9714 154.72487l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm8.844452 4.875l0 -9.859375l1.5 0l0 1.5q0.578125 -1.046875 1.0625 -1.375q0.484375 -0.34375 1.078125 -0.34375q0.84375 0 1.71875 0.546875l-0.578125 1.546875q-0.609375 -0.359375 -1.234375 -0.359375q-0.546875 0 -0.984375 0.328125q-0.421875 0.328125 -0.609375 0.90625q-0.28125 0.890625 -0.28125 1.953125l0 5.15625l-1.671875 0zm5.603302 -4.921875q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 -6.734375l0 -1.9375l1.65625 0l0 1.9375l-1.65625 0zm-2.125 15.484375l0.3125 -1.421875q0.5 0.125 0.796875 0.125q0.515625 0 0.765625 -0.34375q0.25 -0.328125 0.25 -1.6875l0 -10.359375l1.65625 0l0 10.390625q0 1.828125 -0.46875 2.546875q-0.59375 0.921875 -2.0 0.921875q-0.671875 0 -1.3125 -0.171875zm13.019806 -7.0l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547607 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm1.5426636 -10.1875l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.5042114 -4.921875q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281952 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm14.887146 -2.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.2542114 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.7187805 0.21875q-0.40625 1.5 -1.5156555 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.2344055 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.3437805 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm9.578827 -2.078125l0 -1.90625l1.90625 0l0 1.90625l-1.90625 0zm0 7.953125l0 -1.90625l1.90625 0l0 1.90625l-1.90625 0zm18.210388 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm1.5944824 -5.09375q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656921 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m385.80054 657.01575l180.0 0l0 42.11023l-180.0 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m385.80054 657.01575l180.0 0l0 42.11023l-180.0 0z" fill-rule="evenodd"></path><path fill="#000000" d="m402.3206 677.3107q0 -3.390625 1.8125 -5.296875q1.828125 -1.921875 4.703125 -1.921875q1.875 0 3.390625 0.90625q1.515625 0.890625 2.296875 2.5q0.796875 1.609375 0.796875 3.65625q0 2.0625 -0.84375 3.703125q-0.828125 1.625 -2.359375 2.46875q-1.53125 0.84375 -3.296875 0.84375q-1.921875 0 -3.4375 -0.921875q-1.5 -0.9375 -2.28125 -2.53125q-0.78125 -1.609375 -0.78125 -3.40625zm1.859375 0.03125q0 2.453125 1.3125 3.875q1.328125 1.40625 3.3125 1.40625q2.03125 0 3.34375 -1.421875q1.3125 -1.4375 1.3125 -4.0625q0 -1.65625 -0.5625 -2.890625q-0.546875 -1.234375 -1.640625 -1.921875q-1.078125 -0.6875 -2.421875 -0.6875q-1.90625 0 -3.28125 1.3125q-1.375 1.3125 -1.375 4.390625zm19.433289 6.59375l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm7.5788574 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm1.5270386 5.28125l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm15.313232 4.875l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm7.578827 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm9.897858 5.5q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.353302 -6.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.2542114 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm23.074646 -2.125l-8.96875 0l0 -1.5625l8.96875 0l0 1.5625zm0 4.125l-8.96875 0l0 -1.546875l8.96875 0l0 1.546875zm12.187622 3.875l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm13.797607 3.171875l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm3.1569824 5.609375l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m475.09448 161.01575l0 24.724411" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m475.09448 161.01575l0 18.724411" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m473.44275 179.74016l1.6517334 4.538101l1.6517334 -4.538101z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m475.09448 244.72906l0 25.29132" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m475.09448 244.72906l0 19.291351" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m473.44275 264.02042l1.6517334 4.538086l1.6517334 -4.538086z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m475.00787 72.81108l0.09448242 29.196846" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m475.00787 72.81108l0.07510376 23.196877" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m473.4312 96.013306l1.6664124 4.5327225l1.6370544 -4.543419z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m359.0 526.4199l232.18896 0l0 42.11029l-232.18896 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m359.0 526.4199l232.18896 0l0 42.11029l-232.18896 0z" fill-rule="evenodd"></path><path fill="#000000" d="m372.43524 553.33997l0 -13.59375l9.171875 0l0 1.59375l-7.375 0l0 4.21875l6.375 0l0 1.609375l-6.375 0l0 6.171875l-1.796875 0zm17.53659 0l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm3.8913574 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.144806 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.0979614 3.796875l-0.171875 -1.5625q0.546875 0.140625 0.953125 0.140625q0.546875 0 0.875 -0.1875q0.34375 -0.1875 0.5625 -0.515625q0.15625 -0.25 0.5 -1.25q0.046875 -0.140625 0.15625 -0.40625l-3.734375 -9.875l1.796875 0l2.046875 5.71875q0.40625 1.078125 0.71875 2.28125q0.28125 -1.15625 0.6875 -2.25l2.09375 -5.75l1.671875 0l-3.75 10.03125q-0.59375 1.625 -0.9375 2.234375q-0.4375 0.828125 -1.015625 1.203125q-0.578125 0.390625 -1.375 0.390625q-0.484375 0 -1.078125 -0.203125zm19.328125 -8.5625l1.796875 0.453125q-0.5625 2.21875 -2.03125 3.390625q-1.46875 1.15625 -3.59375 1.15625q-2.203125 0 -3.578125 -0.890625q-1.375 -0.90625 -2.09375 -2.59375q-0.71875 -1.703125 -0.71875 -3.65625q0 -2.125 0.796875 -3.703125q0.8125 -1.578125 2.3125 -2.390625q1.5 -0.828125 3.296875 -0.828125q2.046875 0 3.4375 1.046875q1.390625 1.03125 1.9375 2.90625l-1.765625 0.421875q-0.46875 -1.484375 -1.375 -2.15625q-0.90625 -0.6875 -2.265625 -0.6875q-1.5625 0 -2.625 0.75q-1.046875 0.75 -1.484375 2.03125q-0.421875 1.265625 -0.421875 2.609375q0 1.734375 0.5 3.03125q0.515625 1.28125 1.578125 1.921875q1.078125 0.640625 2.3125 0.640625q1.515625 0 2.5625 -0.859375q1.046875 -0.875 1.421875 -2.59375zm2.926056 -0.15625q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm10.375702 0l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm17.125732 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547577 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm8.277069 -1.671875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.500702 5.875l0 -1.25q-0.9375 1.46875 -2.75 1.46875q-1.171875 0 -2.171875 -0.640625q-0.984375 -0.65625 -1.53125 -1.8125q-0.53125 -1.171875 -0.53125 -2.6875q0 -1.46875 0.484375 -2.671875q0.5 -1.203125 1.46875 -1.84375q0.984375 -0.640625 2.203125 -0.640625q0.890625 0 1.578125 0.375q0.703125 0.375 1.140625 0.984375l0 -4.875l1.65625 0l0 13.59375l-1.546875 0zm-5.28125 -4.921875q0 1.890625 0.796875 2.828125q0.8125 0.9375 1.890625 0.9375q1.09375 0 1.859375 -0.890625q0.765625 -0.890625 0.765625 -2.734375q0 -2.015625 -0.78125 -2.953125q-0.78125 -0.953125 -1.921875 -0.953125q-1.109375 0 -1.859375 0.90625q-0.75 0.90625 -0.75 2.859375zm17.637146 8.921875q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.556427 -7.5625l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm10.375732 -3.140625q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656982 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.016357 6.703125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm15.328125 0l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm13.797546 3.171875l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm3.1569824 5.609375l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m475.09448 413.32974l0 24.125977" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m475.09448 413.3297l0 18.126007" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m473.44275 431.45572l1.6517334 4.538086l1.6517334 -4.538086z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m475.09448 329.01575l0 25.322845" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m475.09448 329.01575l0 19.322845" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m473.44275 348.3386l1.6517334 4.538086l1.6517334 -4.538086z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m475.09448 496.44235l0 29.984283" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m475.09448 496.44238l0 23.984253" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m473.44275 520.42664l1.6517334 4.538086l1.6517334 -4.538086z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m359.0 185.73694l232.18896 0l0 58.992126l-232.18896 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m359.0 185.73694l232.18896 0l0 58.992126l-232.18896 0z" fill-rule="evenodd"></path><path fill="#000000" d="m401.82367 212.65694l0 -13.59375l1.796875 0l0 11.984375l6.703125 0l0 1.609375l-8.5 0zm9.844482 -4.375l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75 -0.46875 -1.6875q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46875 2.703125 0.96875q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm16.506073 4.375l0 -12.0l-4.46875 0l0 -1.59375l10.765625 0l0 1.59375l-4.5 0l0 12.0l-1.796875 0zm7.8803406 0l0 -13.59375l2.71875 0l3.21875 9.625q0.4375 1.34375 0.640625 2.015625q0.234375 -0.75 0.734375 -2.1875l3.25 -9.453125l2.421875 0l0 13.59375l-1.734375 0l0 -11.390625l-3.953125 11.390625l-1.625 0l-3.9375 -11.578125l0 11.578125l-1.734375 0zm23.697052 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm10.434021 5.609375q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.556427 -7.5625l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm10.375732 -3.140625q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656952 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.016357 6.703125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.9844055 0l-3.3906555 4.640625l3.6562805 5.21875l-2.0469055 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm9.9687805 -3.5625l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm10.375671 -3.140625q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656982 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm10.219421 10.703125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" d="m395.9714 238.43819l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm8.844452 4.875l0 -9.859375l1.5 0l0 1.5q0.578125 -1.046875 1.0625 -1.375q0.484375 -0.34375 1.078125 -0.34375q0.84375 0 1.71875 0.546875l-0.578125 1.546875q-0.609375 -0.359375 -1.234375 -0.359375q-0.546875 0 -0.984375 0.328125q-0.421875 0.328125 -0.609375 0.90625q-0.28125 0.890625 -0.28125 1.953125l0 5.15625l-1.671875 0zm5.603302 -4.921875q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 -6.734375l0 -1.9375l1.65625 0l0 1.9375l-1.65625 0zm-2.125 15.484375l0.3125 -1.421875q0.5 0.125 0.796875 0.125q0.515625 0 0.765625 -0.34375q0.25 -0.328125 0.25 -1.6875l0 -10.359375l1.65625 0l0 10.390625q0 1.828125 -0.46875 2.546875q-0.59375 0.921875 -2.0 0.921875q-0.671875 0 -1.3125 -0.171875zm13.019806 -7.0l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547607 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm1.5426636 -10.1875l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.5042114 -4.921875q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281952 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm14.887146 -2.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.2542114 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.7187805 0.21875q-0.40625 1.5 -1.5156555 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.2344055 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.3437805 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm9.578827 -2.078125l0 -1.90625l1.90625 0l0 1.90625l-1.90625 0zm0 7.953125l0 -1.90625l1.90625 0l0 1.90625l-1.90625 0zm18.210388 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm1.5944824 -5.09375q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656921 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m359.0 270.02362l232.18896 0l0 58.992126l-232.18896 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m359.0 270.02362l232.18896 0l0 58.992126l-232.18896 0z" fill-rule="evenodd"></path><path fill="#000000" d="m401.82367 296.94363l0 -13.59375l1.796875 0l0 11.984375l6.703125 0l0 1.609375l-8.5 0zm9.844482 -4.375l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75 -0.46875 -1.6875q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46875 2.703125 0.96875q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm16.506073 4.375l0 -12.0l-4.46875 0l0 -1.59375l10.765625 0l0 1.59375l-4.5 0l0 12.0l-1.796875 0zm7.8803406 0l0 -13.59375l2.71875 0l3.21875 9.625q0.4375 1.34375 0.640625 2.015625q0.234375 -0.75 0.734375 -2.1875l3.25 -9.453125l2.421875 0l0 13.59375l-1.734375 0l0 -11.390625l-3.953125 11.390625l-1.625 0l-3.9375 -11.578125l0 11.578125l-1.734375 0zm23.697052 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm10.434021 5.609375q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.556427 -7.5625l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm10.375732 -3.140625q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656952 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.016357 6.703125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.9844055 0l-3.3906555 4.640625l3.6562805 5.21875l-2.0469055 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm9.9687805 -3.5625l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm10.375671 -3.140625q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656982 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm10.219421 10.703125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" d="m395.9714 322.72488l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm8.844452 4.875l0 -9.859375l1.5 0l0 1.5q0.578125 -1.046875 1.0625 -1.375q0.484375 -0.34375 1.078125 -0.34375q0.84375 0 1.71875 0.546875l-0.578125 1.546875q-0.609375 -0.359375 -1.234375 -0.359375q-0.546875 0 -0.984375 0.328125q-0.421875 0.328125 -0.609375 0.90625q-0.28125 0.890625 -0.28125 1.953125l0 5.15625l-1.671875 0zm5.603302 -4.921875q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 -6.734375l0 -1.9375l1.65625 0l0 1.9375l-1.65625 0zm-2.125 15.484375l0.3125 -1.421875q0.5 0.125 0.796875 0.125q0.515625 0 0.765625 -0.34375q0.25 -0.328125 0.25 -1.6875l0 -10.359375l1.65625 0l0 10.390625q0 1.828125 -0.46875 2.546875q-0.59375 0.921875 -2.0 0.921875q-0.671875 0 -1.3125 -0.171875zm13.019806 -7.0l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547607 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm1.5426636 -10.1875l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.5042114 -4.921875q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281952 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm14.887146 -2.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.2542114 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.7187805 0.21875q-0.40625 1.5 -1.5156555 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.2344055 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.3437805 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm9.578827 -2.078125l0 -1.90625l1.90625 0l0 1.90625l-1.90625 0zm0 7.953125l0 -1.90625l1.90625 0l0 1.90625l-1.90625 0zm18.210388 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm1.5944824 -5.09375q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656921 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m359.0 354.33762l232.18896 0l0 58.992126l-232.18896 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m359.0 354.33762l232.18896 0l0 58.992126l-232.18896 0z" fill-rule="evenodd"></path><path fill="#000000" d="m401.82367 381.2576l0 -13.59375l1.796875 0l0 11.984375l6.703125 0l0 1.609375l-8.5 0zm9.844482 -4.375l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75 -0.46875 -1.6875q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46875 2.703125 0.96875q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm16.506073 4.375l0 -12.0l-4.46875 0l0 -1.59375l10.765625 0l0 1.59375l-4.5 0l0 12.0l-1.796875 0zm7.8803406 0l0 -13.59375l2.71875 0l3.21875 9.625q0.4375 1.34375 0.640625 2.015625q0.234375 -0.75 0.734375 -2.1875l3.25 -9.453125l2.421875 0l0 13.59375l-1.734375 0l0 -11.390625l-3.953125 11.390625l-1.625 0l-3.9375 -11.578125l0 11.578125l-1.734375 0zm23.697052 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm10.434021 5.609375q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.556427 -7.5625l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm10.375732 -3.140625q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656952 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.016357 6.703125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.9844055 0l-3.3906555 4.640625l3.6562805 5.21875l-2.0469055 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm9.9687805 -3.5625l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm10.375671 -3.140625q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656982 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm10.219421 10.703125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" d="m395.9714 407.03885l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm8.844452 4.875l0 -9.859375l1.5 0l0 1.5q0.578125 -1.046875 1.0625 -1.375q0.484375 -0.34375 1.078125 -0.34375q0.84375 0 1.71875 0.546875l-0.578125 1.546875q-0.609375 -0.359375 -1.234375 -0.359375q-0.546875 0 -0.984375 0.328125q-0.421875 0.328125 -0.609375 0.90625q-0.28125 0.890625 -0.28125 1.953125l0 5.15625l-1.671875 0zm5.603302 -4.921875q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 -6.734375l0 -1.9375l1.65625 0l0 1.9375l-1.65625 0zm-2.125 15.484375l0.3125 -1.421875q0.5 0.125 0.796875 0.125q0.515625 0 0.765625 -0.34375q0.25 -0.328125 0.25 -1.6875l0 -10.359375l1.65625 0l0 10.390625q0 1.828125 -0.46875 2.546875q-0.59375 0.921875 -2.0 0.921875q-0.671875 0 -1.3125 -0.171875zm13.019806 -7.0l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547607 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm1.5426636 -10.1875l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.5042114 -4.921875q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281952 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm14.887146 -2.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.2542114 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.7187805 0.21875q-0.40625 1.5 -1.5156555 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.2344055 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.3437805 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm9.578827 -2.078125l0 -1.90625l1.90625 0l0 1.90625l-1.90625 0zm0 7.953125l0 -1.90625l1.90625 0l0 1.90625l-1.90625 0zm18.210388 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm1.5944824 -5.09375q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656921 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m359.0 437.45026l232.18896 0l0 58.992096l-232.18896 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m359.0 437.45026l232.18896 0l0 58.992096l-232.18896 0z" fill-rule="evenodd"></path><path fill="#000000" d="m401.82367 464.37024l0 -13.59375l1.796875 0l0 11.984375l6.703125 0l0 1.609375l-8.5 0zm9.844482 -4.375l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75 -0.46875 -1.6875q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46875 2.703125 0.96875q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm16.506073 4.375l0 -12.0l-4.46875 0l0 -1.59375l10.765625 0l0 1.59375l-4.5 0l0 12.0l-1.796875 0zm7.8803406 0l0 -13.59375l2.71875 0l3.21875 9.625q0.4375 1.34375 0.640625 2.015625q0.234375 -0.75 0.734375 -2.1875l3.25 -9.453125l2.421875 0l0 13.59375l-1.734375 0l0 -11.390625l-3.953125 11.390625l-1.625 0l-3.9375 -11.578125l0 11.578125l-1.734375 0zm23.697052 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm10.434021 5.609375q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.556427 -7.5625l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm10.375732 -3.140625q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656952 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.016357 6.703125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.9844055 0l-3.3906555 4.640625l3.6562805 5.21875l-2.0469055 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm9.9687805 -3.5625l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm10.375671 -3.140625q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656982 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm10.219421 10.703125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" d="m395.9714 490.1515l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm8.844452 4.875l0 -9.859375l1.5 0l0 1.5q0.578125 -1.046875 1.0625 -1.375q0.484375 -0.34375 1.078125 -0.34375q0.84375 0 1.71875 0.546875l-0.578125 1.546875q-0.609375 -0.359375 -1.234375 -0.359375q-0.546875 0 -0.984375 0.328125q-0.421875 0.328125 -0.609375 0.90625q-0.28125 0.890625 -0.28125 1.953125l0 5.15625l-1.671875 0zm5.603302 -4.921875q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 -6.734375l0 -1.9375l1.65625 0l0 1.9375l-1.65625 0zm-2.125 15.484375l0.3125 -1.421875q0.5 0.125 0.796875 0.125q0.515625 0 0.765625 -0.34375q0.25 -0.328125 0.25 -1.6875l0 -10.359375l1.65625 0l0 10.390625q0 1.828125 -0.46875 2.546875q-0.59375 0.921875 -2.0 0.921875q-0.671875 0 -1.3125 -0.171875zm13.019806 -7.0l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547607 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm1.5426636 -10.1875l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.5042114 -4.921875q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281952 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm14.887146 -2.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.2542114 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.7187805 0.21875q-0.40625 1.5 -1.5156555 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.2344055 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.3437805 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm9.578827 -2.078125l0 -1.90625l1.90625 0l0 1.90625l-1.90625 0zm0 7.953125l0 -1.90625l1.90625 0l0 1.90625l-1.90625 0zm18.210388 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm1.5944824 -5.09375q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.656921 0q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m405.46194 594.54596l140.06302 0l0 42.11023l-140.06302 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m405.46194 594.54596l140.06302 0l0 42.11023l-140.06302 0z" fill-rule="evenodd"></path><path fill="#000000" d="m442.13754 617.09094l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75 -0.46875 -1.6875q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46875 2.703125 0.96875q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm12.209198 -0.546875q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.688232 4.921875l0 -8.546875l-1.484375 0l0 -1.3125l1.484375 0l0 -1.046875q0 -0.984375 0.171875 -1.46875q0.234375 -0.65625 0.84375 -1.046875q0.609375 -0.40625 1.703125 -0.40625q0.703125 0 1.5624695 0.15625l-0.25 1.46875q-0.5155945 -0.09375 -0.9843445 -0.09375q-0.765625 0 -1.078125 0.328125q-0.3125 0.3125 -0.3125 1.203125l0 0.90625l1.921875 0l0 1.3125l-1.921875 0l0 8.546875l-1.65625 0zm8.433289 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm1.5270386 1.5l0 -9.859375l1.5 0l0 1.390625q0.453125 -0.71875 1.21875 -1.15625q0.78125 -0.453125 1.765625 -0.453125q1.09375 0 1.796875 0.453125q0.703125 0.453125 0.984375 1.28125q1.171875 -1.734375 3.046875 -1.734375q1.46875 0 2.25 0.8125q0.796875 0.8125 0.796875 2.5l0 6.765625l-1.671875 0l0 -6.203125q0 -1.0 -0.15625 -1.4375q-0.15625 -0.453125 -0.59375 -0.71875q-0.421875 -0.265625 -1.0 -0.265625q-1.03125 0 -1.71875 0.6875q-0.6875 0.6875 -0.6875 2.21875l0 5.71875l-1.671875 0l0 -6.40625q0 -1.109375 -0.40625 -1.65625q-0.40625 -0.5625 -1.34375 -0.5625q-0.703125 0 -1.3125 0.375q-0.59375 0.359375 -0.859375 1.078125q-0.265625 0.71875 -0.265625 2.0625l0 5.109375l-1.671875 0zm21.978302 -1.21875q-0.9375 0.796875 -1.796875 1.125q-0.859375 0.3125 -1.84375 0.3125q-1.609375 0 -2.484375 -0.78125q-0.875 -0.796875 -0.875 -2.03125q0 -0.734375 0.328125 -1.328125q0.328125 -0.59375 0.859375 -0.953125q0.53125 -0.359375 1.203125 -0.546875q0.5 -0.140625 1.484375 -0.25q2.03125 -0.25 2.984375 -0.578125q0 -0.34375 0 -0.4375q0 -1.015625 -0.46875 -1.4375q-0.640625 -0.5625 -1.90625 -0.5625q-1.171875 0 -1.734375 0.40625q-0.5625 0.40625 -0.828125 1.46875l-1.640625 -0.234375q0.234375 -1.046875 0.734375 -1.6875q0.515625 -0.640625 1.46875 -0.984375q0.96875 -0.359375 2.25 -0.359375q1.265625 0 2.046875 0.296875q0.78125 0.296875 1.15625 0.75q0.375 0.453125 0.515625 1.140625q0.09375 0.421875 0.09375 1.53125l0 2.234375q0 2.328125 0.09375 2.953125q0.109375 0.609375 0.4375 1.171875l-1.75 0q-0.265625 -0.515625 -0.328125 -1.21875zm-0.140625 -3.71875q-0.90625 0.359375 -2.734375 0.625q-1.03125 0.140625 -1.453125 0.328125q-0.421875 0.1875 -0.65625 0.546875q-0.234375 0.359375 -0.234375 0.796875q0 0.671875 0.5 1.125q0.515625 0.4375 1.484375 0.4375q0.96875 0 1.71875 -0.421875q0.75 -0.4375 1.109375 -1.15625q0.265625 -0.578125 0.265625 -1.671875l0 -0.609375zm2.9694824 4.9375l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m475.09448 568.5302l0.40945435 26.015747" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m475.09448 568.5302l0.31506348 20.01648" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m473.758 588.5727l1.7229309 4.5115356l1.5801086 -4.5635376z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m475.49344 636.6562l0.31497192 20.346436" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m475.49344 636.6562l0.22210693 14.347168" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m474.064 651.02893l1.7217712 4.511963l1.5812988 -4.5631104z" fill-rule="evenodd"></path></g></svg>
+
diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/hotword.svg b/tensorflow/contrib/lite/models/testdata/g3doc/hotword.svg
new file mode 100755
index 0000000000..36187aa321
--- /dev/null
+++ b/tensorflow/contrib/lite/models/testdata/g3doc/hotword.svg
@@ -0,0 +1,4 @@
+<?xml version="1.0" standalone="yes"?>
+
+<svg version="1.1" viewBox="0.0 0.0 720.0 540.0" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"><clipPath id="p.0"><path d="m0 0l720.0 0l0 540.0l-720.0 0l0 -540.0z" clip-rule="nonzero"></path></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l720.0 0l0 540.0l-720.0 0z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m286.0 5.0l166.01575 0l0 41.984253l-166.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m286.0 5.0l166.01575 0l0 41.984253l-166.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m303.62738 31.919998l0 -13.59375l1.8125 0l0 13.59375l-1.8125 0zm4.667694 0l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm10.375702 3.78125l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm15.313232 4.875l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm7.5788574 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm9.897827 5.5q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.3533325 -6.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.254181 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm23.074646 -2.125l-8.96875 0l0 -1.5625l8.96875 0l0 1.5625zm0 4.125l-8.96875 0l0 -1.546875l8.96875 0l0 1.546875zm12.187653 3.875l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm5.016327 -1.921875q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm10.219452 10.703125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m273.00787 70.23491l192.0 0l0 92.7874l-192.0 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m273.00787 70.23491l192.0 0l0 92.7874l-192.0 0z" fill-rule="evenodd"></path><path fill="#000000" d="m344.98923 92.77991l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75 -0.46875 -1.6875q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46875 2.703125 0.96875q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm16.943573 4.375l-5.28125 -13.59375l1.953125 0l3.53125 9.875q0.4375 1.1875 0.71875 2.21875q0.3125 -1.109375 0.734375 -2.21875l3.671875 -9.875l1.84375 0l-5.328125 13.59375l-1.84375 0zm8.552948 0l0 -13.59375l4.6875 0q1.578125 0 2.421875 0.1875q1.15625 0.265625 1.984375 0.96875q1.078125 0.921875 1.609375 2.34375q0.53125 1.40625 0.53125 3.21875q0 1.546875 -0.359375 2.75q-0.359375 1.1875 -0.921875 1.984375q-0.5625 0.78125 -1.234375 1.234375q-0.671875 0.4375 -1.625 0.671875q-0.953125 0.234375 -2.1875 0.234375l-4.90625 0zm1.796875 -1.609375l2.90625 0q1.34375 0 2.109375 -0.25q0.765625 -0.25 1.21875 -0.703125q0.640625 -0.640625 1.0 -1.71875q0.359375 -1.078125 0.359375 -2.625q0 -2.125 -0.703125 -3.265625q-0.703125 -1.15625 -1.703125 -1.546875q-0.71875 -0.28125 -2.328125 -0.28125l-2.859375 0l0 10.390625zm11.769806 1.609375l0 -13.59375l9.171875 0l0 1.59375l-7.375 0l0 4.21875l6.375 0l0 1.609375l-6.375 0l0 6.171875l-1.796875 0z" fill-rule="nonzero"></path><path fill="#000000" d="m296.54065 119.15491l0 -13.59375l2.71875 0l3.21875 9.625q0.4375 1.34375 0.640625 2.015625q0.234375 -0.75 0.734375 -2.1875l3.25 -9.453125l2.421875 0l0 13.59375l-1.734375 0l0 -11.390625l-3.953125 11.390625l-1.625 0l-3.9375 -11.578125l0 11.578125l-1.734375 0zm22.134552 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm9.110107 5.875l0 -9.859375l1.5 0l0 1.390625q0.453125 -0.71875 1.21875 -1.15625q0.78125 -0.453125 1.765625 -0.453125q1.09375 0 1.796875 0.453125q0.703125 0.453125 0.984375 1.28125q1.171875 -1.734375 3.046875 -1.734375q1.46875 0 2.25 0.8125q0.796875 0.8125 0.796875 2.5l0 6.765625l-1.671875 0l0 -6.203125q0 -1.0 -0.15625 -1.4375q-0.15625 -0.453125 -0.59375 -0.71875q-0.421875 -0.265625 -1.0 -0.265625q-1.03125 0 -1.71875 0.6875q-0.6875 0.6875 -0.6875 2.21875l0 5.71875l-1.671875 0l0 -6.40625q0 -1.109375 -0.40625 -1.65625q-0.40625 -0.5625 -1.34375 -0.5625q-0.703125 0 -1.3125 0.375q-0.59375 0.359375 -0.859375 1.078125q-0.265625 0.71875 -0.265625 2.0625l0 5.109375l-1.671875 0zm14.915802 -4.921875q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.266327 4.921875l0 -9.859375l1.5 0l0 1.5q0.578125 -1.046875 1.0625 -1.375q0.484375 -0.34375 1.078125 -0.34375q0.84375 0 1.71875 0.546875l-0.578125 1.546875q-0.609375 -0.359375 -1.234375 -0.359375q-0.546875 0 -0.984375 0.328125q-0.421875 0.328125 -0.609375 0.90625q-0.28125 0.890625 -0.28125 1.953125l0 5.15625l-1.671875 0zm6.150177 3.796875l-0.171875 -1.5625q0.546875 0.140625 0.953125 0.140625q0.546875 0 0.875 -0.1875q0.34375 -0.1875 0.5625 -0.515625q0.15625 -0.25 0.5 -1.25q0.046875 -0.140625 0.15625 -0.40625l-3.734375 -9.875l1.796875 0l2.046875 5.71875q0.40625 1.078125 0.71875 2.28125q0.28125 -1.15625 0.6875 -2.25l2.09375 -5.75l1.671875 0l-3.75 10.03125q-0.59375 1.625 -0.9375 2.234375q-0.4375 0.828125 -1.015625 1.203125q-0.578125 0.390625 -1.375 0.390625q-0.484375 0 -1.078125 -0.203125zm13.917694 -6.734375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.254181 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm23.074646 -2.125l-8.96875 0l0 -1.5625l8.96875 0l0 1.5625zm0 4.125l-8.96875 0l0 -1.546875l8.96875 0l0 1.546875zm12.187653 3.875l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm5.016327 -1.921875q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375z" fill-rule="nonzero"></path><path fill="#000000" d="m326.25818 145.1549q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm7.915802 -4.0l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm5.0163574 -1.921875q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm13.199646 6.703125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm21.448914 0l-1.671875 0l0 -10.640625q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.5625 -1.765625 0.859375l0 -1.625q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.65625zm12.860107 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm7.891327 1.609375l-1.671875 0l0 -10.640625q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.5625 -1.765625 0.859375l0 -1.625q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.65625zm12.750732 -10.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm10.078827 8.40625l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m294.0 411.00525l150.01575 0l0 58.992126l-150.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m294.0 411.00525l150.01575 0l0 58.992126l-150.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m305.7563 437.92526l0 -13.59375l9.171875 0l0 1.59375l-7.375 0l0 4.21875l6.375 0l0 1.609375l-6.375 0l0 6.171875l-1.796875 0zm17.536621 0l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm3.8913574 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.144806 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.097931 3.796875l-0.171875 -1.5625q0.546875 0.140625 0.953125 0.140625q0.546875 0 0.875 -0.1875q0.34375 -0.1875 0.5625 -0.515625q0.15625 -0.25 0.5 -1.25q0.046875 -0.140625 0.15625 -0.40625l-3.734375 -9.875l1.796875 0l2.046875 5.71875q0.40625 1.078125 0.71875 2.28125q0.28125 -1.15625 0.6875 -2.25l2.09375 -5.75l1.671875 0l-3.75 10.03125q-0.59375 1.625 -0.9375 2.234375q-0.4375 0.828125 -1.015625 1.203125q-0.578125 0.390625 -1.375 0.390625q-0.484375 0 -1.078125 -0.203125zm19.328125 -8.5625l1.796875 0.453125q-0.5625 2.21875 -2.03125 3.390625q-1.46875 1.15625 -3.59375 1.15625q-2.203125 0 -3.578125 -0.890625q-1.375 -0.90625 -2.09375 -2.59375q-0.71875 -1.703125 -0.71875 -3.65625q0 -2.125 0.796875 -3.703125q0.8125 -1.578125 2.3125 -2.390625q1.5 -0.828125 3.296875 -0.828125q2.046875 0 3.4375 1.046875q1.390625 1.03125 1.9375 2.90625l-1.765625 0.421875q-0.46875 -1.484375 -1.375 -2.15625q-0.90625 -0.6875 -2.265625 -0.6875q-1.5625 0 -2.625 0.75q-1.046875 0.75 -1.484375 2.03125q-0.421875 1.265625 -0.421875 2.609375q0 1.734375 0.5 3.03125q0.515625 1.28125 1.578125 1.921875q1.078125 0.640625 2.3125 0.640625q1.515625 0 2.5625 -0.859375q1.046875 -0.875 1.421875 -2.59375zm2.9260864 -0.15625q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm10.375702 0l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm17.125702 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547607 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm8.277039 -1.671875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.500732 5.875l0 -1.25q-0.9375 1.46875 -2.75 1.46875q-1.171875 0 -2.171875 -0.640625q-0.984375 -0.65625 -1.53125 -1.8125q-0.53125 -1.171875 -0.53125 -2.6875q0 -1.46875 0.484375 -2.671875q0.5 -1.203125 1.46875 -1.84375q0.984375 -0.640625 2.203125 -0.640625q0.890625 0 1.578125 0.375q0.703125 0.375 1.140625 0.984375l0 -4.875l1.65625 0l0 13.59375l-1.546875 0zm-5.28125 -4.921875q0 1.890625 0.796875 2.828125q0.8125 0.9375 1.890625 0.9375q1.09375 0 1.859375 -0.890625q0.765625 -0.890625 0.765625 -2.734375q0 -2.015625 -0.78125 -2.953125q-0.78125 -0.953125 -1.921875 -0.953125q-1.109375 0 -1.859375 0.90625q-0.75 0.90625 -0.75 2.859375z" fill-rule="nonzero"></path><path fill="#000000" d="m336.6339 463.92526q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm11.3376465 -5.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm1.5944824 -1.953125l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm19.047607 -6.703125l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.059021 4.40625l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm15.277039 -11.8125l0 -1.609375l8.796875 0l0 1.296875q-1.296875 1.375 -2.578125 3.671875q-1.265625 2.296875 -1.96875 4.71875q-0.5 1.703125 -0.640625 3.734375l-1.71875 0q0.03125 -1.609375 0.625 -3.875q0.609375 -2.28125 1.734375 -4.390625q1.140625 -2.109375 2.40625 -3.546875l-6.65625 0zm11.813232 15.8125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m286.0 485.50656l166.01575 0l0 41.984222l-166.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m286.0 485.50656l166.01575 0l0 41.984222l-166.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m300.7158 505.80157q0 -3.390625 1.8125 -5.296875q1.828125 -1.921875 4.703125 -1.921875q1.875 0 3.390625 0.90625q1.515625 0.890625 2.296875 2.5q0.796875 1.609375 0.796875 3.65625q0 2.0625 -0.84375 3.703125q-0.828125 1.625 -2.359375 2.46875q-1.53125 0.84375 -3.296875 0.84375q-1.921875 0 -3.4375 -0.921875q-1.5 -0.9375 -2.28125 -2.53125q-0.78125 -1.609375 -0.78125 -3.40625zm1.859375 0.03125q0 2.453125 1.3125 3.875q1.328125 1.40625 3.3125 1.40625q2.03125 0 3.34375 -1.421875q1.3125 -1.4375 1.3125 -4.0625q0 -1.65625 -0.5625 -2.890625q-0.546875 -1.234375 -1.640625 -1.921875q-1.078125 -0.6875 -2.421875 -0.6875q-1.90625 0 -3.28125 1.3125q-1.375 1.3125 -1.375 4.390625zm19.43332 6.59375l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm7.578827 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm1.5270691 5.28125l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm15.313202 4.875l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm7.578827 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm9.897888 5.5q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.3532715 -6.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.2542114 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm23.074646 -2.125l-8.96875 0l0 -1.5625l8.96875 0l0 1.5625zm0 4.125l-8.96875 0l0 -1.546875l8.96875 0l0 1.546875zm6.953247 -7.9375l0 -1.609375l8.796875 0l0 1.296875q-1.296875 1.375 -2.578125 3.671875q-1.265625 2.296875 -1.96875 4.71875q-0.5 1.703125 -0.640625 3.734375l-1.71875 0q0.03125 -1.609375 0.625 -3.875q0.609375 -2.28125 1.734375 -4.390625q1.140625 -2.109375 2.40625 -3.546875l-6.65625 0zm11.813232 15.8125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m294.0 187.5l150.01575 0l0 58.992126l-150.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m294.0 187.5l150.01575 0l0 58.992126l-150.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m305.7563 214.42l0 -13.59375l9.171875 0l0 1.59375l-7.375 0l0 4.21875l6.375 0l0 1.609375l-6.375 0l0 6.171875l-1.796875 0zm17.536621 0l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm3.8913574 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.144806 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.097931 3.796875l-0.171875 -1.5625q0.546875 0.140625 0.953125 0.140625q0.546875 0 0.875 -0.1875q0.34375 -0.1875 0.5625 -0.515625q0.15625 -0.25 0.5 -1.25q0.046875 -0.140625 0.15625 -0.40625l-3.734375 -9.875l1.796875 0l2.046875 5.71875q0.40625 1.078125 0.71875 2.28125q0.28125 -1.15625 0.6875 -2.25l2.09375 -5.75l1.671875 0l-3.75 10.03125q-0.59375 1.625 -0.9375 2.234375q-0.4375 0.828125 -1.015625 1.203125q-0.578125 0.390625 -1.375 0.390625q-0.484375 0 -1.078125 -0.203125zm19.328125 -8.5625l1.796875 0.453125q-0.5625 2.21875 -2.03125 3.390625q-1.46875 1.15625 -3.59375 1.15625q-2.203125 0 -3.578125 -0.890625q-1.375 -0.90625 -2.09375 -2.59375q-0.71875 -1.703125 -0.71875 -3.65625q0 -2.125 0.796875 -3.703125q0.8125 -1.578125 2.3125 -2.390625q1.5 -0.828125 3.296875 -0.828125q2.046875 0 3.4375 1.046875q1.390625 1.03125 1.9375 2.90625l-1.765625 0.421875q-0.46875 -1.484375 -1.375 -2.15625q-0.90625 -0.6875 -2.265625 -0.6875q-1.5625 0 -2.625 0.75q-1.046875 0.75 -1.484375 2.03125q-0.421875 1.265625 -0.421875 2.609375q0 1.734375 0.5 3.03125q0.515625 1.28125 1.578125 1.921875q1.078125 0.640625 2.3125 0.640625q1.515625 0 2.5625 -0.859375q1.046875 -0.875 1.421875 -2.59375zm2.9260864 -0.15625q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm10.375702 0l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm17.125702 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547607 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm8.277039 -1.671875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.500732 5.875l0 -1.25q-0.9375 1.46875 -2.75 1.46875q-1.171875 0 -2.171875 -0.640625q-0.984375 -0.65625 -1.53125 -1.8125q-0.53125 -1.171875 -0.53125 -2.6875q0 -1.46875 0.484375 -2.671875q0.5 -1.203125 1.46875 -1.84375q0.984375 -0.640625 2.203125 -0.640625q0.890625 0 1.578125 0.375q0.703125 0.375 1.140625 0.984375l0 -4.875l1.65625 0l0 13.59375l-1.546875 0zm-5.28125 -4.921875q0 1.890625 0.796875 2.828125q0.8125 0.9375 1.890625 0.9375q1.09375 0 1.859375 -0.890625q0.765625 -0.890625 0.765625 -2.734375q0 -2.015625 -0.78125 -2.953125q-0.78125 -0.953125 -1.921875 -0.953125q-1.109375 0 -1.859375 0.90625q-0.75 0.90625 -0.75 2.859375z" fill-rule="nonzero"></path><path fill="#000000" d="m321.0703 240.42q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm8.853302 -4.0l-1.671875 0l0 -10.640625q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.5625 -1.765625 0.859375l0 -1.625q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.65625zm12.860107 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm7.891327 1.609375l-1.671875 0l0 -10.640625q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.5625 -1.765625 0.859375l0 -1.625q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.65625zm12.750732 -10.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.059021 4.40625l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm15.152039 -3.5625l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm10.500732 -8.25l0 -1.609375l8.796875 0l0 1.296875q-1.296875 1.375 -2.578125 3.671875q-1.265625 2.296875 -1.96875 4.71875q-0.5 1.703125 -0.640625 3.734375l-1.71875 0q0.03125 -1.609375 0.625 -3.875q0.609375 -2.28125 1.734375 -4.390625q1.140625 -2.109375 2.40625 -3.546875l-6.65625 0zm12.828827 4.4375q-1.046875 -0.375 -1.546875 -1.078125q-0.5 -0.71875 -0.5 -1.703125q0 -1.484375 1.0625 -2.484375q1.078125 -1.015625 2.84375 -1.015625q1.78125 0 2.859375 1.03125q1.09375 1.03125 1.09375 2.515625q0 0.953125 -0.5 1.65625q-0.484375 0.703125 -1.5 1.078125q1.25 0.40625 1.90625 1.3125q0.65625 0.90625 0.65625 2.171875q0 1.75 -1.234375 2.9375q-1.234375 1.1875 -3.25 1.1875q-2.015625 0 -3.25 -1.1875q-1.234375 -1.203125 -1.234375 -2.984375q0 -1.328125 0.671875 -2.21875q0.671875 -0.890625 1.921875 -1.21875zm-0.328125 -2.828125q0 0.96875 0.609375 1.578125q0.625 0.609375 1.625 0.609375q0.953125 0 1.5625 -0.609375q0.625 -0.609375 0.625 -1.484375q0 -0.921875 -0.640625 -1.546875q-0.625 -0.625 -1.578125 -0.625q-0.953125 0 -1.578125 0.609375q-0.625 0.609375 -0.625 1.46875zm-0.546875 6.28125q0 0.71875 0.328125 1.390625q0.34375 0.65625 1.015625 1.03125q0.671875 0.359375 1.4375 0.359375q1.203125 0 1.984375 -0.765625q0.78125 -0.78125 0.78125 -1.96875q0 -1.203125 -0.8125 -1.984375q-0.796875 -0.796875 -2.0 -0.796875q-1.1875 0 -1.96875 0.78125q-0.765625 0.78125 -0.765625 1.953125zm10.235107 7.921875l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m294.0 262.00262l150.01575 0l0 58.992126l-150.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m294.0 262.00262l150.01575 0l0 58.992126l-150.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m305.7563 288.92264l0 -13.59375l9.171875 0l0 1.59375l-7.375 0l0 4.21875l6.375 0l0 1.609375l-6.375 0l0 6.171875l-1.796875 0zm17.536621 0l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm3.8913574 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.144806 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.097931 3.796875l-0.171875 -1.5625q0.546875 0.140625 0.953125 0.140625q0.546875 0 0.875 -0.1875q0.34375 -0.1875 0.5625 -0.515625q0.15625 -0.25 0.5 -1.25q0.046875 -0.140625 0.15625 -0.40625l-3.734375 -9.875l1.796875 0l2.046875 5.71875q0.40625 1.078125 0.71875 2.28125q0.28125 -1.15625 0.6875 -2.25l2.09375 -5.75l1.671875 0l-3.75 10.03125q-0.59375 1.625 -0.9375 2.234375q-0.4375 0.828125 -1.015625 1.203125q-0.578125 0.390625 -1.375 0.390625q-0.484375 0 -1.078125 -0.203125zm19.328125 -8.5625l1.796875 0.453125q-0.5625 2.21875 -2.03125 3.390625q-1.46875 1.15625 -3.59375 1.15625q-2.203125 0 -3.578125 -0.890625q-1.375 -0.90625 -2.09375 -2.59375q-0.71875 -1.703125 -0.71875 -3.65625q0 -2.125 0.796875 -3.703125q0.8125 -1.578125 2.3125 -2.390625q1.5 -0.828125 3.296875 -0.828125q2.046875 0 3.4375 1.046875q1.390625 1.03125 1.9375 2.90625l-1.765625 0.421875q-0.46875 -1.484375 -1.375 -2.15625q-0.90625 -0.6875 -2.265625 -0.6875q-1.5625 0 -2.625 0.75q-1.046875 0.75 -1.484375 2.03125q-0.421875 1.265625 -0.421875 2.609375q0 1.734375 0.5 3.03125q0.515625 1.28125 1.578125 1.921875q1.078125 0.640625 2.3125 0.640625q1.515625 0 2.5625 -0.859375q1.046875 -0.875 1.421875 -2.59375zm2.9260864 -0.15625q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm10.375702 0l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm17.125702 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547607 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm8.277039 -1.671875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.500732 5.875l0 -1.25q-0.9375 1.46875 -2.75 1.46875q-1.171875 0 -2.171875 -0.640625q-0.984375 -0.65625 -1.53125 -1.8125q-0.53125 -1.171875 -0.53125 -2.6875q0 -1.46875 0.484375 -2.671875q0.5 -1.203125 1.46875 -1.84375q0.984375 -0.640625 2.203125 -0.640625q0.890625 0 1.578125 0.375q0.703125 0.375 1.140625 0.984375l0 -4.875l1.65625 0l0 13.59375l-1.546875 0zm-5.28125 -4.921875q0 1.890625 0.796875 2.828125q0.8125 0.9375 1.890625 0.9375q1.09375 0 1.859375 -0.890625q0.765625 -0.890625 0.765625 -2.734375q0 -2.015625 -0.78125 -2.953125q-0.78125 -0.953125 -1.921875 -0.953125q-1.109375 0 -1.859375 0.90625q-0.75 0.90625 -0.75 2.859375z" fill-rule="nonzero"></path><path fill="#000000" d="m326.25818 314.92264q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.556427 -7.5625l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm10.500702 -8.25l0 -1.609375l8.796875 0l0 1.296875q-1.296875 1.375 -2.578125 3.671875q-1.265625 2.296875 -1.96875 4.71875q-0.5 1.703125 -0.640625 3.734375l-1.71875 0q0.03125 -1.609375 0.625 -3.875q0.609375 -2.28125 1.734375 -4.390625q1.140625 -2.109375 2.40625 -3.546875l-6.65625 0zm12.828857 4.4375q-1.046875 -0.375 -1.546875 -1.078125q-0.5 -0.71875 -0.5 -1.703125q0 -1.484375 1.0625 -2.484375q1.078125 -1.015625 2.84375 -1.015625q1.78125 0 2.859375 1.03125q1.09375 1.03125 1.09375 2.515625q0 0.953125 -0.5 1.65625q-0.484375 0.703125 -1.5 1.078125q1.25 0.40625 1.90625 1.3125q0.65625 0.90625 0.65625 2.171875q0 1.75 -1.234375 2.9375q-1.234375 1.1875 -3.25 1.1875q-2.015625 0 -3.25 -1.1875q-1.234375 -1.203125 -1.234375 -2.984375q0 -1.328125 0.671875 -2.21875q0.671875 -0.890625 1.921875 -1.21875zm-0.328125 -2.828125q0 0.96875 0.609375 1.578125q0.625 0.609375 1.625 0.609375q0.953125 0 1.5625 -0.609375q0.625 -0.609375 0.625 -1.484375q0 -0.921875 -0.640625 -1.546875q-0.625 -0.625 -1.578125 -0.625q-0.953125 0 -1.578125 0.609375q-0.625 0.609375 -0.625 1.46875zm-0.546875 6.28125q0 0.71875 0.328125 1.390625q0.34375 0.65625 1.015625 1.03125q0.671875 0.359375 1.4375 0.359375q1.203125 0 1.984375 -0.765625q0.78125 -0.78125 0.78125 -1.96875q0 -1.203125 -0.8125 -1.984375q-0.796875 -0.796875 -2.0 -0.796875q-1.1875 0 -1.96875 0.78125q-0.765625 0.78125 -0.765625 1.953125zm13.215271 3.921875l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm23.933289 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm1.5944824 -1.953125l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm19.047607 -6.703125l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm10.078827 8.40625l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m294.0 336.50394l150.01575 0l0 58.992126l-150.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m294.0 336.50394l150.01575 0l0 58.992126l-150.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m305.7563 363.42395l0 -13.59375l9.171875 0l0 1.59375l-7.375 0l0 4.21875l6.375 0l0 1.609375l-6.375 0l0 6.171875l-1.796875 0zm17.536621 0l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm3.8913574 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.144806 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.097931 3.796875l-0.171875 -1.5625q0.546875 0.140625 0.953125 0.140625q0.546875 0 0.875 -0.1875q0.34375 -0.1875 0.5625 -0.515625q0.15625 -0.25 0.5 -1.25q0.046875 -0.140625 0.15625 -0.40625l-3.734375 -9.875l1.796875 0l2.046875 5.71875q0.40625 1.078125 0.71875 2.28125q0.28125 -1.15625 0.6875 -2.25l2.09375 -5.75l1.671875 0l-3.75 10.03125q-0.59375 1.625 -0.9375 2.234375q-0.4375 0.828125 -1.015625 1.203125q-0.578125 0.390625 -1.375 0.390625q-0.484375 0 -1.078125 -0.203125zm19.328125 -8.5625l1.796875 0.453125q-0.5625 2.21875 -2.03125 3.390625q-1.46875 1.15625 -3.59375 1.15625q-2.203125 0 -3.578125 -0.890625q-1.375 -0.90625 -2.09375 -2.59375q-0.71875 -1.703125 -0.71875 -3.65625q0 -2.125 0.796875 -3.703125q0.8125 -1.578125 2.3125 -2.390625q1.5 -0.828125 3.296875 -0.828125q2.046875 0 3.4375 1.046875q1.390625 1.03125 1.9375 2.90625l-1.765625 0.421875q-0.46875 -1.484375 -1.375 -2.15625q-0.90625 -0.6875 -2.265625 -0.6875q-1.5625 0 -2.625 0.75q-1.046875 0.75 -1.484375 2.03125q-0.421875 1.265625 -0.421875 2.609375q0 1.734375 0.5 3.03125q0.515625 1.28125 1.578125 1.921875q1.078125 0.640625 2.3125 0.640625q1.515625 0 2.5625 -0.859375q1.046875 -0.875 1.421875 -2.59375zm2.9260864 -0.15625q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm10.375702 0l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm17.125702 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547607 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm8.277039 -1.671875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.500732 5.875l0 -1.25q-0.9375 1.46875 -2.75 1.46875q-1.171875 0 -2.171875 -0.640625q-0.984375 -0.65625 -1.53125 -1.8125q-0.53125 -1.171875 -0.53125 -2.6875q0 -1.46875 0.484375 -2.671875q0.5 -1.203125 1.46875 -1.84375q0.984375 -0.640625 2.203125 -0.640625q0.890625 0 1.578125 0.375q0.703125 0.375 1.140625 0.984375l0 -4.875l1.65625 0l0 13.59375l-1.546875 0zm-5.28125 -4.921875q0 1.890625 0.796875 2.828125q0.8125 0.9375 1.890625 0.9375q1.09375 0 1.859375 -0.890625q0.765625 -0.890625 0.765625 -2.734375q0 -2.015625 -0.78125 -2.953125q-0.78125 -0.953125 -1.921875 -0.953125q-1.109375 0 -1.859375 0.90625q-0.75 0.90625 -0.75 2.859375z" fill-rule="nonzero"></path><path fill="#000000" d="m326.25818 389.42395q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm11.337677 -5.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm1.5944519 -1.953125l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm19.047607 -6.703125l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.059021 4.40625l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm23.933289 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm1.5944824 -1.953125l1.765625 -0.15625q0.1875 1.28125 0.890625 1.9375q0.71875 0.640625 1.71875 0.640625q1.203125 0 2.03125 -0.90625q0.84375 -0.90625 0.84375 -2.421875q0 -1.421875 -0.8125 -2.25q-0.796875 -0.828125 -2.09375 -0.828125q-0.796875 0 -1.453125 0.375q-0.640625 0.359375 -1.015625 0.953125l-1.578125 -0.203125l1.328125 -7.0l6.765625 0l0 1.609375l-5.4375 0l-0.734375 3.640625q1.234375 -0.84375 2.578125 -0.84375q1.78125 0 3.0 1.234375q1.234375 1.234375 1.234375 3.171875q0 1.84375 -1.078125 3.1875q-1.3125 1.65625 -3.578125 1.65625q-1.859375 0 -3.03125 -1.03125q-1.171875 -1.046875 -1.34375 -2.765625zm19.047607 -6.703125l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm10.078827 8.40625l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m369.00787 46.984253l0 23.244095" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m369.00787 46.984253l0 17.244095" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m367.35614 64.22835l1.6517334 4.5380936l1.6517334 -4.5380936z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m369.00787 163.02231l0 24.472443" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m369.00787 163.02231l0 18.472443" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m367.35614 181.49475l1.6517334 4.538101l1.6517334 -4.538101z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m369.00787 246.50656l0 15.496063" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m369.00787 246.50656l0 9.496063" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m367.35614 256.00262l1.6517334 4.538086l1.6517334 -4.538086z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m369.00787 320.99475l0 15.496063" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m369.00787 320.99475l0 9.496063" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m367.35614 330.4908l1.6517334 4.538086l1.6517334 -4.538086z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m369.00787 395.49606l0 15.496063" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m369.00787 395.49606l0 9.496063" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m367.35614 404.99213l1.6517334 4.538086l1.6517334 -4.538086z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m369.00787 470.0105l0 15.496063" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m369.00787 470.0105l0 9.496063" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m367.35614 479.50656l1.6517334 4.538086l1.6517334 -4.538086z" fill-rule="evenodd"></path></g></svg>
+
diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/speakerid.svg b/tensorflow/contrib/lite/models/testdata/g3doc/speakerid.svg
new file mode 100755
index 0000000000..dbe4312c46
--- /dev/null
+++ b/tensorflow/contrib/lite/models/testdata/g3doc/speakerid.svg
@@ -0,0 +1,4 @@
+<?xml version="1.0" standalone="yes"?>
+
+<svg version="1.1" viewBox="0.0 0.0 720.0 540.0" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"><clipPath id="p.0"><path d="m0 0l720.0 0l0 540.0l-720.0 0l0 -540.0z" clip-rule="nonzero"></path></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l720.0 0l0 540.0l-720.0 0z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m287.0 39.0l166.01575 0l0 41.984253l-166.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m287.0 39.0l166.01575 0l0 41.984253l-166.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m304.62738 65.92l0 -13.59375l1.8125 0l0 13.59375l-1.8125 0zm4.667694 0l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm10.375702 3.78125l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm15.313232 4.875l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm7.5788574 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm9.897827 5.5q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.3533325 -6.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.254181 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm23.074646 -2.125l-8.96875 0l0 -1.5625l8.96875 0l0 1.5625zm0 4.125l-8.96875 0l0 -1.546875l8.96875 0l0 1.546875zm9.406403 -3.5q-1.046875 -0.375 -1.546875 -1.078125q-0.5 -0.71875 -0.5 -1.703125q0 -1.484375 1.0625 -2.484375q1.078125 -1.015625 2.84375 -1.015625q1.78125 0 2.859375 1.03125q1.09375 1.03125 1.09375 2.515625q0 0.953125 -0.5 1.65625q-0.484375 0.703125 -1.5 1.078125q1.25 0.40625 1.90625 1.3125q0.65625 0.90625 0.65625 2.171875q0 1.75 -1.234375 2.9375q-1.234375 1.1875 -3.25 1.1875q-2.015625 0 -3.25 -1.1875q-1.234375 -1.203125 -1.234375 -2.984375q0 -1.328125 0.671875 -2.21875q0.671875 -0.890625 1.921875 -1.21875zm-0.328125 -2.828125q0 0.96875 0.609375 1.578125q0.625 0.609375 1.625 0.609375q0.953125 0 1.5625 -0.609375q0.625 -0.609375 0.625 -1.484375q0 -0.921875 -0.640625 -1.546875q-0.625 -0.625 -1.578125 -0.625q-0.953125 0 -1.578125 0.609375q-0.625 0.609375 -0.625 1.46875zm-0.546875 6.28125q0 0.71875 0.328125 1.390625q0.34375 0.65625 1.015625 1.03125q0.671875 0.359375 1.4375 0.359375q1.203125 0 1.984375 -0.765625q0.78125 -0.78125 0.78125 -1.96875q0 -1.203125 -0.8125 -1.984375q-0.796875 -0.796875 -2.0 -0.796875q-1.1875 0 -1.96875 0.78125q-0.765625 0.78125 -0.765625 1.953125zm8.672577 -2.78125q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm10.219452 10.703125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m295.0 111.0l150.01575 0l0 41.984253l-150.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m295.0 111.0l150.01575 0l0 41.984253l-150.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m307.1128 137.92l0 -13.59375l1.796875 0l0 11.984375l6.703125 0l0 1.609375l-8.5 0zm9.844452 -4.375l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75 -0.46875 -1.6875q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46875 2.703125 0.96875q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm16.506073 4.375l0 -12.0l-4.46875 0l0 -1.59375l10.765625 0l0 1.59375l-4.5 0l0 12.0l-1.796875 0zm7.880371 0l0 -13.59375l2.71875 0l3.21875 9.625q0.4375 1.34375 0.640625 2.015625q0.234375 -0.75 0.734375 -2.1875l3.25 -9.453125l2.421875 0l0 13.59375l-1.734375 0l0 -11.390625l-3.953125 11.390625l-1.625 0l-3.9375 -11.578125l0 11.578125l-1.734375 0zm21.212646 0l-1.671875 0l0 -10.640625q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.5625 -1.765625 0.859375l0 -1.625q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.65625zm12.918396 4.0q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm5.1345825 -11.375q-1.046875 -0.375 -1.546875 -1.078125q-0.5 -0.71875 -0.5 -1.703125q0 -1.484375 1.0625 -2.484375q1.078125 -1.015625 2.84375 -1.015625q1.78125 0 2.859375 1.03125q1.09375 1.03125 1.09375 2.515625q0 0.953125 -0.5 1.65625q-0.484375 0.703125 -1.5 1.078125q1.25 0.40625 1.90625 1.3125q0.65625 0.90625 0.65625 2.171875q0 1.75 -1.234375 2.9375q-1.234375 1.1875 -3.25 1.1875q-2.015625 0 -3.25 -1.1875q-1.234375 -1.203125 -1.234375 -2.984375q0 -1.328125 0.671875 -2.21875q0.671875 -0.890625 1.921875 -1.21875zm-0.328125 -2.828125q0 0.96875 0.609375 1.578125q0.625 0.609375 1.625 0.609375q0.953125 0 1.5625 -0.609375q0.625 -0.609375 0.625 -1.484375q0 -0.921875 -0.640625 -1.546875q-0.625 -0.625 -1.578125 -0.625q-0.953125 0 -1.578125 0.609375q-0.625 0.609375 -0.625 1.46875zm-0.546875 6.28125q0 0.71875 0.328125 1.390625q0.34375 0.65625 1.015625 1.03125q0.671875 0.359375 1.4375 0.359375q1.203125 0 1.984375 -0.765625q0.78125 -0.78125 0.78125 -1.96875q0 -1.203125 -0.8125 -1.984375q-0.796875 -0.796875 -2.0 -0.796875q-1.1875 0 -1.96875 0.78125q-0.765625 0.78125 -0.765625 1.953125zm8.672577 -2.78125q0 -2.421875 0.5 -3.890625q0.5 -1.46875 1.46875 -2.265625q0.984375 -0.796875 2.46875 -0.796875q1.09375 0 1.921875 0.4375q0.828125 0.4375 1.359375 1.28125q0.546875 0.828125 0.84375 2.015625q0.3125 1.1875 0.3125 3.21875q0 2.390625 -0.5 3.859375q-0.484375 1.46875 -1.46875 2.28125q-0.96875 0.796875 -2.46875 0.796875q-1.96875 0 -3.078125 -1.40625q-1.359375 -1.703125 -1.359375 -5.53125zm1.71875 0q0 3.34375 0.78125 4.453125q0.796875 1.109375 1.9375 1.109375q1.15625 0 1.9375 -1.109375q0.78125 -1.125 0.78125 -4.453125q0 -3.359375 -0.78125 -4.46875q-0.78125 -1.109375 -1.953125 -1.109375q-1.15625 0 -1.828125 0.984375q-0.875 1.234375 -0.875 4.59375zm8.016327 6.703125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm18.640625 -10.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875732 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm6.5788574 8.78125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m295.0 183.0l150.01575 0l0 41.984253l-150.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m295.0 183.0l150.01575 0l0 41.984253l-150.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m307.1128 209.92l0 -13.59375l1.796875 0l0 11.984375l6.703125 0l0 1.609375l-8.5 0zm9.844452 -4.375l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75 -0.46875 -1.6875q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46875 2.703125 0.96875q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm16.506073 4.375l0 -12.0l-4.46875 0l0 -1.59375l10.765625 0l0 1.59375l-4.5 0l0 12.0l-1.796875 0zm7.880371 0l0 -13.59375l2.71875 0l3.21875 9.625q0.4375 1.34375 0.640625 2.015625q0.234375 -0.75 0.734375 -2.1875l3.25 -9.453125l2.421875 0l0 13.59375l-1.734375 0l0 -11.390625l-3.953125 11.390625l-1.625 0l-3.9375 -11.578125l0 11.578125l-1.734375 0zm23.697021 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm10.434021 5.609375q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm11.2283325 -14.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875702 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm4.375702 4.78125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm18.640625 -10.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875732 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm6.5788574 8.78125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m295.0 255.0l150.01575 0l0 41.984253l-150.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m295.0 255.0l150.01575 0l0 41.984253l-150.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m307.1128 281.91998l0 -13.59375l1.796875 0l0 11.984375l6.703125 0l0 1.609375l-8.5 0zm9.844452 -4.375l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75 -0.46875 -1.6875q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46875 2.703125 0.96875q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm16.506073 4.375l0 -12.0l-4.46875 0l0 -1.59375l10.765625 0l0 1.59375l-4.5 0l0 12.0l-1.796875 0zm7.880371 0l0 -13.59375l2.71875 0l3.21875 9.625q0.4375 1.34375 0.640625 2.015625q0.234375 -0.75 0.734375 -2.1875l3.25 -9.453125l2.421875 0l0 13.59375l-1.734375 0l0 -11.390625l-3.953125 11.390625l-1.625 0l-3.9375 -11.578125l0 11.578125l-1.734375 0zm14.9313965 -3.59375l1.671875 -0.21875q0.28125 1.421875 0.96875 2.046875q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.8125q0.8125 -0.828125 0.8125 -2.03125q0 -1.140625 -0.765625 -1.890625q-0.75 -0.75 -1.90625 -0.75q-0.46875 0 -1.171875 0.1875l0.1875 -1.46875q0.15625 0.015625 0.265625 0.015625q1.0625 0 1.90625 -0.546875q0.859375 -0.5625 0.859375 -1.71875q0 -0.921875 -0.625 -1.515625q-0.609375 -0.609375 -1.59375 -0.609375q-0.96875 0 -1.625 0.609375q-0.640625 0.609375 -0.828125 1.84375l-1.671875 -0.296875q0.296875 -1.6875 1.375 -2.609375q1.09375 -0.921875 2.71875 -0.921875q1.109375 0 2.046875 0.484375q0.9375 0.46875 1.421875 1.296875q0.5 0.828125 0.5 1.75q0 0.890625 -0.46875 1.609375q-0.46875 0.71875 -1.40625 1.15625q1.21875 0.265625 1.875 1.15625q0.671875 0.875 0.671875 2.1875q0 1.78125 -1.296875 3.015625q-1.296875 1.234375 -3.28125 1.234375q-1.796875 0 -2.984375 -1.0625q-1.171875 -1.0625 -1.34375 -2.765625zm19.199646 7.59375q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm11.2283325 -14.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875702 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm4.375702 4.78125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm18.640625 -10.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875732 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm6.5788574 8.78125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m295.0 327.0l150.01575 0l0 58.992126l-150.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m295.0 327.0l150.01575 0l0 58.992126l-150.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m306.7563 353.91998l0 -13.59375l9.171875 0l0 1.59375l-7.375 0l0 4.21875l6.375 0l0 1.609375l-6.375 0l0 6.171875l-1.796875 0zm17.536621 0l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm3.8913574 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.144806 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.097931 3.796875l-0.171875 -1.5625q0.546875 0.140625 0.953125 0.140625q0.546875 0 0.875 -0.1875q0.34375 -0.1875 0.5625 -0.515625q0.15625 -0.25 0.5 -1.25q0.046875 -0.140625 0.15625 -0.40625l-3.734375 -9.875l1.796875 0l2.046875 5.71875q0.40625 1.078125 0.71875 2.28125q0.28125 -1.15625 0.6875 -2.25l2.09375 -5.75l1.671875 0l-3.75 10.03125q-0.59375 1.625 -0.9375 2.234375q-0.4375 0.828125 -1.015625 1.203125q-0.578125 0.390625 -1.375 0.390625q-0.484375 0 -1.078125 -0.203125zm19.328125 -8.5625l1.796875 0.453125q-0.5625 2.21875 -2.03125 3.390625q-1.46875 1.15625 -3.59375 1.15625q-2.203125 0 -3.578125 -0.890625q-1.375 -0.90625 -2.09375 -2.59375q-0.71875 -1.703125 -0.71875 -3.65625q0 -2.125 0.796875 -3.703125q0.8125 -1.578125 2.3125 -2.390625q1.5 -0.828125 3.296875 -0.828125q2.046875 0 3.4375 1.046875q1.390625 1.03125 1.9375 2.90625l-1.765625 0.421875q-0.46875 -1.484375 -1.375 -2.15625q-0.90625 -0.6875 -2.265625 -0.6875q-1.5625 0 -2.625 0.75q-1.046875 0.75 -1.484375 2.03125q-0.421875 1.265625 -0.421875 2.609375q0 1.734375 0.5 3.03125q0.515625 1.28125 1.578125 1.921875q1.078125 0.640625 2.3125 0.640625q1.515625 0 2.5625 -0.859375q1.046875 -0.875 1.421875 -2.59375zm2.9260864 -0.15625q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm10.375702 0l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm17.125702 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547607 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm8.277039 -1.671875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.500732 5.875l0 -1.25q-0.9375 1.46875 -2.75 1.46875q-1.171875 0 -2.171875 -0.640625q-0.984375 -0.65625 -1.53125 -1.8125q-0.53125 -1.171875 -0.53125 -2.6875q0 -1.46875 0.484375 -2.671875q0.5 -1.203125 1.46875 -1.84375q0.984375 -0.640625 2.203125 -0.640625q0.890625 0 1.578125 0.375q0.703125 0.375 1.140625 0.984375l0 -4.875l1.65625 0l0 13.59375l-1.546875 0zm-5.28125 -4.921875q0 1.890625 0.796875 2.828125q0.8125 0.9375 1.890625 0.9375q1.09375 0 1.859375 -0.890625q0.765625 -0.890625 0.765625 -2.734375q0 -2.015625 -0.78125 -2.953125q-0.78125 -0.953125 -1.921875 -0.953125q-1.109375 0 -1.859375 0.90625q-0.75 0.90625 -0.75 2.859375z" fill-rule="nonzero"></path><path fill="#000000" d="m342.8172 379.91998q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm11.228302 -14.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875702 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm4.3757324 4.78125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm18.640625 -10.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875732 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm6.578827 8.78125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m282.0 416.00787l177.19684 0l0 41.984253l-177.19684 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.0 416.00787l177.19684 0l0 41.984253l-177.19684 0z" fill-rule="evenodd"></path><path fill="#000000" d="m297.11847 436.3029q0 -3.390625 1.8125 -5.296875q1.828125 -1.921875 4.703125 -1.921875q1.875 0 3.390625 0.90625q1.515625 0.890625 2.296875 2.5q0.796875 1.609375 0.796875 3.65625q0 2.0625 -0.84375 3.703125q-0.828125 1.625 -2.359375 2.46875q-1.53125 0.84375 -3.296875 0.84375q-1.921875 0 -3.4375 -0.921875q-1.5 -0.9375 -2.28125 -2.53125q-0.78125 -1.609375 -0.78125 -3.40625zm1.859375 0.03125q0 2.453125 1.3125 3.875q1.328125 1.40625 3.3125 1.40625q2.03125 0 3.34375 -1.421875q1.3125 -1.4375 1.3125 -4.0625q0 -1.65625 -0.5625 -2.890625q-0.546875 -1.234375 -1.640625 -1.921875q-1.078125 -0.6875 -2.421875 -0.6875q-1.90625 0 -3.28125 1.3125q-1.375 1.3125 -1.375 4.390625zm19.43332 6.59375l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm7.578827 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm1.5270691 5.28125l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm15.313202 4.875l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm7.5788574 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm9.897858 5.5q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.353302 -6.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.2542114 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm23.074646 -2.125l-8.96875 0l0 -1.5625l8.96875 0l0 1.5625zm0 4.125l-8.96875 0l0 -1.546875l8.96875 0l0 1.546875zm15.500122 -6.390625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875732 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm6.578827 8.78125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m370.00787 80.98425l0 30.015747" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m370.00787 80.98425l0 24.015747" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m368.35614 105.0l1.6517334 4.538101l1.6517334 -4.538101z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m370.00787 152.98425l0 30.015747" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m370.00787 152.98425l0 24.015747" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m368.35614 177.0l1.6517334 4.538101l1.6517334 -4.538101z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m370.00787 224.98425l0 30.015747" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m370.00787 224.98425l0 24.015747" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m368.35614 249.0l1.6517334 4.538101l1.6517334 -4.538101z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m370.00787 296.98425l0 30.015747" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m370.00787 296.98425l0 24.015747" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m368.35614 321.0l1.6517334 4.538086l1.6517334 -4.538086z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m370.00787 385.99213l0.5984192 30.015747" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m370.00787 385.99213l0.47885132 24.016937" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m368.8353 410.042l1.7418518 4.5042725l1.5609436 -4.5701294z" fill-rule="evenodd"></path></g></svg>
+
diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/tts.svg b/tensorflow/contrib/lite/models/testdata/g3doc/tts.svg
new file mode 100755
index 0000000000..9664b78f16
--- /dev/null
+++ b/tensorflow/contrib/lite/models/testdata/g3doc/tts.svg
@@ -0,0 +1,4 @@
+<?xml version="1.0" standalone="yes"?>
+
+<svg version="1.1" viewBox="0.0 0.0 720.0 540.0" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"><clipPath id="p.0"><path d="m0 0l720.0 0l0 540.0l-720.0 0l0 -540.0z" clip-rule="nonzero"></path></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l720.0 0l0 540.0l-720.0 0z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m264.0 14.7l166.01575 0l0 41.984253l-166.01575 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m264.0 14.7l166.01575 0l0 41.984253l-166.01575 0z" fill-rule="evenodd"></path><path fill="#000000" d="m276.43954 41.62l0 -13.59375l1.8125 0l0 13.59375l-1.8125 0zm4.6676636 0l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm10.375732 3.78125l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm15.313202 4.875l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm7.5788574 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm9.897858 5.5q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.353302 -6.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.254181 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm23.074646 -2.125l-8.96875 0l0 -1.5625l8.96875 0l0 1.5625zm0 4.125l-8.96875 0l0 -1.546875l8.96875 0l0 1.546875zm6.8439026 0.28125l1.671875 -0.21875q0.28125 1.421875 0.96875 2.046875q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.8125q0.8125 -0.828125 0.8125 -2.03125q0 -1.140625 -0.765625 -1.890625q-0.75 -0.75 -1.90625 -0.75q-0.46875 0 -1.171875 0.1875l0.1875 -1.46875q0.15625 0.015625 0.265625 0.015625q1.0625 0 1.90625 -0.546875q0.859375 -0.5625 0.859375 -1.71875q0 -0.921875 -0.625 -1.515625q-0.609375 -0.609375 -1.59375 -0.609375q-0.96875 0 -1.625 0.609375q-0.640625 0.609375 -0.828125 1.84375l-1.671875 -0.296875q0.296875 -1.6875 1.375 -2.609375q1.09375 -0.921875 2.71875 -0.921875q1.109375 0 2.046875 0.484375q0.9375 0.46875 1.421875 1.296875q0.5 0.828125 0.5 1.75q0 0.890625 -0.46875 1.609375q-0.46875 0.71875 -1.40625 1.15625q1.21875 0.265625 1.875 1.15625q0.671875 0.875 0.671875 2.1875q0 1.78125 -1.296875 3.015625q-1.296875 1.234375 -3.28125 1.234375q-1.796875 0 -2.984375 -1.0625q-1.171875 -1.0625 -1.34375 -2.765625zm10.375702 0l1.671875 -0.21875q0.28125 1.421875 0.96875 2.046875q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.8125q0.8125 -0.828125 0.8125 -2.03125q0 -1.140625 -0.765625 -1.890625q-0.75 -0.75 -1.90625 -0.75q-0.46875 0 -1.171875 0.1875l0.1875 -1.46875q0.15625 0.015625 0.265625 0.015625q1.0625 0 1.90625 -0.546875q0.859375 -0.5625 0.859375 -1.71875q0 -0.921875 -0.625 -1.515625q-0.609375 -0.609375 -1.59375 -0.609375q-0.96875 0 -1.625 0.609375q-0.640625 0.609375 -0.828125 1.84375l-1.671875 -0.296875q0.296875 -1.6875 1.375 -2.609375q1.09375 -0.921875 2.71875 -0.921875q1.109375 0 2.046875 0.484375q0.9375 0.46875 1.421875 1.296875q0.5 0.828125 0.5 1.75q0 0.890625 -0.46875 1.609375q-0.46875 0.71875 -1.40625 1.15625q1.21875 0.265625 1.875 1.15625q0.671875 0.875 0.671875 2.1875q0 1.78125 -1.296875 3.015625q-1.296875 1.234375 -3.28125 1.234375q-1.796875 0 -2.984375 -1.0625q-1.171875 -1.0625 -1.34375 -2.765625zm15.719482 3.59375l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm6.578827 8.78125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m264.0 238.01575l168.0 0l0 41.984253l-168.0 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m264.0 238.01575l168.0 0l0 41.984253l-168.0 0z" fill-rule="evenodd"></path><path fill="#000000" d="m279.91705 264.93576l0 -13.593765l1.796875 0l0 11.98439l6.703125 0l0 1.609375l-8.5 0zm9.844452 -4.375l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75001526 -0.46875 -1.6875153q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46876526 2.703125 0.96876526q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm16.506073 4.375l0 -12.000015l-4.46875 0l0 -1.59375l10.765625 0l0 1.59375l-4.5 0l0 12.000015l-1.796875 0zm7.880371 0l0 -13.593765l2.71875 0l3.21875 9.625015q0.4375 1.34375 0.640625 2.015625q0.234375 -0.75 0.734375 -2.1875l3.25 -9.45314l2.421875 0l0 13.593765l-1.734375 0l0 -11.39064l-3.953125 11.39064l-1.625 0l-3.9375 -11.57814l0 11.57814l-1.734375 0zm21.212677 0l-1.671875 0l0 -10.64064q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.56251526 -1.765625 0.85939026l0 -1.6250153q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.656265zm12.918396 4.0q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.2343903q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.6718903q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm8.853302 -4.0l-1.671875 0l0 -10.64064q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.56251526 -1.765625 0.85939026l0 -1.6250153q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.656265zm12.860077 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625153 0.78125 -2.0156403q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.6093903q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm4.1726074 -5.765625q-1.046875 -0.375 -1.546875 -1.078125q-0.5 -0.71876526 -0.5 -1.7031403q0 -1.484375 1.0625 -2.484375q1.078125 -1.015625 2.84375 -1.015625q1.78125 0 2.859375 1.03125q1.09375 1.03125 1.09375 2.515625q0 0.953125 -0.5 1.6562653q-0.484375 0.703125 -1.5 1.078125q1.25 0.40625 1.90625 1.3125q0.65625 0.90625 0.65625 2.171875q0 1.75 -1.234375 2.9375q-1.234375 1.1875 -3.25 1.1875q-2.015625 0 -3.25 -1.1875q-1.234375 -1.203125 -1.234375 -2.984375q0 -1.328125 0.671875 -2.21875q0.671875 -0.890625 1.921875 -1.21875zm-0.328125 -2.8281403q0 0.96875 0.609375 1.5781403q0.625 0.609375 1.625 0.609375q0.953125 0 1.5625 -0.609375q0.625 -0.60939026 0.625 -1.4843903q0 -0.921875 -0.640625 -1.546875q-0.625 -0.625 -1.578125 -0.625q-0.953125 0 -1.578125 0.609375q-0.625 0.609375 -0.625 1.46875zm-0.546875 6.2812653q0 0.71875 0.328125 1.390625q0.34375 0.65625 1.015625 1.03125q0.671875 0.359375 1.4375 0.359375q1.203125 0 1.984375 -0.765625q0.78125 -0.78125 0.78125 -1.96875q0 -1.203125 -0.8125 -1.984375q-0.796875 -0.796875 -2.0 -0.796875q-1.1875 0 -1.96875 0.78125q-0.765625 0.78125 -0.765625 1.953125zm8.031952 3.921875l3.59375 -5.125l-3.328125 -4.7343903l2.09375 0l1.515625 2.3125153q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.3281403l1.984375 0l-3.390625 4.6406403l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm18.640625 -10.26564l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.1406403q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.9531403 1.453125 -5.7343903q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.8593903q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875732 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.812515l1.359375 0l0 8.812515l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.1406403l-4.25 6.1406403l4.25 0zm6.578827 8.78125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.3906403 -0.890625 -2.6718903q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.0156403 0.71875 4.2343903q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m264.0 296.0l168.0 0l0 41.984253l-168.0 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m264.0 296.0l168.0 0l0 41.984253l-168.0 0z" fill-rule="evenodd"></path><path fill="#000000" d="m285.10492 322.91998l0 -13.59375l1.796875 0l0 11.984375l6.703125 0l0 1.609375l-8.5 0zm9.844452 -4.375l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75 -0.46875 -1.6875q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46875 2.703125 0.96875q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm16.506073 4.375l0 -12.0l-4.46875 0l0 -1.59375l10.765625 0l0 1.59375l-4.5 0l0 12.0l-1.796875 0zm7.880371 0l0 -13.59375l2.71875 0l3.21875 9.625q0.4375 1.34375 0.640625 2.015625q0.234375 -0.75 0.734375 -2.1875l3.25 -9.453125l2.421875 0l0 13.59375l-1.734375 0l0 -11.390625l-3.953125 11.390625l-1.625 0l-3.9375 -11.578125l0 11.578125l-1.734375 0zm23.697021 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm10.434021 5.609375q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm11.2283325 -14.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875702 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm4.375702 4.78125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm18.640625 -10.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875732 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm6.5788574 8.78125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m264.0 358.1l168.0 0l0 41.984253l-168.0 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m264.0 358.1l168.0 0l0 41.984253l-168.0 0z" fill-rule="evenodd"></path><path fill="#000000" d="m285.10492 385.02l0 -13.59375l1.796875 0l0 11.984375l6.703125 0l0 1.609375l-8.5 0zm9.844452 -4.375l1.6875 -0.140625q0.125 1.015625 0.5625 1.671875q0.4375 0.65625 1.359375 1.0625q0.9375 0.40625 2.09375 0.40625q1.03125 0 1.8125 -0.3125q0.796875 -0.3125 1.1875 -0.84375q0.390625 -0.53125 0.390625 -1.15625q0 -0.640625 -0.375 -1.109375q-0.375 -0.484375 -1.234375 -0.8125q-0.546875 -0.21875 -2.421875 -0.65625q-1.875 -0.453125 -2.625 -0.859375q-0.96875 -0.515625 -1.453125 -1.265625q-0.46875 -0.75 -0.46875 -1.6875q0 -1.03125 0.578125 -1.921875q0.59375 -0.90625 1.703125 -1.359375q1.125 -0.46875 2.5 -0.46875q1.515625 0 2.671875 0.484375q1.15625 0.484375 1.765625 1.4375q0.625 0.9375 0.671875 2.140625l-1.71875 0.125q-0.140625 -1.28125 -0.953125 -1.9375q-0.796875 -0.671875 -2.359375 -0.671875q-1.625 0 -2.375 0.609375q-0.75 0.59375 -0.75 1.4375q0 0.734375 0.53125 1.203125q0.515625 0.46875 2.703125 0.96875q2.203125 0.5 3.015625 0.875q1.1875 0.546875 1.75 1.390625q0.578125 0.828125 0.578125 1.921875q0 1.09375 -0.625 2.0625q-0.625 0.953125 -1.796875 1.484375q-1.15625 0.53125 -2.609375 0.53125q-1.84375 0 -3.09375 -0.53125q-1.25 -0.546875 -1.96875 -1.625q-0.703125 -1.078125 -0.734375 -2.453125zm16.506073 4.375l0 -12.0l-4.46875 0l0 -1.59375l10.765625 0l0 1.59375l-4.5 0l0 12.0l-1.796875 0zm7.880371 0l0 -13.59375l2.71875 0l3.21875 9.625q0.4375 1.34375 0.640625 2.015625q0.234375 -0.75 0.734375 -2.1875l3.25 -9.453125l2.421875 0l0 13.59375l-1.734375 0l0 -11.390625l-3.953125 11.390625l-1.625 0l-3.9375 -11.578125l0 11.578125l-1.734375 0zm14.9313965 -3.59375l1.671875 -0.21875q0.28125 1.421875 0.96875 2.046875q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.8125q0.8125 -0.828125 0.8125 -2.03125q0 -1.140625 -0.765625 -1.890625q-0.75 -0.75 -1.90625 -0.75q-0.46875 0 -1.171875 0.1875l0.1875 -1.46875q0.15625 0.015625 0.265625 0.015625q1.0625 0 1.90625 -0.546875q0.859375 -0.5625 0.859375 -1.71875q0 -0.921875 -0.625 -1.515625q-0.609375 -0.609375 -1.59375 -0.609375q-0.96875 0 -1.625 0.609375q-0.640625 0.609375 -0.828125 1.84375l-1.671875 -0.296875q0.296875 -1.6875 1.375 -2.609375q1.09375 -0.921875 2.71875 -0.921875q1.109375 0 2.046875 0.484375q0.9375 0.46875 1.421875 1.296875q0.5 0.828125 0.5 1.75q0 0.890625 -0.46875 1.609375q-0.46875 0.71875 -1.40625 1.15625q1.21875 0.265625 1.875 1.15625q0.671875 0.875 0.671875 2.1875q0 1.78125 -1.296875 3.015625q-1.296875 1.234375 -3.28125 1.234375q-1.796875 0 -2.984375 -1.0625q-1.171875 -1.0625 -1.34375 -2.765625zm19.199646 7.59375q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm11.2283325 -14.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875702 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm4.375702 4.78125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm18.640625 -10.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875732 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm6.5788574 8.78125l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m264.0 78.7l168.0 0l0 58.992126l-168.0 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m264.0 78.7l168.0 0l0 58.992126l-168.0 0z" fill-rule="evenodd"></path><path fill="#000000" d="m279.56058 105.619995l0 -13.59375l9.171875 0l0 1.59375l-7.375 0l0 4.21875l6.375 0l0 1.609375l-6.375 0l0 6.171875l-1.796875 0zm17.53659 0l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm3.8913574 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.144806 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.0979614 3.796875l-0.171875 -1.5625q0.546875 0.140625 0.953125 0.140625q0.546875 0 0.875 -0.1875q0.34375 -0.1875 0.5625 -0.515625q0.15625 -0.25 0.5 -1.25q0.046875 -0.140625 0.15625 -0.40625l-3.734375 -9.875l1.796875 0l2.046875 5.71875q0.40625 1.078125 0.71875 2.28125q0.28125 -1.15625 0.6875 -2.25l2.09375 -5.75l1.671875 0l-3.75 10.03125q-0.59375 1.625 -0.9375 2.234375q-0.4375 0.828125 -1.015625 1.203125q-0.578125 0.390625 -1.375 0.390625q-0.484375 0 -1.078125 -0.203125zm19.328125 -8.5625l1.796875 0.453125q-0.5625 2.21875 -2.03125 3.390625q-1.46875 1.15625 -3.59375 1.15625q-2.203125 0 -3.578125 -0.890625q-1.375 -0.90625 -2.09375 -2.59375q-0.71875 -1.703125 -0.71875 -3.65625q0 -2.125 0.796875 -3.703125q0.8125 -1.578125 2.3125 -2.390625q1.5 -0.828125 3.296875 -0.828125q2.046875 0 3.4375 1.046875q1.390625 1.03125 1.9375 2.90625l-1.765625 0.421875q-0.46875 -1.484375 -1.375 -2.15625q-0.90625 -0.6875 -2.265625 -0.6875q-1.5625 0 -2.625 0.75q-1.046875 0.75 -1.484375 2.03125q-0.421875 1.265625 -0.421875 2.609375q0 1.734375 0.5 3.03125q0.515625 1.28125 1.578125 1.921875q1.078125 0.640625 2.3125 0.640625q1.515625 0 2.5625 -0.859375q1.046875 -0.875 1.421875 -2.59375zm2.926056 -0.15625q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm10.375702 0l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm17.125732 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547577 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm8.277069 -1.671875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.500702 5.875l0 -1.25q-0.9375 1.46875 -2.75 1.46875q-1.171875 0 -2.171875 -0.640625q-0.984375 -0.65625 -1.53125 -1.8125q-0.53125 -1.171875 -0.53125 -2.6875q0 -1.46875 0.484375 -2.671875q0.5 -1.203125 1.46875 -1.84375q0.984375 -0.640625 2.203125 -0.640625q0.890625 0 1.578125 0.375q0.703125 0.375 1.140625 0.984375l0 -4.875l1.65625 0l0 13.59375l-1.546875 0zm-5.28125 -4.921875q0 1.890625 0.796875 2.828125q0.8125 0.9375 1.890625 0.9375q1.09375 0 1.859375 -0.890625q0.765625 -0.890625 0.765625 -2.734375q0 -2.015625 -0.78125 -2.953125q-0.78125 -0.953125 -1.921875 -0.953125q-1.109375 0 -1.859375 0.90625q-0.75 0.90625 -0.75 2.859375zm15.094482 4.921875l-1.671875 0l0 -10.640625q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.5625 -1.765625 0.859375l0 -1.625q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.65625z" fill-rule="nonzero"></path><path fill="#000000" d="m310.4336 131.62q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.5720825 -7.59375l1.671875 -0.21875q0.28125 1.421875 0.96875 2.046875q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.8125q0.8125 -0.828125 0.8125 -2.03125q0 -1.140625 -0.765625 -1.890625q-0.75 -0.75 -1.90625 -0.75q-0.46875 0 -1.171875 0.1875l0.1875 -1.46875q0.15625 0.015625 0.265625 0.015625q1.0625 0 1.90625 -0.546875q0.859375 -0.5625 0.859375 -1.71875q0 -0.921875 -0.625 -1.515625q-0.609375 -0.609375 -1.59375 -0.609375q-0.96875 0 -1.625 0.609375q-0.640625 0.609375 -0.828125 1.84375l-1.671875 -0.296875q0.296875 -1.6875 1.375 -2.609375q1.09375 -0.921875 2.71875 -0.921875q1.109375 0 2.046875 0.484375q0.9375 0.46875 1.421875 1.296875q0.5 0.828125 0.5 1.75q0 0.890625 -0.46875 1.609375q-0.46875 0.71875 -1.40625 1.15625q1.21875 0.265625 1.875 1.15625q0.671875 0.875 0.671875 2.1875q0 1.78125 -1.296875 3.015625q-1.296875 1.234375 -3.28125 1.234375q-1.796875 0 -2.984375 -1.0625q-1.171875 -1.0625 -1.34375 -2.765625zm10.375702 0l1.671875 -0.21875q0.28125 1.421875 0.96875 2.046875q0.703125 0.625 1.6875 0.625q1.1875 0 2.0 -0.8125q0.8125 -0.828125 0.8125 -2.03125q0 -1.140625 -0.765625 -1.890625q-0.75 -0.75 -1.90625 -0.75q-0.46875 0 -1.171875 0.1875l0.1875 -1.46875q0.15625 0.015625 0.265625 0.015625q1.0625 0 1.90625 -0.546875q0.859375 -0.5625 0.859375 -1.71875q0 -0.921875 -0.625 -1.515625q-0.609375 -0.609375 -1.59375 -0.609375q-0.96875 0 -1.625 0.609375q-0.640625 0.609375 -0.828125 1.84375l-1.671875 -0.296875q0.296875 -1.6875 1.375 -2.609375q1.09375 -0.921875 2.71875 -0.921875q1.109375 0 2.046875 0.484375q0.9375 0.46875 1.421875 1.296875q0.5 0.828125 0.5 1.75q0 0.890625 -0.46875 1.609375q-0.46875 0.71875 -1.40625 1.15625q1.21875 0.265625 1.875 1.15625q0.671875 0.875 0.671875 2.1875q0 1.78125 -1.296875 3.015625q-1.296875 1.234375 -3.28125 1.234375q-1.796875 0 -2.984375 -1.0625q-1.171875 -1.0625 -1.34375 -2.765625zm15.719452 3.59375l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm4.3757324 4.78125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm16.265625 0l-1.671875 0l0 -10.640625q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.5625 -1.765625 0.859375l0 -1.625q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.65625zm12.860107 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm4.172577 -5.765625q-1.046875 -0.375 -1.546875 -1.078125q-0.5 -0.71875 -0.5 -1.703125q0 -1.484375 1.0625 -2.484375q1.078125 -1.015625 2.84375 -1.015625q1.78125 0 2.859375 1.03125q1.09375 1.03125 1.09375 2.515625q0 0.953125 -0.5 1.65625q-0.484375 0.703125 -1.5 1.078125q1.25 0.40625 1.90625 1.3125q0.65625 0.90625 0.65625 2.171875q0 1.75 -1.234375 2.9375q-1.234375 1.1875 -3.25 1.1875q-2.015625 0 -3.25 -1.1875q-1.234375 -1.203125 -1.234375 -2.984375q0 -1.328125 0.671875 -2.21875q0.671875 -0.890625 1.921875 -1.21875zm-0.328125 -2.828125q0 0.96875 0.609375 1.578125q0.625 0.609375 1.625 0.609375q0.953125 0 1.5625 -0.609375q0.625 -0.609375 0.625 -1.484375q0 -0.921875 -0.640625 -1.546875q-0.625 -0.625 -1.578125 -0.625q-0.953125 0 -1.578125 0.609375q-0.625 0.609375 -0.625 1.46875zm-0.546875 6.28125q0 0.71875 0.328125 1.390625q0.34375 0.65625 1.015625 1.03125q0.671875 0.359375 1.4375 0.359375q1.203125 0 1.984375 -0.765625q0.78125 -0.78125 0.78125 -1.96875q0 -1.203125 -0.8125 -1.984375q-0.796875 -0.796875 -2.0 -0.796875q-1.1875 0 -1.96875 0.78125q-0.765625 0.78125 -0.765625 1.953125zm10.235077 7.921875l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m257.8 488.0l180.0 0l0 46.992126l-180.0 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m257.8 488.0l180.0 0l0 46.992126l-180.0 0z" fill-rule="evenodd"></path><path fill="#000000" d="m269.1322 508.29498q0 -3.390625 1.8125 -5.296875q1.828125 -1.921875 4.703125 -1.921875q1.875 0 3.390625 0.90625q1.515625 0.890625 2.296875 2.5q0.796875 1.609375 0.796875 3.65625q0 2.0625 -0.84375 3.703125q-0.828125 1.625 -2.359375 2.46875q-1.53125 0.84375 -3.296875 0.84375q-1.921875 0 -3.4375 -0.921875q-1.5 -0.9375 -2.28125 -2.53125q-0.78125 -1.609375 -0.78125 -3.40625zm1.859375 0.03125q0 2.453125 1.3125 3.875q1.328125 1.40625 3.3125 1.40625q2.03125 0 3.34375 -1.421875q1.3125 -1.4375 1.3125 -4.0625q0 -1.65625 -0.5625 -2.890625q-0.546875 -1.234375 -1.640625 -1.921875q-1.078125 -0.6875 -2.421875 -0.6875q-1.90625 0 -3.28125 1.3125q-1.375 1.3125 -1.375 4.390625zm19.433289 6.59375l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm7.5788574 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm1.5270386 5.28125l0 -13.640625l1.53125 0l0 1.28125q0.53125 -0.75 1.203125 -1.125q0.6875 -0.375 1.640625 -0.375q1.265625 0 2.234375 0.65625q0.96875 0.640625 1.453125 1.828125q0.5 1.1875 0.5 2.59375q0 1.515625 -0.546875 2.734375q-0.546875 1.203125 -1.578125 1.84375q-1.03125 0.640625 -2.171875 0.640625q-0.84375 0 -1.515625 -0.34375q-0.65625 -0.359375 -1.078125 -0.890625l0 4.796875l-1.671875 0zm1.515625 -8.65625q0 1.90625 0.765625 2.8125q0.78125 0.90625 1.875 0.90625q1.109375 0 1.890625 -0.9375q0.796875 -0.9375 0.796875 -2.921875q0 -1.875 -0.78125 -2.8125q-0.765625 -0.9375 -1.84375 -0.9375q-1.0625 0 -1.890625 1.0q-0.8125 1.0 -0.8125 2.890625zm15.313232 4.875l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm7.578827 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm9.897858 5.5q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm2.353302 -6.9375l1.65625 -0.265625q0.140625 1.0 0.765625 1.53125q0.640625 0.515625 1.78125 0.515625q1.15625 0 1.703125 -0.46875q0.5625 -0.46875 0.5625 -1.09375q0 -0.5625 -0.484375 -0.890625q-0.34375 -0.21875 -1.703125 -0.5625q-1.84375 -0.46875 -2.5625 -0.796875q-0.703125 -0.34375 -1.078125 -0.9375q-0.359375 -0.609375 -0.359375 -1.328125q0 -0.65625 0.296875 -1.21875q0.3125 -0.5625 0.828125 -0.9375q0.390625 -0.28125 1.0625 -0.484375q0.671875 -0.203125 1.4375 -0.203125q1.171875 0 2.046875 0.34375q0.875 0.328125 1.28125 0.90625q0.421875 0.5625 0.578125 1.515625l-1.625 0.21875q-0.109375 -0.75 -0.65625 -1.171875q-0.53125 -0.4375 -1.5 -0.4375q-1.15625 0 -1.640625 0.390625q-0.484375 0.375 -0.484375 0.875q0 0.328125 0.203125 0.59375q0.203125 0.265625 0.640625 0.4375q0.25 0.09375 1.46875 0.4375q1.765625 0.46875 2.46875 0.765625q0.703125 0.296875 1.09375 0.875q0.40625 0.578125 0.40625 1.4375q0 0.828125 -0.484375 1.578125q-0.484375 0.734375 -1.40625 1.140625q-0.921875 0.390625 -2.078125 0.390625q-1.921875 0 -2.9375 -0.796875q-1.0 -0.796875 -1.28125 -2.359375zm10.015625 -8.75l0 -1.90625l1.671875 0l0 1.90625l-1.671875 0zm0 11.6875l0 -9.859375l1.671875 0l0 9.859375l-1.671875 0zm3.2542114 0l0 -1.359375l6.265625 -7.1875q-1.0625 0.046875 -1.875 0.046875l-4.015625 0l0 -1.359375l8.046875 0l0 1.109375l-5.34375 6.25l-1.015625 1.140625q1.109375 -0.078125 2.09375 -0.078125l4.5625 0l0 1.4375l-8.71875 0zm16.953125 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm23.074646 -2.125l-8.96875 0l0 -1.5625l8.96875 0l0 1.5625zm0 4.125l-8.96875 0l0 -1.546875l8.96875 0l0 1.546875zm13.125122 3.875l-1.671875 0l0 -10.640625q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.5625 -1.765625 0.859375l0 -1.625q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.65625zm4.3444824 -3.140625l1.59375 -0.15625q0.203125 1.140625 0.78125 1.65625q0.578125 0.5 1.484375 0.5q0.765625 0 1.34375 -0.34375q0.578125 -0.359375 0.953125 -0.953125q0.375 -0.59375 0.625 -1.59375q0.25 -1.0 0.25 -2.03125q0 -0.109375 -0.015625 -0.34375q-0.5 0.796875 -1.375 1.296875q-0.859375 0.5 -1.875 0.5q-1.6875 0 -2.859375 -1.21875q-1.171875 -1.234375 -1.171875 -3.234375q0 -2.078125 1.21875 -3.328125q1.234375 -1.265625 3.0625 -1.265625q1.328125 0 2.421875 0.71875q1.109375 0.703125 1.671875 2.03125q0.578125 1.328125 0.578125 3.828125q0 2.609375 -0.578125 4.15625q-0.5625 1.546875 -1.6875 2.359375q-1.109375 0.796875 -2.609375 0.796875q-1.59375 0 -2.609375 -0.890625q-1.0 -0.890625 -1.203125 -2.484375zm6.828125 -6.0q0 -1.4375 -0.765625 -2.28125q-0.765625 -0.859375 -1.84375 -0.859375q-1.109375 0 -1.9375 0.921875q-0.828125 0.90625 -0.828125 2.34375q0 1.3125 0.78125 2.125q0.796875 0.796875 1.9375 0.796875q1.171875 0 1.90625 -0.796875q0.75 -0.8125 0.75 -2.25zm11.953827 -1.125l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm10.078857 8.40625l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m348.0 280.0l0 16.0" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m348.0 280.0l0 10.0" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m346.34827 290.0l1.6517334 4.538086l1.6517334 -4.538086z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m348.0 337.98425l0 20.125977" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m348.0 337.98425l0 14.125977" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m346.34827 352.11023l1.6517334 4.5381165l1.6517334 -4.5381165z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m347.00787 56.684254l1.0078735 22.015743" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m347.00787 56.68425l0.7334595 16.022026" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m346.09134 72.781815l1.857544 4.4578094l1.4424744 -4.6088867z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m348.0 400.08426l0.31497192 21.921265" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m348.0 400.08423l0.22875977 15.921875" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m346.5772 416.02985l1.7167358 4.5138855l1.5863647 -4.5613403z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m264.0 158.19606l168.0 0l0 58.992126l-168.0 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m264.0 158.19606l168.0 0l0 58.992126l-168.0 0z" fill-rule="evenodd"></path><path fill="#000000" d="m279.56058 185.11606l0 -13.59375l9.171875 0l0 1.59375l-7.375 0l0 4.21875l6.375 0l0 1.609375l-6.375 0l0 6.171875l-1.796875 0zm17.53659 0l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm3.8913574 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.144806 0l0 -13.59375l1.671875 0l0 13.59375l-1.671875 0zm4.0979614 3.796875l-0.171875 -1.5625q0.546875 0.140625 0.953125 0.140625q0.546875 0 0.875 -0.1875q0.34375 -0.1875 0.5625 -0.515625q0.15625 -0.25 0.5 -1.25q0.046875 -0.140625 0.15625 -0.40625l-3.734375 -9.875l1.796875 0l2.046875 5.71875q0.40625 1.078125 0.71875 2.28125q0.28125 -1.15625 0.6875 -2.25l2.09375 -5.75l1.671875 0l-3.75 10.03125q-0.59375 1.625 -0.9375 2.234375q-0.4375 0.828125 -1.015625 1.203125q-0.578125 0.390625 -1.375 0.390625q-0.484375 0 -1.078125 -0.203125zm19.328125 -8.5625l1.796875 0.453125q-0.5625 2.21875 -2.03125 3.390625q-1.46875 1.15625 -3.59375 1.15625q-2.203125 0 -3.578125 -0.890625q-1.375 -0.90625 -2.09375 -2.59375q-0.71875 -1.703125 -0.71875 -3.65625q0 -2.125 0.796875 -3.703125q0.8125 -1.578125 2.3125 -2.390625q1.5 -0.828125 3.296875 -0.828125q2.046875 0 3.4375 1.046875q1.390625 1.03125 1.9375 2.90625l-1.765625 0.421875q-0.46875 -1.484375 -1.375 -2.15625q-0.90625 -0.6875 -2.265625 -0.6875q-1.5625 0 -2.625 0.75q-1.046875 0.75 -1.484375 2.03125q-0.421875 1.265625 -0.421875 2.609375q0 1.734375 0.5 3.03125q0.515625 1.28125 1.578125 1.921875q1.078125 0.640625 2.3125 0.640625q1.515625 0 2.5625 -0.859375q1.046875 -0.875 1.421875 -2.59375zm2.926056 -0.15625q0 -2.734375 1.53125 -4.0625q1.265625 -1.09375 3.09375 -1.09375q2.03125 0 3.3125 1.34375q1.296875 1.328125 1.296875 3.671875q0 1.90625 -0.578125 3.0q-0.5625 1.078125 -1.65625 1.6875q-1.078125 0.59375 -2.375 0.59375q-2.0625 0 -3.34375 -1.328125q-1.28125 -1.328125 -1.28125 -3.8125zm1.71875 0q0 1.890625 0.828125 2.828125q0.828125 0.9375 2.078125 0.9375q1.25 0 2.0625 -0.9375q0.828125 -0.953125 0.828125 -2.890625q0 -1.828125 -0.828125 -2.765625q-0.828125 -0.9375 -2.0625 -0.9375q-1.25 0 -2.078125 0.9375q-0.828125 0.9375 -0.828125 2.828125zm9.281982 4.921875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm10.375702 0l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm17.125732 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547577 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm6.546875 2.109375l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm8.277069 -1.671875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.500702 5.875l0 -1.25q-0.9375 1.46875 -2.75 1.46875q-1.171875 0 -2.171875 -0.640625q-0.984375 -0.65625 -1.53125 -1.8125q-0.53125 -1.171875 -0.53125 -2.6875q0 -1.46875 0.484375 -2.671875q0.5 -1.203125 1.46875 -1.84375q0.984375 -0.640625 2.203125 -0.640625q0.890625 0 1.578125 0.375q0.703125 0.375 1.140625 0.984375l0 -4.875l1.65625 0l0 13.59375l-1.546875 0zm-5.28125 -4.921875q0 1.890625 0.796875 2.828125q0.8125 0.9375 1.890625 0.9375q1.09375 0 1.859375 -0.890625q0.765625 -0.890625 0.765625 -2.734375q0 -2.015625 -0.78125 -2.953125q-0.78125 -0.953125 -1.921875 -0.953125q-1.109375 0 -1.859375 0.90625q-0.75 0.90625 -0.75 2.859375zm17.578857 3.3125l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0z" fill-rule="nonzero"></path><path fill="#000000" d="m310.4336 211.11606q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm8.8533325 -4.0l-1.671875 0l0 -10.640625q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.5625 -1.765625 0.859375l0 -1.625q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.65625zm12.860077 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm4.172577 -5.765625q-1.046875 -0.375 -1.546875 -1.078125q-0.5 -0.71875 -0.5 -1.703125q0 -1.484375 1.0625 -2.484375q1.078125 -1.015625 2.84375 -1.015625q1.78125 0 2.859375 1.03125q1.09375 1.03125 1.09375 2.515625q0 0.953125 -0.5 1.65625q-0.484375 0.703125 -1.5 1.078125q1.25 0.40625 1.90625 1.3125q0.65625 0.90625 0.65625 2.171875q0 1.75 -1.234375 2.9375q-1.234375 1.1875 -3.25 1.1875q-2.015625 0 -3.25 -1.1875q-1.234375 -1.203125 -1.234375 -2.984375q0 -1.328125 0.671875 -2.21875q0.671875 -0.890625 1.921875 -1.21875zm-0.328125 -2.828125q0 0.96875 0.609375 1.578125q0.625 0.609375 1.625 0.609375q0.953125 0 1.5625 -0.609375q0.625 -0.609375 0.625 -1.484375q0 -0.921875 -0.640625 -1.546875q-0.625 -0.625 -1.578125 -0.625q-0.953125 0 -1.578125 0.609375q-0.625 0.609375 -0.625 1.46875zm-0.546875 6.28125q0 0.71875 0.328125 1.390625q0.34375 0.65625 1.015625 1.03125q0.671875 0.359375 1.4375 0.359375q1.203125 0 1.984375 -0.765625q0.78125 -0.78125 0.78125 -1.96875q0 -1.203125 -0.8125 -1.984375q-0.796875 -0.796875 -2.0 -0.796875q-1.1875 0 -1.96875 0.78125q-0.765625 0.78125 -0.765625 1.953125zm8.031982 3.921875l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm16.265625 0l-1.671875 0l0 -10.640625q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.5625 -1.765625 0.859375l0 -1.625q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.65625zm12.860107 -1.609375l0 1.609375l-8.984375 0q-0.015625 -0.609375 0.1875 -1.15625q0.34375 -0.921875 1.09375 -1.8125q0.765625 -0.890625 2.1875 -2.0625q2.21875 -1.8125 3.0 -2.875q0.78125 -1.0625 0.78125 -2.015625q0 -0.984375 -0.71875 -1.671875q-0.703125 -0.6875 -1.84375 -0.6875q-1.203125 0 -1.9375 0.734375q-0.71875 0.71875 -0.71875 2.0l-1.71875 -0.171875q0.171875 -1.921875 1.328125 -2.921875q1.15625 -1.015625 3.09375 -1.015625q1.953125 0 3.09375 1.09375q1.140625 1.078125 1.140625 2.6875q0 0.8125 -0.34375 1.609375q-0.328125 0.78125 -1.109375 1.65625q-0.765625 0.859375 -2.5625 2.390625q-1.5 1.265625 -1.9375 1.71875q-0.421875 0.4375 -0.703125 0.890625l6.671875 0zm4.172577 -5.765625q-1.046875 -0.375 -1.546875 -1.078125q-0.5 -0.71875 -0.5 -1.703125q0 -1.484375 1.0625 -2.484375q1.078125 -1.015625 2.84375 -1.015625q1.78125 0 2.859375 1.03125q1.09375 1.03125 1.09375 2.515625q0 0.953125 -0.5 1.65625q-0.484375 0.703125 -1.5 1.078125q1.25 0.40625 1.90625 1.3125q0.65625 0.90625 0.65625 2.171875q0 1.75 -1.234375 2.9375q-1.234375 1.1875 -3.25 1.1875q-2.015625 0 -3.25 -1.1875q-1.234375 -1.203125 -1.234375 -2.984375q0 -1.328125 0.671875 -2.21875q0.671875 -0.890625 1.921875 -1.21875zm-0.328125 -2.828125q0 0.96875 0.609375 1.578125q0.625 0.609375 1.625 0.609375q0.953125 0 1.5625 -0.609375q0.625 -0.609375 0.625 -1.484375q0 -0.921875 -0.640625 -1.546875q-0.625 -0.625 -1.578125 -0.625q-0.953125 0 -1.578125 0.609375q-0.625 0.609375 -0.625 1.46875zm-0.546875 6.28125q0 0.71875 0.328125 1.390625q0.34375 0.65625 1.015625 1.03125q0.671875 0.359375 1.4375 0.359375q1.203125 0 1.984375 -0.765625q0.78125 -0.78125 0.78125 -1.96875q0 -1.203125 -0.8125 -1.984375q-0.796875 -0.796875 -2.0 -0.796875q-1.1875 0 -1.96875 0.78125q-0.765625 0.78125 -0.765625 1.953125zm10.235077 7.921875l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m348.0 137.69212l0 20.503937" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m348.0 137.69212l0 14.503937" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m346.34827 152.19606l1.6517334 4.538101l1.6517334 -4.538101z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m348.0 217.18819l0 20.818893" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m348.0 217.1882l0 14.818893" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m346.34827 232.0071l1.6517334 4.538101l1.6517334 -4.538101z" fill-rule="evenodd"></path><path fill="#000000" fill-opacity="0.0" d="m253.3 422.01575l190.01573 0l0 41.984253l-190.01573 0z" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="2.0" stroke-linejoin="round" stroke-linecap="butt" d="m253.3 422.01575l190.01573 0l0 41.984253l-190.01573 0z" fill-rule="evenodd"></path><path fill="#000000" d="m269.44388 448.93576l0 -13.59375l6.03125 0q1.8125 0 2.75 0.359375q0.953125 0.359375 1.515625 1.296875q0.5625 0.921875 0.5625 2.046875q0 1.453125 -0.9375 2.453125q-0.921875 0.984375 -2.890625 1.25q0.71875 0.34375 1.09375 0.671875q0.78125 0.734375 1.484375 1.8125l2.375 3.703125l-2.265625 0l-1.796875 -2.828125q-0.796875 -1.21875 -1.3125 -1.875q-0.5 -0.65625 -0.90625 -0.90625q-0.40625 -0.265625 -0.8125 -0.359375q-0.3125 -0.078125 -1.015625 -0.078125l-2.078125 0l0 6.046875l-1.796875 0zm1.796875 -7.59375l3.859375 0q1.234375 0 1.921875 -0.25q0.703125 -0.265625 1.0625 -0.828125q0.375 -0.5625 0.375 -1.21875q0 -0.96875 -0.703125 -1.578125q-0.703125 -0.625 -2.21875 -0.625l-4.296875 0l0 4.5zm18.176086 4.421875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm15.547577 2.265625l1.640625 0.21875q-0.265625 1.6875 -1.375 2.65625q-1.109375 0.953125 -2.734375 0.953125q-2.015625 0 -3.25 -1.3125q-1.21875 -1.328125 -1.21875 -3.796875q0 -1.59375 0.515625 -2.78125q0.53125 -1.203125 1.609375 -1.796875q1.09375 -0.609375 2.359375 -0.609375q1.609375 0 2.625 0.8125q1.015625 0.8125 1.3125 2.3125l-1.625 0.25q-0.234375 -1.0 -0.828125 -1.5q-0.59375 -0.5 -1.421875 -0.5q-1.265625 0 -2.0625 0.90625q-0.78125 0.90625 -0.78125 2.859375q0 1.984375 0.765625 2.890625q0.765625 0.890625 1.984375 0.890625q0.984375 0 1.640625 -0.59375q0.65625 -0.609375 0.84375 -1.859375zm9.34375 3.609375l0 -1.453125q-1.140625 1.671875 -3.125 1.671875q-0.859375 0 -1.625 -0.328125q-0.75 -0.34375 -1.125 -0.84375q-0.359375 -0.5 -0.515625 -1.234375q-0.09375 -0.5 -0.09375 -1.5625l0 -6.109375l1.671875 0l0 5.46875q0 1.3125 0.09375 1.765625q0.15625 0.65625 0.671875 1.03125q0.515625 0.375 1.265625 0.375q0.75 0 1.40625 -0.375q0.65625 -0.390625 0.921875 -1.046875q0.28125 -0.671875 0.28125 -1.9375l0 -5.28125l1.671875 0l0 9.859375l-1.5 0zm3.9069824 0l0 -9.859375l1.5 0l0 1.5q0.578125 -1.046875 1.0625 -1.375q0.484375 -0.34375 1.078125 -0.34375q0.84375 0 1.71875 0.546875l-0.578125 1.546875q-0.609375 -0.359375 -1.234375 -0.359375q-0.546875 0 -0.984375 0.328125q-0.421875 0.328125 -0.609375 0.90625q-0.28125 0.890625 -0.28125 1.953125l0 5.15625l-1.671875 0zm6.212677 0l0 -9.859375l1.5 0l0 1.5q0.578125 -1.046875 1.0625 -1.375q0.484375 -0.34375 1.078125 -0.34375q0.84375 0 1.71875 0.546875l-0.578125 1.546875q-0.609375 -0.359375 -1.234375 -0.359375q-0.546875 0 -0.984375 0.328125q-0.421875 0.328125 -0.609375 0.90625q-0.28125 0.890625 -0.28125 1.953125l0 5.15625l-1.671875 0zm12.978302 -3.171875l1.71875 0.21875q-0.40625 1.5 -1.515625 2.34375q-1.09375 0.828125 -2.8125 0.828125q-2.15625 0 -3.421875 -1.328125q-1.265625 -1.328125 -1.265625 -3.734375q0 -2.484375 1.265625 -3.859375q1.28125 -1.375 3.328125 -1.375q1.984375 0 3.234375 1.34375q1.25 1.34375 1.25 3.796875q0 0.140625 -0.015625 0.4375l-7.34375 0q0.09375 1.625 0.921875 2.484375q0.828125 0.859375 2.0625 0.859375q0.90625 0 1.546875 -0.46875q0.65625 -0.484375 1.046875 -1.546875zm-5.484375 -2.703125l5.5 0q-0.109375 -1.234375 -0.625 -1.859375q-0.796875 -0.96875 -2.078125 -0.96875q-1.140625 0 -1.9375 0.78125q-0.78125 0.765625 -0.859375 2.046875zm9.110077 5.875l0 -9.859375l1.5 0l0 1.40625q1.09375 -1.625 3.140625 -1.625q0.890625 0 1.640625 0.328125q0.75 0.3125 1.109375 0.84375q0.375 0.515625 0.53125 1.21875q0.09375 0.46875 0.09375 1.625l0 6.0625l-1.671875 0l0 -6.0q0 -1.015625 -0.203125 -1.515625q-0.1875 -0.515625 -0.6875 -0.8125q-0.5 -0.296875 -1.171875 -0.296875q-1.0625 0 -1.84375 0.671875q-0.765625 0.671875 -0.765625 2.578125l0 5.375l-1.671875 0zm14.031982 -1.5l0.234375 1.484375q-0.703125 0.140625 -1.265625 0.140625q-0.90625 0 -1.40625 -0.28125q-0.5 -0.296875 -0.703125 -0.75q-0.203125 -0.46875 -0.203125 -1.984375l0 -5.65625l-1.234375 0l0 -1.3125l1.234375 0l0 -2.4375l1.65625 -1.0l0 3.4375l1.6875 0l0 1.3125l-1.6875 0l0 5.75q0 0.71875 0.078125 0.921875q0.09375 0.203125 0.296875 0.328125q0.203125 0.125 0.578125 0.125q0.265625 0 0.734375 -0.078125zm9.897858 5.5q-1.375 -1.75 -2.328125 -4.078125q-0.953125 -2.34375 -0.953125 -4.84375q0 -2.21875 0.703125 -4.234375q0.84375 -2.34375 2.578125 -4.671875l1.203125 0q-1.125 1.921875 -1.484375 2.75q-0.5625 1.28125 -0.890625 2.671875q-0.40625 1.734375 -0.40625 3.484375q0 4.46875 2.78125 8.921875l-1.203125 0zm11.228302 -14.265625l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm13.875702 4.40625l0 -3.25l-5.90625 0l0 -1.53125l6.21875 -8.8125l1.359375 0l0 8.8125l1.84375 0l0 1.53125l-1.84375 0l0 3.25l-1.671875 0zm0 -4.78125l0 -6.140625l-4.25 6.140625l4.25 0zm4.3757324 4.78125l3.59375 -5.125l-3.328125 -4.734375l2.09375 0l1.515625 2.3125q0.421875 0.65625 0.671875 1.109375q0.421875 -0.609375 0.765625 -1.09375l1.65625 -2.328125l1.984375 0l-3.390625 4.640625l3.65625 5.21875l-2.046875 0l-2.03125 -3.0625l-0.53125 -0.828125l-2.59375 3.890625l-2.015625 0zm16.265625 0l-1.671875 0l0 -10.640625q-0.59375 0.578125 -1.578125 1.15625q-0.984375 0.5625 -1.765625 0.859375l0 -1.625q1.40625 -0.65625 2.453125 -1.59375q1.046875 -0.9375 1.484375 -1.8125l1.078125 0l0 13.65625zm4.344452 -3.140625l1.59375 -0.15625q0.203125 1.140625 0.78125 1.65625q0.578125 0.5 1.484375 0.5q0.765625 0 1.34375 -0.34375q0.578125 -0.359375 0.953125 -0.953125q0.375 -0.59375 0.625 -1.59375q0.25 -1.0 0.25 -2.03125q0 -0.109375 -0.015625 -0.34375q-0.5 0.796875 -1.375 1.296875q-0.859375 0.5 -1.875 0.5q-1.6875 0 -2.859375 -1.21875q-1.171875 -1.234375 -1.171875 -3.234375q0 -2.078125 1.21875 -3.328125q1.234375 -1.265625 3.0625 -1.265625q1.328125 0 2.421875 0.71875q1.109375 0.703125 1.671875 2.03125q0.578125 1.328125 0.578125 3.828125q0 2.609375 -0.578125 4.15625q-0.5625 1.546875 -1.6875 2.359375q-1.109375 0.796875 -2.609375 0.796875q-1.59375 0 -2.609375 -0.890625q-1.0 -0.890625 -1.203125 -2.484375zm6.828125 -6.0q0 -1.4375 -0.765625 -2.28125q-0.765625 -0.859375 -1.84375 -0.859375q-1.109375 0 -1.9375 0.921875q-0.828125 0.90625 -0.828125 2.34375q0 1.3125 0.78125 2.125q0.796875 0.796875 1.9375 0.796875q1.171875 0 1.90625 -0.796875q0.75 -0.8125 0.75 -2.25zm11.953857 -1.125l-1.65625 0.125q-0.21875 -0.984375 -0.640625 -1.421875q-0.671875 -0.71875 -1.65625 -0.71875q-0.8125 0 -1.40625 0.4375q-0.796875 0.578125 -1.25 1.6875q-0.453125 1.09375 -0.46875 3.140625q0.609375 -0.921875 1.46875 -1.359375q0.875 -0.453125 1.828125 -0.453125q1.671875 0 2.84375 1.234375q1.171875 1.234375 1.171875 3.171875q0 1.28125 -0.546875 2.390625q-0.546875 1.09375 -1.515625 1.6875q-0.96875 0.578125 -2.1875 0.578125q-2.09375 0 -3.40625 -1.53125q-1.3125 -1.546875 -1.3125 -5.0625q0 -3.953125 1.453125 -5.734375q1.265625 -1.5625 3.421875 -1.5625q1.609375 0 2.625 0.90625q1.03125 0.890625 1.234375 2.484375zm-6.8125 5.859375q0 0.859375 0.359375 1.65625q0.375 0.78125 1.03125 1.203125q0.65625 0.40625 1.375 0.40625q1.0625 0 1.8125 -0.84375q0.765625 -0.859375 0.765625 -2.328125q0 -1.40625 -0.75 -2.21875q-0.75 -0.8125 -1.890625 -0.8125q-1.125 0 -1.921875 0.8125q-0.78125 0.8125 -0.78125 2.125zm10.078827 8.40625l-1.1875 0q2.765625 -4.453125 2.765625 -8.921875q0 -1.734375 -0.390625 -3.453125q-0.328125 -1.390625 -0.890625 -2.671875q-0.359375 -0.84375 -1.484375 -2.78125l1.1875 0q1.75 2.328125 2.578125 4.671875q0.71875 2.015625 0.71875 4.234375q0 2.5 -0.96875 4.84375q-0.953125 2.328125 -2.328125 4.078125z" fill-rule="nonzero"></path><path fill="#000000" fill-opacity="0.0" d="m348.30786 464.0l-0.50393677 24.0" fill-rule="evenodd"></path><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m348.30786 464.0l-0.37799072 18.001312" fill-rule="evenodd"></path><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m346.2785 481.96664l1.5561218 4.5717773l1.7466431 -4.502411z" fill-rule="evenodd"></path></g></svg>
+
diff --git a/tensorflow/contrib/lite/models/testdata/smartreply_samples.tsv b/tensorflow/contrib/lite/models/testdata/smartreply_samples.tsv
new file mode 100644
index 0000000000..dfdc783106
--- /dev/null
+++ b/tensorflow/contrib/lite/models/testdata/smartreply_samples.tsv
@@ -0,0 +1,50 @@
+any chance ur free tonight Maybe not
+any updates? No update yet
+anything i can do to help? No, but thanks No, but thank you No, but thanks for asking
+be safe. I will be Will do my best Thanks, I will
+congratulations Thanks thanks Congratulations
+cool, let me know when you have time Cool Yes very cool Yeah, cool
+drive safe Thank you, I will Home now I will thanks
+hang in there, you'll be okay Doing my best Of course we will
+happy birthday! Hey, thanks
+happy new year! Wish you the same Thanks and same to you
+have a safe flight Thanks, love you too Safe travels
+hey What is up? How it going? Can I help you?
+hey, got a sec? What is up? How it going? Can I help you?
+how are you doing? Great and you? I am doing great
+how are you feeling Feeling okay A little better Much much better
+how was your weekend? It was real good
+how you doing Okay and you
+hugs. So sweet Thanks sweetie Take care of yourself
+i'm bored Sorry to hear that Join the club No you are not
+i'm planning on coming next week. let me know if that works. Works Perfect, thanks
+i'm sick Sorry to hear that
+i'm so happy for you Thanks me too
+i'm so hungry Haha me too
+i'm sorry No I am sorry Why sorry? No worries love
+i'm sorry, i'm going to have to cancel. No I am sorry Why sorry? No worries love
+is there anything i can do to help? No, but thanks No, but thanks for asking
+lunch? Yes coming
+okay. lemme know as soon as you find out. Any more questions? It is done
+omg amazing So amazing
+on my way Okay see you soon Cool, see you soon Oh wow, ok
+oops, mistexted. Oops Haha, oh well That was funny
+safe travels. Thanks, love you too Safe travels
+so sorry So sorry
+sorry, i can't. No worries at all Sorry what?
+sorry, i can't do saturday No worries at all
+thank you so much. You are so welcome You are so very welcome You are most welcome
+thanks for coming It was my pleasure
+thanks, this has been great. Glad to help So happy for you
+tomorrow would be ideal. Yes it would
+tried calling Try again?
+ugh, my flight is delayed. Ugh indeed
+what are you guys up to tonight? Nothing planned
+what day works best for you Any day
+what do you want for dinner Your call Whatever is fine
+what time will you be home? Not sure why
+where are you?!? At my house
+wish you were here. I wish the same Me too honey
+you're amazing You are too You are amazing I am
+you're marvelous You are too
+you're the best. I do my best You are the best Well, I try \ No newline at end of file
diff --git a/tensorflow/contrib/lite/nnapi/README.md b/tensorflow/contrib/lite/nnapi/README.md
new file mode 100644
index 0000000000..913467d176
--- /dev/null
+++ b/tensorflow/contrib/lite/nnapi/README.md
@@ -0,0 +1,15 @@
+# Android Neural Network API
+
+The Android Neural Networks API (NNAPI) is an Android C API designed for running
+computationally intensive operators for machine learning on mobile devices.
+Tensorflow Lite is designed to use the NNAPI to perform hardware-accelerated
+inference operators on supported devices.
+Based on the app’s requirements and the hardware capabilities on a device, the
+NNAPI can distribute the computation workload across available on-device
+processors, including dedicated neural network hardware, graphics processing
+units (GPUs), and digital signal processors (DSPs).
+For devices that lack a specialized vendor driver, the NNAPI runtime relies on
+optimized code to execute requests on the CPU. For more information about the
+NNAPI, please refer to the [NNAPI documentation](https://developer.android.com/ndk/guides/neuralnetworks/index.html)
+
+
diff --git a/tensorflow/contrib/lite/toco/README.md b/tensorflow/contrib/lite/toco/README.md
new file mode 100644
index 0000000000..281b2ea5e4
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/README.md
@@ -0,0 +1,26 @@
+# The TensorFlow Lite Optimizing Converter
+
+The TensorFlow Lite Optimizing Converter's most typical use is converting from the TensorFlow GraphDef to the TensorFlow Lite
+format, but it supports much more than that.
+
+## Usage documentation
+
+Usage information is given in these documents:
+
+* [Command-line examples](g3doc/cmdline_examples.md)
+* [Command-line reference](g3doc/cmdline_reference.md)
+* [Python API](g3doc/python_api.md)
+
+## Design documentation
+
+Coming soon!
+
+## Where the converter fits in the TensorFlow landscape
+
+In the typical case, an application developer is using TensorFlow to design and
+train models, then uses TensorFlow's freeze_graph.py to generate a frozen
+inference graph, then uses the converter to convert that into a TensorFlow Lite flatbuffer file,
+then ships that file to client devices where the TensorFlow Lite interpreter handles them
+on-device. This is represented in the following diagram:
+
+![drawing](https://storage.googleapis.com/download.tensorflow.org/example_images/tensorflow_landscape.svg)
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
new file mode 100644
index 0000000000..b9f8c8d152
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
@@ -0,0 +1,509 @@
+# TensorFlow Lite Optimizing Converter command-line examples
+
+This page is a guide to using the TensorFlow Lite Optimizing Converter by
+looking at some example command lines. It is complemented by the following other
+documents:
+
+* [README](../README.md)
+* [Command-line reference](cmdline_reference.md)
+
+Table of contents:
+
+[TOC]
+
+## Convert a TensorFlow GraphDef to TensorFlow Lite for float inference
+
+In this example, we look at the most common task: we have an ordinary TensorFlow
+GraphDef and want to convert it to a TensorFlow Lite flatbuffer to perform
+floating-point inference.
+
+```
+curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
+ | tar xzv -C /tmp
+bazel run --config=opt \
+ //tensorflow/contrib/lite/toco:toco -- \
+ --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
+ --output_file=/tmp/foo.lite \
+ --input_format=TENSORFLOW_GRAPHDEF \
+ --output_format=TFLITE \
+ --input_type=FLOAT \
+ --inference_type=FLOAT \
+ --input_shape=1,128,128,3 \
+ --input_array=input \
+ --output_array=MobilenetV1/Predictions/Reshape_1
+```
+
+To explain each of these flags:
+
+* `--input_format` and `--output_format` determine the formats of the input
+ and output files: here we are converting from `TENSORFLOW_GRAPHDEF` to
+ `TFLITE`.
+* `--input_file` specifies the path of the input file, to be converted. When
+ `--input_format=TENSORFLOW_GRAPHDEF`, this file should be a
+ *[frozen](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)*
+ *inference* graph. Being frozen means in particular that the input file is
+ self-contained, and does not reference any external "checkpoint" file. An
+ *inference* graph is a version of a graph meant to be used for inference,
+ typically not the same graph file as was used for training a given model.
+* `--output_file` specifies the destination to write the converted file to.
+* `--input_array` specifies the input activations, that is, the input "tensor"
+ in the input TensorFlow GraphDef file. The array designated by
+ `--input_array` is the one that the user will have to provide the contents
+ of as input to the runtime inference code.
+* `--output_array` specifies the output activations, that is, the output
+ "tensor" in the input TensorFlow GraphDef file. The runtime inference code
+ will store its results in the array designated by `--output_array`.
+* `--input_shape` specifies the shape of the input array. It is currently
+ required, but the plan is for a future version to no longer require it,
+ allowing to defer the specification of the input shape until runtime. The
+ format of `input_shape` is always a comma-separated list of dimensions,
+ always in TensorFlow convention.
+* `--input_type` specifies what should be the type of the input arrays in the
+ **output** file. `--input_type` does not describe a property of the input
+ file: the type of input arrays is already encoded in the input graph.
+ Rather, `--input_type` is how you specify what should be the type of the
+ inputs to be provided to the output converted graph. This only affects
+ arrays of real numbers: this flag allows to quantized/dequantize
+ real-numbers inputs, switching between floating-point and quantized forms.
+ This flag has no incidence on all other types of input arrays, such as plain
+ integers or strings.
+* `--inference_type` specifies what type of arithmetic the output file should
+ be relying on. It implies in particular the choice of type of the output
+ arrays in the output file. Like `--input_type`, `--inference_type` does not
+ describe a property of the input file.
+
+## Just optimize a TensorFlow GraphDef
+
+The converter accepts both TENSORFLOW_GRAPHDEF and TFLITE file formats as both
+`--input_format` and `--output_format`. This means that conversion from and to
+any supported format is possible, and in particular, same-format "conversions"
+are possible, and effectively ask the converter to optimize and simplify a
+graph. Example:
+
+```
+curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
+ | tar xzv -C /tmp
+bazel run --config=opt \
+ //tensorflow/contrib/lite/toco:toco -- \
+ --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
+ --output_file=/tmp/foo.pb \
+ --input_format=TENSORFLOW_GRAPHDEF \
+ --output_format=TENSORFLOW_GRAPHDEF \
+ --input_shape=1,128,128,3 \
+ --input_array=input \
+ --output_array=MobilenetV1/Predictions/Reshape_1
+```
+
+Here we did not pass `--input_type` and `--inference_type` because they are
+considered not applicable to the TensorFlow GraphDef format (as far as we are
+concerned, TensorFlow GraphDefs are technically always float, and the only
+flavor of "quantized" GraphDef that the converter deals with is "FakeQuantized"
+graphs that are still technically float graphs).
+
+Below in the section about passing arbitrary input/output arrays we give another
+example, using the converter to extract just a sub-graph from a TensorFlow
+GraphDef.
+
+## Convert a TensorFlow Lite flatbuffer back into TensorFlow GraphDef format
+
+As we mentioned that the converter supports file format conversions in any
+direction, let us just give an example of that:
+
+```
+bazel run --config=opt \
+ //tensorflow/contrib/lite/toco:toco -- \
+ --input_file=/tmp/foo.lite \
+ --output_file=/tmp/foo.pb \
+ --input_format=TFLITE \
+ --output_format=TENSORFLOW_GRAPHDEF \
+ --input_shape=1,128,128,3 \
+ --input_array=input \
+ --output_array=MobilenetV1/Predictions/Reshape_1
+```
+
+## Convert a TensorFlow GraphDef to TensorFlow Lite for quantized inference
+
+Let us now look at a quantized model. As mentioned above, the only flavor of
+quantized TensorFlow GraphDefs that the converter is concerned with, is
+"FakeQuantized" models. These are technically float models, but with special
+`FakeQuant*` ops inserted at the boundaries of fused layers to record min-max
+range information allowing to generate a quantized inference workload that is
+able to reproduce exactly the specific quantization behavior that was used
+during training. Indeed, the whole point of quantized training is to allow for
+both training and inference to perform exactly the same arithmetic, so that the
+way that the training process about around quantization inaccuracy is
+effectively helping the quantized inference process to be more accurate.
+
+Given a quantized TensorFlow GraphDef, generating a quantized TensorFlow Lite
+flatbuffer is done like this:
+
+```
+bazel run --config=opt \
+ //tensorflow/contrib/lite/toco:toco -- \
+ --input_file=/tmp/some_quantized_graph.pb \
+ --output_file=/tmp/foo.lite \
+ --input_format=TENSORFLOW_GRAPHDEF \
+ --output_format=TFLITE \
+ --input_type=QUANTIZED_UINT8 \
+ --inference_type=QUANTIZED_UINT8 \
+ --input_shape=1,128,128,3 \
+ --input_array=input \
+ --output_array=MobilenetV1/Predictions/Reshape_1 \
+ --mean_value=128 \
+ --std_value=127
+```
+
+Here, besides changing `--input_file` to point to a (fake-)quantized GraphDef,
+the only other changes are:
+
+* To change `--input_type` and `--inference_type` to `QUANTIZED_UINT8`. This
+ effectively tells the converter to generate an output file that can take a
+ quantized uint8 array as input (`--input_type=QUANTIZED_UINT8`), and have
+ quantized uint8 internal and output arrays as well
+ (`--inference_type=QUANTIZED_UINT8`).
+* To pass `--mean_value` and `--std_value` flags to describe how the quantized
+ uint8 input array values are to be interpreted as the mathematical real
+ numbers that the graph is concerned with (keep in mind that even a
+ "fake-quantized" TensorFlow GraphDef is still technically a float graph).
+ The meaning of `--mean_value` and `--std_value` is explained in the
+ command-line reference; it suffices for now to say that they are a property
+ of each model.
+
+## Use dummy-quantization to try out quantized inference on a float graph
+
+Sometimes, one only has a plain float graph, and one is curious as to how much
+faster inference might run if one could perform quantized inference instead of
+float inference. Rather than requiring users to first invest in quantizing their
+graphs before they can evaluate a possible benefit, the converter allows to
+simply experiment with what we call "dummy quantization": provide some vaguely
+plausible values for the min-max ranges of values in all arrays that do not have
+min-max information, so that quantization can carry on, certainly producing
+inaccurate results (do not use that in production!) but with performance
+characteristics that should be identical to those of an actually quantized
+flavor of the model.
+
+In the present example, we have a model using Relu6 activation functions almost
+everywhere, so a reasonable guess is that most activation ranges should be
+contained in [0, 6] and roughly comparable to it.
+
+```
+curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
+ | tar xzv -C /tmp
+bazel run --config=opt \
+ //tensorflow/contrib/lite/toco:toco -- \
+ --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
+ --output_file=/tmp/foo.cc \
+ --input_format=TENSORFLOW_GRAPHDEF \
+ --output_format=TFLITE \
+ --input_type=QUANTIZED_UINT8 \
+ --inference_type=QUANTIZED_UINT8 \
+ --input_shape=1,128,128,3 \
+ --input_array=input \
+ --output_array=MobilenetV1/Predictions/Reshape_1 \
+ --default_ranges_min=0 \
+ --default_ranges_max=6 \
+ --mean_value=127.5 \
+ --std_value=127.5
+```
+
+## Multiple output arrays
+
+Some models have multiple outputs. Even in a model with only one output, you may
+want for the inference code to return the contents of other arrays as well, or
+to perform inference on a subgraph with multiple outputs (see the section below
+on specifying arbitrary arrays as input/output arrays).
+
+Either way, using `--output_arrays` instead of `--output_array` allows to
+specify a comma-separated list of output arrays.
+
+```
+curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \
+ | tar xzv -C /tmp
+bazel run --config=opt \
+ //tensorflow/contrib/lite/toco:toco -- \
+ --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \
+ --output_file=/tmp/foo.lite \
+ --input_format=TENSORFLOW_GRAPHDEF \
+ --output_format=TFLITE \
+ --input_type=FLOAT \
+ --inference_type=FLOAT \
+ --input_shape=1,224,224,3 \
+ --input_array=input \
+ --output_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu
+```
+
+## Multiple input arrays
+
+Some models have multiple inputs; even in a model with a single input, you may
+want for the inference code to implement only a subgraph with multiple inputs
+(see the section below on specifying arbitrary arrays as input/output arrays).
+
+Either way, multiple input arrays are specified by using `--input_arrays`
+instead of `--input_array` to specify a comma-separated list of input arrays. In
+that case, one also needs to use `--input_shapes` instead of `--input_shape`.
+The syntax for `--input_shapes` is a bit trickier, since already the singular
+`--input_shape` was a comma-separated list of integers! Multiple input shapes
+are delimited by a colon (`:`) in `--input_shapes`.
+
+```
+curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \
+ | tar xzv -C /tmp
+bazel run --config=opt \
+ //tensorflow/contrib/lite/toco:toco -- \
+ --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \
+ --output_file=/tmp/foo.lite \
+ --input_format=TENSORFLOW_GRAPHDEF \
+ --output_format=TFLITE \
+ --input_type=FLOAT \
+ --inference_type=FLOAT \
+ --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \
+ --input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \
+ --output_array=InceptionV1/Logits/Predictions/Reshape_1
+```
+
+## Specifying arbitrary arrays in a graph as input or output arrays
+
+Any array in the input file can be specified as an input or output array. This
+allows to use the converter to extract a sub-graph out of the input graph file.
+The converter then automatically discards any part of the graph that is not
+needed for the subgraph identified by the specified input and output arrays.
+Another use case for specifying multiple output arrays is to get inference code
+to return the contents of some specified intermediate activations array, not
+just the output activations.
+
+In order to know which array you want to pass as `--input_arrays` /
+`--output_arrays`, it helps to have a visualization of the graph. See the
+section below on graph visualization. When using graph visualization for that
+purpose, make sure to use `--dump_graphviz=` to visualize exactly the graph as
+it is in the actual final form being exported to the output file.
+
+Note that the final representation of an on-device inference workload (say, in
+TensorFlow Lite flatbuffers format) tends to have coarser granularity than the
+very fine granularity of the TensorFlow GraphDef representation. For example,
+while a fully-connected layer is typically represented as at least four separate
+ops in TensorFlow GraphDef (Reshape, MatMul, BiasAdd, Relu...), it is typically
+represented as a single "fused" op (FullyConnected) in the converter's optimized
+representation and in the final on-device representation (e.g. in TensorFlow
+Lite flatbuffer format). As the level of granularity gets coarser, some
+intermediate arrays (say, the array between the MatMul and the BiasAdd in the
+TensorFlow GraphDef) are dropped. When specifying intermediate arrays as
+`--input_arrays` / `--output_arrays`, it is generally at least desirable (and
+often required) to specify arrays that are meant to survive in the final form of
+the graph, after fusing. These are typically the outputs of activation functions
+(since everything in each layer until the activation function tends to get
+fused).
+
+Here is an example of extracting just a sub-graph, namely just a single fused
+layer, out of a TensorFlow GraphDef, and exporting a TensorFlow GraphDef
+containing just that subgraph:
+
+```
+curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \
+ | tar xzv -C /tmp
+bazel run --config=opt \
+ //tensorflow/contrib/lite/toco:toco -- \
+ --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \
+ --output_file=/tmp/foo.pb \
+ --input_format=TENSORFLOW_GRAPHDEF \
+ --output_format=TENSORFLOW_GRAPHDEF \
+ --input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \
+ --input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \
+ --output_array=InceptionV1/InceptionV1/Mixed_3b/concat_v2
+```
+
+## Logging
+
+### Standard logging
+
+The converter generates some informative log messages during processing. The
+easiest way to view them is to add `--logtostderr` to command lines. For the
+previous example, that gives:
+
+```
+curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
+ | tar xzv -C /tmp
+bazel run --config=opt \
+ //tensorflow/contrib/lite/toco:toco -- \
+ --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
+ --output_file=/tmp/foo.lite \
+ --input_format=TENSORFLOW_GRAPHDEF \
+ --output_format=TFLITE \
+ --input_type=FLOAT \
+ --inference_type=FLOAT \
+ --input_shape=1,128,128,3 \
+ --input_array=input \
+ --output_array=MobilenetV1/Predictions/Reshape_1 \
+ --logtostderr
+```
+
+After some initialization messages, we get the following informative messages:
+
+```
+I1101 21:51:33.297475 5339 graph_transformations.cc:39] Before general graph transformations: 416 operators, 583 arrays (0 quantized)
+I1101 21:51:33.308972 5339 graph_transformations.cc:39] After general graph transformations pass 1: 31 operators, 89 arrays (0 quantized)
+I1101 21:51:33.309204 5339 graph_transformations.cc:39] Before dequantization graph transformations: 31 operators, 89 arrays (0 quantized)
+I1101 21:51:33.309368 5339 allocate_transient_arrays.cc:312] Total transient array allocated size: 1048576 bytes, theoretical optimal value: 786432 bytes.
+I1101 21:51:33.309484 5339 toco_tooling.cc:249] Estimated count of arithmetic ops: 0.099218 billion (note that a multiply-add is counted as 2 ops).
+```
+
+### Verbose logging
+
+For debugging purposes, the converter supports two levels of verbose logging,
+which can be set by passing a `--v=` flag:
+
+* At `--v=1`, the converter generates text dumps of the graph at various
+ points during processing, as well as log messages about every graph
+ transformation that did take place, typically answering questions of the
+ form "why was my graph transformed in this way"?
+* At `--v=2`, the converter additionally generates log messages about graph
+ transformations that were considered but not actually performed, typically
+ answering questions of the form "why was my graph NOT transformed when I
+ expected it would be?".
+
+### Graph "video" logging
+
+When `--dump_graphviz=` is used (see the section on Graph visualizations), one
+may additionally pass `--dump_graphviz_video`, which causes a graph
+visualization to be dumped after each individual graph transformations, often
+resulting in thousands of files. Typically, one would then bisect into these
+files to understand when a given change was introduced in the graph.
+
+## Graph visualizations
+
+The converter is able to export a graph to the GraphViz Dot format, for easy
+visualization. Combined with the converter's ability to transform the graph into
+a simpler, coarser-granularity representation, that makes it a very powerful
+visualization tool.
+
+There are two ways to get the converter to export a GraphViz Dot file,
+corresponding to two separate use cases. Understanding the difference between
+them is key to getting useful graph visualizations.
+
+### Using `--output_format=GRAPHVIZ_DOT`
+
+The first way to get a graphviz rendering is to pass
+`--output_format=GRAPHVIZ_DOT`, instead of the `--output_format` that you would
+otherwise use. This says: "I just want to get a plausible visualization of that
+graph". The upside is that it makes for very simple command lines, and makes the
+converter very lax about aspects of the graph or the command line that it would
+otherwise complain about. Example:
+
+```
+curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
+ | tar xzv -C /tmp
+bazel run --config=opt \
+ //tensorflow/contrib/lite/toco:toco -- \
+ --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
+ --output_file=/tmp/foo.dot \
+ --input_format=TENSORFLOW_GRAPHDEF \
+ --output_format=GRAPHVIZ_DOT \
+ --input_shape=1,128,128,3 \
+ --input_array=input \
+ --output_array=MobilenetV1/Predictions/Reshape_1
+```
+
+The resulting `.dot` file can be rendered into a PDF as follows:
+
+```
+dot -Tpdf -O /tmp/foo.dot
+```
+
+And the resulting `.dot.pdf` can be viewed in any PDF viewer, but we suggest one
+with a good ability to pan and zoom across a very large page; Google Chrome does
+well in that respect.
+
+```
+google-chrome /tmp/foo.dot.pdf
+```
+
+Example PDF files are viewable online in the next section.
+
+### Using `--dump_graphviz=`
+
+The second way to get a graphviz rendering is to pass a `--dump_graphviz=` flag
+specifying a destination directory to dump GraphViz rendering to. Unlike the
+previous approach, this one allows you to keep your real command-line (with your
+real `--output_format` and other flags) unchanged, just appending a
+`--dump_graphviz=` flag to it. This says: "I want visualizations of the actual
+graph during this specific conversion process". Example:
+
+```
+curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
+ | tar xzv -C /tmp
+bazel run --config=opt \
+ //tensorflow/contrib/lite/toco:toco -- \
+ --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
+ --output_file=/tmp/foo.lite \
+ --input_format=TENSORFLOW_GRAPHDEF \
+ --output_format=TFLITE \
+ --input_type=FLOAT \
+ --inference_type=FLOAT \
+ --input_shape=1,128,128,3 \
+ --input_array=input \
+ --output_array=MobilenetV1/Predictions/Reshape_1 \
+ --dump_graphviz=/tmp
+```
+
+This generates a few files in the destination directory, here `/tmp`. Most
+important are these two files:
+
+```
+/tmp/toco_AT_IMPORT.dot
+/tmp/toco_AFTER_TRANSFORMATIONS.dot
+```
+
+`toco_AT_IMPORT.dot` represents the graph as it was imported from
+`--input_file`, before any transformation was applied to it (besides some
+transformations that are applied immediately while importing). This tends to be
+a complex visualization with limited information, but is useful especially in
+situations where a conversion command fails (this file is generated even if the
+conversion subsequently fails).
+
+`toco_AFTER_TRANSFORMATIONS.dot` represents the graph after all transformations
+were applied to it, just before it was exported to the `--output_file`.
+Typically, this is a much smaller graph, and it conveys much more information
+about each node.
+
+Again, these can be rendered to PDFs:
+
+```
+dot -Tpdf -O /tmp/toco_*.dot
+```
+
+The resulting files can be seen here:
+
+* [toco_AT_IMPORT.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AT_IMPORT.dot.pdf)
+* [toco_AFTER_TRANSFORMATIONS.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AFTER_TRANSFORMATIONS.dot.pdf).
+
+### Legend for the graph visualizations
+
+* Operators are red square boxes with the following hues of red:
+ * Most operators are
+ <span style="background-color:#db4437;color:white;border:1px;border-style:solid;border-color:black;padding:1px">bright
+ red</span>.
+ * Some typically heavy operators (e.g. Conv) are rendered in a
+ <span style="background-color:#c53929;color:white;border:1px;border-style:solid;border-color:black;padding:1px">darker
+ red</span>.
+* Arrays are octogons with the following colors:
+ * Constant arrays are
+ <span style="background-color:#4285f4;color:white;border:1px;border-style:solid;border-color:black;padding:1px">blue</span>.
+ * Activation arrays are gray:
+ * Internal (intermediate) activation arrays are
+ <span style="background-color:#f5f5f5;border:1px;border-style:solid;border-color:black;border:1px;border-style:solid;border-color:black;padding:1px">light
+ gray</span>.
+ * Those activation arrays that are designated as `--input_arrays` or
+ `--output_arrays` are
+ <span style="background-color:#9e9e9e;border:1px;border-style:solid;border-color:black;padding:1px">dark
+ gray</span>.
+ * RNN state arrays are green. Because of the way that the converter
+ represents RNN back-edges explicitly, each RNN state is represented by a
+ pair of green arrays:
+ * The activation array that is the source of the RNN back-edge (i.e.
+ whose contents are copied into the RNN state array after having been
+ computed) is
+ <span style="background-color:#b7e1cd;border:1px;border-style:solid;border-color:black;padding:1px">light
+ green</span>.
+ * The actual RNN state array is
+ <span style="background-color:#0f9d58;color:white;border:1px;border-style:solid;border-color:black;padding:1px">dark
+ green</span>. It is the destination of the RNN back-edge updating
+ it.
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
new file mode 100644
index 0000000000..cc6d416959
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
@@ -0,0 +1,238 @@
+# TensorFlow Lite Optimizing Converter command-line reference
+
+This page is complete reference of command-line flags. It is complemented by the
+following other documents:
+
+* [README](../README.md)
+* [Command-line examples](cmdline_examples.md)
+
+Table of contents:
+
+[TOC]
+
+## High-level overview
+
+A full list and detailed specification of all flags is given in the next
+section. For now we focus on a higher-level description of command lines:
+
+```
+toco \
+ --input_format=... \
+ --output_format=... \
+ --input_file=... \
+ --output_file=... \
+ [model flags...] \
+ [transformation flags...] \
+ [logging flags...]
+```
+
+In other words, the converter requires at least the following mandatory flags:
+`--input_format`, `--output_format`, `--input_file`, `--output_file`. Depending
+on the input and output formats, additional flags may be allowed or mandatory:
+
+* *Model flags* provide additional information about the model stored in the
+ input file.
+ * `--output_array` or `--output_arrays` specify which arrays in the input
+ file are to be considered the output activations.
+ * `--input_array` or `--input_arrays` specify which arrays in the input
+ file are to be considered the input activations.
+ * `--input_shape` or `--input_shapes` specify the shapes of the input
+ arrays.
+ * `--mean_value` or `--mean_values`, and `--std_value` or `--std_values`,
+ give the dequantization parameters of the input arrays, for the case
+ when the output file will accept quantized input arrays.
+* *Transformation flags* specify options of the transformations to be applied
+ to the graph, i.e. they specify requested properties that the output file
+ should have.
+ * `--input_type` specifies the type that the input arrays should have
+ after transformations, in the output file. This is where you choose
+ whether you want runtime inference code to accept float or quantized
+ inputs. This flag only applies to float or quantized inputs, and allows
+ to convert between the two. This flag has no effect on all other types
+ of inputs, such as ordinary integer arrays.
+ * `--inference_type` or `--inference_types` specify the type that generic
+ intermediate and output activation arrays should have after
+ transformations, in the output file. This is where you choose whether
+ you want runtime inference code to perform float or quantized inference
+ arithmetic.
+ * Some transformation flags allow to carry on with quantization when the
+ input graph is not properly quantized: `--default_ranges_min`,
+ `--default_ranges_max`, `--drop_fake_quant`,
+ `--reorder_across_fake_quant`.
+* *Logging flags* described below.
+
+## Command-line flags complete reference
+
+### Mandatory flags
+
+* `--input_format`. Type: string. Specifies the format of the input file.
+ Allowed values:
+ * `TENSORFLOW_GRAPHDEF` &mdash; The TensorFlow GraphDef format. Both
+ binary and text proto formats are allowed.
+ * `TFLITE` &mdash; The TensorFlow Lite flatbuffers format.
+* `--output_format`. Type: string. Specifies the format of the output file.
+ Allowed values:
+ * `TENSORFLOW_GRAPHDEF` &mdash; The TensorFlow GraphDef format. Always
+ produces a file in binary (not text) proto format.
+ * `TFLITE` &mdash; The TensorFlow Lite flatbuffers format.
+ * Whether a float or quantized TensorFlow Lite file will be produced
+ depends on the `--inference_type` flag.
+ * Whether the produced TensorFlow Lite file will accept a float or
+ quantized input depends on the `--input_type` flag.
+ * `GRAPHVIZ_DOT` &mdash; The GraphViz `.dot` format. This asks the
+ converter to generate a reasonable graphical representation of the graph
+ after simplification by a generic set of transformation.
+ * A typical `dot` command line to view the resulting graph might look
+ like: `dot -Tpdf -O file.dot`.
+ * Note that since passing this `--output_format` means losing the
+ information of which output format you actually care about, and
+ since the converter's transformations depend on the specific output
+ format, the resulting visualization may not fully reflect what you
+ would get on the actual output format that you are using. To avoid
+ that concern, and generally to get a visualization of exactly what
+ you get in your actual output format as opposed to just a merely
+ plausible visualization of a model, consider using `--dump_graphviz`
+ instead and keeping your true `--output_format`.
+* `--input_file`. Type: string. Specifies the path of the input file. This may
+ be either an absolute or a relative path.
+* `--output_file`. Type: string. Specifies the path of the output file.
+
+### Model flags
+
+* `--output_array`. Type: string. Specifies a single array as the output
+ activations. Incompatible with `--output_arrays`.
+* `--output_arrays`. Type: comma-separated list of strings. Specifies a list
+ of arrays as the output activations, for models with multiple outputs.
+ Incompatible with `--output_array`.
+* `--input_array`. Type: string. Specifies a single array as the input
+ activations. Incompatible with `--input_arrays`.
+* `--input_arrays`. Type: comma-separated list of strings. Specifies a list of
+ arrays as the input activations, for models with multiple inputs.
+ Incompatible with `--input_array`.
+
+When `--input_array` is used, the following flags are available to provide
+additional information about the single input array:
+
+* `--input_shape`. Type: comma-separated list of integers. Specifies the shape
+ of the input array, in TensorFlow convention: starting with the outer-most
+ dimension (the dimension corresponding to the largest offset stride in the
+ array layout), ending with the inner-most dimension (the dimension along
+ which array entries are typically laid out contiguously in memory).
+ * For example, a typical vision model might pass
+ `--input_shape=1,60,80,3`, meaning a batch size of 1 (no batching), an
+ input image height of 60, an input image width of 80, and an input image
+ depth of 3, for the typical case where the input image is a RGB bitmap
+ (3 channels, depth=3) stored by horizontal scanlines (so 'width' is the
+ next innermost dimension after 'depth').
+* `--mean_value` and `--std_value`. Type: floating-point. The decimal point
+ character is always the dot (`.`) regardless of the locale. These specify
+ the (de-)quantization parameters of the input array, to use when the output
+ file will take a quantized input array (that is, when passing
+ `--input_type=QUANTIZED_UINT8`).
+ * The meaning of mean_value and std_value is as follows: each quantized
+ value in the quantized input array will be interpreted as a mathematical
+ real number (i.e. as an input activation value) according to the
+ following formula:
+ * `real_value = (quantized_input_value - mean_value) / std_value`.
+ * When performing float inference (`--inference_type=FLOAT`) on a
+ quantized input, the quantized input would be immediately dequantized by
+ the inference code according to the above formula, before proceeding
+ with float inference.
+ * When performing quantized inference
+ (`--inference_type=QUANTIZED_UINT8`), no dequantization is ever to be
+ performed by the inference code; however, the quantization parameters of
+ all arrays, including those of the input arrays as specified by
+ mean_value and std_value, all participate in the determination of the
+ fixed-point multipliers used in the quantized inference code.
+
+When `--input_arrays` is used, the following flags are available to provide
+additional information about the multiple input arrays:
+
+* `--input_shapes`. Type: colon-separated list of comma-separated lists of
+ integers. Each comma-separated list of integer gives the shape of one of the
+ input arrays specified in `--input_arrays`, in the same order. See
+ `--input_shape` for details.
+ * Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means that
+ there are two input arrays. The first one, "foo", has shape [2,3]. The
+ second one, "bar", has shape [4,5,6].
+* `--mean_values`, `--std_values`. Type: comma-separated lists of
+ floating-point numbers. Each number gives the corresponding value for one of
+ the input arrays specified in `--input_arrays`, in the same order. See
+ `--mean_value`, `--std_value` for details.
+
+### Transformation flags
+
+* `--input_type`. Type: string. Specifies what should be the type of the
+ entries in the input array(s) in the output file, after transformations, for
+ those input arrays that are originally either floating-point or quantized
+ real numbers in the input file. If there are multiple such input arrays,
+ then they all use this type. Input arrays of other types, such as arrays of
+ plain integers or strings, are not concerned with this flag. Allowed values:
+ * `FLOAT` &mdash; Keep floating-point input arrays as such. Dequantize any
+ quantized input array. entries ("float32").
+ * `QUANTIZED_UINT8` &mdash; Quantize floating-point input arrays, to have
+ 8-bit unsigned integer entries. The quantization params are specified by
+ `--mean_value`, `--std_value` flags as explained in the documentation of
+ these flags.
+* `--inference_type`. Type: string. Specifies what to do with floating-point
+ arrays found in the input file, besides input arrays. In other words, this
+ controls the possible quantization of floating-point weights, intermediate
+ activations, and output activations. Has no effect on arrays that aren't
+ floating-point in the input file. Allowed values:
+ * `FLOAT` &mdash; Keep floating-point arrays as floating-point in the
+ output file. This corresponds to what is commonly called "floating-point
+ inference".
+ * `QUANTIZED_UINT8` &mdash; Quantize floating-point arrays, changing their
+ storage data type from float to some integer type:
+ * All float activations are quantized as `uint8`.
+ * Almost all float weights are quantized as `uint8`.
+ * A few exceptions exist. In particular, the bias-vectors in
+ "Conv" and "FullyConnected" layers are quantized as `int32`
+ instead for technical reasons.
+* `--default_ranges_min`, `--default_ranges_max`. Type: floating-point. The
+ decimal point character is always the dot (`.`) regardless of the locale.
+ These flags enable what is called "dummy quantization". If defined, their
+ effect is to define fallback (min, max) range values for all arrays that do
+ not have a properly specified (min, max) range in the input file, thus
+ allowing to proceed with quantization of non-quantized or
+ incorrectly-quantized input files. This enables easy performance prototyping
+ ("how fast would my model run if I quantized it?") but should never be used
+ in production as the resulting quantized arithmetic is inaccurate.
+* `--drop_fake_quant`. Type: boolean. Default: false. Causes fake-quantization
+ nodes to be dropped from the graph. This may be used to recover a plain
+ float graph from a fake-quantized graph.
+* `--reorder_across_fake_quant`. Type: boolean. Default: false. Normally,
+ fake-quantization nodes must be strict boundaries for graph transformations,
+ in order to ensure that quantized inference has the exact same arithmetic
+ behavior as quantized training --- which is the whole point of quantized
+ training and of FakeQuant nodes in the first place. However, that entails
+ subtle requirements on where exactly FakeQuant nodes must be placed in the
+ graph. Some quantized graphs have FakeQuant nodes at unexpected locations,
+ that prevent graph transformations that are necessary in order to generate a
+ well-formed quantized representation of these graphs. Such graphs should be
+ fixed, but as a temporary work-around, setting this
+ reorder_across_fake_quant flag allows the converter to perform necessary
+ graph transformaitons on them, at the cost of no longer faithfully matching
+ inference and training arithmetic.
+
+### Logging flags
+
+The following are standard Google logging flags:
+
+* `--logtostderr` redirects Google logging to standard error, typically making
+ it visible in a terminal.
+* `--v` sets verbose logging levels (for debugging purposes). Defined levels:
+ * `--v=1`: log all graph transformations that did make a change on the
+ graph.
+ * `--v=2`: log all graph transformations that did *not* make a change on
+ the graph.
+
+The following flags allow to generate graph visualizations of the actual graph
+at various points during transformations:
+
+* `--dump_graphviz=/path` enables dumping of the graphs at various stages of
+ processing as GraphViz `.dot` files. Generally preferred over
+ `--output_format=GRAPHVIZ_DOT` as this allows you to keep your actually
+ relevant `--output_format`.
+* `--dump_graphviz_video` enables dumping of the graph after every single
+ graph transformation (for debugging purposes).
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
new file mode 100644
index 0000000000..440f9c367c
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -0,0 +1,62 @@
+# TensorFlow Lite Optimizing Converter (TOCO) Python API reference
+
+## High-level overview
+
+While the TensorFlow Lite Optimizing Converter can be used from the command
+line, it is often convenient to use it as part of Python model build and
+training script. This is so that conversion can be part of your model
+development pipeline. This allows you to know early and often that you are
+designing a model that can be targeted to devices with mobile.
+
+## API
+
+In Python you can run `help(tf.contrib.lite)` to get documentation on functions.
+In particular, `tf.contrib.lite.toco_convert` presents a simple API and
+`tf.contrib.lite.toco_from_protos` allows more detailed control of TOCO using
+the protobuf interface to TOCO.
+
+## Example
+
+In particular, here we show creating a simple model and converting it to a
+TensorFlow Lite Model.
+
+```python
+import tensorflow as tf
+
+img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
+val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
+out = tf.identity(val, name="out")
+with tf.Session() as sess:
+ tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
+ open("test.tflite", "wb").write(tflite_modeL)
+```
+
+**NOTE** Currently, the TOCO command will cause a fatal error to the Python
+interpreter when TOCO conversion fails. This will be remedied as soon as
+possible.
+
+## Example 2: Export with variables
+
+If a model has variables, they need to be turned into constants. This process is
+known as freezing, and it can actually be accomplished with
+
+```python
+import tensorflow as tf
+
+img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
+var = tf.get_variable("weights", dtype=tf.float32, shape=(1,64,64,3))
+val = img + var
+
+def canonical_name(x):
+ return x.name.split(":")[0]
+
+out = tf.identity(val, name="out")
+with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ out_tensors = [out]
+ frozen_graphdef = tf.graph_util.convert_variables_to_constants(
+ sess, sess.graph_def, map(canonical_name, out_tensors))
+ tflite_model = tf.contrib.lite.toco_convert(
+ frozen_graphdef, [img], out_tensors)
+ open("converted_model.tflite", "wb").write(tflite_model)
+```