author    Suharsh Sivakumar <suharshs@google.com>  2018-03-07 16:56:34 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-03-07 17:02:07 -0800
commit    708def503604a3a9be255edf36623833937c3469 (patch)
tree      439bd84911e7e3f796d9847f5728d2f0ab655ffb /tensorflow/tools/graph_transforms
parent    1408f05c9a1f1180f67112d8adb9cf79b3b0ac44 (diff)
Remove unneeded rewrite, now that contrib.quantize is ready and better.
PiperOrigin-RevId: 188257466
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--  tensorflow/tools/graph_transforms/BUILD                          |   4 -
-rw-r--r--  tensorflow/tools/graph_transforms/fake_quantize_training.cc      |  51 -
-rw-r--r--  tensorflow/tools/graph_transforms/fake_quantize_training_test.cc |  63 -
-rw-r--r--  tensorflow/tools/graph_transforms/remove_ema.cc                  | 146 -
-rw-r--r--  tensorflow/tools/graph_transforms/remove_ema_test.cc             | 121 -
5 files changed, 0 insertions(+), 385 deletions(-)
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index ad3668fa02..fba39526b2 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -91,7 +91,6 @@ cc_library(
srcs = [
"add_default_attributes.cc",
"backports.cc",
- "fake_quantize_training.cc",
"flatten_atrous.cc",
"fold_batch_norms.cc",
"fold_constants_lib.cc",
@@ -105,7 +104,6 @@ cc_library(
"remove_attribute.cc",
"remove_control_dependencies.cc",
"remove_device.cc",
- "remove_ema.cc",
"remove_nodes.cc",
"rename_attribute.cc",
"rename_op.cc",
@@ -148,7 +146,6 @@ tf_cc_test(
srcs = [
"add_default_attributes_test.cc",
"backports_test.cc",
- "fake_quantize_training_test.cc",
"flatten_atrous_test.cc",
"fold_batch_norms_test.cc",
"fold_constants_test.cc",
@@ -161,7 +158,6 @@ tf_cc_test(
"quantize_weights_test.cc",
"remove_attribute_test.cc",
"remove_device_test.cc",
- "remove_ema_test.cc",
"remove_nodes_test.cc",
"rename_attribute_test.cc",
"rename_op_test.cc",
diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training.cc b/tensorflow/tools/graph_transforms/fake_quantize_training.cc
deleted file mode 100644
index 61aecc6e16..0000000000
--- a/tensorflow/tools/graph_transforms/fake_quantize_training.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include "tensorflow/core/graph/quantize_training.h"
-#include "tensorflow/tools/graph_transforms/transform_utils.h"
-
-namespace tensorflow {
-namespace graph_transforms {
-
-// EXPERIMENTAL: This can change without warning.
-// Rewrites the GraphDef for quantized training.
-// Rewrites the forward pass to include the precision loss with quantization so
-// the model can learn to deal with such loss and achieve better accuracy when
-// it is quantized later for inference.
-// Quantization range information is collected in FakeQuantizeWithMinMaxVars
-// ops.
-//
-// TODO(suharshs): Provide instructions on converting the resulting graph for
-// inference.
-// TODO(suharshs): Implement this using the GTT rather than calling the old
-// prototype function.
-Status FakeQuantizeTraining(const GraphDef& input_graph_def,
- const TransformFuncContext& context,
- GraphDef* output_graph_def) {
- // TODO(suharshs): Make num_bits a parameter.
- const int32 num_bits = 8;
- // TODO(suharshs): Make quantization op a parameter?
- const string quant_op_type = "FakeQuantWithMinMaxVars";
-
- return DoQuantizeTrainingOnGraphDef(input_graph_def, num_bits, quant_op_type,
- output_graph_def);
-}
-
-REGISTER_GRAPH_TRANSFORM("fake_quantize_training", FakeQuantizeTraining);
-
-} // namespace graph_transforms
-} // namespace tensorflow
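[Editor's note] For context before the next deleted file: a minimal sketch of how this transform was applied in process. The forward declaration mirrors the one in the deleted tests (the transform had no public header); the ApplyFakeQuantizeTraining driver and its name are hypothetical.

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {
// Forward declaration, as in the deleted tests below (no public header).
Status FakeQuantizeTraining(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def);
}  // namespace graph_transforms
}  // namespace tensorflow

// Hypothetical driver: assumes `input_graph` was already parsed from disk.
tensorflow::Status ApplyFakeQuantizeTraining(
    const tensorflow::GraphDef& input_graph, tensorflow::GraphDef* output) {
  // num_bits (8) and the quantization op (FakeQuantWithMinMaxVars) were
  // hard-coded by the transform, so the context carries no parameters.
  tensorflow::graph_transforms::TransformFuncContext context;
  return tensorflow::graph_transforms::FakeQuantizeTraining(input_graph,
                                                            context, output);
}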
diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
deleted file mode 100644
index 5e4ab209e9..0000000000
--- a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/cc/ops/const_op.h"
-#include "tensorflow/cc/ops/math_ops.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/tools/graph_transforms/transform_utils.h"
-
-namespace tensorflow {
-namespace graph_transforms {
-
-// Declare here, so we don't need a public header.
-Status FakeQuantizeTraining(const GraphDef& input_graph_def,
- const TransformFuncContext& context,
- GraphDef* output_graph_def);
-
-class FakeQuantizeTrainingTest : public ::testing::Test {};
-
-// For now, since the fake_quantize_training transform just calls the
-// quantize_training rewrite from tensorflow/core/graph/quantize_training.h,
-// we just test that the graph has been changed by the transform.
-// TODO(suharshs): Once we implement the fake_quantize_training transform
-// using the GTT, write proper tests of the transform here.
-TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
- auto root = tensorflow::Scope::DisabledShapeInferenceScope();
- using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
-
- Tensor a_data(DT_FLOAT, TensorShape());
- test::FillIota<float>(&a_data, 1.0f);
- Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
-
- Tensor b_data(DT_FLOAT, TensorShape());
- test::FillIota<float>(&b_data, 1.0f);
- Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
-
- Output matmul = MatMul(root.WithOpName("matmul"), a_const, b_const);
- GraphDef graph_def;
- TF_ASSERT_OK(root.ToGraphDef(&graph_def));
-
- GraphDef result;
- TransformFuncContext context;
- TF_ASSERT_OK(FakeQuantizeTraining(graph_def, context, &result));
-
- // Test that the transformation resulted in a graph with more nodes.
- EXPECT_GT(result.node_size(), graph_def.node_size());
-}
-
-} // namespace graph_transforms
-} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/remove_ema.cc b/tensorflow/tools/graph_transforms/remove_ema.cc
deleted file mode 100644
index 22e2626702..0000000000
--- a/tensorflow/tools/graph_transforms/remove_ema.cc
+++ /dev/null
@@ -1,146 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include "tensorflow/tools/graph_transforms/transform_utils.h"
-
-namespace tensorflow {
-namespace graph_transforms {
-
-// EXPERIMENTAL: This can change without warning.
-// Given a graph that has gone through the FakeQuantizeTraining transform and
-// has been frozen afterwards, RemoveEMA simplifies the FakeQuantize estimated
-// moving average subgraphs to make it compatible with the QuantizeNodes
-// transform.
-Status RemoveEMA(const GraphDef& input_graph_def,
- const TransformFuncContext& context,
- GraphDef* output_graph_def) {
- TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
- input_graph_def, // clang-format off
- {"FakeQuantWithMinMaxVars",
- {
- {"*"},
- {"Assign",
- {
- {"Const"},
- {"Merge",
- {
- {"Switch",
- {
- {"Min",
- {
- {"*"},
- {"Range",
- {
- {"*"},
- {"*"},
- {"*"},
- }
- }
- }
- },
- {"IsVariableInitialized"}
- }
- },
- {"Sub",
- {
- {"Const"},
- {"Mul",
- {
- {"Sub"},
- {"Sub",
- {
- {"Const"},
- {"Const"}
- }
- }
- }
- }
- }
- }
- }
- }
- }
- },
- {"Assign",
- {
- {"Const"},
- {"Merge",
- {
- {"Switch",
- {
- {"Max"},
- {"IsVariableInitialized"}
- }
- },
- {"Sub",
- {
- {"Const"},
- {"Mul",
- {
- {"Sub"},
- {"Sub",
- {
- {"Const"},
- {"Const"}
- }
- }
- }
- }
- }
- }
- }
- }
- }
- },
- }
- }, // clang-format on
- [](const NodeMatch& match, const std::set<string>& input_nodes,
- const std::set<string>& output_nodes,
- std::vector<NodeDef>* new_nodes) {
- const NodeDef& fake_quant_node = match.node;
- const NodeDef& input_node = match.inputs[0].node;
- const NodeDef& min_var_node = match.inputs[1].inputs[0].node;
- const NodeDef& max_var_node = match.inputs[2].inputs[0].node;
-
- // Make a new FakeQuantizeWithMinMaxVars operation that uses constants
- // for its min/max arguments rather than an entire EMA subgraph.
- NodeDef new_fake_quant_node;
- new_fake_quant_node.set_op(fake_quant_node.op());
- new_fake_quant_node.set_name(fake_quant_node.name());
- AddNodeInput(input_node.name(), &new_fake_quant_node);
- AddNodeInput(min_var_node.name(), &new_fake_quant_node);
- AddNodeInput(max_var_node.name(), &new_fake_quant_node);
- CopyNodeAttr(fake_quant_node, "narrow_range", "narrow_range",
- &new_fake_quant_node);
- CopyNodeAttr(fake_quant_node, "num_bits", "num_bits",
- &new_fake_quant_node);
-
- new_nodes->push_back(new_fake_quant_node);
- new_nodes->push_back(input_node);
- new_nodes->push_back(min_var_node);
- new_nodes->push_back(max_var_node);
-
- return Status::OK();
- },
- {}, output_graph_def));
- return Status::OK();
-}
-
-REGISTER_GRAPH_TRANSFORM("remove_ema", RemoveEMA);
-
-} // namespace graph_transforms
-} // namespace tensorflow
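[Editor's note] The deeply nested OpTypePattern above is easier to follow against a minimal use of the same ReplaceMatchingOpTypes API. A sketch under the signatures this file already uses; the Relu6-to-Relu rewrite is a hypothetical illustration of the mechanics, not a semantically safe transform.

#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {

// Hypothetical example: match every Relu6 with a single input and re-emit
// it as a Relu with the same name and input, keeping the matched input node.
Status Relu6ToRelu(const GraphDef& input_graph_def,
                   const TransformFuncContext& context,
                   GraphDef* output_graph_def) {
  return ReplaceMatchingOpTypes(
      input_graph_def,
      {"Relu6", {{"*"}}},  // a Relu6 fed by any single node
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& relu6_node = match.node;
        const NodeDef& input_node = match.inputs[0].node;
        NodeDef new_node;
        new_node.set_op("Relu");
        new_node.set_name(relu6_node.name());
        AddNodeInput(input_node.name(), &new_node);
        CopyNodeAttr(relu6_node, "T", "T", &new_node);
        new_nodes->push_back(new_node);
        new_nodes->push_back(input_node);  // preserve the matched input
        return Status::OK();
      },
      {}, output_graph_def);
}

}  // namespace graph_transforms
}  // namespace tensorflow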
diff --git a/tensorflow/tools/graph_transforms/remove_ema_test.cc b/tensorflow/tools/graph_transforms/remove_ema_test.cc
deleted file mode 100644
index 27db90e272..0000000000
--- a/tensorflow/tools/graph_transforms/remove_ema_test.cc
+++ /dev/null
@@ -1,121 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/cc/ops/const_op.h"
-#include "tensorflow/cc/ops/math_ops.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/public/session.h"
-#include "tensorflow/tools/graph_transforms/transform_utils.h"
-
-namespace tensorflow {
-namespace graph_transforms {
-
-// Declare transformations here, so we don't need a public header.
-Status FakeQuantizeTraining(const GraphDef& input_graph_def,
- const TransformFuncContext& context,
- GraphDef* output_graph_def);
-
-Status RemoveEMA(const GraphDef& input_graph_def,
- const TransformFuncContext& context,
- GraphDef* output_graph_def);
-
-Status QuantizeNodes(const GraphDef& input_graph_def,
- const TransformFuncContext& context,
- GraphDef* output_graph_def);
-
-class RemoveEMATest : public ::testing::Test {};
-
-TEST_F(RemoveEMATest, FakeQuant_RemoveEMA_QuantizeTraining) {
- // Build a small graph.
- auto root = tensorflow::Scope::NewRootScope();
- using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
-
- Tensor a_data(DT_FLOAT, TensorShape({1, 1}));
- test::FillIota<float>(&a_data, 1.0f);
- Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
-
- Tensor b_data(DT_FLOAT, TensorShape({1, 1}));
- test::FillIota<float>(&b_data, 1.0f);
- Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
-
- Output matmul = MatMul(root.WithOpName("matmul"), a_const, b_const);
- GraphDef graph_def;
- TF_ASSERT_OK(root.ToGraphDef(&graph_def));
-
- // (1) FakeQuantize the graph.
- GraphDef fake_quantized_graph_def;
- TransformFuncContext context;
- TF_ASSERT_OK(
- FakeQuantizeTraining(graph_def, context, &fake_quantized_graph_def));
-
- // Test that the transformation resulted in a graph with more nodes.
- EXPECT_GT(fake_quantized_graph_def.node_size(), graph_def.node_size());
-
- // (2) Run the graph to initialize the newly added variables.
- std::unique_ptr<Session> session(NewSession(SessionOptions()));
- TF_ASSERT_OK(session->Create(fake_quantized_graph_def));
- std::vector<Tensor> outputs;
- TF_ASSERT_OK(session->Run({}, {"matmul"}, {}, &outputs));
-
- // (3) Freeze the graph. Create a "frozen graph" that matches what we would
- // expect if we actually froze the above graph.
- // TODO(suharshs): Use a c++ freeze graph alternative, when one is available.
- GraphDef frozen_graph_def;
- for (const NodeDef& node : fake_quantized_graph_def.node()) {
- if (node.op() == "Variable" || node.op() == "VariableV2") {
- NodeDef const_node;
- const_node.set_op("Const");
- const_node.set_name(node.name());
- SetNodeAttr("dtype", DT_FLOAT, &const_node);
- Tensor tensor(DT_FLOAT, {});
- tensor.flat<float>()(0) = 1.0f;
- SetNodeTensorAttr<float>("value", tensor, &const_node);
- *(frozen_graph_def.mutable_node()->Add()) = const_node;
- } else {
- *(frozen_graph_def.mutable_node()->Add()) = node;
- }
- }
-
- // Test that freezing the graph resulted in a graph with the same number of
- // nodes.
- EXPECT_EQ(frozen_graph_def.node_size(), fake_quantized_graph_def.node_size());
-
- // (4) RemoveEMA on the graph to make it compatible with QuantizeNodes.
- GraphDef removed_ema_graph_def;
- TF_ASSERT_OK(RemoveEMA(frozen_graph_def, context, &removed_ema_graph_def));
-
- // Test that the transformation resulted in a graph with less nodes.
- EXPECT_LT(removed_ema_graph_def.node_size(), frozen_graph_def.node_size());
-
- // (5) QuantizeNodes and inspect the final graph.
- // TODO(suharshs): Add a more thorough inspection of the structure of
- // the output graph.
- GraphDef quantized_graph_def;
- TF_ASSERT_OK(
- QuantizeNodes(removed_ema_graph_def, context, &quantized_graph_def));
-
- // Test that the transformation resulted in a graph with more nodes.
- EXPECT_GT(quantized_graph_def.node_size(), removed_ema_graph_def.node_size());
-
- // Make sure that the FakeQuantizeWithMinMaxVars op has been removed.
- for (const NodeDef& node : quantized_graph_def.node()) {
- EXPECT_NE(node.op(), "FakeQuantWithMinMaxVars");
- }
-}
-
-} // namespace graph_transforms
-} // namespace tensorflow
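[Editor's note] Taken together, the deleted test encodes the pipeline these transforms supported: fake-quantize for training, run to initialize variables, freeze, strip the EMA subgraphs, then quantize for inference. A condensed sketch of the final two stages, assuming forward declarations like those at the top of this test; the wrapper and its name are hypothetical.

// Hypothetical wrapper: `frozen_graph` is assumed to already have its EMA
// variables frozen into Const nodes, as step (3) of the test constructs.
tensorflow::Status PrepareForQuantizedInference(
    const tensorflow::GraphDef& frozen_graph, tensorflow::GraphDef* output) {
  using namespace tensorflow::graph_transforms;  // NOLINT(build/namespaces)
  TransformFuncContext context;
  tensorflow::GraphDef no_ema_graph;
  // Collapse each EMA subgraph to a FakeQuant op fed by min/max constants...
  TF_RETURN_IF_ERROR(RemoveEMA(frozen_graph, context, &no_ema_graph));
  // ...then rewrite the FakeQuant ops into quantized equivalents.
  return QuantizeNodes(no_ema_graph, context, output);
}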