diff options
Diffstat (limited to 'tensorflow/compiler')
42 files changed, 989 insertions, 645 deletions
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 9909e88e64..29dbe4a08b 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -4,7 +4,6 @@ package( default_visibility = ["//visibility:private"], ) -load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") # Optional runtime utilities for use by code generated by tfcompile. @@ -39,32 +38,24 @@ cc_library( deps = ["//tensorflow/core:test_main"], ) -xla_proto_library( - name = "tfcompile_proto", - srcs = ["tfcompile.proto"], - deps = [ - "//tensorflow/core:protos_all_cc", - ], -) - cc_library( name = "tfcompile_lib", srcs = [ "codegen.cc", "compile.cc", "flags.cc", - "tfcompile_util.cc", ], hdrs = [ "codegen.h", "compile.h", "flags.h", - "tfcompile_util.h", ], deps = [ ":runtime", # needed by codegen to print aligned_buffer_bytes - ":tfcompile_proto", + "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:tf2xla_proto", + "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -82,7 +73,6 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -99,18 +89,6 @@ cc_test( ], ) -cc_test( - name = "tfcompile_util_test", - srcs = ["tfcompile_util_test.cc"], - deps = [ - ":tfcompile_lib", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - cc_binary( name = "tfcompile", visibility = ["//visibility:public"], @@ -123,7 +101,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":tfcompile_lib", - ":tfcompile_proto", + "//tensorflow/compiler/tf2xla:tf2xla_proto", + "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:core_cpu", @@ -226,7 +205,11 @@ test_suite( tags = ["manual"], tests = [ ":benchmark_test", + ":codegen_test", + ":runtime_test", ":test_graph_tfadd_test", + ":test_graph_tfunknownop2_test", + ":test_graph_tfunknownop3_test", ":test_graph_tfunknownop_test", "//tensorflow/compiler/aot/tests:all_tests", ], diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index bbdb342a62..fc5c6ce58d 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -20,8 +20,8 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/aot/tfcompile_util.h" #include "tensorflow/compiler/tf2xla/str_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -35,6 +35,12 @@ namespace tfcompile { namespace { +bool IsAlpha(char c) { + return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); +} + +bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); } + // Convert an XLA type into a C++ type. Status XLATypeToCpp(xla::PrimitiveType type, string* str) { switch (type) { @@ -156,7 +162,7 @@ string RewriteWithName(const string& name, string code, } // Generate methods for args (inputs). -Status GenArgMethods(const Config& config, const xla::ProgramShape& ps, +Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, const CompileResult& compile_result, string* methods) { *methods += R"( void** args() { return args_; } @@ -204,8 +210,8 @@ Status GenArgMethods(const Config& config, const xla::ProgramShape& ps, } // Generate methods for results (outputs). -Status GenResultMethods(const Config& config, const xla::ProgramShape& ps, - string* methods) { +Status GenResultMethods(const tf2xla::Config& config, + const xla::ProgramShape& ps, string* methods) { if (ps.result().element_type() != xla::TUPLE) { // Non-tuple (i.e. single-result) case. if (config.fetch_size() != 1) { @@ -285,11 +291,26 @@ Status GenResultMethods(const Config& config, const xla::ProgramShape& ps, return Status::OK(); } +Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { + for (const tf2xla::Feed& feed : config.feed()) { + if (!feed.name().empty()) { + TF_RETURN_IF_ERROR(ValidateCppIdent(feed.name(), "feed name")); + } + } + for (const tf2xla::Fetch& fetch : config.fetch()) { + if (!fetch.name().empty()) { + TF_RETURN_IF_ERROR(ValidateCppIdent(fetch.name(), "fetch name")); + } + } + return Status::OK(); +} + } // namespace -Status GenerateHeader(const HeaderOpts& opts, const Config& config, +Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, const CompileResult& compile_result, string* header) { TF_RETURN_IF_ERROR(ValidateConfig(config)); + TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); const int64 result_index = compile_result.aot->result_buffer_index(); const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes(); if (result_index < 0 || result_index > temp_sizes.size()) { @@ -574,5 +595,29 @@ Status ParseCppClass(const string& cpp_class, string* class_name, return Status::OK(); } +Status ValidateCppIdent(StringPiece ident, StringPiece msg) { + if (ident.empty()) { + return errors::InvalidArgument("empty identifier: ", msg); + } + // Require that the identifier starts with a nondigit, and is composed of + // nondigits and digits, as specified in section [2.11 Identifiers] of the + // C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is + // defined as [0-9]. + // + // Technically the standard also allows for `universal-character-name`, with a + // table of allowed unicode ranges, as well as `other implementation-defined + // characters`. We disallow those here to give better error messages, at the + // expensive of being more restrictive than the standard. + if (ident[0] != '_' && !IsAlpha(ident[0])) { + return errors::InvalidArgument("illegal leading char: ", msg); + } + for (size_t pos = 1; pos < ident.size(); ++pos) { + if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) { + return errors::InvalidArgument("illegal char: ", msg); + } + } + return Status::OK(); +} + } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 7217c57739..740edd1e83 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -20,6 +20,8 @@ limitations under the License. #include <vector> #include "tensorflow/compiler/aot/compile.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { namespace tfcompile { @@ -37,7 +39,7 @@ struct HeaderOpts { // GenerateHeader uses the meta-information from compile_result to generate a // C++ header giving access to the function in the generated object file. The // header includes API usage documentation. -Status GenerateHeader(const HeaderOpts& opts, const Config& config, +Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, const CompileResult& compile_result, string* header); // ParseCppClass parses `cpp_class` into its `class_name` and `namespaces` @@ -47,6 +49,10 @@ Status GenerateHeader(const HeaderOpts& opts, const Config& config, Status ParseCppClass(const string& cpp_class, string* class_name, std::vector<string>* namespaces); +// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is +// appended to error messages. +Status ValidateCppIdent(StringPiece ident, StringPiece msg); + } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index e3f76f3666..98cbd67e53 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -29,6 +30,41 @@ namespace tensorflow { namespace tfcompile { namespace { +void ExpectErrorContains(const Status& status, StringPiece str) { + EXPECT_NE(Status::OK(), status); + EXPECT_TRUE(StringPiece(status.error_message()).contains(str)) + << "expected error: " << status.error_message() << " to contain: " << str; +} + +TEST(ValidateCppIdent, Simple) { + TF_EXPECT_OK(ValidateCppIdent("a", "")); + TF_EXPECT_OK(ValidateCppIdent("abc", "")); + TF_EXPECT_OK(ValidateCppIdent("_abc", "")); + TF_EXPECT_OK(ValidateCppIdent("_abc123", "")); + // Make sure we didn't skip a valid letter or digit + string ident; + for (char c = 'a'; c <= 'z'; c++) { + ident.append(1, c); + } + for (char c = 'A'; c <= 'Z'; c++) { + ident.append(1, c); + } + for (char c = '0'; c <= '9'; c++) { + ident.append(1, c); + } + ident += "_"; + TF_EXPECT_OK(ValidateCppIdent(ident, "")); + + ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier"); + ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char"); + ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char"); + ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char"); + ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char"); + ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char"); + ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char"); + ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char"); +} + class ParseCppClassTest : public ::testing::Test { protected: void ExpectOK(const string& cpp_class, const string& want_class_name, @@ -91,13 +127,13 @@ TEST(GenerateHeader, Golden) { HeaderOpts opts; opts.class_name = "MyClass"; opts.namespaces = {"foo", "bar"}; - Config config; - Feed* feed = config.add_feed(); + tf2xla::Config config; + tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("feed0"); feed->set_name("myfeed"); feed = config.add_feed(); feed->mutable_id()->set_node_name("feed1"); - Fetch* fetch = config.add_fetch(); + tf2xla::Fetch* fetch = config.add_fetch(); fetch->mutable_id()->set_node_name("fetch0"); fetch->set_name("myfetch"); CompileResult compile_result; diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index a485d2e555..eac8da0ab1 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -15,326 +15,32 @@ limitations under the License. #include "tensorflow/compiler/aot/compile.h" -#include <map> #include <memory> #include <string> -#include <unordered_map> #include <utility> #include <vector> #include "tensorflow/compiler/aot/flags.h" -#include "tensorflow/compiler/aot/tfcompile_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/tf2xla/tf2xla.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/compile_only_client.h" -#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" -#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/graph_def_util.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/versions.pb.h" -#include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/public/session_options.h" namespace tensorflow { namespace tfcompile { -const char* const kArgOp = "_Arg"; -const char* const kRetvalOp = "_Retval"; -const char* const kFeedIdAttr = "_feed_id"; -const char* const kFetchIdAttr = "_fetch_id"; -const char* const kShapeAttr = "_shape"; -const char* const kDebugNameAttr = "_debug_name"; - namespace { -Status DumpGraph(const MainFlags& flags, const string& name, - const Graph& graph) { - if (flags.debug_dir.empty()) { - return Status::OK(); - } - GraphDef graph_def; - graph.ToGraphDef(&graph_def); - string file = io::JoinPath(flags.debug_dir, name + ".pbtxt"); - return WriteTextProto(Env::Default(), file, graph_def); -} - -typedef std::unordered_map<string, Node*> NodeMap; - -// Each feed id identifies the positional output of some node, which may consist -// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed -// tensor with a placeholder. For each feed tensor, replaces all edges so they -// point from a new _Arg node instead. -Status AddArgNodes(Graph* graph, const NodeMap& node_map, - const protobuf::RepeatedPtrField<Feed>& feeds, - const std::unordered_map<string, string>& feed_remapping) { - for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) { - const Feed& feed = feeds[arg_index]; - // All feeds have been replaced by placeholders. - const int output_index = 0; - - const string key = TensorIdToString(feed.id()); - const auto remap_it = feed_remapping.find(key); - auto node_it = node_map.find(remap_it->second); - if (node_it == node_map.end()) { - // Strip off the aot_feed_#/ prefix. - StringPiece name(remap_it->second); - const auto index = name.find('/'); - if (index > 0) name.remove_prefix(index + 1); - return errors::InvalidArgument( - "Node is fed but not needed for fetching: ", name); - } - const Node* feed_node = node_it->second; - - // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a - // "_shape" attr if we can determine it. That way the graph will be - // initialized with whatever shapes we can infer, while the user can still - // explicitly specify or override them. - Node* arg_node = nullptr; - TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp) - .Attr("T", BaseType(feed_node->output_type(output_index))) - .Attr("index", arg_index) - .Attr(kFeedIdAttr, TensorIdToString(feed.id())) - .Attr(kShapeAttr, TensorShape(feed.shape())) - .Attr(kDebugNameAttr, feed.name()) - .Finalize(graph, &arg_node)); - - // Collects out-edges from the feed node that have a matching edge index; - // these will be replaced with edges from the arg node instead. - // - // We must collect the edges first and process them in a second pass, since - // removing the edge from the graph invalidates feed_node->out_edges. - std::vector<const Edge*> feed_edges; - for (const Edge* edge : feed_node->out_edges()) { - if (edge->src_output() == output_index) { - feed_edges.push_back(edge); - } - } - for (const Edge* edge : feed_edges) { - graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input()); - graph->RemoveEdge(edge); - } - } - return Status::OK(); -} - -// Each fetch id identifies the positional output of some node. For each fetch -// node, adds a new _Retval node instead, and adds the node to `retval_nodes`. -Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, - const protobuf::RepeatedPtrField<Fetch>& fetches, - std::unordered_set<const Node*>* retval_nodes) { - for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) { - const TensorId& id = fetches[ret_index].id(); - auto it = node_map.find(id.node_name()); - if (it == node_map.end()) { - return errors::NotFound("Can't find fetch id: ", TensorIdToString(id)); - } - Node* fetch_node = it->second; - if (id.output_index() >= fetch_node->num_outputs()) { - return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id), - ", output index should be < ", - fetch_node->num_outputs()); - } - // Connects fetch_node -> retval_node. - Node* retval_node = nullptr; - TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp) - .Input(fetch_node, id.output_index()) - .Attr("T", BaseType(fetch_node->output_type(id.output_index()))) - .Attr("index", ret_index) - .Attr(kFetchIdAttr, TensorIdToString(id)) - .Finalize(graph, &retval_node)); - retval_nodes->insert(retval_node); - } - return Status::OK(); -} - -// RewriteAndPruneGraph identifies input and output edges (named by the feed and -// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg -// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph -// execution to know the input and output args for the generated function. -Status RewriteAndPruneGraph( - Graph* graph, const Config& config, - const std::unordered_map<string, string>& feed_remapping, - const MainFlags& flags) { - NodeMap node_map; - for (Node* n : graph->nodes()) { - node_map[n->name()] = n; - } - TF_RETURN_IF_ERROR( - AddArgNodes(graph, node_map, config.feed(), feed_remapping)); - std::unordered_set<const Node*> retval_nodes; - TF_RETURN_IF_ERROR( - AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes)); - TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_rewrite", *graph)); - PruneForReverseReachability(graph, retval_nodes); - FixupSourceAndSinkEdges(graph); - TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_prune", *graph)); - // Sanity-check, to make sure the feeds and fetches still exist post-pruning. - std::set<string> missing_feeds, missing_fetches; - for (const Feed& feed : config.feed()) { - missing_feeds.insert(TensorIdToString(feed.id())); - } - for (const Fetch& fetch : config.fetch()) { - missing_fetches.insert(TensorIdToString(fetch.id())); - } - for (const Node* n : graph->op_nodes()) { - if (n->type_string() == kArgOp) { - string feed_id; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id)); - if (missing_feeds.erase(feed_id) == 0) { - return errors::Aborted(kArgOp, - " node found with unknown feed id: ", feed_id); - } - } else if (n->type_string() == kRetvalOp) { - string fetch_id; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id)); - if (missing_fetches.erase(fetch_id) == 0) { - return errors::Aborted(kRetvalOp, - " node found with unknown fetch id: ", fetch_id); - } - } - } - if (!missing_feeds.empty() || !missing_fetches.empty()) { - return errors::Aborted( - "Post graph-pruning", - ", missing feeds: ", str_util::Join(missing_feeds, ", "), - ", missing fetches: ", str_util::Join(missing_fetches, ", ")); - } - return Status::OK(); -} - -// CollectArgNodes collects _Arg nodes from the graph, and performs basic -// sanity-checking to ensure the index and type attributes of each node are -// initialized correctly. -Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) { - std::map<int, Node*> indexed_arg_nodes; - for (Node* n : graph.nodes()) { - if (n->type_string() == kArgOp) { - int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); - auto insert_result = indexed_arg_nodes.insert({index, n}); - if (!insert_result.second) { - const Node* dup = insert_result.first->second; - return errors::InvalidArgument( - "Multiple ", kArgOp, " nodes with index ", index, ", ", - n->DebugString(), " and ", dup->DebugString()); - } - } - } - arg_nodes->clear(); - for (const auto& index_node : indexed_arg_nodes) { - if (index_node.first != arg_nodes->size()) { - return errors::InvalidArgument("Expected ", kArgOp, " node with index ", - arg_nodes->size(), ", but got index ", - index_node.first); - } - arg_nodes->push_back(index_node.second); - } - return Status::OK(); -} - -// Fills in xla_args from the corresponding _Arg nodes in the graph. -Status CreateXlaArgs(const Graph& graph, - std::vector<XlaCompiler::Argument>* xla_args) { - std::vector<Node*> arg_nodes; - TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes)); - for (const Node* node : arg_nodes) { - XlaCompiler::Argument arg; - arg.kind = XlaCompiler::Argument::kParameter; - TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); - TensorShape shape; - TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape)); - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape)); - TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); - xla_args->push_back(arg); - } - return Status::OK(); -} - -// Converts the TensorFlow graph into an XLA computation, by executing the -// graph symbolically, with each op building up the XLA HLO. -Status ConvertGraphToXla(xla::CompileOnlyClient* client, - std::unique_ptr<Graph> graph, - xla::Computation* computation, bool* has_context_arg) { - // Create a device and context to convert the graph into an XLA computation. - XlaOpRegistry::RegisterCompilationKernels(); - // Populate the context with args from the graph. - for (Node* node : graph->nodes()) { - node->set_assigned_device_name(DEVICE_CPU_XLA_JIT); - } - std::vector<XlaCompiler::Argument> xla_args; - TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); - - // Compile the graph into an XLA computation. - XlaCompiler::Options compiler_options; - compiler_options.client = client; - DeviceType device_type(DEVICE_CPU_XLA_JIT); - compiler_options.device_type = &device_type; - compiler_options.flib_def = &graph->flib_def(); - compiler_options.graph_def_version = graph->versions().producer(); - compiler_options.allow_cpu_custom_calls = true; - XlaCompiler compiler(compiler_options); - - XlaCompiler::CompilationResult result; - TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(), - "tfcompile", std::move(graph), - xla_args, &result)); - *has_context_arg = result.requires_runtime_context; - *computation = std::move(*result.computation); - - int num_const_results = 0; - for (int i = 0; i < result.outputs.size(); ++i) { - // Ending up with const results (i.e. output args) is an error, since it - // means that one or more fetches that the user specified will be dropped - // from the generated function. It's most likely a configuration error, - // since the user shouldn't be asking for output args that end up as consts. - // - // TODO(toddw): Provide a way for the user to access const output args, - // e.g. perhaps hard-coded into the header, or somehow copied into the - // output buffers. - if (result.outputs[i].is_constant) { - ++num_const_results; - LOG(ERROR) << "ConstRetVal index:" << i - << " value:" << result.outputs[i].constant_value.DebugString(); - } - } - if (num_const_results > 0) { - return errors::Unimplemented( - "Conversion from TensorFlow graph to XLA resulted in ", - num_const_results, - " constant results. The configuration of " - "the output args (i.e. fetch ids) is probably wrong."); - } - if (computation->IsNull()) { - return errors::Aborted( - "Conversion from TensorFlow graph to XLA resulted in an empty " - "computation."); - } - return Status::OK(); -} - // Compiles the XLA computation into executable code. Status CompileXla(xla::CompileOnlyClient* client, const xla::Computation& computation, @@ -376,41 +82,8 @@ Status CompileXla(xla::CompileOnlyClient* client, } // namespace -Status InitGraph(const GraphDef& graph_def, const Config& config, - const MainFlags& flags, std::unique_ptr<Graph>* graph) { - TF_RETURN_IF_ERROR(ValidateConfig(config)); - - FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library()); - std::unique_ptr<Graph> g(new Graph(flib_def)); - - // Replace references to fed tensors with references to newly added - // placeholders. - GraphDef first_copy_def = graph_def; - - // Maps from name:port of a feed to the name:port of the placeholder to use. - std::unordered_map<string, string> feed_remapping; - TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(), - &feed_remapping, &first_copy_def)); - - // Prune the GraphDef first so that unknown ops that we aren't compiling get - // filtered out. - GraphDef second_copy_def; - TF_RETURN_IF_ERROR( - PruneGraphDefInto(config, first_copy_def, &second_copy_def)); - - TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef( - &second_copy_def, *g->op_registry(), 0 /*node_offset*/)); - - TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), - second_copy_def, g.get())); - TF_RETURN_IF_ERROR( - RewriteAndPruneGraph(g.get(), config, feed_remapping, flags)); - *graph = std::move(g); - return Status::OK(); -} - -Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags, - CompileResult* compile_result) { +Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, + const MainFlags& flags, CompileResult* compile_result) { // Converts the graph into an XLA computation, and compiles the // computation. // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client? @@ -421,8 +94,9 @@ Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags, xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) .ValueOrDie(); xla::Computation computation; - TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), &computation, - &compile_result->has_context_arg)); + TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client, + &computation, + &compile_result->has_context_arg)); if (!flags.debug_dir.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module, computation.Snapshot()); diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index e929272b2e..965c296081 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -18,46 +18,16 @@ limitations under the License. #include <memory> #include <string> -#include <vector> #include "tensorflow/compiler/aot/flags.h" -#include "tensorflow/compiler/aot/tfcompile.pb.h" -#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/graph/graph.h" namespace tensorflow { namespace tfcompile { -// Constants for op types and attribute names. -extern const char* const kArgOp; -extern const char* const kRetvalOp; -extern const char* const kFeedIdAttr; -extern const char* const kFetchIdAttr; -extern const char* const kShapeAttr; -extern const char* const kDebugNameAttr; - -// InitGraph creates a graph based on the graph_def, that may then be compiled -// by CompileGraph. -// -// The graph is rewritten with _Arg and _Retval nodes, representing the inputs -// and outputs of the function that will be compiled. Each feed id causes a new -// _Arg node to be created, where we first collect all existing edges pointing -// from the named node's output index, and then rewrite them to point from that -// _Arg node instead. Each fetch id causes a new _Retval node to be created, -// with a new edge pointing from the named node's output index to that _Retval -// node. All _Retval nodes also point to a special CompileExpressions node, -// used internally to finish the compilation. -// -// The rewritten graph is then pruned to only contain the portion necessary to -// compute the outputs. If dump_graphs is true, graph rewrites will be dumped -// for debugging. -Status InitGraph(const GraphDef& graph_def, const Config& config, - const MainFlags& flags, std::unique_ptr<Graph>* graph); - // CompileResult describes the output of CompileGraph, where the object file // data and meta-information is available in aot. struct CompileResult { @@ -69,20 +39,12 @@ struct CompileResult { int pointer_size = 0; // Size of a pointer in bytes. }; -// CompileGraph compiles the graph into an object file containing a function +// CompileGraph compiles the graph_def into an object file containing a function // that performs the graph operations. // -// The graph must have _Arg and _Retval nodes representing the function inputs -// and outputs. Every _Arg node must have a shape attribute (key=kShapeAttr, -// value=TensorShape) representing the static shape of that input, and every -// _Retval node must point to a CompileExpressions node. -// -// Typically InitGraph is called to perform this initialization, followed by -// full specification of the shape attributes. -// // The XLA compilation options are specified in the flags. -Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags, - CompileResult* result); +Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, + const MainFlags& flags, CompileResult* compile_result); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt index f2d9c34b2d..a4ad334352 100644 --- a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt +++ b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt @@ -1,4 +1,4 @@ -# Text form of tensorflow.tfcompile.Config proto. +# Text form of tensorflow.tf2xla.Config proto. feed { id { node_name: "x_const" } shape { diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt index 5625c0ab03..d3f0e4990c 100644 --- a/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt +++ b/tensorflow/compiler/aot/test_graph_tfunknownop.config.pbtxt @@ -1,4 +1,4 @@ -# Text form of tensorflow.tfcompile.Config proto. +# Text form of tensorflow.tf2xla.Config proto. feed { id { node_name: "x_const" } shape { diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt index 7370ed370d..e0b012adea 100644 --- a/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt +++ b/tensorflow/compiler/aot/test_graph_tfunknownop2.config.pbtxt @@ -1,4 +1,4 @@ -# Text form of tensorflow.tfcompile.Config proto. +# Text form of tensorflow.tf2xla.Config proto. feed { id { node_name: "x_const" } shape { diff --git a/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt index b2d7d54574..662ba1c321 100644 --- a/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt +++ b/tensorflow/compiler/aot/test_graph_tfunknownop3.config.pbtxt @@ -1,4 +1,4 @@ -# Text form of tensorflow.tfcompile.Config proto. +# Text form of tensorflow.tf2xla.Config proto. feed { id { node_name: "x_const" } shape { diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 05d338e4c5..4d65a044bc 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -13,9 +13,11 @@ test_suite( ":test_graph_tfadd_test", ":test_graph_tfadd_with_ckpt_saver_test", ":test_graph_tfadd_with_ckpt_test", + ":test_graph_tffunction_test", ":test_graph_tfgather_test", ":test_graph_tfmatmul_test", ":test_graph_tfmatmulandadd_test", + ":test_graph_tfsplits_test", ":tfcompile_test", ], ) @@ -91,6 +93,15 @@ tf_library( ) tf_library( + name = "test_graph_tffunction", + testonly = 1, + config = "test_graph_tffunction.config.pbtxt", + cpp_class = "FunctionComp", + graph = "test_graph_tffunction.pb", + tags = ["manual"], +) + +tf_library( name = "test_graph_tfgather", testonly = 1, config = "test_graph_tfgather.config.pbtxt", @@ -118,15 +129,6 @@ tf_library( ) tf_library( - name = "test_graph_tffunction", - testonly = 1, - config = "test_graph_tffunction.config.pbtxt", - cpp_class = "FunctionComp", - graph = "test_graph_tffunction.pb", - tags = ["manual"], -) - -tf_library( name = "test_graph_tfsplits", testonly = 1, config = "test_graph_tfsplits.config.pbtxt", diff --git a/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt index 5625c0ab03..d3f0e4990c 100644 --- a/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt +++ b/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt @@ -1,4 +1,4 @@ -# Text form of tensorflow.tfcompile.Config proto. +# Text form of tensorflow.tf2xla.Config proto. feed { id { node_name: "x_const" } shape { diff --git a/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt index 4d876a6e91..8adc9cdc14 100644 --- a/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt +++ b/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt @@ -1,4 +1,4 @@ -# Text form of tensorflow.tfcompile.Config proto. +# Text form of tensorflow.tf2xla.Config proto. feed { id { node_name: "x_hold" } shape { diff --git a/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt index eb9c1cacb7..cbfe458908 100644 --- a/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt +++ b/tensorflow/compiler/aot/tests/test_graph_tffunction.config.pbtxt @@ -1,4 +1,4 @@ -# Text form of tensorflow.tfcompile.Config proto. +# Text form of tensorflow.tf2xla.Config proto. feed { id { node_name: "x_const" } shape { diff --git a/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt index 648ee31fdb..89ed678a9c 100644 --- a/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt +++ b/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt @@ -1,4 +1,4 @@ -# Text form of tensorflow.tfcompile.Config proto. +# Text form of tensorflow.tf2xla.Config proto. feed { id { node_name: "params" } shape { diff --git a/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt index a3ce2029c1..2acd0289c2 100644 --- a/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt +++ b/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt @@ -1,4 +1,4 @@ -# Text form of tensorflow.tfcompile.Config proto. +# Text form of tensorflow.tf2xla.Config proto. feed { id { node_name: "x_hold" } shape { diff --git a/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt index 4a4a237a4f..e5ca6115e9 100644 --- a/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt +++ b/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt @@ -1,4 +1,4 @@ -# Text form of tensorflow.tfcompile.Config proto. +# Text form of tensorflow.tf2xla.Config proto. feed { id { node_name: "x_hold" } shape { diff --git a/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt index 85fc7da442..5adc77336c 100644 --- a/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt +++ b/tensorflow/compiler/aot/tests/test_graph_tfsplits.config.pbtxt @@ -1,4 +1,4 @@ -# Text form of tensorflow.tfcompile.Config proto. +# Text form of tensorflow.tf2xla.Config proto. feed { id { node_name: "x" } shape { diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index f9896988dc..fc1342d84e 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -41,7 +41,7 @@ def tf_library(name, graph, config, graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it is expected to be in the human-readable proto text format, otherwise it is expected to be in the proto binary format. - config: File containing tensorflow.tfcompile.Config proto. If the file ends + config: File containing tensorflow.tf2xla.Config proto. If the file ends in '.pbtxt' it is expected to be in the human-readable proto text format, otherwise it is expected to be in the proto binary format. freeze_checkpoint: If provided, run freeze_graph with this checkpoint to diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index be2cfe4734..cc499c3284 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" -#include "tensorflow/compiler/aot/tfcompile.pb.h" -#include "tensorflow/compiler/aot/tfcompile_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/core/framework/function.h" @@ -54,8 +54,7 @@ const char kUsageHeader[] = "--cpp_class=\"mynamespace::MyComputation\"\n" "\n"; -Status ReadProtoFile(const string& kind, const string& fname, - protobuf::Message* proto) { +Status ReadProtoFile(const string& fname, protobuf::Message* proto) { if (StringPiece(fname).ends_with(".pbtxt")) { return ReadTextProto(Env::Default(), fname, proto); } else { @@ -63,23 +62,17 @@ Status ReadProtoFile(const string& kind, const string& fname, } } -void ParseTensorId(const string& name, TensorId* id) { - const std::pair<StringPiece, int> name_index = ParseTensorName(name); - id->set_node_name(name_index.first.ToString()); - id->set_output_index(name_index.second); -} - Status Main(const MainFlags& flags) { // Process config. - Config config; + tf2xla::Config config; if (flags.config.empty()) { return errors::InvalidArgument("Must specify --config"); } - TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config)); + TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config)); TF_RETURN_IF_ERROR(ValidateConfig(config)); if (flags.dump_fetch_nodes) { std::set<string> nodes; - for (const Fetch& fetch : config.fetch()) { + for (const tf2xla::Fetch& fetch : config.fetch()) { nodes.insert(fetch.id().node_name()); } std::cout << str_util::Join(nodes, ","); @@ -91,12 +84,9 @@ Status Main(const MainFlags& flags) { return errors::InvalidArgument("Must specify --graph"); } GraphDef graph_def; - TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def)); - std::unique_ptr<Graph> graph; - TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &graph)); - + TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def)); CompileResult compile_result; - TF_RETURN_IF_ERROR(CompileGraph(std::move(graph), flags, &compile_result)); + TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result)); // Write output files. Env* env = Env::Default(); diff --git a/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt b/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt index c46e65f71a..3025fc27b1 100644 --- a/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt +++ b/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt @@ -1,4 +1,4 @@ -# Text form of tensorflow.tfcompile.Config proto. +# Text form of tensorflow.tf2xla.Config proto. feed{ id{node_name:"inputs/x_seq_0/read"} shape{dim{size:128}dim{size:1024}} } feed{ id{node_name:"inputs/x_seq_1/read"} shape{dim{size:128}dim{size:1024}} } feed{ id{node_name:"inputs/x_seq_2/read"} shape{dim{size:128}dim{size:1024}} } diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 13fc233054..22f2441a68 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -21,6 +21,40 @@ package( ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") + +xla_proto_library( + name = "tf2xla_proto", + srcs = ["tf2xla.proto"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + name = "tf2xla", + srcs = ["tf2xla.cc"], + hdrs = ["tf2xla.h"], + visibility = ["//visibility:public"], + deps = [ + ":common", + ":dump_graph", + ":tf2xla_proto", + ":tf2xla_util", + ":xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) cc_library( name = "xla_compiler", @@ -96,6 +130,51 @@ cc_library( # Internal targets below this point. +cc_library( + name = "tf2xla_util", + srcs = ["tf2xla_util.cc"], + hdrs = ["tf2xla_util.h"], + deps = [ + ":tf2xla_proto", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "tf2xla_util_test", + srcs = ["tf2xla_util_test.cc"], + deps = [ + ":tf2xla_util", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_test( + name = "tf2xla_test", + srcs = ["tf2xla_test.cc"], + deps = [ + ":tf2xla", + ":tf2xla_proto", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_test( name = "xla_compiler_test", srcs = ["xla_compiler_test.cc"], diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc new file mode 100644 index 0000000000..b29c92190d --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -0,0 +1,370 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla.h" + +#include <map> +#include <memory> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +const char* const kArgOp = "_Arg"; +const char* const kRetvalOp = "_Retval"; +const char* const kFeedIdAttr = "_feed_id"; +const char* const kFetchIdAttr = "_fetch_id"; +const char* const kShapeAttr = "_shape"; +const char* const kDebugNameAttr = "_debug_name"; + +namespace { + +typedef std::unordered_map<string, Node*> NodeMap; + +// Each feed id identifies the positional output of some node, which may consist +// of multiple edges. AddPlaceholdersForFeeds has already replaced each fed +// tensor with a placeholder. For each feed tensor, replaces all edges so they +// point from a new _Arg node instead. +Status AddArgNodes(Graph* graph, const NodeMap& node_map, + const protobuf::RepeatedPtrField<tf2xla::Feed>& feeds, + const std::unordered_map<string, string>& feed_remapping) { + for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) { + const tf2xla::Feed& feed = feeds[arg_index]; + // All feeds have been replaced by placeholders. + const int output_index = 0; + + const string key = TensorIdToString(feed.id()); + const auto remap_it = feed_remapping.find(key); + auto node_it = node_map.find(remap_it->second); + if (node_it == node_map.end()) { + // Strip off the aot_feed_#/ prefix. + StringPiece name(remap_it->second); + const auto index = name.find('/'); + if (index > 0) name.remove_prefix(index + 1); + return errors::InvalidArgument( + "Node is fed but not needed for fetching: ", name); + } + const Node* feed_node = node_it->second; + + // TODO(toddw): Invoke shape inference in AddPlaceholdersForFeeds and add a + // "_shape" attr if we can determine it. That way the graph will be + // initialized with whatever shapes we can infer, while the user can still + // explicitly specify or override them. + Node* arg_node = nullptr; + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp) + .Attr("T", BaseType(feed_node->output_type(output_index))) + .Attr("index", arg_index) + .Attr(kFeedIdAttr, TensorIdToString(feed.id())) + .Attr(kShapeAttr, TensorShape(feed.shape())) + .Attr(kDebugNameAttr, feed.name()) + .Finalize(graph, &arg_node)); + + // Collects out-edges from the feed node that have a matching edge index; + // these will be replaced with edges from the arg node instead. + // + // We must collect the edges first and process them in a second pass, since + // removing the edge from the graph invalidates feed_node->out_edges. + std::vector<const Edge*> feed_edges; + for (const Edge* edge : feed_node->out_edges()) { + if (edge->src_output() == output_index) { + feed_edges.push_back(edge); + } + } + for (const Edge* edge : feed_edges) { + graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input()); + graph->RemoveEdge(edge); + } + } + return Status::OK(); +} + +// Each fetch id identifies the positional output of some node. For each fetch +// node, adds a new _Retval node instead, and adds the node to `retval_nodes`. +Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, + const protobuf::RepeatedPtrField<tf2xla::Fetch>& fetches, + std::unordered_set<const Node*>* retval_nodes) { + for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) { + const tf2xla::TensorId& id = fetches[ret_index].id(); + auto it = node_map.find(id.node_name()); + if (it == node_map.end()) { + return errors::NotFound("Can't find fetch id: ", TensorIdToString(id)); + } + Node* fetch_node = it->second; + if (id.output_index() >= fetch_node->num_outputs()) { + return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id), + ", output index should be < ", + fetch_node->num_outputs()); + } + // Connects fetch_node -> retval_node. + Node* retval_node = nullptr; + TF_RETURN_IF_ERROR( + NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp) + .Input(fetch_node, id.output_index()) + .Attr("T", BaseType(fetch_node->output_type(id.output_index()))) + .Attr("index", ret_index) + .Attr(kFetchIdAttr, TensorIdToString(id)) + .Finalize(graph, &retval_node)); + retval_nodes->insert(retval_node); + } + return Status::OK(); +} + +// RewriteAndPruneGraph identifies input and output edges (named by the feed and +// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg +// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph +// execution to know the input and output args for the generated function. +Status RewriteAndPruneGraph( + Graph* graph, const tf2xla::Config& config, + const std::unordered_map<string, string>& feed_remapping) { + NodeMap node_map; + for (Node* n : graph->nodes()) { + node_map[n->name()] = n; + } + TF_RETURN_IF_ERROR( + AddArgNodes(graph, node_map, config.feed(), feed_remapping)); + std::unordered_set<const Node*> retval_nodes; + TF_RETURN_IF_ERROR( + AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes)); + VLOG(2) << "Post rewrite: " + << dump_graph::DumpGraphToFile("tf2xla_post_rewrite", *graph); + PruneForReverseReachability(graph, retval_nodes); + FixupSourceAndSinkEdges(graph); + VLOG(2) << "Post prune: " + << dump_graph::DumpGraphToFile("tfcompile_post_prune", *graph); + // Sanity-check, to make sure the feeds and fetches still exist post-pruning. + std::set<string> missing_feeds, missing_fetches; + for (const tf2xla::Feed& feed : config.feed()) { + missing_feeds.insert(TensorIdToString(feed.id())); + } + for (const tf2xla::Fetch& fetch : config.fetch()) { + missing_fetches.insert(TensorIdToString(fetch.id())); + } + for (const Node* n : graph->op_nodes()) { + if (n->type_string() == kArgOp) { + string feed_id; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFeedIdAttr, &feed_id)); + if (missing_feeds.erase(feed_id) == 0) { + return errors::Aborted(kArgOp, + " node found with unknown feed id: ", feed_id); + } + } else if (n->type_string() == kRetvalOp) { + string fetch_id; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), kFetchIdAttr, &fetch_id)); + if (missing_fetches.erase(fetch_id) == 0) { + return errors::Aborted(kRetvalOp, + " node found with unknown fetch id: ", fetch_id); + } + } + } + if (!missing_feeds.empty() || !missing_fetches.empty()) { + return errors::Aborted( + "Post graph-pruning", + ", missing feeds: ", str_util::Join(missing_feeds, ", "), + ", missing fetches: ", str_util::Join(missing_fetches, ", ")); + } + return Status::OK(); +} + +// CollectArgNodes collects _Arg nodes from the graph, and performs basic +// sanity-checking to ensure the index and type attributes of each node are +// initialized correctly. +Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) { + std::map<int, Node*> indexed_arg_nodes; + for (Node* n : graph.nodes()) { + if (n->type_string() == kArgOp) { + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + auto insert_result = indexed_arg_nodes.insert({index, n}); + if (!insert_result.second) { + const Node* dup = insert_result.first->second; + return errors::InvalidArgument( + "Multiple ", kArgOp, " nodes with index ", index, ", ", + n->DebugString(), " and ", dup->DebugString()); + } + } + } + arg_nodes->clear(); + for (const auto& index_node : indexed_arg_nodes) { + if (index_node.first != arg_nodes->size()) { + return errors::InvalidArgument("Expected ", kArgOp, " node with index ", + arg_nodes->size(), ", but got index ", + index_node.first); + } + arg_nodes->push_back(index_node.second); + } + return Status::OK(); +} + +// Fills in xla_args from the corresponding _Arg nodes in the graph. +Status CreateXlaArgs(const Graph& graph, + std::vector<XlaCompiler::Argument>* xla_args) { + std::vector<Node*> arg_nodes; + TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes)); + for (const Node* node : arg_nodes) { + XlaCompiler::Argument arg; + arg.kind = XlaCompiler::Argument::kParameter; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); + TensorShape shape; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape)); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); + xla_args->push_back(arg); + } + return Status::OK(); +} + +// Converts the TensorFlow graph into an XLA computation, by executing the +// graph symbolically, with each op building up the XLA HLO. +Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client, + xla::Computation* computation, + bool* requires_runtime_context) { + // Create a device and context to convert the graph into an XLA computation. + XlaOpRegistry::RegisterCompilationKernels(); + // Populate the context with args from the graph. + for (Node* node : graph->nodes()) { + node->set_assigned_device_name(DEVICE_CPU_XLA_JIT); + } + std::vector<XlaCompiler::Argument> xla_args; + TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); + + // Compile the graph into an XLA computation. + XlaCompiler::Options compiler_options; + compiler_options.client = client; + DeviceType device_type(DEVICE_CPU_XLA_JIT); + compiler_options.device_type = &device_type; + compiler_options.flib_def = &graph->flib_def(); + compiler_options.graph_def_version = graph->versions().producer(); + compiler_options.allow_cpu_custom_calls = true; + XlaCompiler compiler(compiler_options); + + XlaCompiler::CompilationResult result; + TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(), + "tfcompile", std::move(graph), + xla_args, &result)); + *requires_runtime_context = result.requires_runtime_context; + *computation = std::move(*result.computation); + + int num_const_results = 0; + for (int i = 0; i < result.outputs.size(); ++i) { + // Ending up with const results (i.e. output args) is an error, since it + // means that one or more fetches that the user specified will be dropped + // from the generated function. It's most likely a configuration error, + // since the user shouldn't be asking for output args that end up as consts. + // + // TODO(toddw): Provide a way for the user to access const output args, + // e.g. perhaps hard-coded into the header, or somehow copied into the + // output buffers. + if (result.outputs[i].is_constant) { + ++num_const_results; + LOG(ERROR) << "ConstRetVal index:" << i + << " value:" << result.outputs[i].constant_value.DebugString(); + } + } + if (num_const_results > 0) { + return errors::Unimplemented( + "Conversion from TensorFlow graph to XLA resulted in ", + num_const_results, + " constant results. The configuration of " + "the output args (i.e. fetch ids) is probably wrong."); + } + if (computation->IsNull()) { + return errors::Aborted( + "Conversion from TensorFlow graph to XLA resulted in an empty " + "computation."); + } + return Status::OK(); +} + +// InitGraph creates a graph based on the graph_def, that may then be converted +// to an xla::Computation via ConvertGraphToXla. +// +// The graph is rewritten with _Arg and _Retval nodes, representing the inputs +// and outputs of the function that will be compiled. Each feed id causes a new +// _Arg node to be created, where we first collect all existing edges pointing +// from the named node's output index, and then rewrite them to point from that +// _Arg node instead. Each fetch id causes a new _Retval node to be created, +// with a new edge pointing from the named node's output index to that _Retval +// node. +Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, + std::unique_ptr<Graph>* graph) { + TF_RETURN_IF_ERROR(ValidateConfig(config)); + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library()); + std::unique_ptr<Graph> g(new Graph(flib_def)); + + // Replace references to fed tensors with references to newly added + // placeholders. + GraphDef first_copy_def = graph_def; + + // Maps from name:port of a feed to the name:port of the placeholder to use. + std::unordered_map<string, string> feed_remapping; + TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, g->op_registry(), + &feed_remapping, &first_copy_def)); + + // Prune the GraphDef first so that unknown ops that we aren't compiling get + // filtered out. + GraphDef second_copy_def; + TF_RETURN_IF_ERROR( + PruneGraphDefInto(config, first_copy_def, &second_copy_def)); + + TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef( + &second_copy_def, *g->op_registry(), /*node_offset=*/0)); + + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), + second_copy_def, g.get())); + TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping)); + *graph = std::move(g); + return Status::OK(); +} + +} // namespace + +Status ConvertGraphDefToXla(const GraphDef& graph_def, + const tf2xla::Config& config, xla::Client* client, + xla::Computation* computation, + bool* requires_runtime_context) { + std::unique_ptr<Graph> graph; + TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); + TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation, + requires_runtime_context)); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h new file mode 100644 index 0000000000..ab99beebf7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_ + +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/core/framework/graph.pb.h" + +namespace tensorflow { + +// Converts a tensorflow::GraphDef into an xla::Computation. The given `config` +// specifies the portion of the graph to convert, via feeds and fetches. Each +// feed is a positional input argument for the generated computation, while each +// fetch is a positional output argument. +// +// The computation is built in the context of the given `client`, which may +// subsequently be used to compile or execute the computation. +// +// If `requires_runtime_context` is filled with true, this indicates the last +// argument of the computation is XlaLocalRuntimeContext*. +Status ConvertGraphDefToXla(const GraphDef& graph_def, + const tf2xla::Config& config, xla::Client* client, + xla::Computation* computation, + bool* requires_runtime_context); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_ diff --git a/tensorflow/compiler/aot/tfcompile.proto b/tensorflow/compiler/tf2xla/tf2xla.proto index cd83840d89..18c9089f5f 100644 --- a/tensorflow/compiler/aot/tfcompile.proto +++ b/tensorflow/compiler/tf2xla/tf2xla.proto @@ -1,10 +1,10 @@ syntax = "proto3"; -package tensorflow.tfcompile; +package tensorflow.tf2xla; option cc_enable_arenas = true; -option java_outer_classname = "CompileProtos"; +option java_outer_classname = "Tf2XlaProtos"; option java_multiple_files = true; -option java_package = "org.tensorflow.tfcompile"; +option java_package = "org.tensorflow.tf2xla"; import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/types.proto"; @@ -19,32 +19,32 @@ message TensorId { }; // Feed represents a single feed tensor in the graph, which corresponds to an -// input argument for the generated function. +// input argument for the generated computation. message Feed { TensorId id = 1; TensorShapeProto shape = 2; string name = 3; // Optional name for generated code. // Optional data type. This is not normally required, as the graph itself - // contains this information. However, if the node being fed is an op that - // is not linked into the tfcompile binary, then the type cannot be inferred - // from the node; in this case, the type should be set here. + // contains this information. However, if the node being fed is an op that is + // not linked into the binary, then the type cannot be inferred from the node; + // in this case, the type should be set here. DataType type = 4; }; // Fetch represents a single fetch tensor in the graph, which corresponds to an -// output argument for the generated function. +// output argument for the generated computation. message Fetch { TensorId id = 1; string name = 2; // Optional name for generated code. }; -// Config represents configuration information for tfcompile. +// Config represents configuration information for tf2xla conversion. message Config { - // Each feed is a positional input argument for the generated function. The - // order of each entry matches the order of each input argument. + // Each feed is a positional input argument for the generated computation. + // The order of each entry matches the order of each input argument. repeated Feed feed = 1; - // Each fetch is a positional output argument for the generated function. The - // order of each entry matches the order of each output argument. + // Each fetch is a positional output argument for the generated computation. + // The order of each entry matches the order of each output argument. repeated Fetch fetch = 2; }; diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc new file mode 100644 index 0000000000..57b53cc660 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla.h" + +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +AttrValue TypeAttrValue(DataType type) { + AttrValue attr_value; + SetAttrValue(type, &attr_value); + return attr_value; +} + +GraphDef SumGraph() { + GraphDef graph_def; + NodeDef* x = graph_def.add_node(); + x->set_name("x"); + x->set_op("Placeholder"); + (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32); + NodeDef* y = graph_def.add_node(); + y->set_name("y"); + y->set_op("Placeholder"); + (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32); + NodeDef* sum = graph_def.add_node(); + sum->set_name("sum"); + sum->set_op("Add"); + sum->add_input("x"); + sum->add_input("y"); + (*sum->mutable_attr())["T"] = TypeAttrValue(DT_INT32); + return graph_def; +} + +tf2xla::Config SumConfig() { + tf2xla::Config config; + config.add_feed()->mutable_id()->set_node_name("x"); + config.add_feed()->mutable_id()->set_node_name("y"); + config.add_fetch()->mutable_id()->set_node_name("sum"); + return config; +} + +TEST(ConvertGraphDefToXla, Sum) { + GraphDef graph_def = SumGraph(); + tf2xla::Config config = SumConfig(); + + xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); + xla::Computation computation; + bool requires_runtime_context; + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation, + &requires_runtime_context)); + ASSERT_FALSE(requires_runtime_context); + + // Set up arguments. + auto x_literal = xla::Literal::CreateR0<int32>(10); + auto y_literal = xla::Literal::CreateR0<int32>(32); + auto x_global_or = client->TransferToServer(*x_literal); + auto y_global_or = client->TransferToServer(*y_literal); + TF_EXPECT_OK(x_global_or.status()); + TF_EXPECT_OK(y_global_or.status()); + std::unique_ptr<xla::GlobalData> x_global = + std::move(x_global_or.ValueOrDie()); + std::unique_ptr<xla::GlobalData> y_global = + std::move(y_global_or.ValueOrDie()); + + // Execute and check result. + auto result_or = + client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); + TF_EXPECT_OK(result_or.status()); + std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie()); + EXPECT_EQ("42", result->ToString()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 629187d621..14e0910cab 100644 --- a/tensorflow/compiler/aot/tfcompile_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/aot/tfcompile_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include <queue> #include <set> #include <unordered_map> -#include "tensorflow/compiler/aot/tfcompile.pb.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -29,21 +29,13 @@ limitations under the License. #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { -namespace tfcompile { namespace { -bool IsAlpha(char c) { - return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z'); -} - -bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); } - -Status ValidateTensorId(const TensorId& id) { +Status ValidateTensorId(const tf2xla::TensorId& id) { if (id.node_name().empty()) { return errors::InvalidArgument("TensorId node_name must be non-empty"); } @@ -53,10 +45,9 @@ Status ValidateTensorId(const TensorId& id) { return Status::OK(); } -Status ValidateFeedFetchName(const string& kind, const string& name, - std::set<string>* names) { +Status CheckNameDuplicates(const string& kind, const string& name, + std::set<string>* names) { if (!name.empty()) { - TF_RETURN_IF_ERROR(ValidateCppIdent(name, kind + " name")); if (!names->insert(name).second) { return errors::InvalidArgument("duplicate ", kind, " name: ", name); } @@ -80,42 +71,18 @@ Status CheckFeedFetchNameConflicts(const string& kind, } // namespace -Status ValidateCppIdent(StringPiece ident, StringPiece msg) { - if (ident.empty()) { - return errors::InvalidArgument("empty identifier: ", msg); - } - // Require that the identifier starts with a nondigit, and is composed of - // nondigits and digits, as specified in section [2.11 Identifiers] of the - // C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is - // defined as [0-9]. - // - // Technically the standard also allows for `universal-character-name`, with a - // table of allowed unicode ranges, as well as `other implementation-defined - // characters`. We disallow those here to give better error messages, at the - // expensive of being more restrictive than the standard. - if (ident[0] != '_' && !IsAlpha(ident[0])) { - return errors::InvalidArgument("illegal leading char: ", msg); - } - for (size_t pos = 1; pos < ident.size(); ++pos) { - if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) { - return errors::InvalidArgument("illegal char: ", msg); - } - } - return Status::OK(); -} - -Status ValidateConfig(const Config& config) { +Status ValidateConfig(const tf2xla::Config& config) { std::set<string> names; - for (const Feed& feed : config.feed()) { + for (const tf2xla::Feed& feed : config.feed()) { TF_RETURN_IF_ERROR(ValidateTensorId(feed.id())); TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape())); - TF_RETURN_IF_ERROR(ValidateFeedFetchName("feed", feed.name(), &names)); + TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names)); } TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names)); names.clear(); - for (const Fetch& fetch : config.fetch()) { + for (const tf2xla::Fetch& fetch : config.fetch()) { TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id())); - TF_RETURN_IF_ERROR(ValidateFeedFetchName("fetch", fetch.name(), &names)); + TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names)); } TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names)); if (config.feed().empty() || config.fetch().empty()) { @@ -125,10 +92,10 @@ Status ValidateConfig(const Config& config) { } Status AddPlaceholdersForFeeds( - const Config& config, const OpRegistryInterface* op_registry, + const tf2xla::Config& config, const OpRegistryInterface* op_registry, std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) { struct PlaceholderInfo { - const Feed* feed = nullptr; // point to Feed in <config>. + const tf2xla::Feed* feed = nullptr; // point to Feed in <config>. string placeholder_name; DataType data_type = DT_INVALID; }; @@ -137,9 +104,9 @@ Status AddPlaceholdersForFeeds( // when creating placeholders (genrules want deterministic output). std::map<string, PlaceholderInfo> placeholder_info; for (int i = 0; i < config.feed_size(); ++i) { - const Feed* feed = &config.feed(i); + const tf2xla::Feed* feed = &config.feed(i); const string name_port = TensorIdToString(feed->id()); - auto& info = placeholder_info[name_port]; + PlaceholderInfo& info = placeholder_info[name_port]; info.feed = feed; info.placeholder_name = strings::StrCat( "aot_feed_", feed->id().output_index(), "/", feed->id().node_name()); @@ -153,7 +120,7 @@ Status AddPlaceholdersForFeeds( } for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) { PlaceholderInfo& info = it->second; - const TensorId& feed_id = info.feed->id(); + const tf2xla::TensorId& feed_id = info.feed->id(); // Find the existing node and determine data type. auto node_it = name_to_node.find(feed_id.node_name()); @@ -214,16 +181,16 @@ Status AddPlaceholdersForFeeds( return Status::OK(); } -Status PruneGraphDefInto(const Config& config, const GraphDef& in, +Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, GraphDef* out) { *out = in; out->clear_node(); // Tensors needed for feeding. std::set<std::pair<string, int>> feed_tensors; - for (const auto& feed_config : config.feed()) { - feed_tensors.insert(std::make_pair(feed_config.id().node_name(), - feed_config.id().output_index())); + for (const tf2xla::Feed& feed : config.feed()) { + feed_tensors.insert( + std::make_pair(feed.id().node_name(), feed.id().output_index())); } // Maps node name to reachability. @@ -279,9 +246,8 @@ Status PruneGraphDefInto(const Config& config, const GraphDef& in, return Status::OK(); } -string TensorIdToString(const TensorId& id) { +string TensorIdToString(const tf2xla::TensorId& id) { return strings::StrCat(id.node_name(), ":", id.output_index()); } -} // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 365f7b0e7b..a29d0c16f9 100644 --- a/tensorflow/compiler/aot/tfcompile_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -13,26 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_ -#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ #include <unordered_map> -#include "tensorflow/compiler/aot/tfcompile.pb.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { -namespace tfcompile { - -// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is -// appended to error messages. -Status ValidateCppIdent(StringPiece ident, StringPiece msg); // ValidateConfig returns OK iff config is valid. -Status ValidateConfig(const Config& config); +Status ValidateConfig(const tf2xla::Config& config); // Modifies <graph_def> to include placeholders for each fed tensor, and // update references to the fed tensors to refer to the placeholders. @@ -40,18 +34,17 @@ Status ValidateConfig(const Config& config); // (except where their input edges are modified by the replacement of other // feeds). Status AddPlaceholdersForFeeds( - const Config& config, const OpRegistryInterface* op_registry, + const tf2xla::Config& config, const OpRegistryInterface* op_registry, std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def); // Returns in <out> a copy of <in>, pruned to only include fetches from // <config>. -Status PruneGraphDefInto(const Config& config, const GraphDef& in, +Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, GraphDef* out); // Returns node:port for the given <id>. -string TensorIdToString(const TensorId& id); +string TensorIdToString(const tf2xla::TensorId& id); -} // namespace tfcompile } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_ +#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/aot/tfcompile_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 5a92851ceb..b98c89f284 100644 --- a/tensorflow/compiler/aot/tfcompile_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/aot/tfcompile_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { -namespace tfcompile { namespace { void ExpectErrorContains(const Status& status, StringPiece str) { @@ -32,45 +31,16 @@ void ExpectErrorContains(const Status& status, StringPiece str) { << "expected error: " << status.error_message() << " to contain: " << str; } -TEST(ValidateCppIdent, Simple) { - TF_EXPECT_OK(ValidateCppIdent("a", "")); - TF_EXPECT_OK(ValidateCppIdent("abc", "")); - TF_EXPECT_OK(ValidateCppIdent("_abc", "")); - TF_EXPECT_OK(ValidateCppIdent("_abc123", "")); - // Make sure we didn't skip a valid letter or digit - string ident; - for (char c = 'a'; c <= 'z'; c++) { - ident.append(1, c); - } - for (char c = 'A'; c <= 'Z'; c++) { - ident.append(1, c); - } - for (char c = '0'; c <= '9'; c++) { - ident.append(1, c); - } - ident += "_"; - TF_EXPECT_OK(ValidateCppIdent(ident, "")); - - ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier"); - ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char"); - ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char"); - ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char"); - ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char"); - ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char"); - ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char"); - ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char"); -} - TEST(ValidateConfig, Good) { - Config config; - Feed* feed = config.add_feed(); + tf2xla::Config config; + tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); feed->mutable_id()->set_output_index(123); feed->set_name("foo_debug"); feed = config.add_feed(); feed->mutable_id()->set_node_name("bar"); feed->mutable_id()->set_output_index(0); - Fetch* fetch = config.add_fetch(); + tf2xla::Fetch* fetch = config.add_fetch(); fetch->mutable_id()->set_node_name("baz"); fetch->mutable_id()->set_output_index(456); fetch->set_name("baz_debug"); @@ -81,62 +51,62 @@ TEST(ValidateConfig, Good) { } TEST(ValidateConfig, BadEmpty) { - Config config; + tf2xla::Config config; ExpectErrorContains(ValidateConfig(config), "feeds and fetches must be specified"); } TEST(ValidateConfig, BadNoFeed) { - Config config; - Fetch* fetch = config.add_fetch(); + tf2xla::Config config; + tf2xla::Fetch* fetch = config.add_fetch(); fetch->mutable_id()->set_node_name("foo"); ExpectErrorContains(ValidateConfig(config), "feeds and fetches must be specified"); } TEST(ValidateConfig, BadNoFetch) { - Config config; - Feed* feed = config.add_feed(); + tf2xla::Config config; + tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); ExpectErrorContains(ValidateConfig(config), "feeds and fetches must be specified"); } TEST(ValidateConfig, BadFeedNodeName) { - Config config; + tf2xla::Config config; config.add_feed(); ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty"); } TEST(ValidateConfig, BadFeedOutputIndex) { - Config config; - Feed* feed = config.add_feed(); + tf2xla::Config config; + tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); feed->mutable_id()->set_output_index(-1); ExpectErrorContains(ValidateConfig(config), "output_index must be positive"); } TEST(ValidateConfig, BadFetchNodeName) { - Config config; - Feed* feed = config.add_feed(); + tf2xla::Config config; + tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); config.add_fetch(); ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty"); } TEST(ValidateConfig, BadFetchOutputIndex) { - Config config; - Feed* feed = config.add_feed(); + tf2xla::Config config; + tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); - Fetch* fetch = config.add_fetch(); + tf2xla::Fetch* fetch = config.add_fetch(); fetch->mutable_id()->set_node_name("bar"); fetch->mutable_id()->set_output_index(-1); ExpectErrorContains(ValidateConfig(config), "output_index must be positive"); } TEST(ValidateConfig, DuplicateFeedName) { - Config config; - Feed* feed = config.add_feed(); + tf2xla::Config config; + tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); feed->set_name("dup"); feed = config.add_feed(); @@ -146,10 +116,10 @@ TEST(ValidateConfig, DuplicateFeedName) { } TEST(ValidateConfig, DuplicateFetchName) { - Config config; - Feed* feed = config.add_feed(); + tf2xla::Config config; + tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); - Fetch* fetch = config.add_fetch(); + tf2xla::Fetch* fetch = config.add_fetch(); fetch->mutable_id()->set_node_name("bar"); fetch->set_name("dup"); fetch = config.add_fetch(); @@ -159,8 +129,8 @@ TEST(ValidateConfig, DuplicateFetchName) { } TEST(ValidateConfig, ConflictingFeedName) { - Config config; - Feed* feed = config.add_feed(); + tf2xla::Config config; + tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); feed->set_name("conflict"); feed = config.add_feed(); @@ -170,10 +140,10 @@ TEST(ValidateConfig, ConflictingFeedName) { } TEST(ValidateConfig, ConflictingFetchName) { - Config config; - Feed* feed = config.add_feed(); + tf2xla::Config config; + tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); - Fetch* fetch = config.add_fetch(); + tf2xla::Fetch* fetch = config.add_fetch(); fetch->mutable_id()->set_node_name("bar"); fetch->set_name("conflict"); fetch = config.add_fetch(); @@ -182,8 +152,8 @@ TEST(ValidateConfig, ConflictingFetchName) { ExpectErrorContains(ValidateConfig(config), "conflicting fetch name"); } -static Config FetchesConfig(std::vector<string> fetches) { - Config config; +static tf2xla::Config FetchesConfig(std::vector<string> fetches) { + tf2xla::Config config; for (const auto& fetch_node_name : fetches) { auto* fetch = config.add_fetch(); fetch->set_name(strings::StrCat("fetch_", fetch_node_name)); @@ -242,5 +212,4 @@ TEST(PruneGraphDefInto, Basic) { } } // namespace -} // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 30afaed732..e41a391ac5 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -1703,7 +1703,6 @@ StatusOr<Computation> ComputationBuilder::Build() { } void ComputationBuilder::AddOpMetadata(OpRequest* request) const { - tensorflow::mutex_lock lock(mutex_); *request->mutable_metadata() = metadata_; } diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index cf1f3b074e..96db56bc53 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -37,7 +37,6 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/platform/types.h" @@ -57,10 +56,10 @@ class ComputationBuilder { ~ComputationBuilder(); // Returns the client the builder was initialized with. - Client* client() { return client_; } + Client* client() const { return client_; } // Returns the computation name. - const string& name() { return name_; } + const string& name() const { return name_; } // Sets OpMetadata that will be added to all instructions until cleared. // @@ -69,13 +68,11 @@ class ComputationBuilder { // instructions generated via this Computation Builder will have the same // OpMetadata attached until a call to ClearOpMetdata. void SetOpMetadata(const OpMetadata& metadata) { - tensorflow::mutex_lock lock(mutex_); metadata_ = metadata; } // Clears the HloMetdata state. void ClearOpMetadata() { - tensorflow::mutex_lock lock(mutex_); metadata_.Clear(); } @@ -826,15 +823,12 @@ class ComputationBuilder { Client* client_; // Mode bit that indicates whether to die when a first error is encountered. - bool die_immediately_on_error_{false}; - - // Mutex to guard against concurrent access to metadata_. - mutable tensorflow::mutex mutex_; + bool die_immediately_on_error_ = false; // The metadata to attach to each op. This is structured as a "modal"-like // operation, in order to simplify client code (and not sprinkle this metadata // throughout the TensorFlow op kernel implementations). - OpMetadata metadata_ GUARDED_BY(mutex_); + OpMetadata metadata_; TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder); }; diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 95f8165795..1a18b28cbb 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -180,15 +180,18 @@ cc_library( cc_library( name = "ir_emitter", - srcs = ["ir_emitter.cc"], + srcs = [ + "elemental_ir_emitter.cc", + "ir_emitter.cc", + ], hdrs = [ + "elemental_ir_emitter.h", "ir_emitter.h", ], deps = [ ":cpu_options", ":cpu_runtime", ":dot_op_emitter", - ":elemental_ir_emitter", ":ir_emission_utils", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", @@ -526,22 +529,6 @@ cc_library( ) cc_library( - name = "elemental_ir_emitter", - srcs = ["elemental_ir_emitter.cc"], - hdrs = ["elemental_ir_emitter.h"], - deps = [ - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:elemental_ir_emitter", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "@llvm//:core", - ], -) - -cc_library( name = "ir_emission_utils", srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 511f89144a..902309b338 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -50,14 +50,6 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } - // Producer or consumer cannot be Map. Maps are technically elementwise but - // of a slightly different form (call instead of a computation). These are not - // yet supported in the CPU backend. - if (producer->opcode() == HloOpcode::kMap || - consumer->opcode() == HloOpcode::kMap) { - return false; - } - // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). if (producer->opcode() != HloOpcode::kFusion && diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 0fc62281a0..b56466d5e4 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -209,6 +209,31 @@ class OpcodeFusionTest : public InstructionFusionTest { std::multiset<HloOpcode>(fused_opcodes.begin(), fused_opcodes.end()), expected_opcodes); } + + HloComputation* CreateAdderToOne(HloModule* module) { + HloComputation::Builder builder(TestName()); + HloInstruction* arg0 = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "arg0")); + HloInstruction* one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one)); + return module->AddEmbeddedComputation(builder.Build()); + } + + HloComputation* CreateMax(HloModule* module) { + HloComputation::Builder builder(TestName()); + HloInstruction* arg0 = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "arg0")); + HloInstruction* arg1 = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "arg1")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, arg0, arg1)); + return module->AddEmbeddedComputation(builder.Build()); + } }; TEST_F(OpcodeFusionTest, Exponential_Bitcast_Negate) { @@ -402,6 +427,49 @@ TEST_F(OpcodeFusionTest, Exponential_Transpose_Negate) { HloOpcode::kParameter}); } +TEST_F(OpcodeFusionTest, UnaryMapOfExp) { + auto module = CreateNewModule(); + + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + + HloInstruction* exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0)); + builder.AddInstruction(HloInstruction::CreateMap( + shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{})); + + module->AddEntryComputation(builder.Build()); + + RunFusionAndCheckOpcodesWereFused( + module.get(), {HloOpcode::kParameter, HloOpcode::kExp, HloOpcode::kMap}); +} + +TEST_F(OpcodeFusionTest, BinaryMapOfExps) { + auto module = CreateNewModule(); + + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {3, 4}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param")); + + HloInstruction* exp0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0)); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1)); + + builder.AddInstruction(HloInstruction::CreateMap( + shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{})); + + module->AddEntryComputation(builder.Build()); + + RunFusionAndCheckOpcodesWereFused( + module.get(), {HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kExp, HloOpcode::kExp, HloOpcode::kMap}); +} } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index fe447adf89..73e039250b 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -64,5 +64,25 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp( } } +llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator) const { + if (hlo->opcode() == HloOpcode::kMap) { + return [this, hlo, &operand_to_generator]( + const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> { + std::vector<llvm::Value*> operands; + for (int i = 0; i < hlo->operand_count(); i++) { + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(i))( + ElementwiseSourceIndex(index, *hlo, 0))); + operands.push_back(operand_value); + } + return ir_emitter_->EmitScalarCall(hlo->shape().element_type(), + hlo->to_apply(), operands, + llvm_ir::IrName(hlo)); + }; + } + return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator); +} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 6f9d6a24b4..7e9f27befb 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/statusor.h" @@ -29,12 +30,19 @@ namespace cpu { class CpuElementalIrEmitter : public ElementalIrEmitter { public: CpuElementalIrEmitter(const HloModuleConfig& module_config, - llvm::IRBuilder<>* ir_builder, llvm::Module* module) - : ElementalIrEmitter(module_config, module, ir_builder) {} + IrEmitter* ir_emitter, llvm::Module* module) + : ElementalIrEmitter(module_config, module, ir_emitter->ir_builder()), + ir_emitter_(ir_emitter) {} + + llvm_ir::ElementGenerator MakeElementGenerator( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator) const override; protected: StatusOr<llvm::Value*> EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const override; + + IrEmitter* ir_emitter_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 94d4ce4a94..91b09f2472 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -136,6 +136,10 @@ DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr( const int64 kReductionDimensionThresholdBytes = 8 * 1024; const bool single_threaded_eigen = !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen(); + + // This is the point at which it is better to call into Eigen and shard the + // dot across multiple worker threads. This is a rough estimate by running + // a matmult benchmark on my local machine, and it can be tuned further. const int64 kMaxSingleThreadedFlops = 16 * 1024; const int64 M = result_shape.dimensions(0); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index bc51ad2b36..8cd8740ee8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -2354,8 +2354,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArrayForOp(operand)); } - CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_, - module_); + CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); @@ -2737,14 +2736,10 @@ llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction* hlo) { } prof_counter_idx = it->second; - uintptr_t hlo_address = reinterpret_cast<uintptr_t>(hlo); - counter_name = tensorflow::strings::StrCat( - "prof_counter_0x", - tensorflow::strings::Hex( - hlo_address, tensorflow::strings::PadSpec(sizeof(hlo_address)))); + counter_name = IrName("prof_counter", hlo->name()); } else { prof_counter_idx = hlo_to_profile_idx_->size(); - counter_name = "prof_counter_computation"; + counter_name = "prof_counter.computation"; } return ir_builder_.CreateGEP(GetProfileCountersArgument(), ir_builder_.getInt64(prof_counter_idx), @@ -3180,12 +3175,27 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { return GetIrArrayForOp(operand).EmitReadArrayElement(index, &ir_builder_); }; } - CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_, - module_); + CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); return EmitTargetElementLoop( hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator)); } +StatusOr<llvm::Value*> IrEmitter::EmitScalarCall( + PrimitiveType return_type, HloComputation* computation, + const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) { + llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation); + std::vector<llvm::Value*> argument_addrs; + for (auto argument : arguments) { + llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry( + argument->getType(), "arg_addr", &ir_builder_); + ir_builder_.CreateStore(argument, argument_addr); + argument_addrs.push_back(argument_addr); + } + return EmitElementFunctionCall(llvm_function, + ShapeUtil::MakeShape(return_type, {}), + argument_addrs, name); +} + unsigned TargetMachineFeatures::largest_register_size_in_bytes( llvm::Function* function) { auto itr = largest_register_size_in_bytes_.find(function); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index bcd33c3810..fa33a1eb7b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -133,6 +133,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { bool is_top_level_computation, std::vector<const HloInstruction*>* instruction_order); + llvm::IRBuilder<>* ir_builder() { return &ir_builder_; } + + // Emits a call to `computation` with scalar arguments `arguments`. + StatusOr<llvm::Value*> EmitScalarCall( + PrimitiveType return_type, HloComputation* computation, + const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name); + protected: // // The following methods implement the DfsHloVisitor interface. diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 5d650b872f..b24fe417ff 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -76,10 +76,11 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Since CUDA 9.0, all GPU versions are included in a single file const char* unified_libdevice_filename = "libdevice.10.bc"; std::vector<string> unified_libdevice_files; - tensorflow::Env::Default()->GetMatchingPaths( + const tensorflow::Status status = + tensorflow::Env::Default()->GetMatchingPaths( tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename), &unified_libdevice_files); - if( unified_libdevice_files.size() == 1 ) { + if (status.ok() && unified_libdevice_files.size() == 1) { return unified_libdevice_filename; } // There are only four libdevice files: compute_{20,30,35,50}. Each GPU diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index efb5fca188..e0c23a3a08 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -77,7 +77,7 @@ class HloOrdering { // Precondition: 'a' and 'b' are in the same computation. // // Derived classes should implement this method for determining order of - // instructions in the same comptuation. ExecutesBefore() analyzes the + // instructions in the same computation. ExecutesBefore() analyzes the // callgraph and uses this method to determine ordering of instructions in // different computations. virtual bool ExecutesBeforeInSameComputation( diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 1865004911..a0f9be3dd8 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -50,7 +50,7 @@ class WhileTest : public ClientLibraryTestBase {}; // while (result < 5) { // result = result + 1; // } -TEST_F(WhileTest, WhileWithScalarResult) { +TEST_F(WhileTest, WhileWithScalarS32Result) { auto result_shape = ShapeUtil::MakeShape(S32, {}); // Create a computation for the condition: repeat for 5 iterations. @@ -81,6 +81,43 @@ TEST_F(WhileTest, WhileWithScalarResult) { ComputeAndCompareR0<int32>(&builder, 5, {}); } +// Tests a while node when the result type T is S64. +// +// int32 result = 0; +// while (result < 5) { +// result = result + 1; +// } +TEST_F(WhileTest, WhileWithScalarS64Result) { + auto result_shape = ShapeUtil::MakeShape(S64, {}); + + // Create a computation for the condition: repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Gt(builder.ConstantR0<int64>(5), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body: add 1 to the result variable. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto input = builder.ConstantR0<int64>(1); + auto result = builder.Add(input, prev); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, TestName()); + auto init = builder.ConstantR0<int64>(0); + auto result = builder.While(condition, body, init); + auto shape = builder.GetShape(result).ConsumeValueOrDie(); + + ComputeAndCompareR0<int64>(&builder, 5, {}); +} + TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { auto result_shape = ShapeUtil::MakeShape(S32, {}); auto orig_shape = ShapeUtil::MakeShape(S32, {2}); |