diff options
author | 2018-09-11 00:50:04 -0700 | |
---|---|---|
committer | 2018-09-11 00:54:33 -0700 | |
commit | 45965cfd8b54fb113275ffdaced5366e28aa3553 (patch) | |
tree | 253c390dceb910360cb3b62d5039bcbcdf0f5c5d /tensorflow/compiler/tf2xla | |
parent | 5375f8c48b3087512f7593cf699346cc0b30a27b (diff) |
Graph optimization pass that creates XlaLaunch ops for the computations that have been explicitly marked to be compiled via xla.compile()
PiperOrigin-RevId: 212407112
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/cc/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/test_util.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/test_util.h | 16 |
4 files changed, 28 insertions, 1 deletions
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index ab289a2b6c..74b131e07e 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -594,6 +594,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index ea8d1b3d14..8ac5eb5df9 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -31,7 +31,9 @@ cc_library( tf_gen_op_wrapper_cc( name = "xla_jit_op_gen", out_ops_file = "ops/xla_jit_op", - deps = ["//tensorflow/compiler/jit/ops:xla_ops"], + deps = [ + "//tensorflow/compiler/jit/ops:xla_ops", + ], ) cc_library( diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index 3c6c9a91b6..f31bfb45a2 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -40,4 +40,12 @@ Status InstantiateFunctionForTest(const string& name, return Status::OK(); } +std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph) { + std::unordered_map<string, Node*> index; + for (Node* node : graph.nodes()) { + index[node->name()] = node; + } + return index; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index e6e4ae92ed..350a868568 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -24,8 +24,10 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -42,6 +44,20 @@ Status InstantiateFunctionForTest(const string& name, const FunctionLibraryDefinition& library, InstantiationResultForTest* result); +// Builds a map from node name to Node* for `graph`. +std::unordered_map<string, Node*> BuildNodeIndex(const Graph& graph); + } // namespace tensorflow +// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for +// equality. +#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \ + do { \ + string diff; \ + EqualGraphDefOptions eq_options; \ + eq_options.ignore_internal_attrs = false; \ + EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \ + << diff << "\nActual: " << SummarizeGraphDef(actual); \ + } while (false) + #endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ |