diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-09-16 12:01:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-16 12:06:03 -0700 |
commit | 92c31bb620b0f8dd6590380dc6a5674f591ce1cb (patch) | |
tree | b8895b75144f03975584b43bb0ea956d1ce837be /tensorflow/compiler/jit | |
parent | aa2094fc9dc6e67d6e440231828de05a6da3cf78 (diff) |
Introduce gmock matchers for TensorFlow nodes
I need these to write readable unit tests for TF graph transformations. All of
my use cases will live inside tensorflow/compiler so putting it in
tensorflow/compiler/jit for now; but we can move these out if other users are
interested.
In the future we may want to auto-generate type safe versions of these from the
op registrations like we generate C++ wrappers today.
PiperOrigin-RevId: 213186810
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r-- | tensorflow/compiler/jit/BUILD | 29 | ||||
-rw-r--r-- | tensorflow/compiler/jit/node_matchers.cc | 458 | ||||
-rw-r--r-- | tensorflow/compiler/jit/node_matchers.h | 197 | ||||
-rw-r--r-- | tensorflow/compiler/jit/node_matchers_test.cc | 179 |
4 files changed, 863 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index f4e1bc5e83..1001c57f3d 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -599,6 +599,35 @@ tf_cuda_cc_test( ], ) +cc_library( + name = "node_matchers", + testonly = True, + srcs = ["node_matchers.cc"], + hdrs = ["node_matchers.h"], + deps = [ + "//tensorflow/cc:ops", + "//tensorflow/compiler/xla:test", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "node_matchers_test", + srcs = ["node_matchers_test.cc"], + deps = [ + ":node_matchers", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/core:ops", + "//tensorflow/core:test_main", + ], +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc new file mode 100644 index 0000000000..d8ace628e6 --- /dev/null +++ b/tensorflow/compiler/jit/node_matchers.cc @@ -0,0 +1,458 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/node_matchers.h" + +#include <utility> +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "tensorflow/core/framework/tensor.pb.h" + +namespace tensorflow { +namespace testing { +namespace matchers { +namespace { + +using impl::NodeMatcherProperties; + +string IndentAllButFirstLine(absl::string_view text) { + std::vector<std::string> lines = absl::StrSplit(text, '\n'); + for (int i = 1; i < lines.size(); i++) { + lines[i].insert(0, " "); + } + return absl::StrJoin(lines, "\n"); +} + +template <typename T> +bool CompareTensor(const Tensor& actual, const Tensor& expected, + ::testing::MatchResultListener* listener) { + if (actual.NumElements() != expected.NumElements()) { + if (listener->IsInterested()) { + *listener << "\nwas looking for tensor with " << expected.NumElements() + << " elements, found tensor with " << actual.NumElements() + << " elements"; + return false; + } + } + + for (int64 i = 0, e = actual.NumElements(); i < e; i++) { + if (actual.flat<T>()(i) != expected.flat<T>()(i)) { + *listener << "\nmismatch in constant tensor at index " << i + << " expected = " << expected.flat<T>()(i) + << " actual = " << actual.flat<T>()(i); + return false; + } + } + + return true; +} + +bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor, + ::testing::MatchResultListener* listener) { + if (tensor.dtype() != expected_tensor.dtype()) { + if (listener->IsInterested()) { + *listener << "\nexpected tensor of type " + << DataType_Name(expected_tensor.dtype()) + << " but found one of type " << DataType_Name(tensor.dtype()); + return false; + } + } + + switch (tensor.dtype()) { + case DT_FLOAT: + return CompareTensor<float>(tensor, expected_tensor, listener); + case DT_DOUBLE: + return CompareTensor<double>(tensor, expected_tensor, listener); + case DT_INT8: + return CompareTensor<int8>(tensor, expected_tensor, listener); + case DT_INT16: + return CompareTensor<int16>(tensor, expected_tensor, listener); + case DT_INT32: + return CompareTensor<int32>(tensor, expected_tensor, listener); + case DT_INT64: + return CompareTensor<int64>(tensor, expected_tensor, listener); + case DT_UINT8: + return CompareTensor<uint8>(tensor, expected_tensor, listener); + case DT_UINT16: + return CompareTensor<uint16>(tensor, expected_tensor, listener); + case DT_UINT32: + return CompareTensor<uint32>(tensor, expected_tensor, listener); + case DT_UINT64: + return CompareTensor<uint64>(tensor, expected_tensor, listener); + default: + LOG(FATAL) << "Unsupported dtype " // Crash ok: testonly. + << DataType_Name(tensor.dtype()); + } +} + +using Input = std::pair<const Node*, int>; + +struct NodeMatcher : public ::testing::MatcherInterface<const Node*> { + bool MatchAndExplain( + const Node* node, + ::testing::MatchResultListener* listener) const override { + if (op && node->type_string() != *op) { + if (listener->IsInterested()) { + *listener << "\nexpected op " << *op << " but found " + << node->type_string(); + } + return false; + } + + if (assigned_device && node->assigned_device_name() != *assigned_device) { + if (listener->IsInterested()) { + *listener << "\nexpected assigned_device " << *assigned_device + << " but found \"" << node->assigned_device_name() << "\""; + } + return false; + } + + if (name && node->name() != *name) { + if (listener->IsInterested()) { + *listener << "\nexpected name " << *name << " but found " + << node->name(); + } + return false; + } + + if (constant_value) { + const TensorProto* proto = nullptr; + if (!GetNodeAttr(node->def(), "value", &proto).ok()) { + if (listener->IsInterested()) { + *listener << "\ncould not find \"value\" attribute in node"; + } + return false; + } + + Tensor tensor(proto->dtype()); + if (!tensor.FromProto(*proto)) { + if (listener->IsInterested()) { + *listener << "\ncould not convert TensorProto in \"value\" attribute " + "to Tensor"; + } + return false; + } + + if (!MatchAndExplainTensor(/*tensor=*/tensor, + /*expected_tensor=*/*constant_value, + listener)) { + return false; + } + } + + if (input_matchers) { + if (input_matchers->size() != node->num_inputs()) { + if (listener->IsInterested()) { + *listener << "\nexpected " << input_matchers->size() + << " inputs but node has " << node->num_inputs(); + } + return false; + } + + for (int input_idx = 0, e = input_matchers->size(); input_idx < e; + input_idx++) { + if (!MatchAndExplainInput(node, input_idx, listener)) { + return false; + } + } + } + + std::vector<const Node*> control_deps; + for (const Edge* e : node->in_edges()) { + if (e->IsControlEdge()) { + control_deps.push_back(e->src()); + } + } + + ::testing::StringMatchResultListener inner_listener; + if (control_dep_set && + !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) { + if (listener->IsInterested()) { + string explanation = inner_listener.str(); + if (!explanation.empty()) { + explanation = absl::StrCat(", ", explanation, ","); + } + *listener << "ctrl_deps" << explanation << " does not match expected: "; + control_dep_set->DescribeTo(listener->stream()); + } + return false; + } + return true; + } + + void DescribeTo(::std::ostream* os) const override { + std::vector<string> predicates; + + if (name) { + predicates.push_back(absl::StrCat("name: ", *name)); + } + + if (op) { + predicates.push_back(absl::StrCat("op: ", *op)); + } + + if (assigned_device) { + predicates.push_back(absl::StrCat("assigned device: ", *assigned_device)); + } + + bool printed_something = !predicates.empty(); + + *os << absl::StrJoin(predicates, ", "); + + if (constant_value) { + printed_something = true; + *os << "constant value: " << constant_value->DebugString(); + } + + if (input_matchers) { + if (!input_matchers->empty()) { + printed_something = true; + *os << " with " << (input_matchers->size() == 1 ? "only " : "") + << "input" << (input_matchers->size() == 1 ? "" : "s") << " "; + } + + if (input_matchers->size() == 1) { + ::std::stringstream ss; + input_matchers->front().DescribeTo(&ss); + printed_something = true; + *os << "matching " << ss.str(); + } else { + int edge_idx = 0; + for (const ::testing::Matcher<Input>& matcher : (*input_matchers)) { + *os << "\n [" << edge_idx << "] matching ("; + ::std::stringstream ss; + matcher.DescribeTo(&ss); + printed_something = true; + *os << IndentAllButFirstLine(ss.str()); + *os << ")"; + edge_idx++; + } + } + } + + if (control_dep_set) { + printed_something = true; + *os << " and control deps "; + control_dep_set->DescribeTo(os); + } + + if (!printed_something) { + *os << "is any node"; + } + } + + bool MatchAndExplainInput(const Node* node, int input_idx, + ::testing::MatchResultListener* listener) const { + const Edge* edge; + if (!node->input_edge(input_idx, &edge).ok()) { + if (listener->IsInterested()) { + *listener << "\ncould not find incoming edge for input " << input_idx; + } + return false; + } + + ::testing::StringMatchResultListener inner_listener; + Input input = {edge->src(), edge->src_output()}; + if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) { + return true; + } + + if (listener->IsInterested()) { + *listener << "\ninput " << input_idx << " does not match expected:\n"; + (*input_matchers)[input_idx].DescribeTo(listener->stream()); + string explanation = inner_listener.str(); + if (!explanation.empty()) { + *listener << ", " << explanation; + } + } + return false; + } + + absl::optional<string> op; + absl::optional<string> name; + absl::optional<string> assigned_device; + absl::optional<Tensor> constant_value; + absl::optional<std::vector<::testing::Matcher<Input>>> input_matchers; + absl::optional<::testing::Matcher<absl::Span<const Node* const>>> + control_dep_set; +}; + +// Matches a dst and dst_output on an input edge. Today we only use this with +// dst_output=0 but we will eventually need to support multi-output operations. +class InputMatcher : public ::testing::MatcherInterface<Input> { + public: + InputMatcher(::testing::Matcher<const Node*> src_matcher, int src_output) + : src_matcher_(std::move(src_matcher)), src_output_(src_output) {} + + bool MatchAndExplain( + Input input, ::testing::MatchResultListener* listener) const override { + ::testing::StringMatchResultListener inner_listener; + if (!src_matcher_.MatchAndExplain(input.first, &inner_listener)) { + if (listener->IsInterested()) { + *listener << "\nsource does not match expected "; + src_matcher_.DescribeTo(listener->stream()); + string explanation = inner_listener.str(); + if (!explanation.empty()) { + *listener << "\n\t" << explanation; + } + } + return false; + } + if (input.second != src_output_) { + if (listener->IsInterested()) { + *listener << "\nexpected output slot to be " << src_output_ + << " but found " << input.second; + } + return false; + } + + return true; + } + + void DescribeTo(::std::ostream* os) const override { + if (src_output_) { + *os << "output slot: " << src_output_ << ", source: ("; + } + + src_matcher_.DescribeTo(os); + + if (src_output_) { + *os << ")"; + } + } + + private: + ::testing::Matcher<const Node*> src_matcher_; + int src_output_; +}; + +std::vector<::testing::Matcher<Input>> NodeMatchersToInputMatchers( + absl::Span<const ::testing::Matcher<const Node*>> node_matchers) { + std::vector<::testing::Matcher<Input>> result; + absl::c_transform(node_matchers, std::back_inserter(result), + [](::testing::Matcher<const Node*> n) { + return ::testing::MakeMatcher(new InputMatcher(n, 0)); + }); + return result; +} +} // namespace + +::testing::Matcher<const Node*> impl::NodeWith( + absl::Span<const NodeMatcherProperties> props) { + NodeMatcher* matcher = new NodeMatcher(); + for (const NodeMatcherProperties& prop : props) { + if (prop.name()) { + DCHECK(!matcher->name); + matcher->name = prop.name(); + } + + if (prop.op()) { + DCHECK(!matcher->op); + matcher->op = prop.op(); + } + + if (prop.constant_value()) { + DCHECK(!matcher->constant_value); + matcher->constant_value = prop.constant_value(); + } + + if (prop.assigned_device()) { + DCHECK(!matcher->assigned_device); + matcher->assigned_device = prop.assigned_device(); + } + + if (prop.input_nodes()) { + DCHECK(!matcher->input_matchers); + matcher->input_matchers = + NodeMatchersToInputMatchers(*prop.input_nodes()); + } + + if (prop.control_deps()) { + DCHECK(!matcher->control_dep_set); + matcher->control_dep_set = + ::testing::UnorderedElementsAreArray(*prop.control_deps()); + } + } + + return ::testing::MakeMatcher(matcher); +} + +impl::NodeMatcherProperties Name(string name) { + impl::NodeMatcherProperties props; + props.set_name(std::move(name)); + return props; +} + +// Matches a node with op `op`. +impl::NodeMatcherProperties Op(string op) { + impl::NodeMatcherProperties props; + props.set_op(std::move(op)); + return props; +} + +// Matches a node with assigned device `assigned_device`. +impl::NodeMatcherProperties AssignedDevice(string assigned_device) { + impl::NodeMatcherProperties props; + props.set_assigned_device(std::move(assigned_device)); + return props; +} + +impl::NodeMatcherProperties impl::Inputs( + absl::Span<const ::testing::Matcher<const Node*>> inputs) { + std::vector<::testing::Matcher<const Node*>> inputs_vector; + absl::c_copy(inputs, std::back_inserter(inputs_vector)); + + impl::NodeMatcherProperties props; + props.set_input_nodes(std::move(inputs_vector)); + return props; +} + +impl::NodeMatcherProperties impl::CtrlDeps( + absl::Span<const ::testing::Matcher<const Node*>> control_deps) { + std::vector<::testing::Matcher<const Node*>> control_deps_vector; + absl::c_copy(control_deps, std::back_inserter(control_deps_vector)); + + impl::NodeMatcherProperties props; + props.set_control_deps(std::move(control_deps_vector)); + return props; +} + +NodeMatcherProperties ConstantValue( + const ::tensorflow::Input::Initializer& val) { + TF_CHECK_OK(val.status); + NodeMatcherProperties props; + props.set_constant_value(val.tensor); + return props; +} + +::testing::Matcher<const Node*> Const( + const ::tensorflow::Input::Initializer& val) { + return NodeWith(ConstantValue(val)); +} +} // namespace matchers + +Node* FindNodeByName(Graph* g, absl::string_view name) { + for (Node* n : g->nodes()) { + if (n->name() == name) { + return n; + } + } + + return nullptr; +} +} // namespace testing +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h new file mode 100644 index 0000000000..0437a7e95c --- /dev/null +++ b/tensorflow/compiler/jit/node_matchers.h @@ -0,0 +1,197 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Provides a set of matchers for tensorflow nodes. +// +// Example usage: +// +// tensorflow::Node* node = ...; +// EXPECT_THAT(node, NodeWith(Name("name"), Op("op"), +// Inputs(NodeWith(Name("input"))))) +// +// Matchable node properties (the expressions that go inside NodeWith(...)) +// are: +// +// - Name(string): matches the node name exactly. We will probably need to +// have this take a string matcher soon in the future. +// +// - Op(string): matches the op exactly. +// +// - AssignedDevice(string): matches the assigned device exactly. +// +// - Inputs(<ordered list>): matches the list of non-control inputs to the node +// exactly (i.e. does not match a suffix or a prefix). +// +// - CtrlDeps(<unordered list>): matches the list of control dependences on the +// node exactly but in any order. +// +// - ConstantValue(tensorflow::Input::Initializer init): matches a Const node +// with the constant value `init`. Implies Op("Const"). +// +// Node properties may not be repeated in a single NodeWith(...) matcher. +// E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail. Since ConstantValue +// implies Op("Const"), a single NodeWith matcher can't have both +// ConstantValue(...) and Op(...). + +#ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ +#define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ + +#include <array> +#include <string> +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace testing { +namespace matchers { + +namespace impl { + +// ----------------------------------------------------------------------------- +// Implementation details. + +// Properties that we match on for a particular Node. If a particular property +// is nullopt then any value for it is allowed. +class NodeMatcherProperties { + public: + using NodeSeqMatcher = std::vector<::testing::Matcher<const Node*>>; + + const absl::optional<string>& name() const { return name_; } + const absl::optional<string>& op() const { return op_; } + const absl::optional<string>& assigned_device() const { + return assigned_device_; + } + const absl::optional<Tensor>& constant_value() const { + return constant_value_; + } + const absl::optional<NodeSeqMatcher>& input_nodes() const { + return input_nodes_; + } + const absl::optional<NodeSeqMatcher>& control_deps() const { + return control_deps_; + } + + void set_name(string name) { + DCHECK(IsEmpty()); + name_ = std::move(name); + } + + void set_op(string op) { + DCHECK(IsEmpty()); + op_ = std::move(op); + } + + void set_assigned_device(string assigned_device) { + DCHECK(IsEmpty()); + assigned_device_ = std::move(assigned_device); + } + + void set_constant_value(Tensor constant_value) { + DCHECK(IsEmpty()); + constant_value_ = std::move(constant_value); + op_ = "Const"; + } + + void set_input_nodes(NodeSeqMatcher input_nodes) { + DCHECK(IsEmpty()); + input_nodes_ = std::move(input_nodes); + } + + void set_control_deps(NodeSeqMatcher control_deps) { + DCHECK(IsEmpty()); + control_deps_ = std::move(control_deps); + } + + bool IsEmpty() const { + return !name().has_value() && !op().has_value() && + !input_nodes().has_value() && !control_deps().has_value(); + } + + private: + absl::optional<string> name_; + absl::optional<string> op_; + absl::optional<string> assigned_device_; + absl::optional<Tensor> constant_value_; + absl::optional<NodeSeqMatcher> input_nodes_; + absl::optional<NodeSeqMatcher> control_deps_; +}; + +::testing::Matcher<const Node*> NodeWith( + absl::Span<const NodeMatcherProperties> props); + +impl::NodeMatcherProperties Inputs( + absl::Span<const ::testing::Matcher<const Node*>> inputs); + +impl::NodeMatcherProperties CtrlDeps( + absl::Span<const ::testing::Matcher<const Node*>> control_deps); +} // namespace impl + +// ----------------------------------------------------------------------------- +// Public interface. + +// Matches a node with name `name`. +impl::NodeMatcherProperties Name(string name); + +// Matches a node with op `op`. +impl::NodeMatcherProperties Op(string op); + +// Matches a node with assigned device `assigned_device`. +impl::NodeMatcherProperties AssignedDevice(string assigned_device); + +// Matches a node with inputs `inputs`. +// +// `inputs` are ordered; `inputs`[i] must match input i. +template <typename... Ts> +impl::NodeMatcherProperties Inputs(Ts... inputs) { + return impl::Inputs({inputs...}); +} + +// Matches a node with control dependences `control_deps`. +// +// `control_deps` are unordered and will match the control deps of a node in any +// order. +template <typename... Ts> +impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) { + return impl::CtrlDeps({control_deps...}); +} + +// Matches a constant node with value `val`. +impl::NodeMatcherProperties ConstantValue( + const ::tensorflow::Input::Initializer& val); + +// The main gmock matcher. See file comment for example usage. +template <typename... Ts> +::testing::Matcher<const Node*> NodeWith(Ts... args) { + std::array<impl::NodeMatcherProperties, sizeof...(Ts)> array = {args...}; + return impl::NodeWith(array); +} + +::testing::Matcher<const Node*> Const( + const ::tensorflow::Input::Initializer& val); +} // namespace matchers + +// If `g` has a node named `name` returns it, otherwise returns null. +Node* FindNodeByName(Graph* g, absl::string_view name); +} // namespace testing +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ diff --git a/tensorflow/compiler/jit/node_matchers_test.cc b/tensorflow/compiler/jit/node_matchers_test.cc new file mode 100644 index 0000000000..93a8994307 --- /dev/null +++ b/tensorflow/compiler/jit/node_matchers_test.cc @@ -0,0 +1,179 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/node_matchers.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" + +namespace tensorflow { +namespace testing { +namespace { + +using ::testing::_; + +using testing::matchers::AssignedDevice; +using testing::matchers::ConstantValue; +using testing::matchers::CtrlDeps; +using testing::matchers::Inputs; +using testing::matchers::Name; +using testing::matchers::NodeWith; +using testing::matchers::Op; + +template <typename M, typename T> +string Explain(const T& t, const M& m) { + ::testing::StringMatchResultListener listener; + EXPECT_THAT(t, ::testing::Not(m)); // For the error message. + EXPECT_FALSE(m.MatchAndExplain(t, &listener)); + return listener.str(); +} + +TEST(NodeMatchers, CheckAgainstConstant) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output placeholder = + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + EXPECT_THAT(placeholder.node(), NodeWith(Op("Placeholder"))); + EXPECT_THAT(placeholder.node(), NodeWith(Name("placeholder"))); + EXPECT_THAT(placeholder.node(), + NodeWith(Op("Placeholder"), Name("placeholder"))); + EXPECT_THAT(placeholder.node(), + NodeWith(Name("placeholder"), Op("Placeholder"))); + EXPECT_THAT(placeholder.node(), NodeWith(Inputs())); + EXPECT_THAT(placeholder.node(), + NodeWith(Op("Placeholder"), Name("placeholder"), Inputs())); + + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Op("Add"))), + "\nexpected op Add but found Placeholder"); + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Name("add"))), + "\nexpected name add but found placeholder"); + EXPECT_EQ(Explain(placeholder.node(), NodeWith(Inputs(NodeWith()))), + "\nexpected 1 inputs but node has 0"); +} + +TEST(NodeMatchers, CheckAgainstBinary) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output placeholder_a = + ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT); + Output placeholder_b = + ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); + Output add = ops::Add(root.WithOpName("add"), placeholder_a, placeholder_b); + + EXPECT_THAT(add.node(), NodeWith(Op("Add"), Name("add"), + Inputs(NodeWith(Name("placeholder_a")), + NodeWith(Name("placeholder_b"))))); + + EXPECT_EQ(Explain(add.node(), NodeWith(Inputs())), + "\nexpected 0 inputs but node has 2"); + EXPECT_EQ( + Explain(add.node(), NodeWith(Inputs(NodeWith(Name("blah")), _))), + "\ninput 0 does not match expected:\nname: blah, \nsource does not match " + "expected name: blah\n\t\nexpected name blah but found placeholder_a"); + EXPECT_EQ( + Explain(add.node(), NodeWith(Inputs(_, NodeWith(Name("blah"))))), + "\ninput 1 does not match expected:\nname: blah, \nsource does not match " + "expected name: blah\n\t\nexpected name blah but found placeholder_b"); +} + +TEST(NodeMatchers, CheckControlDependence) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output placeholder_a = + ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT); + Output placeholder_b = + ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); + Output placeholder_c = + ops::Placeholder(root.WithOpName("placeholder_c"), DT_FLOAT); + Output placeholder_d = + ops::Placeholder(root.WithOpName("placeholder_d"), DT_FLOAT); + + root.graph()->AddControlEdge(placeholder_a.node(), placeholder_c.node()); + root.graph()->AddControlEdge(placeholder_b.node(), placeholder_c.node()); + + EXPECT_THAT(placeholder_c.node(), + NodeWith(Name("placeholder_c"), + CtrlDeps(NodeWith(Name("placeholder_a")), + NodeWith(Name("placeholder_b"))))); + EXPECT_THAT(placeholder_d.node(), + NodeWith(Name("placeholder_d"), CtrlDeps())); + + EXPECT_EQ( + Explain(placeholder_c.node(), NodeWith(CtrlDeps())), + "ctrl_deps, which has 2 elements, does not match expected: is empty"); + EXPECT_EQ(Explain(placeholder_d.node(), NodeWith(CtrlDeps(NodeWith()))), + "ctrl_deps does not match expected: has 1 element and that element " + "is any node"); +} + +TEST(NodeMatchers, ConstVaulue) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output placeholder = + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + Output const_0d = ops::Const(root.WithOpName("const_0d"), 42); + + Output const_2d = ops::Const(root.WithOpName("const_2d"), {{1, 2}, {4, 3}}); + + EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42))); + EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42), Name("const_0d"))); + + EXPECT_THAT(const_2d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))); + + EXPECT_EQ(Explain(placeholder.node(), NodeWith(ConstantValue(42))), + "\nexpected op Const but found Placeholder"); + EXPECT_EQ( + Explain(const_0d.node(), NodeWith(ConstantValue(43))), + "\nmismatch in constant tensor at index 0 expected = 43 actual = 42"); + EXPECT_EQ( + Explain(const_0d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))), + "\nwas looking for tensor with 4 elements, found tensor with 1 elements"); + EXPECT_EQ( + Explain(const_2d.node(), NodeWith(ConstantValue(42))), + "\nwas looking for tensor with 1 elements, found tensor with 4 elements"); +} + +TEST(NodeMatchers, AssignedDevice) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output placeholder_a = + ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT); + Output placeholder_b = + ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT); + + Output assigned_add = + ops::Add(root.WithOpName("assigned_add"), placeholder_a, placeholder_b); + assigned_add.node()->set_assigned_device_name( + "/job:localhost/replica:0/task:0/device:CPU:0"); + + Output unassigned_add = + ops::Add(root.WithOpName("unassigned_add"), placeholder_a, placeholder_b); + + EXPECT_THAT( + assigned_add.node(), + NodeWith(AssignedDevice("/job:localhost/replica:0/task:0/device:CPU:0"))); + EXPECT_THAT(unassigned_add.node(), NodeWith(AssignedDevice(""))); + + EXPECT_EQ(Explain(unassigned_add.node(), + NodeWith(AssignedDevice( + "/job:localhost/replica:0/task:0/device:CPU:0"))), + "\nexpected assigned_device " + "/job:localhost/replica:0/task:0/device:CPU:0 but found \"\""); +} + +} // namespace +} // namespace testing +} // namespace tensorflow |