diff options
author | Peter Hawkins <phawkins@google.com> | 2017-12-23 08:59:26 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-23 09:03:02 -0800 |
commit | c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0 (patch) | |
tree | 55db4d5d76c5acaa565feef0fe510644b162760e | |
parent | c0c2775ce3de682f7913d1aeaf50bbc4d1521934 (diff) |
[TF:XLA] Refactor large list of hard-coded compile-time constant arguments to operators. Add a new .CompileTimeConstInput() annotation on kernel registrations instead.
PiperOrigin-RevId: 180008567
35 files changed, 243 insertions, 208 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 2374620f58..13343967c4 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -215,7 +215,6 @@ cc_library( ":common", ":compilation_passes", "//tensorflow/compiler/jit/kernels:xla_launch_op", - "//tensorflow/compiler/tf2xla:const_analysis", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -245,7 +244,6 @@ cc_library( "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", - "//tensorflow/compiler/tf2xla:const_analysis", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 72fe6f5c7f..3c7dfef03d 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -120,6 +120,7 @@ cc_library( cc_library( name = "xla_compiler", srcs = [ + "const_analysis.cc", "graph_compiler.cc", "xla_compilation_device.cc", "xla_compiler.cc", @@ -133,6 +134,7 @@ cc_library( "xla_gpu_backend.cc", ]), hdrs = [ + "const_analysis.h", "graph_compiler.h", "xla_compilation_device.h", "xla_compiler.h", @@ -145,7 +147,6 @@ cc_library( visibility = [":friends"], deps = [ ":common", - ":const_analysis", ":dump_graph", ":functionalize_control_flow", ":sharding_util", @@ -355,28 +356,16 @@ tf_cc_test( ], ) -cc_library( - name = "const_analysis", - srcs = ["const_analysis.cc"], - hdrs = ["const_analysis.h"], - deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - tf_cc_test( name = "const_analysis_test", size = "small", srcs = ["const_analysis_test.cc"], deps = [ - ":const_analysis", + ":xla_compiler", 
"//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:ops", "//tensorflow/core:test", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index ab2f1e9a7a..0249500910 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -18,6 +18,7 @@ limitations under the License. #include <unordered_map> #include <unordered_set> +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" @@ -27,96 +28,18 @@ namespace tensorflow { // compile-time constants. Status BackwardsConstAnalysis(const Graph& g, std::vector<bool>* compile_time_const_args) { - // TODO(phawkins): annotate these on the kernel registrations, rather than - // using a hard-coded list. - // (operator, argument) pairs that must be compile-time constants. 
- const std::unordered_multimap<string, string> compile_time_const_inputs = { - {"All", "reduction_indices"}, - {"Any", "reduction_indices"}, - {"ArgMin", "dimension"}, - {"ArgMax", "dimension"}, - {"AvgPoolGrad", "orig_input_shape"}, - {"AvgPool3DGrad", "orig_input_shape"}, - {"BatchToSpace", "crops"}, - {"BatchToSpaceND", "block_shape"}, - {"BatchToSpaceND", "crops"}, - {"BroadcastArgs", "s0"}, - {"BroadcastArgs", "s1"}, - {"BroadcastGradientArgs", "s0"}, - {"BroadcastGradientArgs", "s1"}, - {"Concat", "concat_dim"}, - {"ConcatV2", "axis"}, - {"ConcatOffset", "concat_dim"}, - {"ConcatOffset", "shape"}, - {"Conv2DBackpropFilter", "filter_sizes"}, - {"Conv2DBackpropInput", "input_sizes"}, - {"Conv3DBackpropFilterV2", "filter_sizes"}, - {"Conv3DBackpropInputV2", "input_sizes"}, - {"Cumprod", "axis"}, - {"Cumsum", "axis"}, - {"DepthwiseConv2dNativeBackpropFilter", "filter_sizes"}, - {"DepthwiseConv2dNativeBackpropInput", "input_sizes"}, - {"DynamicStitch", "indices"}, - {"ExpandDims", "dim"}, - {"Fill", "dims"}, - {"GatherV2", "axis"}, - {"InvertPermutation", "x"}, - {"LinSpace", "start"}, - {"LinSpace", "stop"}, - {"LinSpace", "num"}, - {"Max", "reduction_indices"}, - {"Mean", "reduction_indices"}, - {"Min", "reduction_indices"}, - {"OneHot", "depth"}, - {"Pad", "paddings"}, - {"PadV2", "paddings"}, - {"MirrorPad", "paddings"}, - {"Multinomial", "num_samples"}, - {"Prod", "reduction_indices"}, - {"RandomStandardNormal", "shape"}, - {"RandomUniform", "shape"}, - {"RandomUniformInt", "shape"}, - {"Range", "start"}, - {"Range", "limit"}, - {"Range", "delta"}, - {"Reshape", "shape"}, - {"ResizeBilinear", "size"}, - {"ResourceStridedSliceAssign", "begin"}, - {"ResourceStridedSliceAssign", "end"}, - {"ResourceStridedSliceAssign", "strides"}, - {"Reverse", "dims"}, - {"ReverseV2", "axis"}, - {"Slice", "begin"}, - {"Slice", "size"}, - {"SpaceToBatch", "paddings"}, - {"SpaceToBatchND", "block_shape"}, - {"SpaceToBatchND", "paddings"}, - {"Split", "split_dim"}, - {"SplitV", 
"split_dim"}, - {"SplitV", "size_splits"}, - {"StackV2", "max_size"}, - {"StridedSlice", "begin"}, - {"StridedSlice", "end"}, - {"StridedSlice", "strides"}, - {"StridedSliceGrad", "shape"}, - {"StridedSliceGrad", "begin"}, - {"StridedSliceGrad", "end"}, - {"StridedSliceGrad", "strides"}, - {"Sum", "reduction_indices"}, - {"TensorArrayV3", "size"}, - {"TensorArraySplitV3", "lengths"}, - {"Tile", "multiples"}, - {"Transpose", "perm"}}; - // Operators that don't look at the data of their inputs, just the shapes. const std::unordered_set<string> metadata_ops = { - "Rank", "Shape", "ShapeN", "Size", + "Rank", + "Shape", + "ShapeN", + "Size", }; Status status; std::unordered_set<Node*> must_be_const; - auto visit = [&status, &metadata_ops, &compile_time_const_inputs, - &must_be_const, compile_time_const_args](Node* node) { + auto visit = [&status, &metadata_ops, &must_be_const, + compile_time_const_args](Node* node) { if (!status.ok()) return; // If this is a metadata-only op, don't propagate the const requirement. @@ -139,16 +62,17 @@ Status BackwardsConstAnalysis(const Graph& g, } // Mark any compile-time constant operator arguments as const. 
- auto range = compile_time_const_inputs.equal_range(node->type_string()); - if (range.first == range.second) return; + const std::unordered_set<string>* const_inputs = + XlaOpRegistry::CompileTimeConstantInputs(node->type_string()); + if (!const_inputs) return; NameRangeMap input_name_ranges; status = NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr); if (!status.ok()) return; - for (auto it = range.first; it != range.second; ++it) { - auto name_range = input_name_ranges.find(it->second); + for (const string& input : *const_inputs) { + auto name_range = input_name_ranges.find(input); if (name_range == input_name_ranges.end()) continue; for (Edge const* edge : node->in_edges()) { diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 21d3e64872..344a2ab2b6 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -159,7 +159,10 @@ class BatchToSpaceNDOp : public XlaOpKernel { block_shape, crops); } }; -REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp); +REGISTER_XLA_OP(Name("BatchToSpaceND") + .CompileTimeConstInput("crops") + .CompileTimeConstInput("block_shape"), + BatchToSpaceNDOp); class BatchToSpaceOp : public XlaOpKernel { public: @@ -181,7 +184,8 @@ class BatchToSpaceOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp); +REGISTER_XLA_OP(Name("BatchToSpace").CompileTimeConstInput("crops"), + BatchToSpaceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index bb031b8c47..ee2c920453 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -65,7 +65,10 @@ class BCastArgsOp : public XlaOpKernel { private: TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp); }; 
-REGISTER_XLA_OP(Name("BroadcastArgs"), BCastArgsOp); +REGISTER_XLA_OP(Name("BroadcastArgs") + .CompileTimeConstInput("s0") + .CompileTimeConstInput("s1"), + BCastArgsOp); // Given shapes of two tensors, computes the reduction indices for the // gradient computation. @@ -121,7 +124,10 @@ class BCastGradArgsOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp); }; -REGISTER_XLA_OP(Name("BroadcastGradientArgs"), BCastGradArgsOp); +REGISTER_XLA_OP(Name("BroadcastGradientArgs") + .CompileTimeConstInput("s0") + .CompileTimeConstInput("s1"), + BCastGradArgsOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 592f3ecc3c..545aa364f9 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -92,7 +92,8 @@ class CategoricalOp : public XlaOpKernel { }; // TODO(b/68769717): Rename this sampler to Categorical. -REGISTER_XLA_OP(Name("Multinomial"), CategoricalOp); +REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstInput("num_samples"), + CategoricalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 73a4740e29..1a246e8df9 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -84,8 +84,8 @@ class ConcatBaseOp : public XlaOpKernel { in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar), errors::InvalidArgument( "ConcatOp : Ranks of all input tensors should match: shape[0] = ", - input_shape.DebugString(), " vs. shape[", i, "] = ", - in_shape.DebugString())); + input_shape.DebugString(), " vs. shape[", i, + "] = ", in_shape.DebugString())); if (in_shape.dims() == 0) { // Inputs that come in as scalars must be reshaped to 1-vectors. 
input_data.push_back(ctx->builder()->Reshape(handle, {1})); @@ -117,8 +117,11 @@ class ConcatV2Op : public ConcatBaseOp { : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {} }; -REGISTER_XLA_OP(Name("Concat"), ConcatOp); -REGISTER_XLA_OP(Name("ConcatV2").TypeConstraint("Tidx", DT_INT32), ConcatV2Op); +REGISTER_XLA_OP(Name("Concat").CompileTimeConstInput("concat_dim"), ConcatOp); +REGISTER_XLA_OP(Name("ConcatV2") + .TypeConstraint("Tidx", DT_INT32) + .CompileTimeConstInput("axis"), + ConcatV2Op); class ConcatOffsetOp : public XlaOpKernel { public: @@ -189,10 +192,10 @@ class ConcatOffsetOp : public XlaOpKernel { } else { const int32 inp0_element = inp0_literal.Get<int>({j}); const int32 inp_element = inp_literal.Get<int>({j}); - OP_REQUIRES( - ctx, (inp0_element == inp_element), - errors::InvalidArgument("input[", i, ",", j, "] mismatch: ", - inp0_element, " vs. ", inp_element)); + OP_REQUIRES(ctx, (inp0_element == inp_element), + errors::InvalidArgument("input[", i, ",", j, + "] mismatch: ", inp0_element, + " vs. 
", inp_element)); out_vec(j) = 0; } } @@ -202,7 +205,10 @@ class ConcatOffsetOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("ConcatOffset"), ConcatOffsetOp); +REGISTER_XLA_OP(Name("ConcatOffset") + .CompileTimeConstInput("concat_dim") + .CompileTimeConstInput("shape"), + ConcatOffsetOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index aaddbe811c..da9be68732 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -445,21 +445,26 @@ class Conv2DBackpropInputOp : public ConvBackpropInputOp { explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} }; -REGISTER_XLA_OP(Name("Conv2DBackpropInput"), Conv2DBackpropInputOp); +REGISTER_XLA_OP( + Name("Conv2DBackpropInput").CompileTimeConstInput("input_sizes"), + Conv2DBackpropInputOp); class Conv3DBackpropInputOp : public ConvBackpropInputOp { public: explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} }; -REGISTER_XLA_OP(Name("Conv3DBackpropInputV2"), Conv3DBackpropInputOp); +REGISTER_XLA_OP( + Name("Conv3DBackpropInputV2").CompileTimeConstInput("input_sizes"), + Conv3DBackpropInputOp); class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp { public: explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; -REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput"), +REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput") + .CompileTimeConstInput("input_sizes"), DepthwiseConv2DBackpropInputOp); class ConvBackpropFilterOp : public XlaOpKernel { @@ -644,7 +649,9 @@ class Conv2DBackpropFilterOp : public ConvBackpropFilterOp { : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) { } 
}; -REGISTER_XLA_OP(Name("Conv2DBackpropFilter"), Conv2DBackpropFilterOp); +REGISTER_XLA_OP( + Name("Conv2DBackpropFilter").CompileTimeConstInput("filter_sizes"), + Conv2DBackpropFilterOp); class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { public: @@ -652,14 +659,17 @@ class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) { } }; -REGISTER_XLA_OP(Name("Conv3DBackpropFilterV2"), Conv3DBackpropFilterOp); +REGISTER_XLA_OP( + Name("Conv3DBackpropFilterV2").CompileTimeConstInput("filter_sizes"), + Conv3DBackpropFilterOp); class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp { public: explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx) : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; -REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter"), +REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter") + .CompileTimeConstInput("filter_sizes"), DepthwiseConv2DBackpropFilterOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 7349dcb987..f2cd21ffb9 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -72,22 +72,24 @@ class DynamicStitchOp : public XlaOpKernel { XLAShapeToTensorShape(indices_input[input_num].shape(), &indices_shape)); const TensorShape& data_shape = data_shapes[input_num]; - OP_REQUIRES(ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape), - errors::InvalidArgument( - "data[", input_num, "].shape = ", - data_shape.DebugString(), " does not start with indices[", - input_num, "].shape = ", indices_shape.DebugString())); - OP_REQUIRES(ctx, - input_num == 0 || SameExtraShape(data0_shape, indices0_shape, - data_shape, indices_shape), - errors::InvalidArgument( - "Need data[0].shape[", indices0_shape.dims(), - ":] = data[", 
input_num, "].shape[", indices_shape.dims(), - ":], got data[0].shape = ", data0_shape.DebugString(), - ", data[", input_num, "].shape = ", - data_shape.DebugString(), ", indices[0].shape = ", - indices0_shape.DebugString(), ", indices[", input_num, - "].shape = ", indices_shape.DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape), + errors::InvalidArgument("data[", input_num, + "].shape = ", data_shape.DebugString(), + " does not start with indices[", input_num, + "].shape = ", indices_shape.DebugString())); + OP_REQUIRES( + ctx, + input_num == 0 || SameExtraShape(data0_shape, indices0_shape, + data_shape, indices_shape), + errors::InvalidArgument( + "Need data[0].shape[", indices0_shape.dims(), ":] = data[", + input_num, "].shape[", indices_shape.dims(), + ":], got data[0].shape = ", data0_shape.DebugString(), ", data[", + input_num, "].shape = ", data_shape.DebugString(), + ", indices[0].shape = ", indices0_shape.DebugString(), + ", indices[", input_num, + "].shape = ", indices_shape.DebugString())); OP_REQUIRES_OK(ctx, XlaHelpers::ReshapeLiteral(indices_input[input_num], @@ -159,8 +161,8 @@ class DynamicStitchOp : public XlaOpKernel { indices0_shape.dims()); std::vector<int64> slice_limit(1 + data0_shape.dims() - indices0_shape.dims()); - std::vector<int64> stride(1 + data0_shape.dims() - - indices0_shape.dims(), 1); + std::vector<int64> stride(1 + data0_shape.dims() - indices0_shape.dims(), + 1); for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); } @@ -198,8 +200,10 @@ class DynamicStitchOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("DynamicStitch"), DynamicStitchOp); -REGISTER_XLA_OP(Name("ParallelDynamicStitch"), DynamicStitchOp); +REGISTER_XLA_OP(Name("DynamicStitch").CompileTimeConstInput("indices"), + DynamicStitchOp); +REGISTER_XLA_OP(Name("ParallelDynamicStitch").CompileTimeConstInput("indices"), + DynamicStitchOp); } // 
namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 9e090fe01c..eaa13b8dfa 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -69,7 +69,7 @@ class FillOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Fill"), FillOp); +REGISTER_XLA_OP(Name("Fill").CompileTimeConstInput("dims"), FillOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index e420f21ca3..70192cb324 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -198,6 +198,7 @@ void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) { } REGISTER_XLA_OP(Name("Gather"), GatherOpDynamicSlice); -REGISTER_XLA_OP(Name("GatherV2"), GatherOpDynamicSlice); +REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"), + GatherOpDynamicSlice); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index d91ebb500b..c0b8f9c179 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -243,7 +243,8 @@ class ResizeBilinearOp : public XlaOpKernel { bool align_corners_; }; -REGISTER_XLA_OP(Name("ResizeBilinear"), ResizeBilinearOp); +REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstInput("size"), + ResizeBilinearOp); class ResizeBilinearGradOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index e0dc1870f2..7bf4b435f5 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -80,7 +80,10 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { 
XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx) : XlaArgMinMaxOp(ctx, /*is_min=*/false) {} -REGISTER_XLA_OP(Name("ArgMax").Device(DEVICE_GPU_XLA_JIT), XlaArgMaxOp); +REGISTER_XLA_OP(Name("ArgMax") + .Device(DEVICE_GPU_XLA_JIT) + .CompileTimeConstInput("dimension"), + XlaArgMaxOp); namespace { @@ -90,7 +93,7 @@ class XlaArgMinOp : public XlaArgMinMaxOp { }; XlaArgMinOp::XlaArgMinOp(OpKernelConstruction* ctx) : XlaArgMinMaxOp(ctx, /*is_min=*/true) {} -REGISTER_XLA_OP(Name("ArgMin"), XlaArgMinOp); +REGISTER_XLA_OP(Name("ArgMin").CompileTimeConstInput("dimension"), XlaArgMinOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 20946e247a..b1f3c3c298 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -56,10 +56,10 @@ class ArgMaxCustomCallOp : public XlaOpKernel { errors::InvalidArgument("dim must be < input rank (", input_shape.dims(), "), but got: ", dim)); const int64 dim_size = input_shape.dim_size(dim); - OP_REQUIRES( - ctx, dim_size > 0, - errors::InvalidArgument("Reduction axis ", dim, " is empty in shape: ", - input_shape.DebugString())); + OP_REQUIRES(ctx, dim_size > 0, + errors::InvalidArgument( + "Reduction axis ", dim, + " is empty in shape: ", input_shape.DebugString())); // The output shape is the input shape contracted along dim. 
TensorShape output_shape; @@ -113,9 +113,11 @@ class ArgMaxCustomCallOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp); }; -REGISTER_XLA_OP( - Name("ArgMax").TypeConstraint("T", DT_FLOAT).Device(DEVICE_CPU_XLA_JIT), - ArgMaxCustomCallOp); +REGISTER_XLA_OP(Name("ArgMax") + .TypeConstraint("T", DT_FLOAT) + .Device(DEVICE_CPU_XLA_JIT) + .CompileTimeConstInput("dimension"), + ArgMaxCustomCallOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index bea1d1600b..05a36a031a 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -92,7 +92,8 @@ class MirrorPadOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(MirrorPadOp); }; -REGISTER_XLA_OP(Name("MirrorPad"), MirrorPadOp); +REGISTER_XLA_OP(Name("MirrorPad").CompileTimeConstInput("paddings"), + MirrorPadOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index 2a9cfcb2eb..9f7c991380 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -76,7 +76,7 @@ class OneHotOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp); }; -REGISTER_XLA_OP(Name("OneHot"), OneHotOp); +REGISTER_XLA_OP(Name("OneHot").CompileTimeConstInput("depth"), OneHotOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index d841bd37b3..791351637a 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -83,8 +83,8 @@ class PadOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Pad"), PadOp); -REGISTER_XLA_OP(Name("PadV2"), PadOp); +REGISTER_XLA_OP(Name("Pad").CompileTimeConstInput("paddings"), PadOp); 
+REGISTER_XLA_OP(Name("PadV2").CompileTimeConstInput("paddings"), PadOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 2b6053d19d..0b5a38967a 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -455,14 +455,16 @@ class AvgPool2DGradOp : public AvgPoolGradOp { errors::InvalidArgument("Invalid data format")); } }; -REGISTER_XLA_OP(Name("AvgPoolGrad"), AvgPool2DGradOp); +REGISTER_XLA_OP(Name("AvgPoolGrad").CompileTimeConstInput("orig_input_shape"), + AvgPool2DGradOp); class AvgPool3DGradOp : public AvgPoolGradOp { public: explicit AvgPool3DGradOp(OpKernelConstruction* ctx) : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {} }; -REGISTER_XLA_OP(Name("AvgPool3DGrad"), AvgPool3DGradOp); +REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"), + AvgPool3DGradOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 2421825ead..c0994c434b 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -52,7 +52,8 @@ class RandomUniformOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp); }; -REGISTER_XLA_OP(Name("RandomUniform"), RandomUniformOp); +REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"), + RandomUniformOp); class RandomUniformIntOp : public XlaOpKernel { public: @@ -83,7 +84,8 @@ class RandomUniformIntOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp); }; -REGISTER_XLA_OP(Name("RandomUniformInt"), RandomUniformIntOp); +REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstInput("shape"), + RandomUniformIntOp); class RandomStandardNormalOp : public XlaOpKernel { public: @@ -111,7 +113,8 @@ class RandomStandardNormalOp : public 
XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp); }; -REGISTER_XLA_OP(Name("RandomStandardNormal"), RandomStandardNormalOp); +REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstInput("shape"), + RandomStandardNormalOp); class TruncatedNormalOp : public XlaOpKernel { public: @@ -183,7 +186,8 @@ class TruncatedNormalOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("TruncatedNormal"), TruncatedNormalOp); +REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"), + TruncatedNormalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 647b627408..03b13b2924 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -35,7 +35,7 @@ class SumOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Sum"), SumOp); +REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp); class ProdOp : public XlaReductionOp { public: @@ -53,7 +53,8 @@ class ProdOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Prod"), ProdOp); +REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"), + ProdOp); class MinOp : public XlaReductionOp { public: @@ -73,7 +74,7 @@ class MinOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Min"), MinOp); +REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp); class MaxOp : public XlaReductionOp { public: @@ -93,7 +94,7 @@ class MaxOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Max"), MaxOp); +REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp); class MeanOp : public XlaReductionOp { public: @@ -115,7 +116,8 @@ class MeanOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Mean"), MeanOp); +REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"), + MeanOp); class AllOp : public XlaReductionOp { 
public: @@ -133,7 +135,7 @@ class AllOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("All"), AllOp); +REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp); class AnyOp : public XlaReductionOp { public: @@ -151,7 +153,7 @@ class AnyOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Any"), AnyOp); +REGISTER_XLA_OP(Name("Any").CompileTimeConstInput("reduction_indices"), AnyOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 5952e75272..af4d64b159 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -95,7 +95,7 @@ class ReshapeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Reshape"), ReshapeOp); +REGISTER_XLA_OP(Name("Reshape").CompileTimeConstInput("shape"), ReshapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index bdfd066f01..17a345fc94 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -65,7 +65,7 @@ class ReverseOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Reverse"), ReverseOp); +REGISTER_XLA_OP(Name("Reverse").CompileTimeConstInput("dims"), ReverseOp); class ReverseV2Op : public XlaOpKernel { public: @@ -103,7 +103,7 @@ class ReverseV2Op : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("ReverseV2"), ReverseV2Op); +REGISTER_XLA_OP(Name("ReverseV2").CompileTimeConstInput("axis"), ReverseV2Op); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 650f8c7dc8..ee4a94164c 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -129,13 +129,19 @@ class CumsumOp : public ScanOp { public: explicit 
CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {} }; -REGISTER_XLA_OP(Name("Cumsum").TypeConstraint("T", kScanOpTypes), CumsumOp); +REGISTER_XLA_OP(Name("Cumsum") + .TypeConstraint("T", kScanOpTypes) + .CompileTimeConstInput("axis"), + CumsumOp); class CumprodOp : public ScanOp { public: explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {} }; -REGISTER_XLA_OP(Name("Cumprod").TypeConstraint("T", kScanOpTypes), CumprodOp); +REGISTER_XLA_OP(Name("Cumprod") + .TypeConstraint("T", kScanOpTypes) + .CompileTimeConstInput("axis"), + CumprodOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index c2b0e1bb4c..2c31f8d908 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -138,7 +138,11 @@ class RangeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Range"), RangeOp); +REGISTER_XLA_OP(Name("Range") + .CompileTimeConstInput("start") + .CompileTimeConstInput("limit") + .CompileTimeConstInput("delta"), + RangeOp); class LinSpaceOp : public XlaOpKernel { public: @@ -207,7 +211,11 @@ class LinSpaceOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("LinSpace"), LinSpaceOp); +REGISTER_XLA_OP(Name("LinSpace") + .CompileTimeConstInput("start") + .CompileTimeConstInput("stop") + .CompileTimeConstInput("num"), + LinSpaceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index e205fadd2b..8fb7a74310 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -150,7 +150,7 @@ class ExpandDimsOp : public XlaOpKernel { ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); } }; -REGISTER_XLA_OP(Name("ExpandDims"), ExpandDimsOp); 
+REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp); class SqueezeOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index fbe8c78d8f..be1e97bf26 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -112,7 +112,9 @@ class SliceOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Slice"), SliceOp); +REGISTER_XLA_OP( + Name("Slice").CompileTimeConstInput("begin").CompileTimeConstInput("size"), + SliceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 83a87f19a7..01b46e160d 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -162,7 +162,10 @@ class SpaceToBatchNDOp : public XlaOpKernel { block_shape, paddings); } }; -REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp); +REGISTER_XLA_OP(Name("SpaceToBatchND") + .CompileTimeConstInput("paddings") + .CompileTimeConstInput("block_shape"), + SpaceToBatchNDOp); class SpaceToBatchOp : public XlaOpKernel { public: @@ -184,7 +187,8 @@ class SpaceToBatchOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp); +REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstInput("paddings"), + SpaceToBatchOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 795eb1794f..79c435c90a 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -103,7 +103,7 @@ class SplitOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Split"), SplitOp); +REGISTER_XLA_OP(Name("Split").CompileTimeConstInput("split_dim"), SplitOp); class SplitVOp : public 
XlaOpKernel { public: @@ -142,8 +142,9 @@ class SplitVOp : public XlaOpKernel { int neg_one_dim = -1; std::vector<int64> split_sizes_vec(num_split, -1); const TensorShape split_size_shape = ctx->InputShape(1); - OP_REQUIRES(ctx, split_size_shape.dims() == 1 && - split_size_shape.num_elements() == num_split, + OP_REQUIRES(ctx, + split_size_shape.dims() == 1 && + split_size_shape.num_elements() == num_split, errors::InvalidArgument( "shape of tensor describing " " the output must have dimension 1 and the same " @@ -171,10 +172,11 @@ class SplitVOp : public XlaOpKernel { } OP_REQUIRES( - ctx, (neg_one_dim == -1 && - total_split_size == input_shape.dim_size(split_dim)) || - (neg_one_dim >= 0 && - total_split_size <= input_shape.dim_size(split_dim)), + ctx, + (neg_one_dim == -1 && + total_split_size == input_shape.dim_size(split_dim)) || + (neg_one_dim >= 0 && + total_split_size <= input_shape.dim_size(split_dim)), errors::InvalidArgument("Determined shape must either match " "input shape along split_dim exactly if " "fully specified, or be less than the size of " @@ -206,7 +208,10 @@ class SplitVOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("SplitV"), SplitVOp); +REGISTER_XLA_OP(Name("SplitV") + .CompileTimeConstInput("split_dim") + .CompileTimeConstInput("size_splits"), + SplitVOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 8013ece861..c912876e65 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -129,7 +129,7 @@ class StackOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackOp); }; -REGISTER_XLA_OP(Name("StackV2"), StackOp); +REGISTER_XLA_OP(Name("StackV2").CompileTimeConstInput("max_size"), StackOp); class StackPushOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc 
index 6af4bd0496..f0525a5fb8 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -106,7 +106,11 @@ class StridedSliceOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP(Name("StridedSlice"), StridedSliceOp); +REGISTER_XLA_OP(Name("StridedSlice") + .CompileTimeConstInput("begin") + .CompileTimeConstInput("end") + .CompileTimeConstInput("strides"), + StridedSliceOp); class StridedSliceGradOp : public XlaOpKernel { public: @@ -211,7 +215,12 @@ class StridedSliceGradOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp); +REGISTER_XLA_OP(Name("StridedSliceGrad") + .CompileTimeConstInput("shape") + .CompileTimeConstInput("begin") + .CompileTimeConstInput("end") + .CompileTimeConstInput("strides"), + StridedSliceGradOp); class StridedSliceAssignOp : public XlaOpKernel { public: @@ -320,7 +329,11 @@ class StridedSliceAssignOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp); +REGISTER_XLA_OP(Name("ResourceStridedSliceAssign") + .CompileTimeConstInput("begin") + .CompileTimeConstInput("end") + .CompileTimeConstInput("strides"), + StridedSliceAssignOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 8a742ff11c..9224072a3c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -192,7 +192,8 @@ class TensorArrayOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp); }; -REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp); +REGISTER_XLA_OP(Name("TensorArrayV3").CompileTimeConstInput("size"), + TensorArrayOp); class TensorArrayWriteOp : public XlaOpKernel { public: @@ -414,8 +415,8 @@ class TensorArrayScatterOp : public XlaOpKernel { // 
start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto index = b->Slice(indices, {i}, {i + 1}, {1}); auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } } @@ -537,7 +538,8 @@ class TensorArraySplitOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp); }; -REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp); +REGISTER_XLA_OP(Name("TensorArraySplitV3").CompileTimeConstInput("lengths"), + TensorArraySplitOp); class TensorArraySizeOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 9ee6bd8925..9aefcd4fc7 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -122,7 +122,7 @@ class TileOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TileOp); }; -REGISTER_XLA_OP(Name("Tile"), TileOp); +REGISTER_XLA_OP(Name("Tile").CompileTimeConstInput("multiples"), TileOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 2fc5d40d10..5c17b7fbf0 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -72,8 +72,9 @@ class TransposeOp : public XlaOpKernel { } } for (int i = 0; i < dims; ++i) { - OP_REQUIRES(ctx, bits[i], errors::InvalidArgument( - i, " is missing from 'perm' argument.")); + OP_REQUIRES( + ctx, bits[i], + errors::InvalidArgument(i, " is missing from 'perm' argument.")); } // 0-D, 1-D, and identity transposes do nothing. 
@@ -87,7 +88,7 @@ class TransposeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Transpose"), TransposeOp); +REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp); // InvertPermutation frequently forms part of the gradient of Transpose. // @@ -103,8 +104,9 @@ class InvertPermutationOp : public XlaOpKernel { explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, FastBoundsCheck(ctx->InputShape(0).num_elements(), - std::numeric_limits<int32>::max()), + OP_REQUIRES(ctx, + FastBoundsCheck(ctx->InputShape(0).num_elements(), + std::numeric_limits<int32>::max()), errors::InvalidArgument("permutation of nonnegative int32s " "must have <= int32 max elements")); @@ -128,7 +130,9 @@ class InvertPermutationOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32), +REGISTER_XLA_OP(Name("InvertPermutation") + .TypeConstraint("T", DT_INT32) + .CompileTimeConstInput("x"), InvertPermutationOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index faf47434b5..97bb100fb1 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -83,6 +83,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; return false; } } + if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible compile time constant inputs."; + return false; + } return true; } @@ -263,6 +268,17 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels( return kernels; } +/* static */ const std::unordered_set<string>* +XlaOpRegistry::CompileTimeConstantInputs(const string& op) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.ops_.find(op); + if (it == registry.ops_.end()) { + return nullptr; + } + 
return &it->second->compile_time_constant_inputs; +} + std::vector<string> XlaOpRegistry::BackendNames() { std::vector<string> names; XlaOpRegistry& registry = Instance(); @@ -337,6 +353,12 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( return *this; } +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( + StringPiece input_name) { + registration_->compile_time_constant_inputs.insert(input_name.ToString()); + return *this; +} + std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build( XlaOpRegistry::Factory factory) { registration_->factory = factory; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 8bfd9758f7..ff7453194a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -128,6 +128,11 @@ class XlaOpRegistry { const string& compilation_device_name, bool include_compilation_only_kernels); + // Returns the set of compile-time constant inputs to 'op'. Returns nullptr + // if the op is not registered. + static const std::unordered_set<string>* CompileTimeConstantInputs( + const string& op); + private: friend class XlaBackendRegistrar; friend class XlaOpRegistrar; @@ -181,6 +186,9 @@ class XlaOpRegistry { bool has_device_whitelist = false; std::unordered_set<string> device_whitelist; + // Names of arguments that must be compile-time constants. + std::unordered_set<string> compile_time_constant_inputs; + // Factory used to build OpKernels that perform symbolic execution. Factory factory; }; @@ -242,6 +250,9 @@ class XlaOpRegistrationBuilder { // Allow DT_RESOURCE types for type parameters. XlaOpRegistrationBuilder& AllowResourceTypes(); + // Mark 'input_name' as an argument whose value must be known at compile-time. + XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name); + std::unique_ptr<XlaOpRegistry::OpRegistration> Build( XlaOpRegistry::Factory factory); |