From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 6 Nov 2015 16:27:58 -0800 Subject: TensorFlow: Initial commit of TensorFlow library. TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108 --- tensorflow/core/graph/algorithm_test.cc | 103 ++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 tensorflow/core/graph/algorithm_test.cc (limited to 'tensorflow/core/graph/algorithm_test.cc') diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc new file mode 100644 index 0000000000..48f2e1ebd7 --- /dev/null +++ b/tensorflow/core/graph/algorithm_test.cc @@ -0,0 +1,103 @@ +#include "tensorflow/core/graph/algorithm.h" + +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/kernels/ops_util.h" +#include +#include "tensorflow/core/public/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +// TODO(josh11b): Test setting the "device" field of a NodeDef. +// TODO(josh11b): Test that feeding won't prune targets. + +namespace tensorflow { +namespace { + +REGISTER_OP("TestParams").Output("o: float"); +REGISTER_OP("TestInput").Output("a: float").Output("b: float"); +REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); + +// Compares that the order of nodes in 'inputs' respects the +// pair orders described in 'ordered_pairs'. +bool ExpectBefore(const std::vector>& ordered_pairs, + const std::vector& inputs, string* error) { + for (const std::pair& pair : ordered_pairs) { + const string& before_node = pair.first; + const string& after_node = pair.second; + bool seen_before = false; + bool seen_both = false; + for (const Node* node : inputs) { + if (!seen_before && after_node == node->name()) { + *error = strings::StrCat("Saw ", after_node, " before ", before_node); + return false; + } + + if (before_node == node->name()) { + seen_before = true; + } else if (after_node == node->name()) { + seen_both = seen_before; + break; + } + } + if (!seen_both) { + *error = strings::StrCat("didn't see either ", before_node, " or ", + after_node); + return false; + } + } + + return true; +} + +TEST(AlgorithmTest, ReversePostOrder) { + RequireDefaultOps(); + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* w1 = SourceOp("TestParams", b.opts().WithName("W1")); + Node* w2 = SourceOp("TestParams", b.opts().WithName("W2")); + Node* input = + SourceOp("TestInput", b.opts().WithName("input").WithControlInput(w1)); + Node* t1 = BinaryOp("TestMul", w1, {input, 1}, b.opts().WithName("t1")); + BinaryOp("TestMul", w1, {input, 1}, + b.opts().WithName("t2").WithControlInput(t1)); + BinaryOp("TestMul", w2, {input, 1}, b.opts().WithName("t3")); + + Graph g(OpRegistry::Global()); + ASSERT_OK(b.ToGraph(&g)); + std::vector order; + + // Test reverse post order: + GetReversePostOrder(g, &order); + + // Check that the order respects the dependencies correctly. + std::vector> reverse_orders = { + {"W1", "input"}, {"W1", "t1"}, {"W1", "t2"}, {"W1", "t3"}, + {"input", "t1"}, {"input", "t3"}, {"t1", "t2"}, {"W2", "t3"}}; + string error; + EXPECT_TRUE(ExpectBefore(reverse_orders, order, &error)) << error; + + // A false ordering should fail the check. + reverse_orders = {{"input", "W1"}}; + EXPECT_FALSE(ExpectBefore(reverse_orders, order, &error)); + + // Test post order: + GetPostOrder(g, &order); + + // Check that the order respects the dependencies correctly. + std::vector> orders = { + {"input", "W1"}, {"t1", "W1"}, {"t2", "W1"}, {"t3", "W1"}, + {"t1", "input"}, {"t3", "input"}, {"t2", "t1"}, {"t3", "W2"}}; + EXPECT_TRUE(ExpectBefore(orders, order, &error)) << error; + + // A false ordering should fail the check. + orders = {{"W1", "t3"}}; + EXPECT_FALSE(ExpectBefore(orders, order, &error)); +} + +} // namespace +} // namespace tensorflow -- cgit v1.2.3