aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-11 00:50:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-11 00:54:33 -0700
commit45965cfd8b54fb113275ffdaced5366e28aa3553 (patch)
tree253c390dceb910360cb3b62d5039bcbcdf0f5c5d /tensorflow/compiler/tf2xla
parent5375f8c48b3087512f7593cf699346cc0b30a27b (diff)
Graph optimization pass that creates XlaLaunch ops for the computations that have been explicitly marked to be compiled via xla.compile()
PiperOrigin-RevId: 212407112
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/cc/BUILD4
-rw-r--r--tensorflow/compiler/tf2xla/test_util.cc8
-rw-r--r--tensorflow/compiler/tf2xla/test_util.h16
4 files changed, 28 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index ab289a2b6c..74b131e07e 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -594,6 +594,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD
index ea8d1b3d14..8ac5eb5df9 100644
--- a/tensorflow/compiler/tf2xla/cc/BUILD
+++ b/tensorflow/compiler/tf2xla/cc/BUILD
@@ -31,7 +31,9 @@ cc_library(
tf_gen_op_wrapper_cc(
name = "xla_jit_op_gen",
out_ops_file = "ops/xla_jit_op",
- deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
+ deps = [
+ "//tensorflow/compiler/jit/ops:xla_ops",
+ ],
)
cc_library(
diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc
index 3c6c9a91b6..f31bfb45a2 100644
--- a/tensorflow/compiler/tf2xla/test_util.cc
+++ b/tensorflow/compiler/tf2xla/test_util.cc
@@ -40,4 +40,12 @@ Status InstantiateFunctionForTest(const string& name,
return Status::OK();
}
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph) {
+ std::unordered_map<string, Node*> index;
+ for (Node* node : graph.nodes()) {
+ index[node->name()] = node;
+ }
+ return index;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h
index e6e4ae92ed..350a868568 100644
--- a/tensorflow/compiler/tf2xla/test_util.h
+++ b/tensorflow/compiler/tf2xla/test_util.h
@@ -24,8 +24,10 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
@@ -42,6 +44,20 @@ Status InstantiateFunctionForTest(const string& name,
const FunctionLibraryDefinition& library,
InstantiationResultForTest* result);
+// Builds a map from node name to Node* for `graph`.
+std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph);
+
} // namespace tensorflow
+// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for
+// equality.
+#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \
+ do { \
+ string diff; \
+ EqualGraphDefOptions eq_options; \
+ eq_options.ignore_internal_attrs = false; \
+ EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \
+ << diff << "\nActual: " << SummarizeGraphDef(actual); \
+ } while (false)
+
#endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_