diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/BUILD')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/BUILD | 177 |
1 files changed, 177 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD new file mode 100644 index 0000000000..d913f898e9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -0,0 +1,177 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [ + "//tensorflow/compiler/tf2xla:internal", + ], + features = ["no_layering_check"], +) + +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts") + +tf_kernel_library( + name = "xla_ops", + srcs = [ + "aggregate_ops.cc", + "batch_matmul_op.cc", + "bcast_ops.cc", + "bias_ops.cc", + "binary_ops.cc", + "cast_op.cc", + "concat_op.cc", + "conv_ops.cc", + "cwise_ops.cc", + "declaration_op.cc", + "depthwise_conv_ops.cc", + "diag_op.cc", + "dynamic_stitch_op.cc", + "fill_op.cc", + "function_ops.cc", + "identity_op.cc", + "l2loss_op.cc", + "lrn_ops.cc", + "matmul_op.cc", + "no_op.cc", + "pack_op.cc", + "pad_op.cc", + "pooling_ops.cc", + "random_ops.cc", + "reduction_ops.cc", + "reduction_ops_common.cc", + "relu_op.cc", + "reshape_op.cc", + "retval_op.cc", + "select_op.cc", + "sequence_ops.cc", + "shape_op.cc", + "slice_op.cc", + "softmax_op.cc", + "split_op.cc", + "strided_slice_op.cc", + "tile_ops.cc", + "transpose_op.cc", + "unary_ops.cc", + "unpack_op.cc", + ], + hdrs = [ + "cwise_ops.h", + "reduction_ops.h", + ], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core/kernels:concat_lib", + "//tensorflow/core/kernels:conv_2d", + "//tensorflow/core/kernels:conv_ops", + "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:depthwise_conv_op", + "//tensorflow/core/kernels:matmul_op", + "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/kernels:pooling_ops", + "//tensorflow/core/kernels:sendrecv_ops", + "//tensorflow/core/kernels:transpose_op", + ], +) + +# Kernels that only work on CPU, because they use XLA custom calls. +# Only link this when using the CPU backend for XLA. +# +# TODO(cwhipkey): move into xla_ops when ops can be registered for +# CPU compilation only (b/31363654). +tf_kernel_library( + name = "xla_cpu_only_ops", + srcs = [ + "gather_op.cc", + "index_ops.cc", + ], + deps = [ + ":gather_op_kernel_float_int32", + ":gather_op_kernel_float_int64", + ":index_ops_kernel_argmax_float_1d", + ":index_ops_kernel_argmax_float_2d", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow_opensource", + ], +) + +tf_kernel_library( + name = "gather_op_kernel_float_int32", + srcs = ["gather_op_kernel_float_int32.cc"], + # Makes the custom-call function visible to LLVM during JIT. + linkopts = export_dynamic_linkopts, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:gather_functor", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "gather_op_kernel_float_int64", + srcs = ["gather_op_kernel_float_int64.cc"], + # Makes the custom-call function visible to LLVM during JIT. + linkopts = export_dynamic_linkopts, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:gather_functor", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "index_ops_kernel_argmax_float_1d", + srcs = ["index_ops_kernel_argmax_float_1d.cc"], + # Makes the custom-call function visible to LLVM during JIT. + linkopts = export_dynamic_linkopts, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "index_ops_kernel_argmax_float_2d", + srcs = ["index_ops_kernel_argmax_float_2d.cc"], + # Makes the custom-call function visible to LLVM during JIT. + linkopts = export_dynamic_linkopts, + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) |