about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Peter Hawkins <phawkins@google.com> 2017-12-23 08:59:26 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-12-23 09:03:02 -0800
commit: c975bc2b3fdc9674dd71a7ed89c74ce8ea2d46f0 (patch)
tree: 55db4d5d76c5acaa565feef0fe510644b162760e
parent: c0c2775ce3de682f7913d1aeaf50bbc4d1521934 (diff)
[TF:XLA] Refactor large list of hard-coded compile-time constant arguments to operators. Add a new .CompileTimeConstInput() annotation on kernel registrations instead.
PiperOrigin-RevId: 180008567
-rw-r--r-- tensorflow/compiler/jit/BUILD 2
-rw-r--r-- tensorflow/compiler/tf2xla/BUILD 19
-rw-r--r-- tensorflow/compiler/tf2xla/const_analysis.cc 100
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc 8
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/bcast_ops.cc 10
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/categorical_op.cc 3
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/concat_op.cc 24
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/conv_ops.cc 22
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc 44
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/fill_op.cc 2
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/gather_op.cc 3
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc 3
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/index_ops.cc 7
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc 16
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc 3
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/one_hot_op.cc 2
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/pad_op.cc 4
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/pooling_ops.cc 6
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/random_ops.cc 12
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/reduction_ops.cc 16
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/reshape_op.cc 2
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/reverse_op.cc 4
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/scan_ops.cc 10
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/sequence_ops.cc 12
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/shape_op.cc 2
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/slice_op.cc 4
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc 8
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/split_op.cc 21
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/stack_ops.cc 2
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc 19
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc 10
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/tile_ops.cc 2
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/transpose_op.cc 16
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_registry.cc 22
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_registry.h 11
35 files changed, 243 insertions, 208 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 2374620f58..13343967c4 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -215,7 +215,6 @@ cc_library(
":common",
":compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
- "//tensorflow/compiler/tf2xla:const_analysis",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -245,7 +244,6 @@ cc_library(
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:parallel_check_op",
"//tensorflow/compiler/jit/ops:xla_ops",
- "//tensorflow/compiler/tf2xla:const_analysis",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 72fe6f5c7f..3c7dfef03d 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -120,6 +120,7 @@ cc_library(
cc_library(
name = "xla_compiler",
srcs = [
+ "const_analysis.cc",
"graph_compiler.cc",
"xla_compilation_device.cc",
"xla_compiler.cc",
@@ -133,6 +134,7 @@ cc_library(
"xla_gpu_backend.cc",
]),
hdrs = [
+ "const_analysis.h",
"graph_compiler.h",
"xla_compilation_device.h",
"xla_compiler.h",
@@ -145,7 +147,6 @@ cc_library(
visibility = [":friends"],
deps = [
":common",
- ":const_analysis",
":dump_graph",
":functionalize_control_flow",
":sharding_util",
@@ -355,28 +356,16 @@ tf_cc_test(
],
)
-cc_library(
- name = "const_analysis",
- srcs = ["const_analysis.cc"],
- hdrs = ["const_analysis.h"],
- deps = [
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- ],
-)
-
tf_cc_test(
name = "const_analysis_test",
size = "small",
srcs = ["const_analysis_test.cc"],
deps = [
- ":const_analysis",
+ ":xla_compiler",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:ops",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index ab2f1e9a7a..0249500910 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -27,96 +28,18 @@ namespace tensorflow {
// compile-time constants.
Status BackwardsConstAnalysis(const Graph& g,
std::vector<bool>* compile_time_const_args) {
- // TODO(phawkins): annotate these on the kernel registrations, rather than
- // using a hard-coded list.
- // (operator, argument) pairs that must be compile-time constants.
- const std::unordered_multimap<string, string> compile_time_const_inputs = {
- {"All", "reduction_indices"},
- {"Any", "reduction_indices"},
- {"ArgMin", "dimension"},
- {"ArgMax", "dimension"},
- {"AvgPoolGrad", "orig_input_shape"},
- {"AvgPool3DGrad", "orig_input_shape"},
- {"BatchToSpace", "crops"},
- {"BatchToSpaceND", "block_shape"},
- {"BatchToSpaceND", "crops"},
- {"BroadcastArgs", "s0"},
- {"BroadcastArgs", "s1"},
- {"BroadcastGradientArgs", "s0"},
- {"BroadcastGradientArgs", "s1"},
- {"Concat", "concat_dim"},
- {"ConcatV2", "axis"},
- {"ConcatOffset", "concat_dim"},
- {"ConcatOffset", "shape"},
- {"Conv2DBackpropFilter", "filter_sizes"},
- {"Conv2DBackpropInput", "input_sizes"},
- {"Conv3DBackpropFilterV2", "filter_sizes"},
- {"Conv3DBackpropInputV2", "input_sizes"},
- {"Cumprod", "axis"},
- {"Cumsum", "axis"},
- {"DepthwiseConv2dNativeBackpropFilter", "filter_sizes"},
- {"DepthwiseConv2dNativeBackpropInput", "input_sizes"},
- {"DynamicStitch", "indices"},
- {"ExpandDims", "dim"},
- {"Fill", "dims"},
- {"GatherV2", "axis"},
- {"InvertPermutation", "x"},
- {"LinSpace", "start"},
- {"LinSpace", "stop"},
- {"LinSpace", "num"},
- {"Max", "reduction_indices"},
- {"Mean", "reduction_indices"},
- {"Min", "reduction_indices"},
- {"OneHot", "depth"},
- {"Pad", "paddings"},
- {"PadV2", "paddings"},
- {"MirrorPad", "paddings"},
- {"Multinomial", "num_samples"},
- {"Prod", "reduction_indices"},
- {"RandomStandardNormal", "shape"},
- {"RandomUniform", "shape"},
- {"RandomUniformInt", "shape"},
- {"Range", "start"},
- {"Range", "limit"},
- {"Range", "delta"},
- {"Reshape", "shape"},
- {"ResizeBilinear", "size"},
- {"ResourceStridedSliceAssign", "begin"},
- {"ResourceStridedSliceAssign", "end"},
- {"ResourceStridedSliceAssign", "strides"},
- {"Reverse", "dims"},
- {"ReverseV2", "axis"},
- {"Slice", "begin"},
- {"Slice", "size"},
- {"SpaceToBatch", "paddings"},
- {"SpaceToBatchND", "block_shape"},
- {"SpaceToBatchND", "paddings"},
- {"Split", "split_dim"},
- {"SplitV", "split_dim"},
- {"SplitV", "size_splits"},
- {"StackV2", "max_size"},
- {"StridedSlice", "begin"},
- {"StridedSlice", "end"},
- {"StridedSlice", "strides"},
- {"StridedSliceGrad", "shape"},
- {"StridedSliceGrad", "begin"},
- {"StridedSliceGrad", "end"},
- {"StridedSliceGrad", "strides"},
- {"Sum", "reduction_indices"},
- {"TensorArrayV3", "size"},
- {"TensorArraySplitV3", "lengths"},
- {"Tile", "multiples"},
- {"Transpose", "perm"}};
-
// Operators that don't look at the data of their inputs, just the shapes.
const std::unordered_set<string> metadata_ops = {
- "Rank", "Shape", "ShapeN", "Size",
+ "Rank",
+ "Shape",
+ "ShapeN",
+ "Size",
};
Status status;
std::unordered_set<Node*> must_be_const;
- auto visit = [&status, &metadata_ops, &compile_time_const_inputs,
- &must_be_const, compile_time_const_args](Node* node) {
+ auto visit = [&status, &metadata_ops, &must_be_const,
+ compile_time_const_args](Node* node) {
if (!status.ok()) return;
// If this is a metadata-only op, don't propagate the const requirement.
@@ -139,16 +62,17 @@ Status BackwardsConstAnalysis(const Graph& g,
}
// Mark any compile-time constant operator arguments as const.
- auto range = compile_time_const_inputs.equal_range(node->type_string());
- if (range.first == range.second) return;
+ const std::unordered_set<string>* const_inputs =
+ XlaOpRegistry::CompileTimeConstantInputs(node->type_string());
+ if (!const_inputs) return;
NameRangeMap input_name_ranges;
status =
NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr);
if (!status.ok()) return;
- for (auto it = range.first; it != range.second; ++it) {
- auto name_range = input_name_ranges.find(it->second);
+ for (const string& input : *const_inputs) {
+ auto name_range = input_name_ranges.find(input);
if (name_range == input_name_ranges.end()) continue;
for (Edge const* edge : node->in_edges()) {
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
index 21d3e64872..344a2ab2b6 100644
--- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -159,7 +159,10 @@ class BatchToSpaceNDOp : public XlaOpKernel {
block_shape, crops);
}
};
-REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp);
+REGISTER_XLA_OP(Name("BatchToSpaceND")
+ .CompileTimeConstInput("block_shape")
+ .CompileTimeConstInput("crops"),
+ BatchToSpaceNDOp);
class BatchToSpaceOp : public XlaOpKernel {
public:
@@ -181,7 +184,8 @@ class BatchToSpaceOp : public XlaOpKernel {
private:
int block_size_;
};
-REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp);
+REGISTER_XLA_OP(Name("BatchToSpace").CompileTimeConstInput("crops"),
+ BatchToSpaceOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
index bb031b8c47..ee2c920453 100644
--- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -65,7 +65,10 @@ class BCastArgsOp : public XlaOpKernel {
private:
TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp);
};
-REGISTER_XLA_OP(Name("BroadcastArgs"), BCastArgsOp);
+REGISTER_XLA_OP(Name("BroadcastArgs")
+ .CompileTimeConstInput("s0")
+ .CompileTimeConstInput("s1"),
+ BCastArgsOp);
// Given shapes of two tensors, computes the reduction indices for the
// gradient computation.
@@ -121,7 +124,10 @@ class BCastGradArgsOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp);
};
-REGISTER_XLA_OP(Name("BroadcastGradientArgs"), BCastGradArgsOp);
+REGISTER_XLA_OP(Name("BroadcastGradientArgs")
+ .CompileTimeConstInput("s0")
+ .CompileTimeConstInput("s1"),
+ BCastGradArgsOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index 592f3ecc3c..545aa364f9 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -92,7 +92,8 @@ class CategoricalOp : public XlaOpKernel {
};
// TODO(b/68769717): Rename this sampler to Categorical.
-REGISTER_XLA_OP(Name("Multinomial"), CategoricalOp);
+REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstInput("num_samples"),
+ CategoricalOp);
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index 73a4740e29..1a246e8df9 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -84,8 +84,8 @@ class ConcatBaseOp : public XlaOpKernel {
in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar),
errors::InvalidArgument(
"ConcatOp : Ranks of all input tensors should match: shape[0] = ",
- input_shape.DebugString(), " vs. shape[", i, "] = ",
- in_shape.DebugString()));
+ input_shape.DebugString(), " vs. shape[", i,
+ "] = ", in_shape.DebugString()));
if (in_shape.dims() == 0) {
// Inputs that come in as scalars must be reshaped to 1-vectors.
input_data.push_back(ctx->builder()->Reshape(handle, {1}));
@@ -117,8 +117,11 @@ class ConcatV2Op : public ConcatBaseOp {
: ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {}
};
-REGISTER_XLA_OP(Name("Concat"), ConcatOp);
-REGISTER_XLA_OP(Name("ConcatV2").TypeConstraint("Tidx", DT_INT32), ConcatV2Op);
+REGISTER_XLA_OP(Name("Concat").CompileTimeConstInput("concat_dim"), ConcatOp);
+REGISTER_XLA_OP(Name("ConcatV2")
+ .TypeConstraint("Tidx", DT_INT32)
+ .CompileTimeConstInput("axis"),
+ ConcatV2Op);
class ConcatOffsetOp : public XlaOpKernel {
public:
@@ -189,10 +192,10 @@ class ConcatOffsetOp : public XlaOpKernel {
} else {
const int32 inp0_element = inp0_literal.Get<int>({j});
const int32 inp_element = inp_literal.Get<int>({j});
- OP_REQUIRES(
- ctx, (inp0_element == inp_element),
- errors::InvalidArgument("input[", i, ",", j, "] mismatch: ",
- inp0_element, " vs. ", inp_element));
+ OP_REQUIRES(ctx, (inp0_element == inp_element),
+ errors::InvalidArgument("input[", i, ",", j,
+ "] mismatch: ", inp0_element,
+ " vs. ", inp_element));
out_vec(j) = 0;
}
}
@@ -202,7 +205,10 @@ class ConcatOffsetOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("ConcatOffset"), ConcatOffsetOp);
+REGISTER_XLA_OP(Name("ConcatOffset")
+ .CompileTimeConstInput("concat_dim")
+ .CompileTimeConstInput("shape"),
+ ConcatOffsetOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index aaddbe811c..da9be68732 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -445,21 +445,26 @@ class Conv2DBackpropInputOp : public ConvBackpropInputOp {
explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx)
: ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {}
};
-REGISTER_XLA_OP(Name("Conv2DBackpropInput"), Conv2DBackpropInputOp);
+REGISTER_XLA_OP(
+ Name("Conv2DBackpropInput").CompileTimeConstInput("input_sizes"),
+ Conv2DBackpropInputOp);
class Conv3DBackpropInputOp : public ConvBackpropInputOp {
public:
explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx)
: ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {}
};
-REGISTER_XLA_OP(Name("Conv3DBackpropInputV2"), Conv3DBackpropInputOp);
+REGISTER_XLA_OP(
+ Name("Conv3DBackpropInputV2").CompileTimeConstInput("input_sizes"),
+ Conv3DBackpropInputOp);
class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp {
public:
explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx)
: ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
};
-REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput"),
+REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput")
+ .CompileTimeConstInput("input_sizes"),
DepthwiseConv2DBackpropInputOp);
class ConvBackpropFilterOp : public XlaOpKernel {
@@ -644,7 +649,9 @@ class Conv2DBackpropFilterOp : public ConvBackpropFilterOp {
: ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {
}
};
-REGISTER_XLA_OP(Name("Conv2DBackpropFilter"), Conv2DBackpropFilterOp);
+REGISTER_XLA_OP(
+ Name("Conv2DBackpropFilter").CompileTimeConstInput("filter_sizes"),
+ Conv2DBackpropFilterOp);
class Conv3DBackpropFilterOp : public ConvBackpropFilterOp {
public:
@@ -652,14 +659,17 @@ class Conv3DBackpropFilterOp : public ConvBackpropFilterOp {
: ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {
}
};
-REGISTER_XLA_OP(Name("Conv3DBackpropFilterV2"), Conv3DBackpropFilterOp);
+REGISTER_XLA_OP(
+ Name("Conv3DBackpropFilterV2").CompileTimeConstInput("filter_sizes"),
+ Conv3DBackpropFilterOp);
class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp {
public:
explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx)
: ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
};
-REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter"),
+REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter")
+ .CompileTimeConstInput("filter_sizes"),
DepthwiseConv2DBackpropFilterOp);
} // namespace
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
index 7349dcb987..f2cd21ffb9 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -72,22 +72,24 @@ class DynamicStitchOp : public XlaOpKernel {
XLAShapeToTensorShape(indices_input[input_num].shape(),
&indices_shape));
const TensorShape& data_shape = data_shapes[input_num];
- OP_REQUIRES(ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape),
- errors::InvalidArgument(
- "data[", input_num, "].shape = ",
- data_shape.DebugString(), " does not start with indices[",
- input_num, "].shape = ", indices_shape.DebugString()));
- OP_REQUIRES(ctx,
- input_num == 0 || SameExtraShape(data0_shape, indices0_shape,
- data_shape, indices_shape),
- errors::InvalidArgument(
- "Need data[0].shape[", indices0_shape.dims(),
- ":] = data[", input_num, "].shape[", indices_shape.dims(),
- ":], got data[0].shape = ", data0_shape.DebugString(),
- ", data[", input_num, "].shape = ",
- data_shape.DebugString(), ", indices[0].shape = ",
- indices0_shape.DebugString(), ", indices[", input_num,
- "].shape = ", indices_shape.DebugString()));
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape),
+ errors::InvalidArgument("data[", input_num,
+ "].shape = ", data_shape.DebugString(),
+ " does not start with indices[", input_num,
+ "].shape = ", indices_shape.DebugString()));
+ OP_REQUIRES(
+ ctx,
+ input_num == 0 || SameExtraShape(data0_shape, indices0_shape,
+ data_shape, indices_shape),
+ errors::InvalidArgument(
+ "Need data[0].shape[", indices0_shape.dims(), ":] = data[",
+ input_num, "].shape[", indices_shape.dims(),
+ ":], got data[0].shape = ", data0_shape.DebugString(), ", data[",
+ input_num, "].shape = ", data_shape.DebugString(),
+ ", indices[0].shape = ", indices0_shape.DebugString(),
+ ", indices[", input_num,
+ "].shape = ", indices_shape.DebugString()));
OP_REQUIRES_OK(ctx,
XlaHelpers::ReshapeLiteral(indices_input[input_num],
@@ -159,8 +161,8 @@ class DynamicStitchOp : public XlaOpKernel {
indices0_shape.dims());
std::vector<int64> slice_limit(1 + data0_shape.dims() -
indices0_shape.dims());
- std::vector<int64> stride(1 + data0_shape.dims() -
- indices0_shape.dims(), 1);
+ std::vector<int64> stride(1 + data0_shape.dims() - indices0_shape.dims(),
+ 1);
for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) {
slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d);
}
@@ -198,8 +200,10 @@ class DynamicStitchOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("DynamicStitch"), DynamicStitchOp);
-REGISTER_XLA_OP(Name("ParallelDynamicStitch"), DynamicStitchOp);
+REGISTER_XLA_OP(Name("DynamicStitch").CompileTimeConstInput("indices"),
+ DynamicStitchOp);
+REGISTER_XLA_OP(Name("ParallelDynamicStitch").CompileTimeConstInput("indices"),
+ DynamicStitchOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
index 9e090fe01c..eaa13b8dfa 100644
--- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
@@ -69,7 +69,7 @@ class FillOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Fill"), FillOp);
+REGISTER_XLA_OP(Name("Fill").CompileTimeConstInput("dims"), FillOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index e420f21ca3..70192cb324 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -198,6 +198,7 @@ void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) {
}
REGISTER_XLA_OP(Name("Gather"), GatherOpDynamicSlice);
-REGISTER_XLA_OP(Name("GatherV2"), GatherOpDynamicSlice);
+REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"),
+ GatherOpDynamicSlice);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index d91ebb500b..c0b8f9c179 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -243,7 +243,8 @@ class ResizeBilinearOp : public XlaOpKernel {
bool align_corners_;
};
-REGISTER_XLA_OP(Name("ResizeBilinear"), ResizeBilinearOp);
+REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstInput("size"),
+ ResizeBilinearOp);
class ResizeBilinearGradOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
index e0dc1870f2..7bf4b435f5 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
@@ -80,7 +80,10 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx)
: XlaArgMinMaxOp(ctx, /*is_min=*/false) {}
-REGISTER_XLA_OP(Name("ArgMax").Device(DEVICE_GPU_XLA_JIT), XlaArgMaxOp);
+REGISTER_XLA_OP(Name("ArgMax")
+ .Device(DEVICE_GPU_XLA_JIT)
+ .CompileTimeConstInput("dimension"),
+ XlaArgMaxOp);
namespace {
@@ -90,7 +93,7 @@ class XlaArgMinOp : public XlaArgMinMaxOp {
};
XlaArgMinOp::XlaArgMinOp(OpKernelConstruction* ctx)
: XlaArgMinMaxOp(ctx, /*is_min=*/true) {}
-REGISTER_XLA_OP(Name("ArgMin"), XlaArgMinOp);
+REGISTER_XLA_OP(Name("ArgMin").CompileTimeConstInput("dimension"), XlaArgMinOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 20946e247a..b1f3c3c298 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -56,10 +56,10 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
errors::InvalidArgument("dim must be < input rank (",
input_shape.dims(), "), but got: ", dim));
const int64 dim_size = input_shape.dim_size(dim);
- OP_REQUIRES(
- ctx, dim_size > 0,
- errors::InvalidArgument("Reduction axis ", dim, " is empty in shape: ",
- input_shape.DebugString()));
+ OP_REQUIRES(ctx, dim_size > 0,
+ errors::InvalidArgument(
+ "Reduction axis ", dim,
+ " is empty in shape: ", input_shape.DebugString()));
// The output shape is the input shape contracted along dim.
TensorShape output_shape;
@@ -113,9 +113,11 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp);
};
-REGISTER_XLA_OP(
- Name("ArgMax").TypeConstraint("T", DT_FLOAT).Device(DEVICE_CPU_XLA_JIT),
- ArgMaxCustomCallOp);
+REGISTER_XLA_OP(Name("ArgMax")
+ .TypeConstraint("T", DT_FLOAT)
+ .Device(DEVICE_CPU_XLA_JIT)
+ .CompileTimeConstInput("dimension"),
+ ArgMaxCustomCallOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
index bea1d1600b..05a36a031a 100644
--- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
@@ -92,7 +92,8 @@ class MirrorPadOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(MirrorPadOp);
};
-REGISTER_XLA_OP(Name("MirrorPad"), MirrorPadOp);
+REGISTER_XLA_OP(Name("MirrorPad").CompileTimeConstInput("paddings"),
+ MirrorPadOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc
index 2a9cfcb2eb..9f7c991380 100644
--- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc
@@ -76,7 +76,7 @@ class OneHotOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
};
-REGISTER_XLA_OP(Name("OneHot"), OneHotOp);
+REGISTER_XLA_OP(Name("OneHot").CompileTimeConstInput("depth"), OneHotOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
index d841bd37b3..791351637a 100644
--- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
@@ -83,8 +83,8 @@ class PadOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Pad"), PadOp);
-REGISTER_XLA_OP(Name("PadV2"), PadOp);
+REGISTER_XLA_OP(Name("Pad").CompileTimeConstInput("paddings"), PadOp);
+REGISTER_XLA_OP(Name("PadV2").CompileTimeConstInput("paddings"), PadOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index 2b6053d19d..0b5a38967a 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -455,14 +455,16 @@ class AvgPool2DGradOp : public AvgPoolGradOp {
errors::InvalidArgument("Invalid data format"));
}
};
-REGISTER_XLA_OP(Name("AvgPoolGrad"), AvgPool2DGradOp);
+REGISTER_XLA_OP(Name("AvgPoolGrad").CompileTimeConstInput("orig_input_shape"),
+ AvgPool2DGradOp);
class AvgPool3DGradOp : public AvgPoolGradOp {
public:
explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
: AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
};
-REGISTER_XLA_OP(Name("AvgPool3DGrad"), AvgPool3DGradOp);
+REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"),
+ AvgPool3DGradOp);
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 2421825ead..c0994c434b 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -52,7 +52,8 @@ class RandomUniformOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp);
};
-REGISTER_XLA_OP(Name("RandomUniform"), RandomUniformOp);
+REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"),
+ RandomUniformOp);
class RandomUniformIntOp : public XlaOpKernel {
public:
@@ -83,7 +84,8 @@ class RandomUniformIntOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp);
};
-REGISTER_XLA_OP(Name("RandomUniformInt"), RandomUniformIntOp);
+REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstInput("shape"),
+ RandomUniformIntOp);
class RandomStandardNormalOp : public XlaOpKernel {
public:
@@ -111,7 +113,8 @@ class RandomStandardNormalOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp);
};
-REGISTER_XLA_OP(Name("RandomStandardNormal"), RandomStandardNormalOp);
+REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstInput("shape"),
+ RandomStandardNormalOp);
class TruncatedNormalOp : public XlaOpKernel {
public:
@@ -183,7 +186,8 @@ class TruncatedNormalOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("TruncatedNormal"), TruncatedNormalOp);
+REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"),
+ TruncatedNormalOp);
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index 647b627408..03b13b2924 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -35,7 +35,7 @@ class SumOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("Sum"), SumOp);
+REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp);
class ProdOp : public XlaReductionOp {
public:
@@ -53,7 +53,8 @@ class ProdOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("Prod"), ProdOp);
+REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"),
+ ProdOp);
class MinOp : public XlaReductionOp {
public:
@@ -73,7 +74,7 @@ class MinOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("Min"), MinOp);
+REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp);
class MaxOp : public XlaReductionOp {
public:
@@ -93,7 +94,7 @@ class MaxOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("Max"), MaxOp);
+REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp);
class MeanOp : public XlaReductionOp {
public:
@@ -115,7 +116,8 @@ class MeanOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("Mean"), MeanOp);
+REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"),
+ MeanOp);
class AllOp : public XlaReductionOp {
public:
@@ -133,7 +135,7 @@ class AllOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("All"), AllOp);
+REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp);
class AnyOp : public XlaReductionOp {
public:
@@ -151,7 +153,7 @@ class AnyOp : public XlaReductionOp {
}
};
-REGISTER_XLA_OP(Name("Any"), AnyOp);
+REGISTER_XLA_OP(Name("Any").CompileTimeConstInput("reduction_indices"), AnyOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
index 5952e75272..af4d64b159 100644
--- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -95,7 +95,7 @@ class ReshapeOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Reshape"), ReshapeOp);
+REGISTER_XLA_OP(Name("Reshape").CompileTimeConstInput("shape"), ReshapeOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
index bdfd066f01..17a345fc94 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
@@ -65,7 +65,7 @@ class ReverseOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Reverse"), ReverseOp);
+REGISTER_XLA_OP(Name("Reverse").CompileTimeConstInput("dims"), ReverseOp);
class ReverseV2Op : public XlaOpKernel {
public:
@@ -103,7 +103,7 @@ class ReverseV2Op : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("ReverseV2"), ReverseV2Op);
+REGISTER_XLA_OP(Name("ReverseV2").CompileTimeConstInput("axis"), ReverseV2Op);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index 650f8c7dc8..ee4a94164c 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -129,13 +129,19 @@ class CumsumOp : public ScanOp {
public:
explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {}
};
-REGISTER_XLA_OP(Name("Cumsum").TypeConstraint("T", kScanOpTypes), CumsumOp);
+REGISTER_XLA_OP(Name("Cumsum")
+ .TypeConstraint("T", kScanOpTypes)
+ .CompileTimeConstInput("axis"),
+ CumsumOp);
class CumprodOp : public ScanOp {
public:
explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {}
};
-REGISTER_XLA_OP(Name("Cumprod").TypeConstraint("T", kScanOpTypes), CumprodOp);
+REGISTER_XLA_OP(Name("Cumprod")
+ .TypeConstraint("T", kScanOpTypes)
+ .CompileTimeConstInput("axis"),
+ CumprodOp);
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
index c2b0e1bb4c..2c31f8d908 100644
--- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -138,7 +138,11 @@ class RangeOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Range"), RangeOp);
+REGISTER_XLA_OP(Name("Range")
+ .CompileTimeConstInput("start")
+ .CompileTimeConstInput("limit")
+ .CompileTimeConstInput("delta"),
+ RangeOp);
class LinSpaceOp : public XlaOpKernel {
public:
@@ -207,7 +211,11 @@ class LinSpaceOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("LinSpace"), LinSpaceOp);
+REGISTER_XLA_OP(Name("LinSpace")
+ .CompileTimeConstInput("start")
+ .CompileTimeConstInput("stop")
+ .CompileTimeConstInput("num"),
+ LinSpaceOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index e205fadd2b..8fb7a74310 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -150,7 +150,7 @@ class ExpandDimsOp : public XlaOpKernel {
ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape));
}
};
-REGISTER_XLA_OP(Name("ExpandDims"), ExpandDimsOp);
+REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp);
class SqueezeOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
index fbe8c78d8f..be1e97bf26 100644
--- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -112,7 +112,9 @@ class SliceOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Slice"), SliceOp);
+REGISTER_XLA_OP(
+ Name("Slice").CompileTimeConstInput("begin").CompileTimeConstInput("size"),
+ SliceOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
index 83a87f19a7..01b46e160d 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -162,7 +162,10 @@ class SpaceToBatchNDOp : public XlaOpKernel {
block_shape, paddings);
}
};
-REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp);
+REGISTER_XLA_OP(Name("SpaceToBatchND")
+ .CompileTimeConstInput("paddings")
+ .CompileTimeConstInput("block_shape"),
+ SpaceToBatchNDOp);
class SpaceToBatchOp : public XlaOpKernel {
public:
@@ -184,7 +187,8 @@ class SpaceToBatchOp : public XlaOpKernel {
private:
int block_size_;
};
-REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp);
+REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstInput("paddings"),
+ SpaceToBatchOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
index 795eb1794f..79c435c90a 100644
--- a/tensorflow/compiler/tf2xla/kernels/split_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -103,7 +103,7 @@ class SplitOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Split"), SplitOp);
+REGISTER_XLA_OP(Name("Split").CompileTimeConstInput("split_dim"), SplitOp);
class SplitVOp : public XlaOpKernel {
public:
@@ -142,8 +142,9 @@ class SplitVOp : public XlaOpKernel {
int neg_one_dim = -1;
std::vector<int64> split_sizes_vec(num_split, -1);
const TensorShape split_size_shape = ctx->InputShape(1);
- OP_REQUIRES(ctx, split_size_shape.dims() == 1 &&
- split_size_shape.num_elements() == num_split,
+ OP_REQUIRES(ctx,
+ split_size_shape.dims() == 1 &&
+ split_size_shape.num_elements() == num_split,
errors::InvalidArgument(
"shape of tensor describing "
" the output must have dimension 1 and the same "
@@ -171,10 +172,11 @@ class SplitVOp : public XlaOpKernel {
}
OP_REQUIRES(
- ctx, (neg_one_dim == -1 &&
- total_split_size == input_shape.dim_size(split_dim)) ||
- (neg_one_dim >= 0 &&
- total_split_size <= input_shape.dim_size(split_dim)),
+ ctx,
+ (neg_one_dim == -1 &&
+ total_split_size == input_shape.dim_size(split_dim)) ||
+ (neg_one_dim >= 0 &&
+ total_split_size <= input_shape.dim_size(split_dim)),
errors::InvalidArgument("Determined shape must either match "
"input shape along split_dim exactly if "
"fully specified, or be less than the size of "
@@ -206,7 +208,10 @@ class SplitVOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("SplitV"), SplitVOp);
+REGISTER_XLA_OP(Name("SplitV")
+ .CompileTimeConstInput("split_dim")
+ .CompileTimeConstInput("size_splits"),
+ SplitVOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
index 8013ece861..c912876e65 100644
--- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
@@ -129,7 +129,7 @@ class StackOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
};
-REGISTER_XLA_OP(Name("StackV2"), StackOp);
+REGISTER_XLA_OP(Name("StackV2").CompileTimeConstInput("max_size"), StackOp);
class StackPushOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 6af4bd0496..f0525a5fb8 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -106,7 +106,11 @@ class StridedSliceOp : public XlaOpKernel {
DataType index_type_;
};
-REGISTER_XLA_OP(Name("StridedSlice"), StridedSliceOp);
+REGISTER_XLA_OP(Name("StridedSlice")
+ .CompileTimeConstInput("begin")
+ .CompileTimeConstInput("end")
+ .CompileTimeConstInput("strides"),
+ StridedSliceOp);
class StridedSliceGradOp : public XlaOpKernel {
public:
@@ -211,7 +215,12 @@ class StridedSliceGradOp : public XlaOpKernel {
DataType index_type_;
};
-REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp);
+REGISTER_XLA_OP(Name("StridedSliceGrad")
+ .CompileTimeConstInput("shape")
+ .CompileTimeConstInput("begin")
+ .CompileTimeConstInput("end")
+ .CompileTimeConstInput("strides"),
+ StridedSliceGradOp);
class StridedSliceAssignOp : public XlaOpKernel {
public:
@@ -320,7 +329,11 @@ class StridedSliceAssignOp : public XlaOpKernel {
DataType index_type_;
};
-REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp);
+REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
+ .CompileTimeConstInput("begin")
+ .CompileTimeConstInput("end")
+ .CompileTimeConstInput("strides"),
+ StridedSliceAssignOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 8a742ff11c..9224072a3c 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -192,7 +192,8 @@ class TensorArrayOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp);
};
-REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp);
+REGISTER_XLA_OP(Name("TensorArrayV3").CompileTimeConstInput("size"),
+ TensorArrayOp);
class TensorArrayWriteOp : public XlaOpKernel {
public:
@@ -414,8 +415,8 @@ class TensorArrayScatterOp : public XlaOpKernel {
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto index = b->Slice(indices, {i}, {i + 1}, {1});
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}
}
@@ -537,7 +538,8 @@ class TensorArraySplitOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp);
};
-REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp);
+REGISTER_XLA_OP(Name("TensorArraySplitV3").CompileTimeConstInput("lengths"),
+ TensorArraySplitOp);
class TensorArraySizeOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
index 9ee6bd8925..9aefcd4fc7 100644
--- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -122,7 +122,7 @@ class TileOp : public XlaOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(TileOp);
};
-REGISTER_XLA_OP(Name("Tile"), TileOp);
+REGISTER_XLA_OP(Name("Tile").CompileTimeConstInput("multiples"), TileOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
index 2fc5d40d10..5c17b7fbf0 100644
--- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
@@ -72,8 +72,9 @@ class TransposeOp : public XlaOpKernel {
}
}
for (int i = 0; i < dims; ++i) {
- OP_REQUIRES(ctx, bits[i], errors::InvalidArgument(
- i, " is missing from 'perm' argument."));
+ OP_REQUIRES(
+ ctx, bits[i],
+ errors::InvalidArgument(i, " is missing from 'perm' argument."));
}
// 0-D, 1-D, and identity transposes do nothing.
@@ -87,7 +88,7 @@ class TransposeOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Transpose"), TransposeOp);
+REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp);
// InvertPermutation frequently forms part of the gradient of Transpose.
//
@@ -103,8 +104,9 @@ class InvertPermutationOp : public XlaOpKernel {
explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- OP_REQUIRES(ctx, FastBoundsCheck(ctx->InputShape(0).num_elements(),
- std::numeric_limits<int32>::max()),
+ OP_REQUIRES(ctx,
+ FastBoundsCheck(ctx->InputShape(0).num_elements(),
+ std::numeric_limits<int32>::max()),
errors::InvalidArgument("permutation of nonnegative int32s "
"must have <= int32 max elements"));
@@ -128,7 +130,9 @@ class InvertPermutationOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32),
+REGISTER_XLA_OP(Name("InvertPermutation")
+ .TypeConstraint("T", DT_INT32)
+ .CompileTimeConstInput("x"),
InvertPermutationOp);
} // namespace
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index faf47434b5..97bb100fb1 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -83,6 +83,11 @@ XlaOpRegistry::~XlaOpRegistry() = default;
return false;
}
}
+ if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) {
+ LOG(WARNING) << "Registrations of " << x.name
+ << " have incompatible compile time constant inputs.";
+ return false;
+ }
return true;
}
@@ -263,6 +268,17 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
return kernels;
}
+/* static */ const std::unordered_set<string>*
+XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
+ XlaOpRegistry& registry = Instance();
+ mutex_lock lock(registry.mutex_);  // Held while ops_ is searched below.
+ auto it = registry.ops_.find(op);
+ if (it == registry.ops_.end()) {
+ return nullptr;  // No registration exists under the name 'op'.
+ }
+ return &it->second->compile_time_constant_inputs;  // Pointer into the registration (no copy). NOTE(review): it is used after 'lock' is released; presumably registrations are never mutated or removed once added -- confirm.
+}
+
std::vector<string> XlaOpRegistry::BackendNames() {
std::vector<string> names;
XlaOpRegistry& registry = Instance();
@@ -337,6 +353,12 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
return *this;
}
+XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput(
+ StringPiece input_name) {
+ registration_->compile_time_constant_inputs.insert(input_name.ToString());  // Stored as an owned string in a set, so repeating a name has no further effect.
+ return *this;  // Returns the builder itself so further annotations can be chained.
+}
+
std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
XlaOpRegistry::Factory factory) {
registration_->factory = factory;
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 8bfd9758f7..ff7453194a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -128,6 +128,11 @@ class XlaOpRegistry {
const string& compilation_device_name,
bool include_compilation_only_kernels);
+ // Returns the set of compile-time constant inputs to 'op'. Returns nullptr
+ // if the op is not registered.
+ static const std::unordered_set<string>* CompileTimeConstantInputs(
+ const string& op);
+
private:
friend class XlaBackendRegistrar;
friend class XlaOpRegistrar;
@@ -181,6 +186,9 @@ class XlaOpRegistry {
bool has_device_whitelist = false;
std::unordered_set<string> device_whitelist;
+ // Names of arguments that must be compile-time constants.
+ std::unordered_set<string> compile_time_constant_inputs;
+
// Factory used to build OpKernels that perform symbolic execution.
Factory factory;
};
@@ -242,6 +250,9 @@ class XlaOpRegistrationBuilder {
// Allow DT_RESOURCE types for type parameters.
XlaOpRegistrationBuilder& AllowResourceTypes();
+ // Mark 'input_name' as an argument whose value must be known at compile-time.
+ XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name);
+
std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
XlaOpRegistry::Factory factory);