aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2016-12-21 20:50:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-21 21:05:54 -0800
commit0f0e29e7ba06c50fe4a1a7718e63731b96563a8d (patch)
treebfbb7dcc4dba793bfef84d0b1681f20ad4eeff2f /tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
parentbe60473c88175dbc9359c9d1bbb384518757ee81 (diff)
Create Graph Transform Tool for rewriting model files.
Change: 142729497
Diffstat (limited to 'tensorflow/tools/graph_transforms/fold_batch_norms_test.cc')
-rw-r--r--tensorflow/tools/graph_transforms/fold_batch_norms_test.cc93
1 files changed, 93 insertions, 0 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
new file mode 100644
index 0000000000..b9983fdd0b
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
@@ -0,0 +1,93 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Declare here, so we don't need a public header.
+Status FoldBatchNorms(const GraphDef& input_graph_def,
+ const TransformFuncContext& context,
+ GraphDef* output_graph_def);
+
+class FoldBatchNormsTest : public ::testing::Test {
+ protected:
+ void TestFoldBatchNorms() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
+ test::FillValues<float>(
+ &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+ -5.0f, -3.0f, -6.0f});
+ Output input_op =
+ Const(root.WithOpName("input_op"), Input::Initializer(input_data));
+
+ Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
+ test::FillValues<float>(&weights_data,
+ {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
+ Output weights_op =
+ Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
+
+ Output conv_op = Conv2D(root.WithOpName("conv_op"), input_op, weights_op,
+ {1, 1, 1, 1}, "VALID");
+
+ Tensor mul_values_data(DT_FLOAT, TensorShape({2}));
+ test::FillValues<float>(&mul_values_data, {2.0f, 3.0f});
+ Output mul_values_op = Const(root.WithOpName("mul_values"),
+ Input::Initializer(mul_values_data));
+
+ Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op);
+
+ GraphDef original_graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+
+ std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(original_session->Create(original_graph_def));
+ std::vector<Tensor> original_outputs;
+ TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
+
+ GraphDef fused_graph_def;
+ TF_ASSERT_OK(
+ FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
+
+ std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(fused_session->Create(fused_graph_def));
+ std::vector<Tensor> fused_outputs;
+ TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
+
+ test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+
+ for (const NodeDef& node : fused_graph_def.node()) {
+ EXPECT_NE("Mul", node.op());
+ }
+ }
+};
+
+TEST_F(FoldBatchNormsTest, TestFoldBatchNorms) { TestFoldBatchNorms(); }
+
+} // namespace graph_transforms
+} // namespace tensorflow