diff options
Diffstat (limited to 'tensorflow/compiler/aot/tfcompile_main.cc')
-rw-r--r-- | tensorflow/compiler/aot/tfcompile_main.cc | 142 |
1 files changed, 142 insertions, 0 deletions
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc new file mode 100644 index 0000000000..85ef9560bb --- /dev/null +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "tensorflow/compiler/aot/codegen.h" +#include "tensorflow/compiler/aot/compile.h" +#include "tensorflow/compiler/aot/flags.h" +#include "tensorflow/compiler/aot/tfcompile.pb.h" +#include "tensorflow/compiler/aot/tfcompile_util.h" +#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace tfcompile { + +const char kUsageHeader[] = + "tfcompile performs ahead-of-time compilation of a TensorFlow graph,\n" + "resulting in an object file compiled for your target architecture, and a\n" + "header file that gives access to the functionality in the object file.\n" + "A typical invocation looks like this:\n" + "\n" + " $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt\n" + "\n"; + +Status ReadProtoFile(const string& kind, const string& fname, + protobuf::Message* proto) { + if (StringPiece(fname).ends_with(".pbtxt")) { + return ReadTextProto(Env::Default(), fname, proto); + } else { + return ReadBinaryProto(Env::Default(), fname, proto); + } +} + +void ParseTensorId(const string& name, TensorId* id) { + const std::pair<StringPiece, int> name_index = ParseTensorName(name); + id->set_node_name(name_index.first.ToString()); + id->set_output_index(name_index.second); +} + +Status Main(const MainFlags& flags) { + // Process config. + Config config; + TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config)); + TF_RETURN_IF_ERROR(ValidateConfig(config)); + if (flags.dump_fetch_nodes) { + std::set<string> nodes; + for (const Fetch& fetch : config.fetch()) { + nodes.insert(fetch.id().node_name()); + } + std::cout << str_util::Join(nodes, ","); + return Status::OK(); + } + + // Read and initialize the graph. + GraphDef graph_def; + TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def)); + std::unique_ptr<Graph> graph; + FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library()); + TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &flib, &graph)); + + CompileResult compile_result; + TF_RETURN_IF_ERROR( + CompileGraph(std::move(graph), flags, &flib, &compile_result)); + + // Write output files. + Env* env = Env::Default(); + const std::vector<char>& obj = compile_result.aot->object_file_data(); + TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object, + StringPiece(obj.data(), obj.size()))); + HeaderOpts header_opts; + TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name, + &header_opts.namespaces)); + string header; + TF_RETURN_IF_ERROR( + GenerateHeader(header_opts, config, compile_result, &header)); + TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header)); + return Status::OK(); +} + +} // end namespace tfcompile +} // end namespace tensorflow + +int main(int argc, char** argv) { + tensorflow::tfcompile::MainFlags flags; + flags.target_triple = "x86_64-pc-linux"; + flags.out_object = "out.o"; + flags.out_header = "out.h"; + + std::vector<tensorflow::Flag> flag_list; + AppendMainFlags(&flag_list, &flags); + xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list); + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list); + + tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; + usage += tensorflow::Flags::Usage(argv[0], flag_list); + bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + QCHECK(parsed_flags_ok) << "\n" << usage; + + tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(argc == 1 && !flags.config.empty() && + (flags.dump_fetch_nodes || + (!flags.graph.empty() && !flags.entry_point.empty()))) + << "\n" + << usage; + + TF_QCHECK_OK(tensorflow::tfcompile::Main(flags)); + return 0; +} |