aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/node_matchers.h
blob: 0437a7e95c1eb3bdcdbe24a440dd90a5943c0894 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Provides a set of matchers for tensorflow nodes.
//
// Example usage:
//
//  tensorflow::Node* node = ...;
//  EXPECT_THAT(node, NodeWith(Name("name"), Op("op"),
//                             Inputs(NodeWith(Name("input")))))
//
// Matchable node properties (the expressions that go inside NodeWith(...))
// are:
//
//  - Name(string): matches the node name exactly.  We will probably need to
//    make this take a string matcher soon.
//
//  - Op(string): matches the op exactly.
//
//  - AssignedDevice(string): matches the assigned device exactly.
//
//  - Inputs(<ordered list>): matches the list of non-control inputs to the node
//    exactly (i.e. does not match a suffix or a prefix).
//
//  - CtrlDeps(<unordered list>): matches the list of control dependences on the
//    node exactly but in any order.
//
//  - ConstantValue(tensorflow::Input::Initializer init): matches a Const node
//    with the constant value `init`.  Implies Op("Const").
//
// Node properties may not be repeated in a single NodeWith(...) matcher.
// E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail.  Since ConstantValue
// implies Op("Const"), a single NodeWith matcher can't have both
// ConstantValue(...) and Op(...).

#ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
#define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_

#include <array>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {
namespace testing {
namespace matchers {

namespace impl {

// -----------------------------------------------------------------------------
// Implementation details.

// Properties that we match on for a particular Node.  If a particular property
// is nullopt then any value for it is allowed.
class NodeMatcherProperties {
 public:
  using NodeSeqMatcher = std::vector<::testing::Matcher<const Node*>>;

  const absl::optional<string>& name() const { return name_; }
  const absl::optional<string>& op() const { return op_; }
  const absl::optional<string>& assigned_device() const {
    return assigned_device_;
  }
  const absl::optional<Tensor>& constant_value() const {
    return constant_value_;
  }
  const absl::optional<NodeSeqMatcher>& input_nodes() const {
    return input_nodes_;
  }
  const absl::optional<NodeSeqMatcher>& control_deps() const {
    return control_deps_;
  }

  void set_name(string name) {
    DCHECK(IsEmpty());
    name_ = std::move(name);
  }

  void set_op(string op) {
    DCHECK(IsEmpty());
    op_ = std::move(op);
  }

  void set_assigned_device(string assigned_device) {
    DCHECK(IsEmpty());
    assigned_device_ = std::move(assigned_device);
  }

  void set_constant_value(Tensor constant_value) {
    DCHECK(IsEmpty());
    constant_value_ = std::move(constant_value);
    op_ = "Const";
  }

  void set_input_nodes(NodeSeqMatcher input_nodes) {
    DCHECK(IsEmpty());
    input_nodes_ = std::move(input_nodes);
  }

  void set_control_deps(NodeSeqMatcher control_deps) {
    DCHECK(IsEmpty());
    control_deps_ = std::move(control_deps);
  }

  bool IsEmpty() const {
    return !name().has_value() && !op().has_value() &&
           !input_nodes().has_value() && !control_deps().has_value();
  }

 private:
  absl::optional<string> name_;
  absl::optional<string> op_;
  absl::optional<string> assigned_device_;
  absl::optional<Tensor> constant_value_;
  absl::optional<NodeSeqMatcher> input_nodes_;
  absl::optional<NodeSeqMatcher> control_deps_;
};

// Non-template backends for the variadic public helpers of the same name.

// Builds one node matcher from `props`, where each element describes one
// property of the node.
::testing::Matcher<const Node*> NodeWith(
    absl::Span<const NodeMatcherProperties> props);

// Builds the ordered "non-control inputs" property from `inputs`.
impl::NodeMatcherProperties Inputs(
    absl::Span<const ::testing::Matcher<const Node*>> inputs);

// Builds the unordered "control dependences" property from `control_deps`.
impl::NodeMatcherProperties CtrlDeps(
    absl::Span<const ::testing::Matcher<const Node*>> control_deps);
}  // namespace impl

// -----------------------------------------------------------------------------
// Public interface.

// Property: the node's name must equal `name` exactly.
impl::NodeMatcherProperties Name(string name);

// Property: the node's op must equal `op` exactly.
impl::NodeMatcherProperties Op(string op);

// Property: the node's assigned device must equal `assigned_device` exactly.
impl::NodeMatcherProperties AssignedDevice(
    string assigned_device);

// Property: the node's non-control inputs are exactly `inputs`, in order.
// The i-th argument must match input i; a prefix or suffix does not match.
template <typename... Ts>
impl::NodeMatcherProperties Inputs(Ts... inputs) {
  // Gather the arguments as node matchers and pass them on as a span.
  const std::array<::testing::Matcher<const Node*>, sizeof...(Ts)> matchers = {
      inputs...};
  return impl::Inputs(matchers);
}

// Property: the node's control dependences are exactly `control_deps`.
// Order does not matter; each argument may match any one control dep.
template <typename... Ts>
impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) {
  // Gather the arguments as node matchers and pass them on as a span.
  const std::array<::testing::Matcher<const Node*>, sizeof...(Ts)> matchers = {
      control_deps...};
  return impl::CtrlDeps(matchers);
}

// Property: the node is a Const whose value is `val`.  Implies Op("Const"),
// so it cannot be combined with Op(...) inside the same NodeWith(...).
impl::NodeMatcherProperties ConstantValue(
    const ::tensorflow::Input::Initializer& val);

// The main gmock matcher.  See the file comment for example usage.
//
// Each argument is one property (Name(...), Op(...), Inputs(...), ...).  The
// arguments are taken by value and then moved, not copied a second time, into
// a fixed-size array that is handed to impl::NodeWith as a span.
template <typename... Ts>
::testing::Matcher<const Node*> NodeWith(Ts... args) {
  std::array<impl::NodeMatcherProperties, sizeof...(Ts)> props = {
      std::move(args)...};
  return impl::NodeWith(props);
}

// Returns a node matcher built from the constant `val`.
// NOTE(review): presumably shorthand for NodeWith(ConstantValue(val)); the
// definition is not visible here, so confirm against the implementation.
::testing::Matcher<const Node*> Const(
    const ::tensorflow::Input::Initializer& val);
}  // namespace matchers

// Looks up the node called `name` in `g`.  Returns null when `g` contains no
// node with that name.
Node* FindNodeByName(Graph* g, absl::string_view name);
}  // namespace testing
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_