aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-13 16:08:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 16:12:39 -0700
commit29b56bde1e28e558111b917fd44b973e2aea7fcf (patch)
tree5de81939bd5226b22deb304d4e3f6914fdb04c0c /tensorflow/compiler/tf2xla
parent4292b8107175b3c3223f65c75b3ca091bd0604ec (diff)
Automated rollback of commit ac60b46e2c5962fd8099a4406c1788d826ad3c0d
PiperOrigin-RevId: 212896336
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/test_util.cc8
-rw-r--r--tensorflow/compiler/tf2xla/test_util.h16
3 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index d549e7bb59..ba1e3b2b4f 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -611,6 +611,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc
index 3c6c9a91b6..f31bfb45a2 100644
--- a/tensorflow/compiler/tf2xla/test_util.cc
+++ b/tensorflow/compiler/tf2xla/test_util.cc
@@ -40,4 +40,12 @@ Status InstantiateFunctionForTest(const string& name,
return Status::OK();
}
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph) {
+ std::unordered_map<string, Node*> index;
+ for (Node* node : graph.nodes()) {
+ index[node->name()] = node;
+ }
+ return index;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h
index e6e4ae92ed..350a868568 100644
--- a/tensorflow/compiler/tf2xla/test_util.h
+++ b/tensorflow/compiler/tf2xla/test_util.h
@@ -24,8 +24,10 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
@@ -42,6 +44,20 @@ Status InstantiateFunctionForTest(const string& name,
const FunctionLibraryDefinition& library,
InstantiationResultForTest* result);
+// Builds a map from node name to Node* for `graph`.
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph);
+
} // namespace tensorflow
+// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for
+// equality.
+#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \
+ do { \
+ string diff; \
+ EqualGraphDefOptions eq_options; \
+ eq_options.ignore_internal_attrs = false; \
+ EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \
+ << diff << "\nActual: " << SummarizeGraphDef(actual); \
+ } while (false)
+
#endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_