1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
|
# Code in this package is distributed under the Apache 2.0 license.
licenses(["notice"])

# Every rule declared below is visible only to
# //tensorflow/compiler/tf2xla:internal (presumably a package_group — confirm)
# unless the rule sets its own `visibility` attribute.
package(default_visibility = ["//tensorflow/compiler/tf2xla:internal"])
load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
# Main tf2xla kernel library. Each .cc file in srcs presumably implements the
# XLA translation of one or more TensorFlow ops (inferred from file names —
# confirm against the sources).
#
# Header layout: only index_ops.h and shape_util.h are listed in hdrs. The
# headers listed in srcs (cwise_ops.h, gather_op_helpers.h, reduction_ops.h)
# are, assuming tf_kernel_library forwards srcs/hdrs to cc_library, private to
# this target and cannot be included by its dependents.
tf_kernel_library(
name = "xla_ops",
srcs = [
"aggregate_ops.cc",
"arg_op.cc",
"batch_matmul_op.cc",
"batch_norm_op.cc",
"batchtospace_op.cc",
"bcast_ops.cc",
"bias_ops.cc",
"binary_ops.cc",
"bucketize_op.cc",
"cast_op.cc",
"categorical_op.cc",
"cholesky_op.cc",
"clip_by_value_op.cc",
"concat_op.cc",
"const_op.cc",
"conv_ops.cc",
"cross_op.cc",
"cwise_ops.cc",
"cwise_ops.h",
"depthtospace_op.cc",
"diag_op.cc",
"dynamic_slice_ops.cc",
"dynamic_stitch_op.cc",
"elu_op.cc",
"extract_image_patches_op.cc",
"fake_quantize_ops.cc",
"fft_ops.cc",
"fill_op.cc",
"function_ops.cc",
"gather_op.cc",
"gather_op_helpers.h",
"identity_op.cc",
"image_ops.cc",
"image_resize_ops.cc",
"index_ops.cc",
"l2loss_op.cc",
"listdiff_op.cc",
"lrn_ops.cc",
"matmul_op.cc",
"matrix_band_part_op.cc",
"matrix_set_diag_op.cc",
"matrix_triangular_solve_op.cc",
"mirror_pad_op.cc",
"no_op.cc",
"one_hot_op.cc",
"pack_op.cc",
"pad_op.cc",
"pooling_ops.cc",
"qr_op.cc",
"quantize_and_dequantize_op.cc",
"random_ops.cc",
"reduce_window_op.cc",
"reduction_ops.cc",
"reduction_ops.h",
"reduction_ops_common.cc",
"relu_op.cc",
"reshape_op.cc",
"retval_op.cc",
"reverse_op.cc",
"reverse_sequence_op.cc",
"scan_ops.cc",
"scatter_nd_op.cc",
"segment_reduction_ops.cc",
"select_op.cc",
"sendrecv_ops.cc",
"sequence_ops.cc",
"shape_op.cc",
"shape_util.cc",
"slice_op.cc",
"softmax_op.cc",
"sort_ops.cc",
"spacetobatch_op.cc",
"spacetodepth_op.cc",
"sparse_to_dense_op.cc",
"split_op.cc",
"stack_ops.cc",
"stateless_random_ops.cc",
"strided_slice_op.cc",
"tensor_array_ops.cc",
"tile_ops.cc",
"topk_op.cc",
"training_ops.cc",
"transpose_op.cc",
"unary_ops.cc",
"unpack_op.cc",
"variable_ops.cc",
],
# The only headers exported to targets that depend on :xla_ops.
hdrs = [
"index_ops.h",
"shape_util.h",
],
deps = [
# The If and While kernels are built as separate libraries in this
# package; :xla_ops pulls them in so depending on it links them too.
":if_op",
":while_op",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
"//tensorflow/compiler/tf2xla/lib:cholesky",
"//tensorflow/compiler/tf2xla/lib:qr",
"//tensorflow/compiler/tf2xla/lib:random",
"//tensorflow/compiler/tf2xla/lib:scatter",
"//tensorflow/compiler/tf2xla/lib:triangular_solve",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/tf2xla/lib:while_loop",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/lib:prng",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:image_ops_op_lib",  # presumably op definitions for image_ops.cc / image_resize_ops.cc — verify
"//tensorflow/core:lib",
"//tensorflow/core:linalg_ops_op_lib",  # presumably op definitions for cholesky/qr/triangular-solve kernels — verify
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:spectral_ops_op_lib",  # presumably op definitions for fft_ops.cc — verify
"//tensorflow/core:stateless_random_ops_op_lib",  # presumably op definitions for stateless_random_ops.cc — verify
# TensorFlow core kernel libraries; presumably reused for shared helpers
# (shape checks, bounds checks, op utilities) — confirm per dependency.
"//tensorflow/core/kernels:bounds_check",
"//tensorflow/core/kernels:concat_lib",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:conv_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/kernels:pooling_ops",
"//tensorflow/core/kernels:random_op",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:sparse_to_dense_op",
"//tensorflow/core/kernels:stack_ops",
"//tensorflow/core/kernels:training_ops",
"//tensorflow/core/kernels:transpose_op",
],
)
# Standalone kernel library built from while_op.cc (presumably the XLA lowering
# of TensorFlow's While op). It is a separate target so other rules can depend
# on it directly; while_op.h is its exported header.
tf_kernel_library(
    name = "while_op",
    srcs = [
        "while_op.cc",
    ],
    hdrs = [
        "while_op.h",
    ],
    deps = [
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/ops:xla_ops",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)
# Standalone kernel library built from if_op.cc (presumably the XLA lowering of
# TensorFlow's If op). It is a separate target so other rules can depend on it
# directly; if_op.h is its exported header.
tf_kernel_library(
    name = "if_op",
    srcs = [
        "if_op.cc",
    ],
    hdrs = [
        "if_op.h",
    ],
    deps = [
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/ops:xla_ops",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)
# Kernels whose XLA implementation is a dummy that does nothing (no-op):
# the ones in assert_op.cc and check_numerics_op.cc.
tf_kernel_library(
    name = "xla_dummy_ops",
    srcs = ["assert_op.cc", "check_numerics_op.cc"],
    deps = [
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/core:array_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:logging_ops_op_lib",
    ],
    # Link every object file even when the binary references none of its
    # symbols; presumably required so the kernel registrations survive the
    # linker — confirm. Kept as the integer 1 on purpose: the macro receives
    # this value, and in Starlark True == 1 is False.
    alwayslink = 1,
)
# CPU-only kernels. They rely on XLA custom calls, so link this target only
# when XLA runs with its CPU backend.
tf_kernel_library(
    name = "xla_cpu_only_ops",
    srcs = [
        "index_ops_cpu.cc",
    ],
    deps = [
        # Argmax kernels defined below; presumably the custom-call targets
        # that index_ops_cpu.cc emits calls to — confirm.
        ":index_ops_kernel_argmax_float_1d",
        ":index_ops_kernel_argmax_float_2d",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client/lib:arithmetic",
        "//tensorflow/compiler/xla/client/xla_client:xla_builder",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core/kernels:argmax_op",
        "//tensorflow/core/kernels:bounds_check",
    ],
)
# Float argmax kernel, 1-D variant (judging by the name), built against the XLA
# CPU custom-call target registry. Unlike the package default, this target is
# publicly visible.
cc_library(
    name = "index_ops_kernel_argmax_float_1d",
    srcs = [
        "index_ops_kernel_argmax_float_1d.cc",
    ],
    copts = tf_copts(),
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
        "//tensorflow/core:framework_lite",
        "//third_party/eigen3",
    ],
    # Keep the object file at link time even though nothing references it by
    # symbol; presumably its custom-call registration runs as a static
    # initializer — confirm.
    alwayslink = 1,
)
# Float argmax kernel, 2-D variant (judging by the name), built against the XLA
# CPU custom-call target registry. Unlike the package default, this target is
# publicly visible.
cc_library(
    name = "index_ops_kernel_argmax_float_2d",
    srcs = [
        "index_ops_kernel_argmax_float_2d.cc",
    ],
    copts = tf_copts(),
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
        "//tensorflow/core:framework_lite",
        "//third_party/eigen3",
    ],
    # Keep the object file at link time even though nothing references it by
    # symbol; presumably its custom-call registration runs as a static
    # initializer — confirm.
    alwayslink = 1,
)
|