# -*- Python -*- """Build macro that compiles a TensorFlow graph into a cc_library. To use from your BUILD file, add the following line to load the macro: load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") Then call the macro like this: tf_library( name = "test_graph_tfmatmul", config = "test_graph_tfmatmul.config.pbtxt", cpp_class = "MatMulComp", graph = ":test_graph_tfmatmul.pb", ) """ load( "//tensorflow:tensorflow.bzl", "if_android", "tf_cc_test", "tf_copts", ) def tf_library( name, graph, config, freeze_checkpoint = None, freeze_saver = None, cpp_class = None, gen_test = True, gen_benchmark = True, visibility = None, testonly = None, tfcompile_flags = None, tfcompile_tool = "//tensorflow/compiler/aot:tfcompile", include_standard_runtime_deps = True, enable_xla_hlo_profiling = False, deps = None, tags = None): """Runs tfcompile to compile a TensorFlow graph into executable code. Given an invocation of tf_library(name="foo", ...), generates the following build targets: foo: A cc_library containing the generated header and computation. foo_test: A cc_test with simple tests and benchmarks. Only created if gen_test=True. foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, useful for mobile devices or other platforms that can't compile the full test libraries. Only created if gen_benchmark=True. The output header is called .h. Args: name: The name of the build rule. graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it is expected to be in the human-readable proto text format, otherwise it is expected to be in the proto binary format. config: File containing tensorflow.tf2xla.Config proto. If the file ends in '.pbtxt' it is expected to be in the human-readable proto text format, otherwise it is expected to be in the proto binary format. freeze_checkpoint: If provided, run freeze_graph with this checkpoint to convert variables into constants. freeze_saver: If provided, run freeze_graph with this saver, in SaverDef binary form, to convert variables into constants. cpp_class: The name of the generated C++ class, wrapping the generated function. The syntax of this flag is [[::],...]. This mirrors the C++ syntax for referring to a class, where multiple namespaces may precede the class name, separated by double-colons. The class will be generated in the given namespace(s), or if no namespaces are given, within the global namespace. gen_test: If True, also generate a cc_test rule that builds a simple test and benchmark. gen_benchmark: If True, also generate a binary with a simple benchmark. Unlike the output of gen_test, this benchmark can be run on android. visibility: Bazel build visibility. testonly: Bazel testonly attribute. tfcompile_flags: Extra flags to pass to tfcompile to control compilation. tfcompile_tool: The tfcompile binary. A non-default can be passed to use a tfcompile built with extra dependencies. include_standard_runtime_deps: If True, the standard list of kernel/runtime deps is added to deps. If False, deps must contain the full set of deps needed by the generated library. enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program, and emit metadata that lets us pretty-print the gathered profile counters. deps: a list of deps to include on the build rules for the generated library, added to the standard deps if standard_runtime_deps is True. tags: tags to apply to subsidiary build rules. """ if not cpp_class: fail("cpp_class must be specified") tfcompile_graph = graph if freeze_checkpoint or freeze_saver: if not freeze_checkpoint: fail("freeze_checkpoint must be specified when freeze_saver is " + "specified") freeze_name = "freeze_" + name freeze_file = freeze_name + ".pb" # First run tfcompile to generate the list of out_nodes. # # Here and below, we set CUDA_VISIBLE_DEVICES='' to prevent the code we # launch from using any GPUs which might be present. This is important # because builds may run concurrently with tests, and tests need to be # able to assume that they have control of the full GPU. out_nodes_file = "out_nodes_" + freeze_name native.genrule( name = ("gen_" + out_nodes_file), srcs = [config], outs = [out_nodes_file], cmd = ("CUDA_VISIBLE_DEVICES='' " + "$(location " + tfcompile_tool + ")" + " --config=$(location " + config + ")" + " --dump_fetch_nodes > $@"), tools = [tfcompile_tool], # Run tfcompile on the build host, rather than forge, since it's # typically way faster on the local machine. local = 1, tags = tags, ) # Now run freeze_graph to convert variables into constants. freeze_args = ( " --input_graph=$(location " + graph + ")" + " --checkpoint_version=1" + " --input_binary=" + str(not graph.endswith(".pbtxt")) + " --input_checkpoint=$(location " + freeze_checkpoint + ")" + " --output_graph=$(location " + freeze_file + ")" + " --output_node_names=$$(<$(location " + out_nodes_file + "))" ) freeze_saver_srcs = [] if freeze_saver: freeze_args += " --input_saver=$(location " + freeze_saver + ")" freeze_saver_srcs += [freeze_saver] native.genrule( name = freeze_name, srcs = [ graph, freeze_checkpoint, out_nodes_file, ] + freeze_saver_srcs, outs = [freeze_file], cmd = ( "CUDA_VISIBLE_DEVICES='' " + "$(location " + "//tensorflow/python/tools:freeze_graph)" + freeze_args ), tools = ["//tensorflow/python/tools:freeze_graph"], tags = tags, ) tfcompile_graph = freeze_file # Rule that runs tfcompile to produce the header and object file. header_file = name + ".h" metadata_object_file = name + "_tfcompile_metadata.o" function_object_file = name + "_tfcompile_function.o" ep = ("__" + native.package_name() + "__" + name).replace("/", "_") if type(tfcompile_flags) == type(""): flags = tfcompile_flags else: flags = " ".join([ "'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or []) ]) if enable_xla_hlo_profiling: profiling_flag = "--xla_hlo_profile" else: profiling_flag = "" native.genrule( name = ("gen_" + name), srcs = [ tfcompile_graph, config, ], outs = [ header_file, metadata_object_file, function_object_file, ], cmd = ( "CUDA_VISIBLE_DEVICES='' " + "$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + " --config=$(location " + config + ")" + " --entry_point=" + ep + " --cpp_class=" + cpp_class + " --target_triple=" + target_llvm_triple() + " --out_header=$(@D)/" + header_file + " --out_metadata_object=$(@D)/" + metadata_object_file + " --out_function_object=$(@D)/" + function_object_file + " " + flags + " " + profiling_flag ), tools = [tfcompile_tool], visibility = visibility, testonly = testonly, # Run tfcompile on the build host since it's typically faster on the # local machine. # # Note that setting the local=1 attribute on a *test target* causes the # test infrastructure to skip that test. However this is a genrule, not # a test target, and runs with --genrule_strategy=forced_forge, meaning # the local=1 attribute is ignored, and the genrule is still run. # # https://www.bazel.io/versions/master/docs/be/general.html#genrule local = 1, tags = tags, ) # Rule that runs tfcompile to produce the SessionModule proto, useful for # debugging. TODO(b/64813587): Once the SessionModule proto is # deterministic, move this into the main rule above. session_module_pb = name + "_session_module.pb" native.genrule( name = (name + "_session_module"), srcs = [ tfcompile_graph, config, ], outs = [ session_module_pb, ], cmd = ( "CUDA_VISIBLE_DEVICES='' " + "$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + " --config=$(location " + config + ")" + " --entry_point=" + ep + " --cpp_class=" + cpp_class + " --target_triple=" + target_llvm_triple() + " --out_session_module=$(@D)/" + session_module_pb + " " + flags ), tools = [tfcompile_tool], visibility = visibility, testonly = testonly, local = 1, tags = tags, ) # The cc_library rule packaging up the header and object file, and needed # kernel implementations. need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) native.cc_library( name = name, srcs = [function_object_file, metadata_object_file], hdrs = [header_file], visibility = visibility, testonly = testonly, deps = [ # These deps are required by all tf_library targets even if # include_standard_runtime_deps is False. Without them, the # generated code will fail to compile. "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", "//tensorflow/core:framework_lite", ] + (need_xla_data_proto and [ # If we're generating the program shape, we must depend on the # proto. "//tensorflow/compiler/xla:xla_data_proto", ] or []) + (enable_xla_hlo_profiling and [ "//tensorflow/compiler/xla/service:hlo_profile_printer_data", ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually # needed. "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort", "//tensorflow/compiler/xla/service/cpu:runtime_matmul", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//third_party/eigen3", ] or []) + (deps or []), tags = tags, ) # Variables used for gen_test and gen_benchmark. cpp_class_split = cpp_class.rsplit("::", maxsplit = 2) if len(cpp_class_split) == 1: no_ns_name = cpp_class_split[0] else: no_ns_name = cpp_class_split[1] sed_replace = ( "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " + "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " + "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" " ) if gen_test: test_name = name + "_test" test_file = test_name + ".cc" # Rule to rewrite test.cc to produce the test_file. native.genrule( name = ("gen_" + test_name), testonly = 1, srcs = [ "//tensorflow/compiler/aot:test.cc", header_file, ], outs = [test_file], cmd = ( "sed " + sed_replace + " $(location //tensorflow/compiler/aot:test.cc) " + "> $(OUTS)" ), tags = tags, ) # The cc_test rule for the generated code. To ensure that this works # reliably across build configurations, we must use tf_cc_test instead # of native.cc_test. This is related to how we build # //tensorflow/core:lib -- see the note in # tensorflow/core/BUILD for more details. tf_cc_test( name = test_name, srcs = [test_file], deps = [ ":" + name, "//tensorflow/compiler/aot:tf_library_test_main", "//tensorflow/compiler/xla:executable_run_options", "//third_party/eigen3", "//tensorflow/core:lib", "//tensorflow/core:test", ], tags = tags, ) if gen_benchmark: benchmark_name = name + "_benchmark" benchmark_file = benchmark_name + ".cc" benchmark_main = ("//tensorflow/compiler/aot:" + "benchmark_main.template") # Rule to rewrite benchmark.cc to produce the benchmark_file. native.genrule( name = ("gen_" + benchmark_name), srcs = [ benchmark_main, header_file, ], testonly = testonly, outs = [benchmark_file], cmd = ("sed " + sed_replace + " $(location " + benchmark_main + ") " + "> $(OUTS)"), tags = tags, ) # The cc_benchmark rule for the generated code. This does not need the # tf_cc_binary since we (by deliberate design) do not depend on # //tensorflow/core:lib. # # Note: to get smaller size on android for comparison, compile with: # --copt=-fvisibility=hidden # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN native.cc_binary( name = benchmark_name, srcs = [benchmark_file], testonly = testonly, copts = tf_copts(), linkopts = if_android(["-pie", "-s"]), deps = [ ":" + name, "//tensorflow/compiler/aot:benchmark", "//tensorflow/compiler/xla:executable_run_options", "//third_party/eigen3", ] + if_android([ "//tensorflow/compiler/aot:benchmark_extra_android", ]), tags = tags, ) def target_llvm_triple(): """Returns the target LLVM triple to be used for compiling the target.""" # TODO(toddw): Add target_triple for other targets. For details see: # http://llvm.org/docs/doxygen/html/Triple_8h_source.html return select({ "//tensorflow:android_armeabi": "armv5-none-android", "//tensorflow:android_arm": "armv7-none-android", "//tensorflow:android_arm64": "aarch64-none-android", "//tensorflow:android_x86": "i686-none-android", "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", "//tensorflow:darwin": "x86_64-none-darwin", "//conditions:default": "x86_64-pc-linux", })