aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/toco/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc106
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD13
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc167
-rw-r--r--tensorflow/contrib/lite/toco/tflite/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc38
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h7
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc54
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc9
-rw-r--r--tensorflow/core/BUILD2
11 files changed, 92 insertions, 307 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 02d0890a7a..a75553db84 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -213,7 +213,6 @@ cc_library(
"graph_transformations/quantization_util.cc",
"graph_transformations/quantization_util.h",
"graph_transformations/quantize.cc",
- "graph_transformations/quantize_weights.cc",
"graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc",
"graph_transformations/remove_final_dequantize_op.cc",
"graph_transformations/remove_tensorflow_assert.cc",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 99f4a7d8f6..34945ecc45 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -142,7 +142,6 @@ DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits);
DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
DECLARE_GRAPH_TRANSFORMATION(Quantize)
-DECLARE_GRAPH_TRANSFORMATION(QuantizeWeights)
DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
deleted file mode 100644
index 7a8515f6d1..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <iterator>
-#include <string>
-#include <vector>
-
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-
-namespace toco {
-
-namespace {
-
-// The minimum number of elements a weights array must have to be quantized
-// by this transformation.
-// TODO(suharshs): Make this minimum size configurable.
-const int kWeightsMinSize = 1024;
-
-// Gets the quantization params from the float array.
-void GetQuantizationParamsFromArray(const Array& array,
- QuantizationParams* params) {
- const std::vector<float>& float_vals =
- array.GetBuffer<ArrayDataType::kFloat>().data;
- auto minmax = std::minmax_element(float_vals.begin(), float_vals.end());
- *params = tflite::ChooseQuantizationParams<uint8>(
- *minmax.first, *minmax.second, array.narrow_range);
-}
-
-} // namespace
-
-bool QuantizeWeights::Run(Model* model, std::size_t op_index) {
- const auto op_it = model->operators.begin() + op_index;
- Operator* op = op_it->get();
-
- // Get the weights tensor, if the current operator has one.
- int weights_index;
- if (op->type == OperatorType::kConv ||
- op->type == OperatorType::kDepthwiseConv ||
- op->type == OperatorType::kFullyConnected) {
- weights_index = 1;
- } else if (op->type == OperatorType::kLstmCell) {
- weights_index = LstmCellOperator::WEIGHTS_INPUT;
- } else {
- return false;
- }
-
- // Return early if the array isn't a constant param, this can happen in early
- // transformation passes until transpose operations following the weight array
- // are resolved.
- const string weights = op->inputs[weights_index];
- if (!IsConstantParameterArray(*model, weights)) {
- return false;
- }
-
- // Return early if the weight tensor is not type float.
- Array& weights_array = model->GetArray(weights);
- if (weights_array.data_type != ArrayDataType::kFloat) {
- return false;
- }
-
- // Return early if the tensor is too small. Small tensors don't take up too
- // much space and can result in bad quantization results.
- if (weights_array.GetBuffer<ArrayDataType::kFloat>().data.size() <
- kWeightsMinSize) {
- return false;
- }
-
- // Quantize the weight tensor to type kUint8.
- QuantizationParams params;
- GetQuantizationParamsFromArray(weights_array, &params);
- QuantizeArray(this, model, weights, ArrayDataType::kUint8, params);
-
- // Insert a Dequantize operation after the quantized weights tensor.
- auto* dequantize_op = new DequantizeOperator;
- model->operators.emplace(op_it, dequantize_op);
-
- // Create a new intermediate tensor to connect the Dequantize op to the
- // original op.
- const string dequantized_output =
- AvailableArrayName(*model, weights + "_dequantized");
- Array& dequantized_output_array = model->GetOrCreateArray(dequantized_output);
- dequantized_output_array.data_type = ArrayDataType::kFloat;
-
- // Connect up the new Dequantize op with the weights and original op.
- op->inputs[weights_index] = dequantized_output;
- dequantize_op->inputs = {weights};
- dequantize_op->outputs = {dequantized_output};
-
- return true;
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
index e163fc9ae1..acf1e3ede5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -20,19 +20,6 @@ tf_cc_test(
)
tf_cc_test(
- name = "quantize_weights_test",
- srcs = ["quantize_weights_test.cc"],
- tags = ["no_oss"],
- deps = [
- "//tensorflow/contrib/lite/toco:graph_transformations",
- "//tensorflow/contrib/lite/toco:model",
- "//tensorflow/contrib/lite/toco:tooling_util",
- "@com_google_absl//absl/memory",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-tf_cc_test(
name = "resolve_constant_concatenation_test",
srcs = ["resolve_constant_concatenation_test.cc"],
tags = ["no_oss"],
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc
deleted file mode 100644
index c05eb0929f..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc
+++ /dev/null
@@ -1,167 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <math.h>
-#include <string>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/memory/memory.h"
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-
-namespace toco {
-
-class QuantizeWeightsTest : public ::testing::Test {
- protected:
- QuantizeWeightsTest() {}
-
- // The name of the weights input array.
- const string kWeightsName = "weights";
- // The zero_point of the values in the input array.
- const int kZeroPoint = 128;
-
- // Prepare a hypothetical TOCO model of a quantizable fully connected float
- // layer.
- void PrepareModel(Model* model, int elements_per_dim) {
- std::vector<string> fc_input_names = {"inputs", kWeightsName};
-
- const int kDim = 4;
- const int buf_size = std::pow(elements_per_dim, static_cast<double>(kDim));
- auto in_buf = absl::make_unique<float[]>(buf_size);
- // Initialize the array with values from -128.0 to 127.0, since these values
- // should be exactly representable by quantization.
- for (int i = 0; i < buf_size; i++) {
- in_buf[i] = static_cast<float>(i % 256 - kZeroPoint);
- }
-
- for (const string& fc_input_name : fc_input_names) {
- Array& in_array = model->GetOrCreateArray(fc_input_name);
- in_array.data_type = ArrayDataType::kFloat;
-
- // Initialize shape for the input array.
- Shape* in_array_shape = in_array.mutable_shape();
- std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims();
- in_array_shape_dim->resize(kDim, elements_per_dim);
- auto& in_array_buffer =
- in_array.GetMutableBuffer<ArrayDataType::kFloat>();
- in_array_buffer.data.resize(buf_size);
- float* buf_ptr =
- in_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
- std::copy(in_buf.get(), in_buf.get() + buf_size, buf_ptr);
- }
-
- auto* fc_op = new FullyConnectedOperator;
- fc_op->inputs = fc_input_names;
- fc_op->outputs = {"fc_op_outputs"};
- Array& out_array = model->GetOrCreateArray(fc_op->outputs[0]);
- out_array.data_type = ArrayDataType::kFloat;
- Shape* out_array_shape = out_array.mutable_shape();
- std::vector<int>* out_array_shape_dim = out_array_shape->mutable_dims();
- out_array_shape_dim->resize(kDim, elements_per_dim);
- model->operators.push_back(std::unique_ptr<Operator>(fc_op));
- }
-};
-
-TEST_F(QuantizeWeightsTest, QuantizedFullyConnected) {
- // Test that weight arrays that are large enough are quantized.
- Model model;
- // 6 elements per dim gives us 1296 elements, which is sufficient to be
- // quantized.
- PrepareModel(&model, 6);
-
- // Check the state of the graph before the transformation.
- const auto& float_array_map = model.GetArrayMap();
- EXPECT_EQ(float_array_map.size(), 3);
- // Before the transformation, all arrays should be type float.
- for (const auto& element : float_array_map) {
- EXPECT_EQ(element.second->data_type, ArrayDataType::kFloat);
- }
- const std::vector<float> float_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
-
- // Invoke the transformation.
- GraphTransformationsSet graph_transformation_set;
- graph_transformation_set.Add(new toco::QuantizeWeights);
- (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
-
- // Check the state of the graph after the transformation.
- const auto& quantized_array_map = model.GetArrayMap();
- EXPECT_EQ(quantized_array_map.size(), 4);
- // After the transformation, three arrays should be type float and one array
- // should be uint8.
- int num_float = 0;
- int num_uint8 = 0;
- for (const auto& element : quantized_array_map) {
- if (element.second->data_type == ArrayDataType::kFloat) {
- num_float++;
- } else if (element.second->data_type == ArrayDataType::kUint8) {
- num_uint8++;
- } else {
- FAIL() << "Unexpected array type.";
- }
- }
- EXPECT_EQ(num_float, 3);
- EXPECT_EQ(num_uint8, 1);
- // Ensure that the values were quantized correctly.
- const std::vector<uint8>& quantized_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kUint8>().data;
- for (int i = 0; i < quantized_weight_vals.size(); i++) {
- EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i] + kZeroPoint);
- }
-
- // Ensure that a Dequantize operator has been inserted before the
- // FullyConnectedLayer.
- EXPECT_EQ(model.operators[0]->type, OperatorType::kDequantize);
-}
-
-TEST_F(QuantizeWeightsTest, NotQuantizedFullyConnected) {
- // Test that weight arrays that are too small are left untouched.
- Model model;
- // 5 elements per dim gives us 625 elements, which is NOT sufficient to be
- // quantized.
- PrepareModel(&model, 5);
-
- // Check the state of the graph before the transformation.
- const auto& float_array_map = model.GetArrayMap();
- EXPECT_EQ(float_array_map.size(), 3);
- // Before the transformation, all arrays should be type float.
- for (auto it = float_array_map.begin(); it != float_array_map.end(); it++) {
- EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat);
- }
- std::vector<float> float_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
-
- // Invoke the transformation.
- GraphTransformationsSet graph_transformation_set;
- graph_transformation_set.Add(new toco::QuantizeWeights);
- (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
-
- // Check the state of the graph after the transformation.
- const auto& post_array_map = model.GetArrayMap();
- EXPECT_EQ(post_array_map.size(), 3);
- for (auto it = post_array_map.begin(); it != post_array_map.end(); it++) {
- EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat);
- }
- // Ensure that the values remain unchanged.
- std::vector<float> const& quantized_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
- for (int i = 0; i < quantized_weight_vals.size(); i++) {
- EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i]);
- }
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
index 709c53606b..71cdb7703e 100644
--- a/tensorflow/contrib/lite/toco/tflite/BUILD
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -91,6 +91,7 @@ cc_library(
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/contrib/lite/toco:model",
"//tensorflow/contrib/lite/toco:tooling_util",
+ "//tensorflow/contrib/lite/tools/optimize:quantize_weights",
"@com_google_absl//absl/strings",
"@flatbuffers",
],
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 5ad307af14..a27d00eb77 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -16,10 +16,12 @@ limitations under the License.
#include "flatbuffers/flexbuffers.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
#include "tensorflow/contrib/lite/toco/tflite/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"
#include "tensorflow/contrib/lite/version.h"
namespace toco {
@@ -61,6 +63,13 @@ details::OperatorKey GetOperatorKey(
return details::OperatorKey(op.type, custom_code, version);
}
+void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
+ string* file_contents) {
+ const uint8_t* buffer = builder.GetBufferPointer();
+ int size = builder.GetSize();
+ *file_contents = string(reinterpret_cast<const char*>(buffer), size);
+}
+
} // Anonymous namespace.
namespace details {
@@ -311,14 +320,16 @@ Offset<Vector<Offset<Buffer>>> ExportBuffers(
return builder->CreateVector(buffer_vector);
}
-void Export(const Model& model, bool allow_custom_ops,
+void Export(const Model& model, bool allow_custom_ops, bool quantize_weights,
string* output_file_contents) {
const auto ops_by_type = BuildOperatorByTypeMap();
- Export(model, allow_custom_ops, output_file_contents, ops_by_type);
+ Export(model, allow_custom_ops, quantize_weights, output_file_contents,
+ ops_by_type);
}
void Export(
- const Model& model, bool allow_custom_ops, string* output_file_contents,
+ const Model& model, bool allow_custom_ops, bool quantize_weights,
+ string* output_file_contents,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
@@ -390,9 +401,24 @@ void Export(
CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
builder.CreateVector(subgraphs), description, buffers);
::tflite::FinishModelBuffer(builder, new_model_location);
- const uint8_t* buffer = builder.GetBufferPointer();
- int size = builder.GetSize();
- *output_file_contents = string(reinterpret_cast<const char*>(buffer), size);
+
+ if (quantize_weights) {
+ // Call the quantize_weights tool.
+ LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. "
+ "dump_graphviz will only output the model before this "
+ "transformation. To visualize the output graph use "
+ "lite/tools/optimize.py.";
+ flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const ::tflite::Model* input_model = ::tflite::GetModel(buffer);
+ if (::tflite::optimize::QuantizeWeights(&q_builder, input_model) !=
+ kTfLiteOk) {
+ LOG(QFATAL) << "Quantize weights transformation failed.";
+ }
+ WriteModelToString(q_builder, output_file_contents);
+ } else {
+ WriteModelToString(builder, output_file_contents);
+ }
}
} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 58ea5c725c..915d5dd3d6 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -25,18 +25,19 @@ namespace tflite {
// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
// result in the given string.
-void Export(const Model& model, bool allow_custom_ops,
+void Export(const Model& model, bool allow_custom_ops, bool quantize_weights,
string* output_file_contents);
// This if backward-compatibility.
// TODO(ycling): Remove the deprecated entry functions.
inline void Export(const Model& model, string* output_file_contents) {
- Export(model, true, output_file_contents);
+ Export(model, true, false, output_file_contents);
}
// Export API with custom TFLite operator mapping.
void Export(
- const Model& model, bool allow_custom_ops, string* output_file_contents,
+ const Model& model, bool allow_custom_ops, bool quantize_weights,
+ string* output_file_contents,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
namespace details {
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index a95937ba0f..4994ea30de 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -52,6 +52,42 @@ class ExportTest : public ::testing::Test {
input_model_.operators.emplace_back(new SubOperator);
}
+ void BuildQuantizableTestModel() {
+ input_model_.GetOrCreateArray("inputs");
+ Array& weight_array = input_model_.GetOrCreateArray("weights");
+
+ // Make the buffer large enough for QuantizeWeights transformation to take
+ // effect.
+ int buf_size = 1296;
+ auto weight_buf = absl::make_unique<float[]>(buf_size);
+ for (int i = 0; i < buf_size; i++) {
+ // Fill the array with some garbage values.
+ weight_buf[i] = static_cast<float>(i % 128);
+ }
+
+ weight_array.data_type = ArrayDataType::kFloat;
+
+ // Initialize shape for the input array.
+ Shape* weight_array_shape = weight_array.mutable_shape();
+ std::vector<int>* weight_array_shape_dim =
+ weight_array_shape->mutable_dims();
+ weight_array_shape_dim->resize(4, 6);
+ auto& weight_array_buffer =
+ weight_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ weight_array_buffer.data.resize(buf_size);
+ float* buf_ptr =
+ weight_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
+ std::copy(weight_buf.get(), weight_buf.get() + buf_size, buf_ptr);
+
+ {
+ auto* op = new ConvOperator;
+ op->padding.type = PaddingType::kSame;
+ op->inputs = {"inputs", "weights"};
+ input_model_.operators.emplace_back(op);
+ }
+ input_model_.operators.emplace_back(new AddOperator);
+ }
+
Model input_model_;
};
@@ -81,7 +117,7 @@ TEST_F(ExportTest, Export) {
BuildTestModel();
string result;
- Export(input_model_, true, &result);
+ Export(input_model_, true, false, &result);
auto* model = ::tflite::GetModel(result.data());
@@ -108,6 +144,20 @@ TEST_F(ExportTest, Export) {
EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2));
}
+TEST_F(ExportTest, QuantizeWeights) {
+ // Sanity check for quantize_weights parameter.
+ BuildQuantizableTestModel();
+ string unquantized_result;
+ Export(input_model_, true, /*quantize_weights*/ false, &unquantized_result);
+
+ BuildQuantizableTestModel();
+ string quantized_result;
+ Export(input_model_, true, /*quantize_weights*/ true, &quantized_result);
+
+ // The quantized models should be smaller.
+ EXPECT_LT(quantized_result.size(), unquantized_result.size());
+}
+
// This test is based on a hypothetical scenario that dilation is supported
// only in Conv version 2. So Toco populates version=1 when dialation
// parameters are all 1, and version=2 otehrwise.
@@ -239,7 +289,7 @@ TEST_F(VersionedOpExportTest, Export) {
string result;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- Export(input_model_, true, &result, ops_by_type);
+ Export(input_model_, true, false, &result, ops_by_type);
auto* model = ::tflite::GetModel(result.data());
auto operator_codes = model->operator_codes();
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 34130a02b0..243d0dabdb 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -281,12 +281,6 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
RunGraphTransformations(model, "general graph transformations",
transformations);
- if (toco_flags.quantize_weights()) {
- // Run the quantize weights transformation after batchnorms have been
- // folded into the weights.
- RunGraphTransformations(model, "quantize weights transformation",
- {new QuantizeWeights});
- }
if (quantize_output) {
if (toco_flags.propagate_fake_quant_num_bits()) {
RunGraphTransformations(model,
@@ -404,7 +398,8 @@ void Export(const TocoFlags& toco_flags, const Model& model,
ExportTensorFlowGraphDef(model, output_file_contents);
break;
case TFLITE:
- toco::tflite::Export(model, allow_custom_ops, output_file_contents);
+ toco::tflite::Export(model, allow_custom_ops,
+ toco_flags.quantize_weights(), output_file_contents);
break;
case GRAPHVIZ_DOT:
DumpGraphviz(model, output_file_contents);
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 9b7a5018b2..84b11024fd 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2231,7 +2231,7 @@ cc_library(
"platform/macros.h",
"platform/platform.h",
"platform/types.h",
- ],
+ ] + if_windows(["platform/windows/integral_types.h"]),
copts = tf_copts(),
linkopts = ["-ldl"],
deps = [