Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--  tensorflow/compiler/aot/BUILD | 35
-rw-r--r--  tensorflow/compiler/aot/codegen.cc | 55
-rw-r--r--  tensorflow/compiler/aot/codegen.h | 8
-rw-r--r--  tensorflow/compiler/aot/codegen_test.cc | 42
-rw-r--r--  tensorflow/compiler/aot/compile.cc | 340
-rw-r--r--  tensorflow/compiler/aot/compile.h | 46
-rw-r--r--  tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt | 2
-rw-r--r--  tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt | 2
-rw-r--r--  tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt | 2
-rw-r--r--  tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt | 2
-rw-r--r--  tensorflow/compiler/aot/tests/BUILD | 20
-rw-r--r--  tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt | 2
-rw-r--r--  tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt | 2
-rw-r--r--  tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt | 2
-rw-r--r--  tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt | 2
-rw-r--r--  tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt | 2
-rw-r--r--  tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt | 2
-rw-r--r--  tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt | 2
-rw-r--r--  tensorflow/compiler/aot/tfcompile.bzl | 2
-rw-r--r--  tensorflow/compiler/aot/tfcompile_main.cc | 26
-rw-r--r--  tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt | 2
-rw-r--r--  tensorflow/compiler/tf2xla/BUILD | 79
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla.cc | 370
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla.h | 43
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla.proto (renamed from tensorflow/compiler/aot/tfcompile.proto) | 26
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_test.cc | 99
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util.cc (renamed from tensorflow/compiler/aot/tfcompile_util.cc) | 74
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util.h (renamed from tensorflow/compiler/aot/tfcompile_util.h) | 23
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util_test.cc (renamed from tensorflow/compiler/aot/tfcompile_util_test.cc) | 89
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.cc | 1
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.h | 14
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 23
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc | 68
-rw-r--r--  tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc | 20
-rw-r--r--  tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h | 12
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 30
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h | 7
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering.h | 2
-rw-r--r--  tensorflow/compiler/xla/tests/while_test.cc | 39
42 files changed, 989 insertions(+), 645 deletions(-)
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 9909e88e64..29dbe4a08b 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -4,7 +4,6 @@ package(
default_visibility = ["//visibility:private"],
)
-load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
# Optional runtime utilities for use by code generated by tfcompile.
@@ -39,32 +38,24 @@ cc_library(
deps = ["//tensorflow/core:test_main"],
)
-xla_proto_library(
- name = "tfcompile_proto",
- srcs = ["tfcompile.proto"],
- deps = [
- "//tensorflow/core:protos_all_cc",
- ],
-)
-
cc_library(
name = "tfcompile_lib",
srcs = [
"codegen.cc",
"compile.cc",
"flags.cc",
- "tfcompile_util.cc",
],
hdrs = [
"codegen.h",
"compile.h",
"flags.h",
- "tfcompile_util.h",
],
deps = [
":runtime", # needed by codegen to print aligned_buffer_bytes
- ":tfcompile_proto",
+ "//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:tf2xla_proto",
+ "//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -82,7 +73,6 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core:stream_executor_no_cuda",
],
)
@@ -99,18 +89,6 @@ cc_test(
],
)
-cc_test(
- name = "tfcompile_util_test",
- srcs = ["tfcompile_util_test.cc"],
- deps = [
- ":tfcompile_lib",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
cc_binary(
name = "tfcompile",
visibility = ["//visibility:public"],
@@ -123,7 +101,8 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":tfcompile_lib",
- ":tfcompile_proto",
+ "//tensorflow/compiler/tf2xla:tf2xla_proto",
+ "//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu",
@@ -226,7 +205,11 @@ test_suite(
tags = ["manual"],
tests = [
":benchmark_test",
+ ":codegen_test",
+ ":runtime_test",
":test_graph_tfadd_test",
+ ":test_graph_tfunknownop2_test",
+ ":test_graph_tfunknownop3_test",
":test_graph_tfunknownop_test",
"//tensorflow/compiler/aot/tests:all_tests",
],
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index bbdb342a62..fc5c6ce58d 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/aot/runtime.h"
-#include "tensorflow/compiler/aot/tfcompile_util.h"
#include "tensorflow/compiler/tf2xla/str_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -35,6 +35,12 @@ namespace tfcompile {
namespace {
+bool IsAlpha(char c) {
+ return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
+}
+
+bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
+
// Convert an XLA type into a C++ type.
Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
switch (type) {
@@ -156,7 +162,7 @@ string RewriteWithName(const string& name, string code,
}
// Generate methods for args (inputs).
-Status GenArgMethods(const Config& config, const xla::ProgramShape& ps,
+Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
const CompileResult& compile_result, string* methods) {
*methods += R"(
void** args() { return args_; }
@@ -204,8 +210,8 @@ Status GenArgMethods(const Config& config, const xla::ProgramShape& ps,
}
// Generate methods for results (outputs).
-Status GenResultMethods(const Config& config, const xla::ProgramShape& ps,
- string* methods) {
+Status GenResultMethods(const tf2xla::Config& config,
+ const xla::ProgramShape& ps, string* methods) {
if (ps.result().element_type() != xla::TUPLE) {
// Non-tuple (i.e. single-result) case.
if (config.fetch_size() != 1) {
@@ -285,11 +291,26 @@ Status GenResultMethods(const Config& config, const xla::ProgramShape& ps,
return Status::OK();
}
+Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
+ for (const tf2xla::Feed& feed : config.feed()) {
+ if (!feed.name().empty()) {
+ TF_RETURN_IF_ERROR(ValidateCppIdent(feed.name(), "feed name"));
+ }
+ }
+ for (const tf2xla::Fetch& fetch : config.fetch()) {
+ if (!fetch.name().empty()) {
+ TF_RETURN_IF_ERROR(ValidateCppIdent(fetch.name(), "fetch name"));
+ }
+ }
+ return Status::OK();
+}
+
} // namespace
-Status GenerateHeader(const HeaderOpts& opts, const Config& config,
+Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
const CompileResult& compile_result, string* header) {
TF_RETURN_IF_ERROR(ValidateConfig(config));
+ TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
const int64 result_index = compile_result.aot->result_buffer_index();
const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
if (result_index < 0 || result_index > temp_sizes.size()) {
@@ -574,5 +595,29 @@ Status ParseCppClass(const string& cpp_class, string* class_name,
return Status::OK();
}
+Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
+ if (ident.empty()) {
+ return errors::InvalidArgument("empty identifier: ", msg);
+ }
+ // Require that the identifier starts with a nondigit, and is composed of
+ // nondigits and digits, as specified in section [2.11 Identifiers] of the
+ // C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
+ // defined as [0-9].
+ //
+ // Technically the standard also allows for `universal-character-name`, with a
+ // table of allowed unicode ranges, as well as `other implementation-defined
+ // characters`. We disallow those here to give better error messages, at the
+  // expense of being more restrictive than the standard.
+ if (ident[0] != '_' && !IsAlpha(ident[0])) {
+ return errors::InvalidArgument("illegal leading char: ", msg);
+ }
+ for (size_t pos = 1; pos < ident.size(); ++pos) {
+ if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
+ return errors::InvalidArgument("illegal char: ", msg);
+ }
+ }
+ return Status::OK();
+}
+
} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h
index 7217c57739..740edd1e83 100644
--- a/tensorflow/compiler/aot/codegen.h
+++ b/tensorflow/compiler/aot/codegen.h
@@ -20,6 +20,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/aot/compile.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
namespace tfcompile {
@@ -37,7 +39,7 @@ struct HeaderOpts {
// GenerateHeader uses the meta-information from compile_result to generate a
// C++ header giving access to the function in the generated object file. The
// header includes API usage documentation.
-Status GenerateHeader(const HeaderOpts& opts, const Config& config,
+Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
const CompileResult& compile_result, string* header);
// ParseCppClass parses `cpp_class` into its `class_name` and `namespaces`
@@ -47,6 +49,10 @@ Status GenerateHeader(const HeaderOpts& opts, const Config& config,
Status ParseCppClass(const string& cpp_class, string* class_name,
std::vector<string>* namespaces);
+// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
+// appended to error messages.
+Status ValidateCppIdent(StringPiece ident, StringPiece msg);
+
} // namespace tfcompile
} // namespace tensorflow
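A hedged usage sketch of the new ValidateCppIdent helper declared above; the wrapper function and its name are illustrative, and only the ValidateCppIdent calls and error strings come from this diff:

#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace tfcompile {

// Illustrative caller: pre-validate names that will appear in generated C++.
Status CheckGeneratedNames() {
  // Starts with a nondigit and uses only [_a-zA-Z0-9]: returns OK.
  TF_RETURN_IF_ERROR(ValidateCppIdent("my_feed0", "feed name"));
  // Starts with a digit: yields InvalidArgument("illegal leading char: feed name").
  return ValidateCppIdent("0feed", "feed name");
}

}  // namespace tfcompile
}  // namespace tensorflow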
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index e3f76f3666..98cbd67e53 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
@@ -29,6 +30,41 @@ namespace tensorflow {
namespace tfcompile {
namespace {
+void ExpectErrorContains(const Status& status, StringPiece str) {
+ EXPECT_NE(Status::OK(), status);
+ EXPECT_TRUE(StringPiece(status.error_message()).contains(str))
+ << "expected error: " << status.error_message() << " to contain: " << str;
+}
+
+TEST(ValidateCppIdent, Simple) {
+ TF_EXPECT_OK(ValidateCppIdent("a", ""));
+ TF_EXPECT_OK(ValidateCppIdent("abc", ""));
+ TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
+ TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
+ // Make sure we didn't skip a valid letter or digit
+ string ident;
+ for (char c = 'a'; c <= 'z'; c++) {
+ ident.append(1, c);
+ }
+ for (char c = 'A'; c <= 'Z'; c++) {
+ ident.append(1, c);
+ }
+ for (char c = '0'; c <= '9'; c++) {
+ ident.append(1, c);
+ }
+ ident += "_";
+ TF_EXPECT_OK(ValidateCppIdent(ident, ""));
+
+ ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
+ ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
+ ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
+ ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
+ ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
+ ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
+ ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
+ ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
+}
+
class ParseCppClassTest : public ::testing::Test {
protected:
void ExpectOK(const string& cpp_class, const string& want_class_name,
@@ -91,13 +127,13 @@ TEST(GenerateHeader, Golden) {
HeaderOpts opts;
opts.class_name = "MyClass";
opts.namespaces = {"foo", "bar"};
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("feed0");
feed->set_name("myfeed");
feed = config.add_feed();
feed->mutable_id()->set_node_name("feed1");
- Fetch* fetch = config.add_fetch();
+ tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("fetch0");
fetch->set_name("myfetch");
CompileResult compile_result;
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index a485d2e555..eac8da0ab1 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -15,326 +15,32 @@ limitations under the License.
#include "tensorflow/compiler/aot/compile.h"
-#include <map>
#include <memory>
#include <string>
-#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/compiler/aot/flags.h"
-#include "tensorflow/compiler/aot/tfcompile_util.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/compile_only_client.h"
-#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
-#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/graph_def_util.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/versions.pb.h"
-#include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace tfcompile {
-const char* const kArgOp = "_Arg";
-const char* const kRetvalOp = "_Retval";
-const char* const kFeedIdAttr = "_feed_id";
-const char* const kFetchIdAttr = "_fetch_id";
-const char* const kShapeAttr = "_shape";
-const char* const kDebugNameAttr = "_debug_name";
-
namespace {
-Status DumpGraph(const MainFlags& flags, const string& name,
- const Graph& graph) {
- if (flags.debug_dir.empty()) {
- return Status::OK();
- }
- GraphDef graph_def;
- graph.ToGraphDef(&graph_def);
- string file = io::JoinPath(flags.debug_dir, name + ".pbtxt");
- return WriteTextProto(Env::Default(), file, graph_def);
-}
-
-typedef std::unordered_map<string, Node*> NodeMap;
-
-// Each feed id identifies the positional output of some node, which may consist
-// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
-// tensor with a placeholder. For each feed tensor, replaces all edges so they
-// point from a new _Arg node instead.
-Status AddArgNodes(Graph* graph, const NodeMap& node_map,
- const protobuf::RepeatedPtrField<Feed>& feeds,
- const std::unordered_map<string, string>& feed_remapping) {
- for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
- const Feed& feed = feeds[arg_index];
- // All feeds have been replaced by placeholders.
- const int output_index = 0;
-
- const string key = TensorIdToString(feed.id());
- const auto remap_it = feed_remapping.find(key);
- auto node_it = node_map.find(remap_it->second);
- if (node_it == node_map.end()) {
- // Strip off the aot_feed_#/ prefix.
- StringPiece name(remap_it->second);
- const auto index = name.find('/');
- if (index > 0) name.remove_prefix(index + 1);
- return errors::InvalidArgument(
- "Node is fed but not needed for fetching: ", name);
- }
- const Node* feed_node = node_it->second;
-
- // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
- // "_shape" attr if we can determine it. That way the graph will be
- // initialized with whatever shapes we can infer, while the user can still
- // explicitly specify or override them.
- Node* arg_node = nullptr;
- TF_RETURN_IF_ERROR(
- NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
- .Attr("T", BaseType(feed_node->output_type(output_index)))
- .Attr("index", arg_index)
- .Attr(kFeedIdAttr, TensorIdToString(feed.id()))
- .Attr(kShapeAttr, TensorShape(feed.shape()))
- .Attr(kDebugNameAttr, feed.name())
- .Finalize(graph, &arg_node));
-
- // Collects out-edges from the feed node that have a matching edge index;
- // these will be replaced with edges from the arg node instead.
- //
- // We must collect the edges first and process them in a second pass, since
- // removing the edge from the graph invalidates feed_node->out_edges.
- std::vector<const Edge*> feed_edges;
- for (const Edge* edge : feed_node->out_edges()) {
- if (edge->src_output() == output_index) {
- feed_edges.push_back(edge);
- }
- }
- for (const Edge* edge : feed_edges) {
- graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
- graph->RemoveEdge(edge);
- }
- }
- return Status::OK();
-}
-
-// Each fetch id identifies the positional output of some node. For each fetch
-// node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
-Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
- const protobuf::RepeatedPtrField<Fetch>& fetches,
- std::unordered_set<const Node*>* retval_nodes) {
- for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
- const TensorId& id = fetches[ret_index].id();
- auto it = node_map.find(id.node_name());
- if (it == node_map.end()) {
- return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
- }
- Node* fetch_node = it->second;
- if (id.output_index() >= fetch_node->num_outputs()) {
- return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
- ", output index should be < ",
- fetch_node->num_outputs());
- }
- // Connects fetch_node -> retval_node.
- Node* retval_node = nullptr;
- TF_RETURN_IF_ERROR(
- NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
- .Input(fetch_node, id.output_index())
- .Attr("T", BaseType(fetch_node->output_type(id.output_index())))
- .Attr("index", ret_index)
- .Attr(kFetchIdAttr, TensorIdToString(id))
- .Finalize(graph, &retval_node));
- retval_nodes->insert(retval_node);
- }
- return Status::OK();
-}
-
-// RewriteAndPruneGraph identifies input and output edges (named by the feed and
-// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
-// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
-// execution to know the input and output args for the generated function.
-Status RewriteAndPruneGraph(
- Graph* graph, const Config& config,
- const std::unordered_map<string, string>& feed_remapping,
- const MainFlags& flags) {
- NodeMap node_map;
- for (Node* n : graph->nodes()) {
- node_map[n->name()] = n;
- }
- TF_RETURN_IF_ERROR(
- AddArgNodes(graph, node_map, config.feed(), feed_remapping));
- std::unordered_set<const Node*> retval_nodes;
- TF_RETURN_IF_ERROR(
- AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
- TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_rewrite", *graph));
- PruneForReverseReachability(graph, retval_nodes);
- FixupSourceAndSinkEdges(graph);
- TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_prune", *graph));
- // Sanity-check, to make sure the feeds and fetches still exist post-pruning.
- std::set<string> missing_feeds, missing_fetches;
- for (const Feed& feed : config.feed()) {
- missing_feeds.insert(TensorIdToString(feed.id()));
- }
- for (const Fetch& fetch : config.fetch()) {
- missing_fetches.insert(TensorIdToString(fetch.id()));
- }
- for (const Node* n : graph->op_nodes()) {
- if (n->type_string() == kArgOp) {
- string feed_id;
- TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
- if (missing_feeds.erase(feed_id) == 0) {
- return errors::Aborted(kArgOp,
- " node found with unknown feed id: ", feed_id);
- }
- } else if (n->type_string() == kRetvalOp) {
- string fetch_id;
- TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
- if (missing_fetches.erase(fetch_id) == 0) {
- return errors::Aborted(kRetvalOp,
- " node found with unknown fetch id: ", fetch_id);
- }
- }
- }
- if (!missing_feeds.empty() || !missing_fetches.empty()) {
- return errors::Aborted(
- "Post graph-pruning",
- ", missing feeds: ", str_util::Join(missing_feeds, ", "),
- ", missing fetches: ", str_util::Join(missing_fetches, ", "));
- }
- return Status::OK();
-}
-
-// CollectArgNodes collects _Arg nodes from the graph, and performs basic
-// sanity-checking to ensure the index and type attributes of each node are
-// initialized correctly.
-Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
- std::map<int, Node*> indexed_arg_nodes;
- for (Node* n : graph.nodes()) {
- if (n->type_string() == kArgOp) {
- int index;
- TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
- auto insert_result = indexed_arg_nodes.insert({index, n});
- if (!insert_result.second) {
- const Node* dup = insert_result.first->second;
- return errors::InvalidArgument(
- "Multiple ", kArgOp, " nodes with index ", index, ", ",
- n->DebugString(), " and ", dup->DebugString());
- }
- }
- }
- arg_nodes->clear();
- for (const auto& index_node : indexed_arg_nodes) {
- if (index_node.first != arg_nodes->size()) {
- return errors::InvalidArgument("Expected ", kArgOp, " node with index ",
- arg_nodes->size(), ", but got index ",
- index_node.first);
- }
- arg_nodes->push_back(index_node.second);
- }
- return Status::OK();
-}
-
-// Fills in xla_args from the corresponding _Arg nodes in the graph.
-Status CreateXlaArgs(const Graph& graph,
- std::vector<XlaCompiler::Argument>* xla_args) {
- std::vector<Node*> arg_nodes;
- TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
- for (const Node* node : arg_nodes) {
- XlaCompiler::Argument arg;
- arg.kind = XlaCompiler::Argument::kParameter;
- TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
- TensorShape shape;
- TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
- TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
- TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
- xla_args->push_back(arg);
- }
- return Status::OK();
-}
-
-// Converts the TensorFlow graph into an XLA computation, by executing the
-// graph symbolically, with each op building up the XLA HLO.
-Status ConvertGraphToXla(xla::CompileOnlyClient* client,
- std::unique_ptr<Graph> graph,
- xla::Computation* computation, bool* has_context_arg) {
- // Create a device and context to convert the graph into an XLA computation.
- XlaOpRegistry::RegisterCompilationKernels();
- // Populate the context with args from the graph.
- for (Node* node : graph->nodes()) {
- node->set_assigned_device_name(DEVICE_CPU_XLA_JIT);
- }
- std::vector<XlaCompiler::Argument> xla_args;
- TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
-
- // Compile the graph into an XLA computation.
- XlaCompiler::Options compiler_options;
- compiler_options.client = client;
- DeviceType device_type(DEVICE_CPU_XLA_JIT);
- compiler_options.device_type = &device_type;
- compiler_options.flib_def = &graph->flib_def();
- compiler_options.graph_def_version = graph->versions().producer();
- compiler_options.allow_cpu_custom_calls = true;
- XlaCompiler compiler(compiler_options);
-
- XlaCompiler::CompilationResult result;
- TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
- "tfcompile", std::move(graph),
- xla_args, &result));
- *has_context_arg = result.requires_runtime_context;
- *computation = std::move(*result.computation);
-
- int num_const_results = 0;
- for (int i = 0; i < result.outputs.size(); ++i) {
- // Ending up with const results (i.e. output args) is an error, since it
- // means that one or more fetches that the user specified will be dropped
- // from the generated function. It's most likely a configuration error,
- // since the user shouldn't be asking for output args that end up as consts.
- //
- // TODO(toddw): Provide a way for the user to access const output args,
- // e.g. perhaps hard-coded into the header, or somehow copied into the
- // output buffers.
- if (result.outputs[i].is_constant) {
- ++num_const_results;
- LOG(ERROR) << "ConstRetVal index:" << i
- << " value:" << result.outputs[i].constant_value.DebugString();
- }
- }
- if (num_const_results > 0) {
- return errors::Unimplemented(
- "Conversion from TensorFlow graph to XLA resulted in ",
- num_const_results,
- " constant results. The configuration of "
- "the output args (i.e. fetch ids) is probably wrong.");
- }
- if (computation->IsNull()) {
- return errors::Aborted(
- "Conversion from TensorFlow graph to XLA resulted in an empty "
- "computation.");
- }
- return Status::OK();
-}
-
// Compiles the XLA computation into executable code.
Status CompileXla(xla::CompileOnlyClient* client,
const xla::Computation& computation,
@@ -376,41 +82,8 @@ Status CompileXla(xla::CompileOnlyClient* client,
} // namespace
-Status InitGraph(const GraphDef& graph_def, const Config& config,
- const MainFlags& flags, std::unique_ptr<Graph>* graph) {
- TF_RETURN_IF_ERROR(ValidateConfig(config));
-
- FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
- std::unique_ptr<Graph> g(new Graph(flib_def));
-
- // Replace references to fed tensors with references to newly added
- // placeholders.
- GraphDef first_copy_def = graph_def;
-
- // Maps from name:port of a feed to the name:port of the placeholder to use.
- std::unordered_map<string, string> feed_remapping;
- TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
- &feed_remapping, &first_copy_def));
-
- // Prune the GraphDef first so that unknown ops that we aren't compiling get
- // filtered out.
- GraphDef second_copy_def;
- TF_RETURN_IF_ERROR(
- PruneGraphDefInto(config, first_copy_def, &second_copy_def));
-
- TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
- &second_copy_def, *g->op_registry(), 0 /*node_offset*/));
-
- TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
- second_copy_def, g.get()));
- TF_RETURN_IF_ERROR(
- RewriteAndPruneGraph(g.get(), config, feed_remapping, flags));
- *graph = std::move(g);
- return Status::OK();
-}
-
-Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
- CompileResult* compile_result) {
+Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
+ const MainFlags& flags, CompileResult* compile_result) {
// Converts the graph into an XLA computation, and compiles the
// computation.
// TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
@@ -421,8 +94,9 @@ Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
.ValueOrDie();
xla::Computation computation;
- TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation,
- &compile_result->has_context_arg));
+ TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client,
+ &computation,
+ &compile_result->has_context_arg));
if (!flags.debug_dir.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
computation.Snapshot());
diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h
index e929272b2e..965c296081 100644
--- a/tensorflow/compiler/aot/compile.h
+++ b/tensorflow/compiler/aot/compile.h
@@ -18,46 +18,16 @@ limitations under the License.
#include <memory>
#include <string>
-#include <vector>
#include "tensorflow/compiler/aot/flags.h"
-#include "tensorflow/compiler/aot/tfcompile.pb.h"
-#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/graph/graph.h"
namespace tensorflow {
namespace tfcompile {
-// Constants for op types and attribute names.
-extern const char* const kArgOp;
-extern const char* const kRetvalOp;
-extern const char* const kFeedIdAttr;
-extern const char* const kFetchIdAttr;
-extern const char* const kShapeAttr;
-extern const char* const kDebugNameAttr;
-
-// InitGraph creates a graph based on the graph_def, that may then be compiled
-// by CompileGraph.
-//
-// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
-// and outputs of the function that will be compiled. Each feed id causes a new
-// _Arg node to be created, where we first collect all existing edges pointing
-// from the named node's output index, and then rewrite them to point from that
-// _Arg node instead. Each fetch id causes a new _Retval node to be created,
-// with a new edge pointing from the named node's output index to that _Retval
-// node. All _Retval nodes also point to a special CompileExpressions node,
-// used internally to finish the compilation.
-//
-// The rewritten graph is then pruned to only contain the portion necessary to
-// compute the outputs. If dump_graphs is true, graph rewrites will be dumped
-// for debugging.
-Status InitGraph(const GraphDef& graph_def, const Config& config,
- const MainFlags& flags, std::unique_ptr<Graph>* graph);
-
// CompileResult describes the output of CompileGraph, where the object file
// data and meta-information is available in aot.
struct CompileResult {
@@ -69,20 +39,12 @@ struct CompileResult {
int pointer_size = 0; // Size of a pointer in bytes.
};
-// CompileGraph compiles the graph into an object file containing a function
+// CompileGraph compiles the graph_def into an object file containing a function
// that performs the graph operations.
//
-// The graph must have _Arg and _Retval nodes representing the function inputs
-// and outputs. Every _Arg node must have a shape attribute (key=kShapeAttr,
-// value=TensorShape) representing the static shape of that input, and every
-// _Retval node must point to a CompileExpressions node.
-//
-// Typically InitGraph is called to perform this initialization, followed by
-// full specification of the shape attributes.
-//
// The XLA compilation options are specified in the flags.
-Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
- CompileResult* result);
+Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
+ const MainFlags& flags, CompileResult* compile_result);
} // namespace tfcompile
} // namespace tensorflow
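For orientation, a minimal sketch of driving the new CompileGraph signature end to end; it mirrors the tfcompile_main.cc changes later in this diff, and the wrapper name is hypothetical:

#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"

namespace tensorflow {
namespace tfcompile {

Status RunAotCompile(const GraphDef& graph_def, const tf2xla::Config& config,
                     const MainFlags& flags) {
  CompileResult compile_result;  // Receives object-file data and metadata.
  // The graph rewriting formerly done by InitGraph now happens inside
  // CompileGraph, via ConvertGraphDefToXla.
  return CompileGraph(graph_def, config, flags, &compile_result);
}

}  // namespace tfcompile
}  // namespace tensorflow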
diff --git a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt
index f2d9c34b2d..a4ad334352 100644
--- a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt
+++ b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {
diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt
index 5625c0ab03..d3f0e4990c 100644
--- a/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt
+++ b/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {
diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt
index 7370ed370d..e0b012adea 100644
--- a/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt
+++ b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {
diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt
index b2d7d54574..662ba1c321 100644
--- a/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt
+++ b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 05d338e4c5..4d65a044bc 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -13,9 +13,11 @@ test_suite(
":test_graph_tfadd_test",
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
+ ":test_graph_tffunction_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
":test_graph_tfmatmulandadd_test",
+ ":test_graph_tfsplits_test",
":tfcompile_test",
],
)
@@ -91,6 +93,15 @@ tf_library(
)
tf_library(
+ name = "test_graph_tffunction",
+ testonly = 1,
+ config = "test_graph_tffunction.config.pbtxt",
+ cpp_class = "FunctionComp",
+ graph = "test_graph_tffunction.pb",
+ tags = ["manual"],
+)
+
+tf_library(
name = "test_graph_tfgather",
testonly = 1,
config = "test_graph_tfgather.config.pbtxt",
@@ -118,15 +129,6 @@ tf_library(
)
tf_library(
- name = "test_graph_tffunction",
- testonly = 1,
- config = "test_graph_tffunction.config.pbtxt",
- cpp_class = "FunctionComp",
- graph = "test_graph_tffunction.pb",
- tags = ["manual"],
-)
-
-tf_library(
name = "test_graph_tfsplits",
testonly = 1,
config = "test_graph_tfsplits.config.pbtxt",
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt
index 5625c0ab03..d3f0e4990c 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt
index 4d876a6e91..8adc9cdc14 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_hold" }
shape {
diff --git a/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt
index eb9c1cacb7..cbfe458908 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_const" }
shape {
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt
index 648ee31fdb..89ed678a9c 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "params" }
shape {
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt
index a3ce2029c1..2acd0289c2 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_hold" }
shape {
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt
index 4a4a237a4f..e5ca6115e9 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x_hold" }
shape {
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt
index 85fc7da442..5adc77336c 100644
--- a/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt
+++ b/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed {
id { node_name: "x" }
shape {
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index f9896988dc..fc1342d84e 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -41,7 +41,7 @@ def tf_library(name, graph, config,
graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it
is expected to be in the human-readable proto text format, otherwise it is
expected to be in the proto binary format.
- config: File containing tensorflow.tfcompile.Config proto. If the file ends
+ config: File containing tensorflow.tf2xla.Config proto. If the file ends
in '.pbtxt' it is expected to be in the human-readable proto text format,
otherwise it is expected to be in the proto binary format.
freeze_checkpoint: If provided, run freeze_graph with this checkpoint to
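The config files renamed throughout this diff are text-format tensorflow.tf2xla.Config protos. As a hedged sketch, the same message can be built in C++ with the accessors exercised by codegen_test.cc above; the node names here are hypothetical:

#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"

tensorflow::tf2xla::Config MakeExampleConfig() {
  tensorflow::tf2xla::Config config;
  tensorflow::tf2xla::Feed* feed = config.add_feed();
  feed->mutable_id()->set_node_name("x_hold");    // feed { id { node_name: ... } }
  feed->mutable_shape()->add_dim()->set_size(2);  // shape { dim { size: 2 } }
  feed->set_name("x");                            // optional name used by codegen
  tensorflow::tf2xla::Fetch* fetch = config.add_fetch();
  fetch->mutable_id()->set_node_name("y_out");
  fetch->set_name("y");
  return config;
}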
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index be2cfe4734..cc499c3284 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -21,8 +21,8 @@ limitations under the License.
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
-#include "tensorflow/compiler/aot/tfcompile.pb.h"
-#include "tensorflow/compiler/aot/tfcompile_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/core/framework/function.h"
@@ -54,8 +54,7 @@ const char kUsageHeader[] =
"--cpp_class=\"mynamespace::MyComputation\"\n"
"\n";
-Status ReadProtoFile(const string& kind, const string& fname,
- protobuf::Message* proto) {
+Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (StringPiece(fname).ends_with(".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
@@ -63,23 +62,17 @@ Status ReadProtoFile(const string& kind, const string& fname,
}
}
-void ParseTensorId(const string& name, TensorId* id) {
- const std::pair<StringPiece, int> name_index = ParseTensorName(name);
- id->set_node_name(name_index.first.ToString());
- id->set_output_index(name_index.second);
-}
-
Status Main(const MainFlags& flags) {
// Process config.
- Config config;
+ tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
- TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config));
+ TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
- for (const Fetch& fetch : config.fetch()) {
+ for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << str_util::Join(nodes, ",");
@@ -91,12 +84,9 @@ Status Main(const MainFlags& flags) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
- TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def));
- std::unique_ptr<Graph> graph;
- TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &graph));
-
+ TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
- TF_RETURN_IF_ERROR(CompileGraph(std::move(graph), flags, &compile_result));
+ TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();
diff --git a/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt b/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt
index c46e65f71a..3025fc27b1 100644
--- a/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt
+++ b/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt
@@ -1,4 +1,4 @@
-# Text form of tensorflow.tfcompile.Config proto.
+# Text form of tensorflow.tf2xla.Config proto.
feed{ id{node_name:"inputs/x_seq_0/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/x_seq_1/read"} shape{dim{size:128}dim{size:1024}} }
feed{ id{node_name:"inputs/x_seq_2/read"} shape{dim{size:128}dim{size:1024}} }
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 13fc233054..22f2441a68 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -21,6 +21,40 @@ package(
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
+load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
+
+xla_proto_library(
+ name = "tf2xla_proto",
+ srcs = ["tf2xla.proto"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "tf2xla",
+ srcs = ["tf2xla.cc"],
+ hdrs = ["tf2xla.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":common",
+ ":dump_graph",
+ ":tf2xla_proto",
+ ":tf2xla_util",
+ ":xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
cc_library(
name = "xla_compiler",
@@ -96,6 +130,51 @@ cc_library(
# Internal targets below this point.
+cc_library(
+ name = "tf2xla_util",
+ srcs = ["tf2xla_util.cc"],
+ hdrs = ["tf2xla_util.h"],
+ deps = [
+ ":tf2xla_proto",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "tf2xla_util_test",
+ srcs = ["tf2xla_util_test.cc"],
+ deps = [
+ ":tf2xla_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "tf2xla_test",
+ srcs = ["tf2xla_test.cc"],
+ deps = [
+ ":tf2xla",
+ ":tf2xla_proto",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
cc_test(
name = "xla_compiler_test",
srcs = ["xla_compiler_test.cc"],
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
new file mode 100644
index 0000000000..b29c92190d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -0,0 +1,370 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/tf2xla.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+const char* const kArgOp = "_Arg";
+const char* const kRetvalOp = "_Retval";
+const char* const kFeedIdAttr = "_feed_id";
+const char* const kFetchIdAttr = "_fetch_id";
+const char* const kShapeAttr = "_shape";
+const char* const kDebugNameAttr = "_debug_name";
+
+namespace {
+
+typedef std::unordered_map<string, Node*> NodeMap;
+
+// Each feed id identifies the positional output of some node, which may consist
+// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed
+// tensor with a placeholder. For each feed tensor, replaces all edges so they
+// point from a new _Arg node instead.
+Status AddArgNodes(Graph* graph, const NodeMap& node_map,
+ const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds,
+ const std::unordered_map<string, string>& feed_remapping) {
+ for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
+ const tf2xla::Feed& feed = feeds[arg_index];
+ // All feeds have been replaced by placeholders.
+ const int output_index = 0;
+
+ const string key = TensorIdToString(feed.id());
+ const auto remap_it = feed_remapping.find(key);
+ auto node_it = node_map.find(remap_it->second);
+ if (node_it == node_map.end()) {
+ // Strip off the aot_feed_#/ prefix.
+ StringPiece name(remap_it->second);
+ const auto index = name.find('/');
+ if (index > 0) name.remove_prefix(index + 1);
+ return errors::InvalidArgument(
+ "Node is fed but not needed for fetching: ", name);
+ }
+ const Node* feed_node = node_it->second;
+
+ // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a
+ // "_shape" attr if we can determine it. That way the graph will be
+ // initialized with whatever shapes we can infer, while the user can still
+ // explicitly specify or override them.
+ Node* arg_node = nullptr;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
+ .Attr("T", BaseType(feed_node->output_type(output_index)))
+ .Attr("index", arg_index)
+ .Attr(kFeedIdAttr, TensorIdToString(feed.id()))
+ .Attr(kShapeAttr, TensorShape(feed.shape()))
+ .Attr(kDebugNameAttr, feed.name())
+ .Finalize(graph, &arg_node));
+
+ // Collects out-edges from the feed node that have a matching edge index;
+ // these will be replaced with edges from the arg node instead.
+ //
+ // We must collect the edges first and process them in a second pass, since
+ // removing the edge from the graph invalidates feed_node->out_edges.
+ std::vector<const Edge*> feed_edges;
+ for (const Edge* edge : feed_node->out_edges()) {
+ if (edge->src_output() == output_index) {
+ feed_edges.push_back(edge);
+ }
+ }
+ for (const Edge* edge : feed_edges) {
+ graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
+ graph->RemoveEdge(edge);
+ }
+ }
+ return Status::OK();
+}
+
+// Each fetch id identifies the positional output of some node. For each fetch
+// node, adds a new _Retval node instead, and adds the node to `retval_nodes`.
+Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
+ const protobuf::RepeatedPtrField<tf2xla::Fetch>& fetches,
+ std::unordered_set<const Node*>* retval_nodes) {
+ for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
+ const tf2xla::TensorId& id = fetches[ret_index].id();
+ auto it = node_map.find(id.node_name());
+ if (it == node_map.end()) {
+ return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
+ }
+ Node* fetch_node = it->second;
+ if (id.output_index() >= fetch_node->num_outputs()) {
+ return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
+ ", output index should be < ",
+ fetch_node->num_outputs());
+ }
+ // Connects fetch_node -> retval_node.
+ Node* retval_node = nullptr;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
+ .Input(fetch_node, id.output_index())
+ .Attr("T", BaseType(fetch_node->output_type(id.output_index())))
+ .Attr("index", ret_index)
+ .Attr(kFetchIdAttr, TensorIdToString(id))
+ .Finalize(graph, &retval_node));
+ retval_nodes->insert(retval_node);
+ }
+ return Status::OK();
+}
+
+// RewriteAndPruneGraph identifies input and output edges (named by the feed and
+// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
+// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
+// execution to know the input and output args for the generated function.
+Status RewriteAndPruneGraph(
+ Graph* graph, const tf2xla::Config& config,
+ const std::unordered_map<string, string>& feed_remapping) {
+ NodeMap node_map;
+ for (Node* n : graph->nodes()) {
+ node_map[n->name()] = n;
+ }
+ TF_RETURN_IF_ERROR(
+ AddArgNodes(graph, node_map, config.feed(), feed_remapping));
+ std::unordered_set<const Node*> retval_nodes;
+ TF_RETURN_IF_ERROR(
+ AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
+ VLOG(2) << "Post rewrite: "
+ << dump_graph::DumpGraphToFile("tf2xla_post_rewrite", *graph);
+ PruneForReverseReachability(graph, retval_nodes);
+ FixupSourceAndSinkEdges(graph);
+ VLOG(2) << "Post prune: "
+ << dump_graph::DumpGraphToFile("tfcompile_post_prune", *graph);
+ // Sanity-check, to make sure the feeds and fetches still exist post-pruning.
+ std::set<string> missing_feeds, missing_fetches;
+ for (const tf2xla::Feed& feed : config.feed()) {
+ missing_feeds.insert(TensorIdToString(feed.id()));
+ }
+ for (const tf2xla::Fetch& fetch : config.fetch()) {
+ missing_fetches.insert(TensorIdToString(fetch.id()));
+ }
+ for (const Node* n : graph->op_nodes()) {
+ if (n->type_string() == kArgOp) {
+ string feed_id;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id));
+ if (missing_feeds.erase(feed_id) == 0) {
+ return errors::Aborted(kArgOp,
+ " node found with unknown feed id: ", feed_id);
+ }
+ } else if (n->type_string() == kRetvalOp) {
+ string fetch_id;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id));
+ if (missing_fetches.erase(fetch_id) == 0) {
+ return errors::Aborted(kRetvalOp,
+ " node found with unknown fetch id: ", fetch_id);
+ }
+ }
+ }
+ if (!missing_feeds.empty() || !missing_fetches.empty()) {
+ return errors::Aborted(
+ "Post graph-pruning",
+ ", missing feeds: ", str_util::Join(missing_feeds, ", "),
+ ", missing fetches: ", str_util::Join(missing_fetches, ", "));
+ }
+ return Status::OK();
+}
+
+// CollectArgNodes collects _Arg nodes from the graph, and performs basic
+// sanity-checking to ensure the index and type attributes of each node are
+// initialized correctly.
+Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
+ std::map<int, Node*> indexed_arg_nodes;
+ for (Node* n : graph.nodes()) {
+ if (n->type_string() == kArgOp) {
+ int index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+ auto insert_result = indexed_arg_nodes.insert({index, n});
+ if (!insert_result.second) {
+ const Node* dup = insert_result.first->second;
+ return errors::InvalidArgument(
+ "Multiple ", kArgOp, " nodes with index ", index, ", ",
+ n->DebugString(), " and ", dup->DebugString());
+ }
+ }
+ }
+ arg_nodes->clear();
+ for (const auto& index_node : indexed_arg_nodes) {
+ if (index_node.first != arg_nodes->size()) {
+ return errors::InvalidArgument("Expected ", kArgOp, " node with index ",
+ arg_nodes->size(), ", but got index ",
+ index_node.first);
+ }
+ arg_nodes->push_back(index_node.second);
+ }
+ return Status::OK();
+}
+
+// Fills in xla_args from the corresponding _Arg nodes in the graph.
+Status CreateXlaArgs(const Graph& graph,
+ std::vector<XlaCompiler::Argument>* xla_args) {
+ std::vector<Node*> arg_nodes;
+ TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
+ for (const Node* node : arg_nodes) {
+ XlaCompiler::Argument arg;
+ arg.kind = XlaCompiler::Argument::kParameter;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
+ TensorShape shape;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
+ xla_args->push_back(arg);
+ }
+ return Status::OK();
+}
+
+// Converts the TensorFlow graph into an XLA computation, by executing the
+// graph symbolically, with each op building up the XLA HLO.
+Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
+ xla::Computation* computation,
+ bool* requires_runtime_context) {
+ // Create a device and context to convert the graph into an XLA computation.
+ XlaOpRegistry::RegisterCompilationKernels();
+ // Populate the context with args from the graph.
+ for (Node* node : graph->nodes()) {
+ node->set_assigned_device_name(DEVICE_CPU_XLA_JIT);
+ }
+ std::vector<XlaCompiler::Argument> xla_args;
+ TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
+
+ // Compile the graph into an XLA computation.
+ XlaCompiler::Options compiler_options;
+ compiler_options.client = client;
+ DeviceType device_type(DEVICE_CPU_XLA_JIT);
+ compiler_options.device_type = &device_type;
+ compiler_options.flib_def = &graph->flib_def();
+ compiler_options.graph_def_version = graph->versions().producer();
+ compiler_options.allow_cpu_custom_calls = true;
+ XlaCompiler compiler(compiler_options);
+
+ XlaCompiler::CompilationResult result;
+ TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(),
+ "tfcompile", std::move(graph),
+ xla_args, &result));
+ *requires_runtime_context = result.requires_runtime_context;
+ *computation = std::move(*result.computation);
+
+ int num_const_results = 0;
+ for (int i = 0; i < result.outputs.size(); ++i) {
+ // Ending up with const results (i.e. output args) is an error, since it
+ // means that one or more fetches that the user specified will be dropped
+ // from the generated function. It's most likely a configuration error,
+ // since the user shouldn't be asking for output args that end up as consts.
+ //
+ // TODO(toddw): Provide a way for the user to access const output args,
+ // e.g. perhaps hard-coded into the header, or somehow copied into the
+ // output buffers.
+ if (result.outputs[i].is_constant) {
+ ++num_const_results;
+ LOG(ERROR) << "ConstRetVal index:" << i
+ << " value:" << result.outputs[i].constant_value.DebugString();
+ }
+ }
+ if (num_const_results > 0) {
+ return errors::Unimplemented(
+ "Conversion from TensorFlow graph to XLA resulted in ",
+ num_const_results,
+ " constant results. The configuration of "
+ "the output args (i.e. fetch ids) is probably wrong.");
+ }
+ if (computation->IsNull()) {
+ return errors::Aborted(
+ "Conversion from TensorFlow graph to XLA resulted in an empty "
+ "computation.");
+ }
+ return Status::OK();
+}
+
+// InitGraph creates a graph based on the graph_def, that may then be converted
+// to an xla::Computation via ConvertGraphToXla.
+//
+// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
+// and outputs of the function that will be compiled. Each feed id causes a new
+// _Arg node to be created, where we first collect all existing edges pointing
+// from the named node's output index, and then rewrite them to point from that
+// _Arg node instead. Each fetch id causes a new _Retval node to be created,
+// with a new edge pointing from the named node's output index to that _Retval
+// node.
+Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
+ std::unique_ptr<Graph>* graph) {
+ TF_RETURN_IF_ERROR(ValidateConfig(config));
+
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library());
+ std::unique_ptr<Graph> g(new Graph(flib_def));
+
+ // Replace references to fed tensors with references to newly added
+ // placeholders.
+ GraphDef first_copy_def = graph_def;
+
+ // Maps from name:port of a feed to the name:port of the placeholder to use.
+ std::unordered_map<string, string> feed_remapping;
+ TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(),
+ &feed_remapping, &first_copy_def));
+
+ // Prune the GraphDef first so that unknown ops that we aren't compiling get
+ // filtered out.
+ GraphDef second_copy_def;
+ TF_RETURN_IF_ERROR(
+ PruneGraphDefInto(config, first_copy_def, &second_copy_def));
+
+ TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(
+ &second_copy_def, *g->op_registry(), /*node_offset=*/0));
+
+ TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
+ second_copy_def, g.get()));
+ TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
+ *graph = std::move(g);
+ return Status::OK();
+}
+
+} // namespace
+
+Status ConvertGraphDefToXla(const GraphDef& graph_def,
+ const tf2xla::Config& config, xla::Client* client,
+ xla::Computation* computation,
+ bool* requires_runtime_context) {
+ std::unique_ptr<Graph> graph;
+ TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
+ TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation,
+ requires_runtime_context));
+ return Status::OK();
+}
+
+} // namespace tensorflow
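
The functions above define an implicit contract for rewritten graphs: every feed becomes an _Arg node carrying index, type, shape, and debug-name attributes, with indices forming a dense range. A sketch of such a node, using only the attr names visible in this hunk ("index" and "T"); the node name is hypothetical, and the strings behind kArgOp, kShapeAttr, and kDebugNameAttr are defined elsewhere in this file:

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.pb.h"

// Sketch of the _Arg node CollectArgNodes/CreateXlaArgs expect for feed 0,
// assuming kArgOp == "_Arg".
tensorflow::NodeDef MakeArgNodeSketch() {
  tensorflow::NodeDef arg;
  arg.set_name("arg_feed_x");  // hypothetical name
  arg.set_op("_Arg");
  tensorflow::AddNodeAttr("index", 0, &arg);                 // feed position
  tensorflow::AddNodeAttr("T", tensorflow::DT_INT32, &arg);  // feed dtype
  // The shape (kShapeAttr) and debug name (kDebugNameAttr) attrs are set
  // similarly; their attr-name strings are not shown in this hunk.
  return arg;
}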
diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h
new file mode 100644
index 0000000000..ab99beebf7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/tf2xla.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
+#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
+
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/core/framework/graph.pb.h"
+
+namespace tensorflow {
+
+// Converts a tensorflow::GraphDef into an xla::Computation. The given `config`
+// specifies the portion of the graph to convert, via feeds and fetches. Each
+// feed is a positional input argument for the generated computation, while each
+// fetch is a positional output argument.
+//
+// The computation is built in the context of the given `client`, which may
+// subsequently be used to compile or execute the computation.
+//
+// If `requires_runtime_context` is set to true, the last argument of the
+// computation is an XlaLocalRuntimeContext*.
+Status ConvertGraphDefToXla(const GraphDef& graph_def,
+ const tf2xla::Config& config, xla::Client* client,
+ xla::Computation* computation,
+ bool* requires_runtime_context);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_
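
A minimal calling sketch for this entry point, assuming a valid graph_def/config pair (the new tf2xla_test.cc below exercises the full flow):

#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/lib/core/errors.h"

tensorflow::Status BuildComputation(const tensorflow::GraphDef& graph_def,
                                    const tensorflow::tf2xla::Config& config,
                                    xla::Computation* computation) {
  xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
  bool requires_runtime_context = false;
  TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla(
      graph_def, config, client, computation, &requires_runtime_context));
  // If requires_runtime_context is true, the caller must supply an
  // XlaLocalRuntimeContext* as the final execution argument.
  return tensorflow::Status::OK();
}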
diff --git a/tensorflow/compiler/aot/tfcompile.proto b/tensorflow/compiler/tf2xla/tf2xla.proto
index cd83840d89..18c9089f5f 100644
--- a/tensorflow/compiler/aot/tfcompile.proto
+++ b/tensorflow/compiler/tf2xla/tf2xla.proto
@@ -1,10 +1,10 @@
syntax = "proto3";
-package tensorflow.tfcompile;
+package tensorflow.tf2xla;
option cc_enable_arenas = true;
-option java_outer_classname = "CompileProtos";
+option java_outer_classname = "Tf2XlaProtos";
option java_multiple_files = true;
-option java_package = "org.tensorflow.tfcompile";
+option java_package = "org.tensorflow.tf2xla";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
@@ -19,32 +19,32 @@ message TensorId {
};
// Feed represents a single feed tensor in the graph, which corresponds to an
-// input argument for the generated function.
+// input argument for the generated computation.
message Feed {
TensorId id = 1;
TensorShapeProto shape = 2;
string name = 3; // Optional name for generated code.
// Optional data type. This is not normally required, as the graph itself
- // contains this information. However, if the node being fed is an op that
- // is not linked into the tfcompile binary, then the type cannot be inferred
- // from the node; in this case, the type should be set here.
+ // contains this information. However, if the node being fed is an op that is
+ // not linked into the binary, then the type cannot be inferred from the node;
+ // in this case, the type should be set here.
DataType type = 4;
};
// Fetch represents a single fetch tensor in the graph, which corresponds to an
-// output argument for the generated function.
+// output argument for the generated computation.
message Fetch {
TensorId id = 1;
string name = 2; // Optional name for generated code.
};
-// Config represents configuration information for tfcompile.
+// Config represents configuration information for tf2xla conversion.
message Config {
- // Each feed is a positional input argument for the generated function. The
- // order of each entry matches the order of each input argument.
+ // Each feed is a positional input argument for the generated computation.
+ // The order of each entry matches the order of each input argument.
repeated Feed feed = 1;
- // Each fetch is a positional output argument for the generated function. The
- // order of each entry matches the order of each output argument.
+ // Each fetch is a positional output argument for the generated computation.
+ // The order of each entry matches the order of each output argument.
repeated Fetch fetch = 2;
};
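
For illustration, a hypothetical two-feed, one-fetch Config in the style of the existing *.config.pbtxt files (node names invented):

feed {
  id { node_name: "x" }
  shape { dim { size: 2 } }
  name: "x_in"
}
feed {
  id { node_name: "y" }
  shape { dim { size: 2 } }
}
fetch {
  id { node_name: "sum" }
  name: "sum_out"
}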
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
new file mode 100644
index 0000000000..57b53cc660
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -0,0 +1,99 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/tf2xla.h"
+
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+AttrValue TypeAttrValue(DataType type) {
+ AttrValue attr_value;
+ SetAttrValue(type, &attr_value);
+ return attr_value;
+}
+
+GraphDef SumGraph() {
+ GraphDef graph_def;
+ NodeDef* x = graph_def.add_node();
+ x->set_name("x");
+ x->set_op("Placeholder");
+ (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32);
+ NodeDef* y = graph_def.add_node();
+ y->set_name("y");
+ y->set_op("Placeholder");
+ (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32);
+ NodeDef* sum = graph_def.add_node();
+ sum->set_name("sum");
+ sum->set_op("Add");
+ sum->add_input("x");
+ sum->add_input("y");
+ (*sum->mutable_attr())["T"] = TypeAttrValue(DT_INT32);
+ return graph_def;
+}
+
+tf2xla::Config SumConfig() {
+ tf2xla::Config config;
+ config.add_feed()->mutable_id()->set_node_name("x");
+ config.add_feed()->mutable_id()->set_node_name("y");
+ config.add_fetch()->mutable_id()->set_node_name("sum");
+ return config;
+}
+
+TEST(ConvertGraphDefToXla, Sum) {
+ GraphDef graph_def = SumGraph();
+ tf2xla::Config config = SumConfig();
+
+ xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
+ xla::Computation computation;
+ bool requires_runtime_context;
+ TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation,
+ &requires_runtime_context));
+ ASSERT_FALSE(requires_runtime_context);
+
+ // Set up arguments.
+ auto x_literal = xla::Literal::CreateR0<int32>(10);
+ auto y_literal = xla::Literal::CreateR0<int32>(32);
+ auto x_global_or = client->TransferToServer(*x_literal);
+ auto y_global_or = client->TransferToServer(*y_literal);
+ TF_EXPECT_OK(x_global_or.status());
+ TF_EXPECT_OK(y_global_or.status());
+ std::unique_ptr<xla::GlobalData> x_global =
+ std::move(x_global_or.ValueOrDie());
+ std::unique_ptr<xla::GlobalData> y_global =
+ std::move(y_global_or.ValueOrDie());
+
+ // Execute and check result.
+ auto result_or =
+ client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
+ TF_EXPECT_OK(result_or.status());
+ std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
+ EXPECT_EQ("42", result->ToString());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/tfcompile_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index 629187d621..14e0910cab 100644
--- a/tensorflow/compiler/aot/tfcompile_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/aot/tfcompile_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include <queue>
#include <set>
#include <unordered_map>
-#include "tensorflow/compiler/aot/tfcompile.pb.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -29,21 +29,13 @@ limitations under the License.
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
-namespace tfcompile {
namespace {
-bool IsAlpha(char c) {
- return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
-}
-
-bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
-
-Status ValidateTensorId(const TensorId& id) {
+Status ValidateTensorId(const tf2xla::TensorId& id) {
if (id.node_name().empty()) {
return errors::InvalidArgument("TensorId node_name must be non-empty");
}
@@ -53,10 +45,9 @@ Status ValidateTensorId(const TensorId& id) {
return Status::OK();
}
-Status ValidateFeedFetchName(const string& kind, const string& name,
- std::set<string>* names) {
+Status CheckNameDuplicates(const string& kind, const string& name,
+ std::set<string>* names) {
if (!name.empty()) {
- TF_RETURN_IF_ERROR(ValidateCppIdent(name, kind + " name"));
if (!names->insert(name).second) {
return errors::InvalidArgument("duplicate ", kind, " name: ", name);
}
@@ -80,42 +71,18 @@ Status CheckFeedFetchNameConflicts(const string& kind,
} // namespace
-Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
- if (ident.empty()) {
- return errors::InvalidArgument("empty identifier: ", msg);
- }
- // Require that the identifier starts with a nondigit, and is composed of
- // nondigits and digits, as specified in section [2.11 Identifiers] of the
- // C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
- // defined as [0-9].
- //
- // Technically the standard also allows for `universal-character-name`, with a
- // table of allowed unicode ranges, as well as `other implementation-defined
- // characters`. We disallow those here to give better error messages, at the
- // expensive of being more restrictive than the standard.
- if (ident[0] != '_' && !IsAlpha(ident[0])) {
- return errors::InvalidArgument("illegal leading char: ", msg);
- }
- for (size_t pos = 1; pos < ident.size(); ++pos) {
- if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
- return errors::InvalidArgument("illegal char: ", msg);
- }
- }
- return Status::OK();
-}
-
-Status ValidateConfig(const Config& config) {
+Status ValidateConfig(const tf2xla::Config& config) {
std::set<string> names;
- for (const Feed& feed : config.feed()) {
+ for (const tf2xla::Feed& feed : config.feed()) {
TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
- TF_RETURN_IF_ERROR(ValidateFeedFetchName("feed", feed.name(), &names));
+ TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
}
TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
names.clear();
- for (const Fetch& fetch : config.fetch()) {
+ for (const tf2xla::Fetch& fetch : config.fetch()) {
TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
- TF_RETURN_IF_ERROR(ValidateFeedFetchName("fetch", fetch.name(), &names));
+ TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
}
TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
if (config.feed().empty() || config.fetch().empty()) {
@@ -125,10 +92,10 @@ Status ValidateConfig(const Config& config) {
}
Status AddPlaceholdersForFeeds(
- const Config& config, const OpRegistryInterface* op_registry,
+ const tf2xla::Config& config, const OpRegistryInterface* op_registry,
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
struct PlaceholderInfo {
- const Feed* feed = nullptr; // point to Feed in <config>.
+ const tf2xla::Feed* feed = nullptr; // point to Feed in <config>.
string placeholder_name;
DataType data_type = DT_INVALID;
};
@@ -137,9 +104,9 @@ Status AddPlaceholdersForFeeds(
// when creating placeholders (genrules want deterministic output).
std::map<string, PlaceholderInfo> placeholder_info;
for (int i = 0; i < config.feed_size(); ++i) {
- const Feed* feed = &config.feed(i);
+ const tf2xla::Feed* feed = &config.feed(i);
const string name_port = TensorIdToString(feed->id());
- auto& info = placeholder_info[name_port];
+ PlaceholderInfo& info = placeholder_info[name_port];
info.feed = feed;
info.placeholder_name = strings::StrCat(
"aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
@@ -153,7 +120,7 @@ Status AddPlaceholdersForFeeds(
}
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
PlaceholderInfo& info = it->second;
- const TensorId& feed_id = info.feed->id();
+ const tf2xla::TensorId& feed_id = info.feed->id();
// Find the existing node and determine data type.
auto node_it = name_to_node.find(feed_id.node_name());
@@ -214,16 +181,16 @@ Status AddPlaceholdersForFeeds(
return Status::OK();
}
-Status PruneGraphDefInto(const Config& config, const GraphDef& in,
+Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
GraphDef* out) {
*out = in;
out->clear_node();
// Tensors needed for feeding.
std::set<std::pair<string, int>> feed_tensors;
- for (const auto& feed_config : config.feed()) {
- feed_tensors.insert(std::make_pair(feed_config.id().node_name(),
- feed_config.id().output_index()));
+ for (const tf2xla::Feed& feed : config.feed()) {
+ feed_tensors.insert(
+ std::make_pair(feed.id().node_name(), feed.id().output_index()));
}
// Maps node name to reachability.
@@ -279,9 +246,8 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in,
return Status::OK();
}
-string TensorIdToString(const TensorId& id) {
+string TensorIdToString(const tf2xla::TensorId& id) {
return strings::StrCat(id.node_name(), ":", id.output_index());
}
-} // namespace tfcompile
} // namespace tensorflow
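
A worked instance of the placeholder naming scheme above: feeding output 0 of a node named "x" produces a placeholder named "aot_feed_0/x".

#include "tensorflow/core/lib/strings/strcat.h"

// Mirrors the StrCat call in AddPlaceholdersForFeeds; yields "aot_feed_0/x".
const tensorflow::string placeholder_name =
    tensorflow::strings::StrCat("aot_feed_", 0, "/", "x");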
diff --git a/tensorflow/compiler/aot/tfcompile_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index 365f7b0e7b..a29d0c16f9 100644
--- a/tensorflow/compiler/aot/tfcompile_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -13,26 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
-#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
+#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
#include <unordered_map>
-#include "tensorflow/compiler/aot/tfcompile.pb.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
-namespace tfcompile {
-
-// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
-// appended to error messages.
-Status ValidateCppIdent(StringPiece ident, StringPiece msg);
// ValidateConfig returns OK iff config is valid.
-Status ValidateConfig(const Config& config);
+Status ValidateConfig(const tf2xla::Config& config);
// Modifies <graph_def> to include placeholders for each fed tensor, and
// update references to the fed tensors to refer to the placeholders.
@@ -40,18 +34,17 @@ Status ValidateConfig(const Config& config);
// (except where their input edges are modified by the replacement of other
// feeds).
Status AddPlaceholdersForFeeds(
- const Config& config, const OpRegistryInterface* op_registry,
+ const tf2xla::Config& config, const OpRegistryInterface* op_registry,
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
// Returns in <out> a copy of <in>, pruned to only include fetches from
// <config>.
-Status PruneGraphDefInto(const Config& config, const GraphDef& in,
+Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
GraphDef* out);
// Returns node:port for the given <id>.
-string TensorIdToString(const TensorId& id);
+string TensorIdToString(const tf2xla::TensorId& id);
-} // namespace tfcompile
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
+#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
diff --git a/tensorflow/compiler/aot/tfcompile_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
index 5a92851ceb..b98c89f284 100644
--- a/tensorflow/compiler/aot/tfcompile_util_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/aot/tfcompile_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
-namespace tfcompile {
namespace {
void ExpectErrorContains(const Status& status, StringPiece str) {
@@ -32,45 +31,16 @@ void ExpectErrorContains(const Status& status, StringPiece str) {
<< "expected error: " << status.error_message() << " to contain: " << str;
}
-TEST(ValidateCppIdent, Simple) {
- TF_EXPECT_OK(ValidateCppIdent("a", ""));
- TF_EXPECT_OK(ValidateCppIdent("abc", ""));
- TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
- TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
- // Make sure we didn't skip a valid letter or digit
- string ident;
- for (char c = 'a'; c <= 'z'; c++) {
- ident.append(1, c);
- }
- for (char c = 'A'; c <= 'Z'; c++) {
- ident.append(1, c);
- }
- for (char c = '0'; c <= '9'; c++) {
- ident.append(1, c);
- }
- ident += "_";
- TF_EXPECT_OK(ValidateCppIdent(ident, ""));
-
- ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
- ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
- ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
- ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
- ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
- ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
- ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
- ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
-}
-
TEST(ValidateConfig, Good) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->mutable_id()->set_output_index(123);
feed->set_name("foo_debug");
feed = config.add_feed();
feed->mutable_id()->set_node_name("bar");
feed->mutable_id()->set_output_index(0);
- Fetch* fetch = config.add_fetch();
+ tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("baz");
fetch->mutable_id()->set_output_index(456);
fetch->set_name("baz_debug");
@@ -81,62 +51,62 @@ TEST(ValidateConfig, Good) {
}
TEST(ValidateConfig, BadEmpty) {
- Config config;
+ tf2xla::Config config;
ExpectErrorContains(ValidateConfig(config),
"feeds and fetches must be specified");
}
TEST(ValidateConfig, BadNoFeed) {
- Config config;
- Fetch* fetch = config.add_fetch();
+ tf2xla::Config config;
+ tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("foo");
ExpectErrorContains(ValidateConfig(config),
"feeds and fetches must be specified");
}
TEST(ValidateConfig, BadNoFetch) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
ExpectErrorContains(ValidateConfig(config),
"feeds and fetches must be specified");
}
TEST(ValidateConfig, BadFeedNodeName) {
- Config config;
+ tf2xla::Config config;
config.add_feed();
ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}
TEST(ValidateConfig, BadFeedOutputIndex) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->mutable_id()->set_output_index(-1);
ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
}
TEST(ValidateConfig, BadFetchNodeName) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
config.add_fetch();
ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
}
TEST(ValidateConfig, BadFetchOutputIndex) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
- Fetch* fetch = config.add_fetch();
+ tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("bar");
fetch->mutable_id()->set_output_index(-1);
ExpectErrorContains(ValidateConfig(config), "output_index must be positive");
}
TEST(ValidateConfig, DuplicateFeedName) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->set_name("dup");
feed = config.add_feed();
@@ -146,10 +116,10 @@ TEST(ValidateConfig, DuplicateFeedName) {
}
TEST(ValidateConfig, DuplicateFetchName) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
- Fetch* fetch = config.add_fetch();
+ tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("bar");
fetch->set_name("dup");
fetch = config.add_fetch();
@@ -159,8 +129,8 @@ TEST(ValidateConfig, DuplicateFetchName) {
}
TEST(ValidateConfig, ConflictingFeedName) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
feed->set_name("conflict");
feed = config.add_feed();
@@ -170,10 +140,10 @@ TEST(ValidateConfig, ConflictingFeedName) {
}
TEST(ValidateConfig, ConflictingFetchName) {
- Config config;
- Feed* feed = config.add_feed();
+ tf2xla::Config config;
+ tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("foo");
- Fetch* fetch = config.add_fetch();
+ tf2xla::Fetch* fetch = config.add_fetch();
fetch->mutable_id()->set_node_name("bar");
fetch->set_name("conflict");
fetch = config.add_fetch();
@@ -182,8 +152,8 @@ TEST(ValidateConfig, ConflictingFetchName) {
ExpectErrorContains(ValidateConfig(config), "conflicting fetch name");
}
-static Config FetchesConfig(std::vector<string> fetches) {
- Config config;
+static tf2xla::Config FetchesConfig(std::vector<string> fetches) {
+ tf2xla::Config config;
for (const auto& fetch_node_name : fetches) {
auto* fetch = config.add_fetch();
fetch->set_name(strings::StrCat("fetch_", fetch_node_name));
@@ -242,5 +212,4 @@ TEST(PruneGraphDefInto, Basic) {
}
} // namespace
-} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 30afaed732..e41a391ac5 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -1703,7 +1703,6 @@ StatusOr<Computation> ComputationBuilder::Build() {
}
void ComputationBuilder::AddOpMetadata(OpRequest* request) const {
- tensorflow::mutex_lock lock(mutex_);
*request->mutable_metadata() = metadata_;
}
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index cf1f3b074e..96db56bc53 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -37,7 +37,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"
@@ -57,10 +56,10 @@ class ComputationBuilder {
~ComputationBuilder();
// Returns the client the builder was initialized with.
- Client* client() { return client_; }
+ Client* client() const { return client_; }
// Returns the computation name.
- const string& name() { return name_; }
+ const string& name() const { return name_; }
// Sets OpMetadata that will be added to all instructions until cleared.
//
@@ -69,13 +68,11 @@ class ComputationBuilder {
// instructions generated via this Computation Builder will have the same
// OpMetadata attached until a call to ClearOpMetadata.
void SetOpMetadata(const OpMetadata& metadata) {
- tensorflow::mutex_lock lock(mutex_);
metadata_ = metadata;
}
// Clears the HloMetadata state.
void ClearOpMetadata() {
- tensorflow::mutex_lock lock(mutex_);
metadata_.Clear();
}
@@ -826,15 +823,12 @@ class ComputationBuilder {
Client* client_;
// Mode bit that indicates whether to die when a first error is encountered.
- bool die_immediately_on_error_{false};
-
- // Mutex to guard against concurrent access to metadata_.
- mutable tensorflow::mutex mutex_;
+ bool die_immediately_on_error_ = false;
// The metadata to attach to each op. This is structured as a "modal"-like
// operation, in order to simplify client code (and not sprinkle this metadata
// throughout the TensorFlow op kernel implementations).
- OpMetadata metadata_ GUARDED_BY(mutex_);
+ OpMetadata metadata_;
TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
};
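
The modal pattern described above, as a short usage sketch (builder, x, and y are assumed in scope; dropping the mutex presumes each builder is driven from a single thread):

xla::OpMetadata metadata;
metadata.set_op_type("Add");           // fields defined in xla_data.proto
metadata.set_op_name("forward/add");
builder.SetOpMetadata(metadata);
auto sum = builder.Add(x, y);          // this op carries `metadata`
builder.ClearOpMetadata();             // later ops are unannotated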
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 95f8165795..1a18b28cbb 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -180,15 +180,18 @@ cc_library(
cc_library(
name = "ir_emitter",
- srcs = ["ir_emitter.cc"],
+ srcs = [
+ "elemental_ir_emitter.cc",
+ "ir_emitter.cc",
+ ],
hdrs = [
+ "elemental_ir_emitter.h",
"ir_emitter.h",
],
deps = [
":cpu_options",
":cpu_runtime",
":dot_op_emitter",
- ":elemental_ir_emitter",
":ir_emission_utils",
":simple_orc_jit",
"//tensorflow/compiler/xla:shape_util",
@@ -526,22 +529,6 @@ cc_library(
)
cc_library(
- name = "elemental_ir_emitter",
- srcs = ["elemental_ir_emitter.cc"],
- hdrs = ["elemental_ir_emitter.h"],
- deps = [
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:elemental_ir_emitter",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
- "@llvm//:core",
- ],
-)
-
-cc_library(
name = "ir_emission_utils",
srcs = ["ir_emission_utils.cc"],
hdrs = ["ir_emission_utils.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
index 511f89144a..902309b338 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
@@ -50,14 +50,6 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return false;
}
- // Producer or consumer cannot be Map. Maps are technically elementwise but
- // of a slightly different form (call instead of a computation). These are not
- // yet supported in the CPU backend.
- if (producer->opcode() == HloOpcode::kMap ||
- consumer->opcode() == HloOpcode::kMap) {
- return false;
- }
-
// Cost condition: not fuse (simple, expensive producers) and (consumers who
// reuse operand elements).
if (producer->opcode() != HloOpcode::kFusion &&
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 0fc62281a0..b56466d5e4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -209,6 +209,31 @@ class OpcodeFusionTest : public InstructionFusionTest {
std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()),
expected_opcodes);
}
+
+ HloComputation* CreateAdderToOne(HloModule* module) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* arg0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {}), "arg0"));
+ HloInstruction* one = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one));
+ return module->AddEmbeddedComputation(builder.Build());
+ }
+
+ HloComputation* CreateMax(HloModule* module) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* arg0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {}), "arg0"));
+ HloInstruction* arg1 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {}), "arg1"));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, arg0, arg1));
+ return module->AddEmbeddedComputation(builder.Build());
+ }
};
TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) {
@@ -402,6 +427,49 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) {
HloOpcode::kParameter});
}
+TEST_F(OpcodeFusionTest, UnaryMapOfExp) {
+ auto module = CreateNewModule();
+
+ HloComputation::Builder builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+
+ HloInstruction* exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
+ builder.AddInstruction(HloInstruction::CreateMap(
+ shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{}));
+
+ module->AddEntryComputation(builder.Build());
+
+ RunFusionAndCheckOpcodesWereFused(
+ module.get(), {HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap});
+}
+
+TEST_F(OpcodeFusionTest, BinaryMapOfExps) {
+ auto module = CreateNewModule();
+
+ HloComputation::Builder builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {3, 4});
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, shape, "param"));
+
+ HloInstruction* exp0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
+ HloInstruction* exp1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1));
+
+ builder.AddInstruction(HloInstruction::CreateMap(
+ shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{}));
+
+ module->AddEntryComputation(builder.Build());
+
+ RunFusionAndCheckOpcodesWereFused(
+ module.get(), {HloOpcode::kParameter, HloOpcode::kParameter,
+ HloOpcode::kExp, HloOpcode::kExp, HloOpcode::kMap});
+}
} // namespace
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index fe447adf89..73e039250b 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -64,5 +64,25 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
}
}
+llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
+ const HloInstruction* hlo,
+ const HloToElementGeneratorMap& operand_to_generator) const {
+ if (hlo->opcode() == HloOpcode::kMap) {
+ return [this, hlo, &operand_to_generator](
+ const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ std::vector<llvm::Value*> operands;
+ for (int i = 0; i < hlo->operand_count(); i++) {
+ TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
+ operand_to_generator.at(hlo->operand(i))(
+ ElementwiseSourceIndex(index, *hlo, 0)));
+ operands.push_back(operand_value);
+ }
+ return ir_emitter_->EmitScalarCall(hlo->shape().element_type(),
+ hlo->to_apply(), operands,
+ llvm_ir::IrName(hlo));
+ };
+ }
+ return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator);
+}
} // namespace cpu
} // namespace xla
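
Conceptually, the new generator turns a kMap into one scalar call per output element; a plain C++ model of those semantics (an illustration only, not the emitted IR):

#include <functional>
#include <vector>

// result[i] = f(operand0[i], ..., operandN[i]); the call to `f` stands in
// for IrEmitter::EmitScalarCall on the mapped computation.
std::vector<float> MapElementwise(
    const std::vector<std::vector<float>>& operands,
    const std::function<float(const std::vector<float>&)>& f) {
  std::vector<float> result(operands.at(0).size());
  for (size_t i = 0; i < result.size(); ++i) {
    std::vector<float> args;
    for (const auto& operand : operands) args.push_back(operand[i]);
    result[i] = f(args);
  }
  return result;
}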
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index 6f9d6a24b4..7e9f27befb 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -29,12 +30,19 @@ namespace cpu {
class CpuElementalIrEmitter : public ElementalIrEmitter {
public:
CpuElementalIrEmitter(const HloModuleConfig& module_config,
- llvm::IRBuilder<>* ir_builder, llvm::Module* module)
- : ElementalIrEmitter(module_config, module, ir_builder) {}
+ IrEmitter* ir_emitter, llvm::Module* module)
+ : ElementalIrEmitter(module_config, module, ir_emitter->ir_builder()),
+ ir_emitter_(ir_emitter) {}
+
+ llvm_ir::ElementGenerator MakeElementGenerator(
+ const HloInstruction* hlo,
+ const HloToElementGeneratorMap& operand_to_generator) const override;
protected:
StatusOr<llvm::Value*> EmitFloatUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const override;
+
+ IrEmitter* ir_emitter_;
};
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
index 94d4ce4a94..91b09f2472 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
@@ -136,6 +136,10 @@ DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr(
const int64 kReductionDimensionThresholdBytes = 8 * 1024;
const bool single_threaded_eigen =
!dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen();
+
+ // This is the point at which it is better to call into Eigen and shard the
+ // dot across multiple worker threads. This is a rough estimate obtained by
+ // running a matmul benchmark on my local machine, and it can be tuned further.
const int64 kMaxSingleThreadedFlops = 16 * 1024;
const int64 M = result_shape.dimensions(0);
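
A back-of-the-envelope reading of that threshold, assuming the usual 2*M*N*K flop count for a dot (the exact accounting in ProfitableToImplementDotInLlvmIr is not shown in this hunk):

#include <cstdint>
#include <iostream>

int main() {
  const int64_t kMaxSingleThreadedFlops = 16 * 1024;
  const int64_t M = 20, N = 20, K = 20;
  const int64_t flops = 2 * M * N * K;  // 16000 flops, just under the cutoff
  std::cout << (flops <= kMaxSingleThreadedFlops
                    ? "emit single-threaded LLVM IR dot\n"
                    : "call into multi-threaded Eigen\n");
  return 0;
}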
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index bc51ad2b36..8cd8740ee8 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -2354,8 +2354,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArrayForOp(operand));
}
- CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_,
- module_);
+ CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
@@ -2737,14 +2736,10 @@ llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction* hlo) {
}
prof_counter_idx = it->second;
- uintptr_t hlo_address = reinterpret_cast<uintptr_t>(hlo);
- counter_name = tensorflow::strings::StrCat(
- "prof_counter_0x",
- tensorflow::strings::Hex(
- hlo_address, tensorflow::strings::PadSpec(sizeof(hlo_address))));
+ counter_name = IrName("prof_counter", hlo->name());
} else {
prof_counter_idx = hlo_to_profile_idx_->size();
- counter_name = "prof_counter_computation";
+ counter_name = "prof_counter.computation";
}
return ir_builder_.CreateGEP(GetProfileCountersArgument(),
ir_builder_.getInt64(prof_counter_idx),
@@ -3180,12 +3175,27 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
return GetIrArrayForOp(operand).EmitReadArrayElement(index, &ir_builder_);
};
}
- CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_,
- module_);
+ CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
return EmitTargetElementLoop(
hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}
+StatusOr<llvm::Value*> IrEmitter::EmitScalarCall(
+ PrimitiveType return_type, HloComputation* computation,
+ const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) {
+ llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation);
+ std::vector<llvm::Value*> argument_addrs;
+ for (auto argument : arguments) {
+ llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ argument->getType(), "arg_addr", &ir_builder_);
+ ir_builder_.CreateStore(argument, argument_addr);
+ argument_addrs.push_back(argument_addr);
+ }
+ return EmitElementFunctionCall(llvm_function,
+ ShapeUtil::MakeShape(return_type, {}),
+ argument_addrs, name);
+}
+
unsigned TargetMachineFeatures::largest_register_size_in_bytes(
llvm::Function* function) {
auto itr = largest_register_size_in_bytes_.find(function);
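
The spill-to-stack pattern EmitScalarCall uses, modeled in plain C++ (the real ABI of the emitted functions is richer; the by-address signature here is an assumption for illustration):

#include <vector>

// Stand-in for the emitted llvm::Function: takes its scalar arguments by
// address, as EmitElementFunctionCall passes argument_addrs.
static float AddOneByAddr(const std::vector<float*>& arg_addrs) {
  return *arg_addrs[0] + 1.0f;
}

float CallScalarSketch(float argument) {
  float slot = argument;  // alloca-at-entry + store in the emitted IR
  std::vector<float*> arg_addrs = {&slot};
  return AddOneByAddr(arg_addrs);  // EmitElementFunctionCall analogue
}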
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index bcd33c3810..fa33a1eb7b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -133,6 +133,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
bool is_top_level_computation,
std::vector<const HloInstruction*>* instruction_order);
+ llvm::IRBuilder<>* ir_builder() { return &ir_builder_; }
+
+ // Emits a call to `computation` with scalar arguments `arguments`.
+ StatusOr<llvm::Value*> EmitScalarCall(
+ PrimitiveType return_type, HloComputation* computation,
+ const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
+
protected:
//
// The following methods implement the DfsHloVisitor interface.
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 5d650b872f..b24fe417ff 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -76,10 +76,11 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path,
// Since CUDA 9.0, all GPU versions are included in a single file
const char* unified_libdevice_filename = "libdevice.10.bc";
std::vector<string> unified_libdevice_files;
- tensorflow::Env::Default()->GetMatchingPaths(
+ const tensorflow::Status status =
+ tensorflow::Env::Default()->GetMatchingPaths(
tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename),
&unified_libdevice_files);
- if( unified_libdevice_files.size() == 1 ) {
+ if (status.ok() && unified_libdevice_files.size() == 1) {
return unified_libdevice_filename;
}
// There are only four libdevice files: compute_{20,30,35,50}. Each GPU
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index efb5fca188..e0c23a3a08 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -77,7 +77,7 @@ class HloOrdering {
// Precondition: 'a' and 'b' are in the same computation.
//
// Derived classes should implement this method for determining order of
- // instructions in the same comptuation. ExecutesBefore() analyzes the
+ // instructions in the same computation. ExecutesBefore() analyzes the
// callgraph and uses this method to determine ordering of instructions in
// different computations.
virtual bool ExecutesBeforeInSameComputation(
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 1865004911..a0f9be3dd8 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -50,7 +50,7 @@ class WhileTest : public ClientLibraryTestBase {};
// while (result < 5) {
// result = result + 1;
// }
-TEST_F(WhileTest, WhileWithScalarResult) {
+TEST_F(WhileTest, WhileWithScalarS32Result) {
auto result_shape = ShapeUtil::MakeShape(S32, {});
// Create a computation for the condition: repeat for 5 iterations.
@@ -81,6 +81,43 @@ TEST_F(WhileTest, WhileWithScalarResult) {
ComputeAndCompareR0<int32>(&builder, 5, {});
}
+// Tests a while node when the result type T is S64.
+//
+// int64 result = 0;
+// while (result < 5) {
+// result = result + 1;
+// }
+TEST_F(WhileTest, WhileWithScalarS64Result) {
+ auto result_shape = ShapeUtil::MakeShape(S64, {});
+
+ // Create a computation for the condition: repeat for 5 iterations.
+ Computation condition;
+ {
+ ComputationBuilder builder(client_, "condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ builder.Gt(builder.ConstantR0<int64>(5), prev);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the body: add 1 to the result variable.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto input = builder.ConstantR0<int64>(1);
+ auto result = builder.Add(input, prev);
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ ComputationBuilder builder(client_, TestName());
+ auto init = builder.ConstantR0<int64>(0);
+ auto result = builder.While(condition, body, init);
+ auto shape = builder.GetShape(result).ConsumeValueOrDie();
+
+ ComputeAndCompareR0<int64>(&builder, 5, {});
+}
+
TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
auto result_shape = ShapeUtil::MakeShape(S32, {});
auto orig_shape = ShapeUtil::MakeShape(S32, {2});