aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-09-16 12:01:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-16 12:06:03 -0700
commit92c31bb620b0f8dd6590380dc6a5674f591ce1cb (patch)
treeb8895b75144f03975584b43bb0ea956d1ce837be /tensorflow/compiler/jit
parentaa2094fc9dc6e67d6e440231828de05a6da3cf78 (diff)
Introduce gmock matchers for TensorFlow nodes
I need these to write readable unit tests for TF graph transformations. All of my use cases will live inside tensorflow/compiler so putting it in tensorflow/compiler/jit for now; but we can move these out if other users are interested. In the future we may want to auto-generate type safe versions of these from the op registrations like we generate C++ wrappers today. PiperOrigin-RevId: 213186810
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--tensorflow/compiler/jit/BUILD29
-rw-r--r--tensorflow/compiler/jit/node_matchers.cc458
-rw-r--r--tensorflow/compiler/jit/node_matchers.h197
-rw-r--r--tensorflow/compiler/jit/node_matchers_test.cc179
4 files changed, 863 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index f4e1bc5e83..1001c57f3d 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -599,6 +599,35 @@ tf_cuda_cc_test(
],
)
+cc_library(
+ name = "node_matchers",
+ testonly = True,
+ srcs = ["node_matchers.cc"],
+ hdrs = ["node_matchers.h"],
+ deps = [
+ "//tensorflow/cc:ops",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+tf_cc_test(
+ name = "node_matchers_test",
+ srcs = ["node_matchers_test.cc"],
+ deps = [
+ ":node_matchers",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:test_main",
+ ],
+)
+
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc
new file mode 100644
index 0000000000..d8ace628e6
--- /dev/null
+++ b/tensorflow/compiler/jit/node_matchers.cc
@@ -0,0 +1,458 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/node_matchers.h"
+
+#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+
+namespace tensorflow {
+namespace testing {
+namespace matchers {
+namespace {
+
+using impl::NodeMatcherProperties;
+
+string IndentAllButFirstLine(absl::string_view text) {
+ std::vector<std::string> lines = absl::StrSplit(text, '\n');
+ for (int i = 1; i < lines.size(); i++) {
+ lines[i].insert(0, " ");
+ }
+ return absl::StrJoin(lines, "\n");
+}
+
+template <typename T>
+bool CompareTensor(const Tensor& actual, const Tensor& expected,
+ ::testing::MatchResultListener* listener) {
+ if (actual.NumElements() != expected.NumElements()) {
+ if (listener->IsInterested()) {
+ *listener << "\nwas looking for tensor with " << expected.NumElements()
+ << " elements, found tensor with " << actual.NumElements()
+ << " elements";
+ return false;
+ }
+ }
+
+ for (int64 i = 0, e = actual.NumElements(); i < e; i++) {
+ if (actual.flat<T>()(i) != expected.flat<T>()(i)) {
+ *listener << "\nmismatch in constant tensor at index " << i
+ << " expected = " << expected.flat<T>()(i)
+ << " actual = " << actual.flat<T>()(i);
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor,
+ ::testing::MatchResultListener* listener) {
+ if (tensor.dtype() != expected_tensor.dtype()) {
+ if (listener->IsInterested()) {
+ *listener << "\nexpected tensor of type "
+ << DataType_Name(expected_tensor.dtype())
+ << " but found one of type " << DataType_Name(tensor.dtype());
+ return false;
+ }
+ }
+
+ switch (tensor.dtype()) {
+ case DT_FLOAT:
+ return CompareTensor<float>(tensor, expected_tensor, listener);
+ case DT_DOUBLE:
+ return CompareTensor<double>(tensor, expected_tensor, listener);
+ case DT_INT8:
+ return CompareTensor<int8>(tensor, expected_tensor, listener);
+ case DT_INT16:
+ return CompareTensor<int16>(tensor, expected_tensor, listener);
+ case DT_INT32:
+ return CompareTensor<int32>(tensor, expected_tensor, listener);
+ case DT_INT64:
+ return CompareTensor<int64>(tensor, expected_tensor, listener);
+ case DT_UINT8:
+ return CompareTensor<uint8>(tensor, expected_tensor, listener);
+ case DT_UINT16:
+ return CompareTensor<uint16>(tensor, expected_tensor, listener);
+ case DT_UINT32:
+ return CompareTensor<uint32>(tensor, expected_tensor, listener);
+ case DT_UINT64:
+ return CompareTensor<uint64>(tensor, expected_tensor, listener);
+ default:
+ LOG(FATAL) << "Unsupported dtype " // Crash ok: testonly.
+ << DataType_Name(tensor.dtype());
+ }
+}
+
+using Input = std::pair<const Node*, int>;
+
+struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
+ bool MatchAndExplain(
+ const Node* node,
+ ::testing::MatchResultListener* listener) const override {
+ if (op && node->type_string() != *op) {
+ if (listener->IsInterested()) {
+ *listener << "\nexpected op " << *op << " but found "
+ << node->type_string();
+ }
+ return false;
+ }
+
+ if (assigned_device && node->assigned_device_name() != *assigned_device) {
+ if (listener->IsInterested()) {
+ *listener << "\nexpected assigned_device " << *assigned_device
+ << " but found \"" << node->assigned_device_name() << "\"";
+ }
+ return false;
+ }
+
+ if (name && node->name() != *name) {
+ if (listener->IsInterested()) {
+ *listener << "\nexpected name " << *name << " but found "
+ << node->name();
+ }
+ return false;
+ }
+
+ if (constant_value) {
+ const TensorProto* proto = nullptr;
+ if (!GetNodeAttr(node->def(), "value", &proto).ok()) {
+ if (listener->IsInterested()) {
+ *listener << "\ncould not find \"value\" attribute in node";
+ }
+ return false;
+ }
+
+ Tensor tensor(proto->dtype());
+ if (!tensor.FromProto(*proto)) {
+ if (listener->IsInterested()) {
+ *listener << "\ncould not convert TensorProto in \"value\" attribute "
+ "to Tensor";
+ }
+ return false;
+ }
+
+ if (!MatchAndExplainTensor(/*tensor=*/tensor,
+ /*expected_tensor=*/*constant_value,
+ listener)) {
+ return false;
+ }
+ }
+
+ if (input_matchers) {
+ if (input_matchers->size() != node->num_inputs()) {
+ if (listener->IsInterested()) {
+ *listener << "\nexpected " << input_matchers->size()
+ << " inputs but node has " << node->num_inputs();
+ }
+ return false;
+ }
+
+ for (int input_idx = 0, e = input_matchers->size(); input_idx < e;
+ input_idx++) {
+ if (!MatchAndExplainInput(node, input_idx, listener)) {
+ return false;
+ }
+ }
+ }
+
+ std::vector<const Node*> control_deps;
+ for (const Edge* e : node->in_edges()) {
+ if (e->IsControlEdge()) {
+ control_deps.push_back(e->src());
+ }
+ }
+
+ ::testing::StringMatchResultListener inner_listener;
+ if (control_dep_set &&
+ !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) {
+ if (listener->IsInterested()) {
+ string explanation = inner_listener.str();
+ if (!explanation.empty()) {
+ explanation = absl::StrCat(", ", explanation, ",");
+ }
+ *listener << "ctrl_deps" << explanation << " does not match expected: ";
+ control_dep_set->DescribeTo(listener->stream());
+ }
+ return false;
+ }
+ return true;
+ }
+
+ void DescribeTo(::std::ostream* os) const override {
+ std::vector<string> predicates;
+
+ if (name) {
+ predicates.push_back(absl::StrCat("name: ", *name));
+ }
+
+ if (op) {
+ predicates.push_back(absl::StrCat("op: ", *op));
+ }
+
+ if (assigned_device) {
+ predicates.push_back(absl::StrCat("assigned device: ", *assigned_device));
+ }
+
+ bool printed_something = !predicates.empty();
+
+ *os << absl::StrJoin(predicates, ", ");
+
+ if (constant_value) {
+ printed_something = true;
+ *os << "constant value: " << constant_value->DebugString();
+ }
+
+ if (input_matchers) {
+ if (!input_matchers->empty()) {
+ printed_something = true;
+ *os << " with " << (input_matchers->size() == 1 ? "only " : "")
+ << "input" << (input_matchers->size() == 1 ? "" : "s") << " ";
+ }
+
+ if (input_matchers->size() == 1) {
+ ::std::stringstream ss;
+ input_matchers->front().DescribeTo(&ss);
+ printed_something = true;
+ *os << "matching " << ss.str();
+ } else {
+ int edge_idx = 0;
+ for (const ::testing::Matcher<Input>& matcher : (*input_matchers)) {
+ *os << "\n [" << edge_idx << "] matching (";
+ ::std::stringstream ss;
+ matcher.DescribeTo(&ss);
+ printed_something = true;
+ *os << IndentAllButFirstLine(ss.str());
+ *os << ")";
+ edge_idx++;
+ }
+ }
+ }
+
+ if (control_dep_set) {
+ printed_something = true;
+ *os << " and control deps ";
+ control_dep_set->DescribeTo(os);
+ }
+
+ if (!printed_something) {
+ *os << "is any node";
+ }
+ }
+
+ bool MatchAndExplainInput(const Node* node, int input_idx,
+ ::testing::MatchResultListener* listener) const {
+ const Edge* edge;
+ if (!node->input_edge(input_idx, &edge).ok()) {
+ if (listener->IsInterested()) {
+ *listener << "\ncould not find incoming edge for input " << input_idx;
+ }
+ return false;
+ }
+
+ ::testing::StringMatchResultListener inner_listener;
+ Input input = {edge->src(), edge->src_output()};
+ if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) {
+ return true;
+ }
+
+ if (listener->IsInterested()) {
+ *listener << "\ninput " << input_idx << " does not match expected:\n";
+ (*input_matchers)[input_idx].DescribeTo(listener->stream());
+ string explanation = inner_listener.str();
+ if (!explanation.empty()) {
+ *listener << ", " << explanation;
+ }
+ }
+ return false;
+ }
+
+ absl::optional<string> op;
+ absl::optional<string> name;
+ absl::optional<string> assigned_device;
+ absl::optional<Tensor> constant_value;
+ absl::optional<std::vector<::testing::Matcher<Input>>> input_matchers;
+ absl::optional<::testing::Matcher<absl::Span<const Node* const>>>
+ control_dep_set;
+};
+
+// Matches a dst and dst_output on an input edge. Today we only use this with
+// dst_output=0 but we will eventually need to support multi-output operations.
+class InputMatcher : public ::testing::MatcherInterface<Input> {
+ public:
+ InputMatcher(::testing::Matcher<const Node*> src_matcher, int src_output)
+ : src_matcher_(std::move(src_matcher)), src_output_(src_output) {}
+
+ bool MatchAndExplain(
+ Input input, ::testing::MatchResultListener* listener) const override {
+ ::testing::StringMatchResultListener inner_listener;
+ if (!src_matcher_.MatchAndExplain(input.first, &inner_listener)) {
+ if (listener->IsInterested()) {
+ *listener << "\nsource does not match expected ";
+ src_matcher_.DescribeTo(listener->stream());
+ string explanation = inner_listener.str();
+ if (!explanation.empty()) {
+ *listener << "\n\t" << explanation;
+ }
+ }
+ return false;
+ }
+ if (input.second != src_output_) {
+ if (listener->IsInterested()) {
+ *listener << "\nexpected output slot to be " << src_output_
+ << " but found " << input.second;
+ }
+ return false;
+ }
+
+ return true;
+ }
+
+ void DescribeTo(::std::ostream* os) const override {
+ if (src_output_) {
+ *os << "output slot: " << src_output_ << ", source: (";
+ }
+
+ src_matcher_.DescribeTo(os);
+
+ if (src_output_) {
+ *os << ")";
+ }
+ }
+
+ private:
+ ::testing::Matcher<const Node*> src_matcher_;
+ int src_output_;
+};
+
+std::vector<::testing::Matcher<Input>> NodeMatchersToInputMatchers(
+ absl::Span<const ::testing::Matcher<const Node*>> node_matchers) {
+ std::vector<::testing::Matcher<Input>> result;
+ absl::c_transform(node_matchers, std::back_inserter(result),
+ [](::testing::Matcher<const Node*> n) {
+ return ::testing::MakeMatcher(new InputMatcher(n, 0));
+ });
+ return result;
+}
+} // namespace
+
+::testing::Matcher<const Node*> impl::NodeWith(
+ absl::Span<const NodeMatcherProperties> props) {
+ NodeMatcher* matcher = new NodeMatcher();
+ for (const NodeMatcherProperties& prop : props) {
+ if (prop.name()) {
+ DCHECK(!matcher->name);
+ matcher->name = prop.name();
+ }
+
+ if (prop.op()) {
+ DCHECK(!matcher->op);
+ matcher->op = prop.op();
+ }
+
+ if (prop.constant_value()) {
+ DCHECK(!matcher->constant_value);
+ matcher->constant_value = prop.constant_value();
+ }
+
+ if (prop.assigned_device()) {
+ DCHECK(!matcher->assigned_device);
+ matcher->assigned_device = prop.assigned_device();
+ }
+
+ if (prop.input_nodes()) {
+ DCHECK(!matcher->input_matchers);
+ matcher->input_matchers =
+ NodeMatchersToInputMatchers(*prop.input_nodes());
+ }
+
+ if (prop.control_deps()) {
+ DCHECK(!matcher->control_dep_set);
+ matcher->control_dep_set =
+ ::testing::UnorderedElementsAreArray(*prop.control_deps());
+ }
+ }
+
+ return ::testing::MakeMatcher(matcher);
+}
+
+impl::NodeMatcherProperties Name(string name) {
+ impl::NodeMatcherProperties props;
+ props.set_name(std::move(name));
+ return props;
+}
+
+// Matches a node with op `op`.
+impl::NodeMatcherProperties Op(string op) {
+ impl::NodeMatcherProperties props;
+ props.set_op(std::move(op));
+ return props;
+}
+
+// Matches a node with assigned device `assigned_device`.
+impl::NodeMatcherProperties AssignedDevice(string assigned_device) {
+ impl::NodeMatcherProperties props;
+ props.set_assigned_device(std::move(assigned_device));
+ return props;
+}
+
+impl::NodeMatcherProperties impl::Inputs(
+ absl::Span<const ::testing::Matcher<const Node*>> inputs) {
+ std::vector<::testing::Matcher<const Node*>> inputs_vector;
+ absl::c_copy(inputs, std::back_inserter(inputs_vector));
+
+ impl::NodeMatcherProperties props;
+ props.set_input_nodes(std::move(inputs_vector));
+ return props;
+}
+
+impl::NodeMatcherProperties impl::CtrlDeps(
+ absl::Span<const ::testing::Matcher<const Node*>> control_deps) {
+ std::vector<::testing::Matcher<const Node*>> control_deps_vector;
+ absl::c_copy(control_deps, std::back_inserter(control_deps_vector));
+
+ impl::NodeMatcherProperties props;
+ props.set_control_deps(std::move(control_deps_vector));
+ return props;
+}
+
+NodeMatcherProperties ConstantValue(
+ const ::tensorflow::Input::Initializer& val) {
+ TF_CHECK_OK(val.status);
+ NodeMatcherProperties props;
+ props.set_constant_value(val.tensor);
+ return props;
+}
+
+::testing::Matcher<const Node*> Const(
+ const ::tensorflow::Input::Initializer& val) {
+ return NodeWith(ConstantValue(val));
+}
+} // namespace matchers
+
+Node* FindNodeByName(Graph* g, absl::string_view name) {
+ for (Node* n : g->nodes()) {
+ if (n->name() == name) {
+ return n;
+ }
+ }
+
+ return nullptr;
+}
+} // namespace testing
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h
new file mode 100644
index 0000000000..0437a7e95c
--- /dev/null
+++ b/tensorflow/compiler/jit/node_matchers.h
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Provides a set of matchers for tensorflow nodes.
+//
+// Example usage:
+//
+// tensorflow::Node* node = ...;
+// EXPECT_THAT(node, NodeWith(Name("name"), Op("op"),
+// Inputs(NodeWith(Name("input")))))
+//
+// Matchable node properties (the expressions that go inside NodeWith(...))
+// are:
+//
+// - Name(string): matches the node name exactly. We will probably need to
+// have this take a string matcher soon in the future.
+//
+// - Op(string): matches the op exactly.
+//
+// - AssignedDevice(string): matches the assigned device exactly.
+//
+// - Inputs(<ordered list>): matches the list of non-control inputs to the node
+// exactly (i.e. does not match a suffix or a prefix).
+//
+// - CtrlDeps(<unordered list>): matches the list of control dependences on the
+// node exactly but in any order.
+//
+// - ConstantValue(tensorflow::Input::Initializer init): matches a Const node
+// with the constant value `init`. Implies Op("Const").
+//
+// Node properties may not be repeated in a single NodeWith(...) matcher.
+// E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail. Since ConstantValue
+// implies Op("Const"), a single NodeWith matcher can't have both
+// ConstantValue(...) and Op(...).
+
+#ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
+#define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
+
+#include <array>
+#include <string>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace testing {
+namespace matchers {
+
+namespace impl {
+
+// -----------------------------------------------------------------------------
+// Implementation details.
+
+// Properties that we match on for a particular Node. If a particular property
+// is nullopt then any value for it is allowed.
+class NodeMatcherProperties {
+ public:
+ using NodeSeqMatcher = std::vector<::testing::Matcher<const Node*>>;
+
+ const absl::optional<string>& name() const { return name_; }
+ const absl::optional<string>& op() const { return op_; }
+ const absl::optional<string>& assigned_device() const {
+ return assigned_device_;
+ }
+ const absl::optional<Tensor>& constant_value() const {
+ return constant_value_;
+ }
+ const absl::optional<NodeSeqMatcher>& input_nodes() const {
+ return input_nodes_;
+ }
+ const absl::optional<NodeSeqMatcher>& control_deps() const {
+ return control_deps_;
+ }
+
+ void set_name(string name) {
+ DCHECK(IsEmpty());
+ name_ = std::move(name);
+ }
+
+ void set_op(string op) {
+ DCHECK(IsEmpty());
+ op_ = std::move(op);
+ }
+
+ void set_assigned_device(string assigned_device) {
+ DCHECK(IsEmpty());
+ assigned_device_ = std::move(assigned_device);
+ }
+
+ void set_constant_value(Tensor constant_value) {
+ DCHECK(IsEmpty());
+ constant_value_ = std::move(constant_value);
+ op_ = "Const";
+ }
+
+ void set_input_nodes(NodeSeqMatcher input_nodes) {
+ DCHECK(IsEmpty());
+ input_nodes_ = std::move(input_nodes);
+ }
+
+ void set_control_deps(NodeSeqMatcher control_deps) {
+ DCHECK(IsEmpty());
+ control_deps_ = std::move(control_deps);
+ }
+
+ bool IsEmpty() const {
+ return !name().has_value() && !op().has_value() &&
+ !input_nodes().has_value() && !control_deps().has_value();
+ }
+
+ private:
+ absl::optional<string> name_;
+ absl::optional<string> op_;
+ absl::optional<string> assigned_device_;
+ absl::optional<Tensor> constant_value_;
+ absl::optional<NodeSeqMatcher> input_nodes_;
+ absl::optional<NodeSeqMatcher> control_deps_;
+};
+
+::testing::Matcher<const Node*> NodeWith(
+ absl::Span<const NodeMatcherProperties> props);
+
+impl::NodeMatcherProperties Inputs(
+ absl::Span<const ::testing::Matcher<const Node*>> inputs);
+
+impl::NodeMatcherProperties CtrlDeps(
+ absl::Span<const ::testing::Matcher<const Node*>> control_deps);
+} // namespace impl
+
+// -----------------------------------------------------------------------------
+// Public interface.
+
+// Matches a node with name `name`.
+impl::NodeMatcherProperties Name(string name);
+
+// Matches a node with op `op`.
+impl::NodeMatcherProperties Op(string op);
+
+// Matches a node with assigned device `assigned_device`.
+impl::NodeMatcherProperties AssignedDevice(string assigned_device);
+
+// Matches a node with inputs `inputs`.
+//
+// `inputs` are ordered; `inputs`[i] must match input i.
+template <typename... Ts>
+impl::NodeMatcherProperties Inputs(Ts... inputs) {
+ return impl::Inputs({inputs...});
+}
+
+// Matches a node with control dependences `control_deps`.
+//
+// `control_deps` are unordered and will match the control deps of a node in any
+// order.
+template <typename... Ts>
+impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) {
+ return impl::CtrlDeps({control_deps...});
+}
+
+// Matches a constant node with value `val`.
+impl::NodeMatcherProperties ConstantValue(
+ const ::tensorflow::Input::Initializer& val);
+
+// The main gmock matcher. See file comment for example usage.
+template <typename... Ts>
+::testing::Matcher<const Node*> NodeWith(Ts... args) {
+ std::array<impl::NodeMatcherProperties, sizeof...(Ts)> array = {args...};
+ return impl::NodeWith(array);
+}
+
+::testing::Matcher<const Node*> Const(
+ const ::tensorflow::Input::Initializer& val);
+} // namespace matchers
+
+// If `g` has a node named `name` returns it, otherwise returns null.
+Node* FindNodeByName(Graph* g, absl::string_view name);
+} // namespace testing
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
diff --git a/tensorflow/compiler/jit/node_matchers_test.cc b/tensorflow/compiler/jit/node_matchers_test.cc
new file mode 100644
index 0000000000..93a8994307
--- /dev/null
+++ b/tensorflow/compiler/jit/node_matchers_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/node_matchers.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/math_ops.h"
+
+namespace tensorflow {
+namespace testing {
+namespace {
+
+using ::testing::_;
+
+using testing::matchers::AssignedDevice;
+using testing::matchers::ConstantValue;
+using testing::matchers::CtrlDeps;
+using testing::matchers::Inputs;
+using testing::matchers::Name;
+using testing::matchers::NodeWith;
+using testing::matchers::Op;
+
+template <typename M, typename T>
+string Explain(const T& t, const M& m) {
+ ::testing::StringMatchResultListener listener;
+ EXPECT_THAT(t, ::testing::Not(m)); // For the error message.
+ EXPECT_FALSE(m.MatchAndExplain(t, &listener));
+ return listener.str();
+}
+
+TEST(NodeMatchers, CheckAgainstConstant) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output placeholder =
+ ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
+
+ EXPECT_THAT(placeholder.node(), NodeWith(Op("Placeholder")));
+ EXPECT_THAT(placeholder.node(), NodeWith(Name("placeholder")));
+ EXPECT_THAT(placeholder.node(),
+ NodeWith(Op("Placeholder"), Name("placeholder")));
+ EXPECT_THAT(placeholder.node(),
+ NodeWith(Name("placeholder"), Op("Placeholder")));
+ EXPECT_THAT(placeholder.node(), NodeWith(Inputs()));
+ EXPECT_THAT(placeholder.node(),
+ NodeWith(Op("Placeholder"), Name("placeholder"), Inputs()));
+
+ EXPECT_EQ(Explain(placeholder.node(), NodeWith(Op("Add"))),
+ "\nexpected op Add but found Placeholder");
+ EXPECT_EQ(Explain(placeholder.node(), NodeWith(Name("add"))),
+ "\nexpected name add but found placeholder");
+ EXPECT_EQ(Explain(placeholder.node(), NodeWith(Inputs(NodeWith()))),
+ "\nexpected 1 inputs but node has 0");
+}
+
+TEST(NodeMatchers, CheckAgainstBinary) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output placeholder_a =
+ ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT);
+ Output placeholder_b =
+ ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT);
+ Output add = ops::Add(root.WithOpName("add"), placeholder_a, placeholder_b);
+
+ EXPECT_THAT(add.node(), NodeWith(Op("Add"), Name("add"),
+ Inputs(NodeWith(Name("placeholder_a")),
+ NodeWith(Name("placeholder_b")))));
+
+ EXPECT_EQ(Explain(add.node(), NodeWith(Inputs())),
+ "\nexpected 0 inputs but node has 2");
+ EXPECT_EQ(
+ Explain(add.node(), NodeWith(Inputs(NodeWith(Name("blah")), _))),
+ "\ninput 0 does not match expected:\nname: blah, \nsource does not match "
+ "expected name: blah\n\t\nexpected name blah but found placeholder_a");
+ EXPECT_EQ(
+ Explain(add.node(), NodeWith(Inputs(_, NodeWith(Name("blah"))))),
+ "\ninput 1 does not match expected:\nname: blah, \nsource does not match "
+ "expected name: blah\n\t\nexpected name blah but found placeholder_b");
+}
+
+TEST(NodeMatchers, CheckControlDependence) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output placeholder_a =
+ ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT);
+ Output placeholder_b =
+ ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT);
+ Output placeholder_c =
+ ops::Placeholder(root.WithOpName("placeholder_c"), DT_FLOAT);
+ Output placeholder_d =
+ ops::Placeholder(root.WithOpName("placeholder_d"), DT_FLOAT);
+
+ root.graph()->AddControlEdge(placeholder_a.node(), placeholder_c.node());
+ root.graph()->AddControlEdge(placeholder_b.node(), placeholder_c.node());
+
+ EXPECT_THAT(placeholder_c.node(),
+ NodeWith(Name("placeholder_c"),
+ CtrlDeps(NodeWith(Name("placeholder_a")),
+ NodeWith(Name("placeholder_b")))));
+ EXPECT_THAT(placeholder_d.node(),
+ NodeWith(Name("placeholder_d"), CtrlDeps()));
+
+ EXPECT_EQ(
+ Explain(placeholder_c.node(), NodeWith(CtrlDeps())),
+ "ctrl_deps, which has 2 elements, does not match expected: is empty");
+ EXPECT_EQ(Explain(placeholder_d.node(), NodeWith(CtrlDeps(NodeWith()))),
+ "ctrl_deps does not match expected: has 1 element and that element "
+ "is any node");
+}
+
+TEST(NodeMatchers, ConstVaulue) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output placeholder =
+ ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
+ Output const_0d = ops::Const(root.WithOpName("const_0d"), 42);
+
+ Output const_2d = ops::Const(root.WithOpName("const_2d"), {{1, 2}, {4, 3}});
+
+ EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42)));
+ EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42), Name("const_0d")));
+
+ EXPECT_THAT(const_2d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}})));
+
+ EXPECT_EQ(Explain(placeholder.node(), NodeWith(ConstantValue(42))),
+ "\nexpected op Const but found Placeholder");
+ EXPECT_EQ(
+ Explain(const_0d.node(), NodeWith(ConstantValue(43))),
+ "\nmismatch in constant tensor at index 0 expected = 43 actual = 42");
+ EXPECT_EQ(
+ Explain(const_0d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))),
+ "\nwas looking for tensor with 4 elements, found tensor with 1 elements");
+ EXPECT_EQ(
+ Explain(const_2d.node(), NodeWith(ConstantValue(42))),
+ "\nwas looking for tensor with 1 elements, found tensor with 4 elements");
+}
+
+TEST(NodeMatchers, AssignedDevice) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output placeholder_a =
+ ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT);
+ Output placeholder_b =
+ ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT);
+
+ Output assigned_add =
+ ops::Add(root.WithOpName("assigned_add"), placeholder_a, placeholder_b);
+ assigned_add.node()->set_assigned_device_name(
+ "/job:localhost/replica:0/task:0/device:CPU:0");
+
+ Output unassigned_add =
+ ops::Add(root.WithOpName("unassigned_add"), placeholder_a, placeholder_b);
+
+ EXPECT_THAT(
+ assigned_add.node(),
+ NodeWith(AssignedDevice("/job:localhost/replica:0/task:0/device:CPU:0")));
+ EXPECT_THAT(unassigned_add.node(), NodeWith(AssignedDevice("")));
+
+ EXPECT_EQ(Explain(unassigned_add.node(),
+ NodeWith(AssignedDevice(
+ "/job:localhost/replica:0/task:0/device:CPU:0"))),
+ "\nexpected assigned_device "
+ "/job:localhost/replica:0/task:0/device:CPU:0 but found \"\"");
+}
+
+} // namespace
+} // namespace testing
+} // namespace tensorflow