aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.cc6
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h4
-rw-r--r--tensorflow/compiler/xla/literal_util.cc10
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc57
-rw-r--r--tensorflow/compiler/xla/service/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc7
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.h2
-rw-r--r--tensorflow/compiler/xla/service/computation_placer.cc9
-rw-r--r--tensorflow/compiler/xla/service/computation_placer.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc9
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc17
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h13
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization_test.cc16
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc23
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc28
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h3
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc33
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc6
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/multidimensional_slice_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/params_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc51
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc3
-rw-r--r--tensorflow/compiler/xla/util.h18
-rw-r--r--tensorflow/compiler/xla/xla_data.proto3
34 files changed, 115 insertions, 270 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index dcc313707b..735a69d596 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -256,8 +256,7 @@ void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs,
ComputationDataHandle ComputationBuilder::Slice(
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> stride) {
+ tensorflow::gtl::ArraySlice<int64> limit_indices) {
if (!first_error_.ok() || !PrepareComputation().ok()) {
return ComputationDataHandle();
}
@@ -270,9 +269,6 @@ ComputationDataHandle ComputationBuilder::Slice(
for (int64 index : limit_indices) {
request.add_limit_indices(index);
}
- for (int64 index : stride) {
- request.add_stride(index);
- }
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_slice_request() = request;
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index b411346459..5dceb03281 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -211,11 +211,9 @@ class ComputationBuilder {
//
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
- // The stride parameter determines the stride over the slice
ComputationDataHandle Slice(const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> stride);
+ tensorflow::gtl::ArraySlice<int64> limit_indices);
// Enqueues a slice operation onto the computation that slices the 'operand'
// from dynamic start indices which are passed in 'start_indices'.
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index b6bd1158d2..1b125e3596 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -1205,7 +1205,11 @@ void Literal::Resize<double>(int64 num_elements, double value) {
template <>
void Literal::Resize<half>(int64 num_elements, half value) {
CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
- mutable_f16s()->resize(num_elements, value);
+ mutable_f16s()->resize(num_elements * sizeof(half));
+ auto data = GetMutableArraySlice<half>();
+ for (int i = 0; i < num_elements; i++) {
+ data[i] = value;
+ }
}
template <typename RepeatedFieldT, typename NativeT>
@@ -1248,7 +1252,7 @@ LiteralProto Literal::ToProto() const {
case F16:
*proto.mutable_f16s() =
string(reinterpret_cast<const char*>(f16s_.data()),
- f16s_.size() * sizeof(half));
+ f16s_.size() / sizeof(half));
break;
case F32:
CopyToRepeatedField(proto.mutable_f32s(), f32s());
@@ -1304,7 +1308,7 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
const string& s(literal_proto.f16s());
CHECK_EQ(0, s.size() % sizeof(half));
f16s_ = std::vector<half>(s.size() / sizeof(half));
- memcpy(f16s_.data(), s.data(), s.size());
+ memcpy(f16s_.data(), s.data(), s.size() / sizeof(half));
break;
}
case F32:
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 5a550ef4c6..ffae623b0c 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -939,62 +939,5 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
}
}
-// Note that f16 is currently stored in a byte array in little endian byte order
-TEST_F(LiteralUtilTest, ToProto_f16) {
- half h1(1.0f);
- half h2(2.0f);
-
- auto m = Literal::CreateR2<half>({{h1, h2}, {h2, h1}});
- Literal* l = m.get();
- EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
- EXPECT_EQ(4, l->f16s().size());
- EXPECT_EQ(4, l->f16s_size());
-
- LiteralProto p = l->ToProto();
- EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
- EXPECT_EQ(8, p.f16s().size());
- const char* d = p.f16s().data();
- EXPECT_EQ(d[0], 0);
- EXPECT_EQ(d[1], 0x3C);
- EXPECT_EQ(d[2], 0);
- EXPECT_EQ(d[3], 0x40);
- EXPECT_EQ(d[4], 0);
- EXPECT_EQ(d[5], 0x40);
- EXPECT_EQ(d[6], 0);
- EXPECT_EQ(d[7], 0x3C);
-}
-
-// Note that f16 is currently stored in a byte array in little endian byte order
-TEST_F(LiteralUtilTest, CopyFromProto_f16) {
- half h1(1.0f);
- half h2(2.0f);
-
- const char half_vals[8] = {
- 0x00, 0x3C, 0x00, 0x40, 0x00, 0x40, 0x00, 0x3C
- };
- LiteralProto p;
- p.mutable_shape()->set_element_type(F16);
- p.mutable_shape()->clear_dimensions();
- p.mutable_shape()->add_dimensions(4);
- p.clear_f16s();
- p.set_f16s(half_vals, 8);
-
-
- Literal literal(p);
- ASSERT_EQ(4, literal.f16s_size());
- ASSERT_EQ(h1, literal.f16s(0));
- ASSERT_EQ(h2, literal.f16s(1));
- ASSERT_EQ(h2, literal.f16s(2));
- ASSERT_EQ(h1, literal.f16s(3));
-
- const std::vector<half>& r = literal.f16s();
- ASSERT_EQ(4, r.size());
- ASSERT_EQ(h1, r[0]);
- ASSERT_EQ(h2, r[1]);
- ASSERT_EQ(h2, r[2]);
- ASSERT_EQ(h1, r[3]);
-}
-
-
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 99b1337b11..718a2d798c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -90,6 +90,8 @@ cc_library(
":hlo_query",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 5709ac3067..0187c09d7b 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -855,7 +855,6 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
// Second, construct the slice instruction to perform the negative padding.
std::vector<int64> start_indices;
std::vector<int64> end_indices;
- std::vector<int64> strides;
for (int64 i = 0; i < pad->padding_config().dimensions_size(); ++i) {
const PaddingConfig::PaddingConfigDimension& padding_dimension =
pad->padding_config().dimensions(i);
@@ -869,18 +868,16 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
}
start_indices.push_back(start);
end_indices.push_back(end);
- strides.push_back(1);
}
// Verify that the slice shape matches the pad shape.
TF_ASSIGN_OR_RETURN(Shape inferred_slice_shape,
ShapeInference::InferSliceShape(
- nonzero_pad_shape, start_indices, end_indices,
- strides));
+ nonzero_pad_shape, start_indices, end_indices));
TF_RET_CHECK(ShapeUtil::Compatible(inferred_slice_shape, pad->shape()));
std::unique_ptr<HloInstruction> slice = HloInstruction::CreateSlice(
- pad->shape(), nonzero_pad, start_indices, end_indices, strides);
+ pad->shape(), nonzero_pad, start_indices, end_indices);
return ReplaceWithNewInstruction(pad, std::move(slice));
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 7e52c8fb0c..0792006ddb 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -520,7 +520,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
+ ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}));
Shape result_shape = ShapeUtil::MakeShape(F32, {3 * kParamLength});
builder.AddInstruction(HloInstruction::CreateConcatenate(
result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0));
@@ -551,7 +551,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
+ ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}));
Shape result_shape = ShapeUtil::MakeShape(F32, {0});
builder.AddInstruction(HloInstruction::CreateConcatenate(
result_shape, {empty_literal, empty_slice}, 0));
@@ -1132,7 +1132,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
- /*limit_indices=*/{dim0, dim1}, /*slices=*/{1, 1}));
+ /*limit_indices=*/{dim0, dim1}));
HloModule module(TestName());
HloComputation* computation = module.AddEntryComputation(builder.Build());
@@ -1537,7 +1537,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
- slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
+ slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}));
HloModule module(TestName());
auto computation = module.AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 56568fd446..c498b86dd4 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -731,7 +731,7 @@ TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
@@ -763,7 +763,7 @@ TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));
@@ -800,7 +800,7 @@ TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
auto tuple_element = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0));
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}, {1}));
+ HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast}));
@@ -835,7 +835,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
// Slice output is 10 elements.
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
// Broadcast output is 40 elements.
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
@@ -867,7 +867,7 @@ TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
// Broadcast output is 40 elements.
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));
@@ -904,7 +904,7 @@ TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
// Slice output is 10 elements.
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}, {1}));
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
// Broadcast output is 40 elements.
auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index a5f7cc0aeb..a31e9b1782 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -588,7 +588,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
if (update_uses_tuple_element1) {
// Create a slice instruction as an additional user of 'gte1'.
slice = builder.AddInstruction(
- HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}, {1}));
+ HloInstruction::CreateSlice(update_shape, gte1, {0}, {3}));
update = builder.AddInstruction(HloInstruction::CreateBinary(
update_shape, HloOpcode::kAdd, update, slice));
}
diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h
index 0a1911cbd1..dd00c58240 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.h
+++ b/tensorflow/compiler/xla/service/compile_only_service.h
@@ -55,7 +55,7 @@ class CompileOnlyService : public Service {
// Override Service methods that require or imply the existence of an
// execute backend. Note that this does not include TransferToClient, as
- // computing constants produces global data that we may wish to transfer.
+ // computing contants produces global data that we may wish to transfer.
tensorflow::Status Execute(const ExecuteRequest* arg,
ExecuteResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc
index cdfa30dd9a..cdf277581f 100644
--- a/tensorflow/compiler/xla/service/computation_placer.cc
+++ b/tensorflow/compiler/xla/service/computation_placer.cc
@@ -49,18 +49,17 @@ Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const {
return Status::OK();
}
-/* static */ StatusOr<std::unique_ptr<DeviceAssignment>>
-DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) {
+/* static */ StatusOr<DeviceAssignment> DeviceAssignment::Deserialize(
+ const DeviceAssignmentProto& proto) {
TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count());
- auto assignment = MakeUnique<DeviceAssignment>(proto.replica_count(),
- proto.computation_count());
+ DeviceAssignment assignment(proto.replica_count(), proto.computation_count());
for (int computation = 0; computation < proto.computation_count();
++computation) {
const auto& computation_device = proto.computation_devices(computation);
TF_RET_CHECK(computation_device.replica_device_ids_size() ==
proto.replica_count());
for (int replica = 0; replica < proto.replica_count(); ++replica) {
- (*assignment)(replica, computation) =
+ assignment(replica, computation) =
computation_device.replica_device_ids(replica);
}
}
diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h
index 7d9abcd100..4d26d6bb85 100644
--- a/tensorflow/compiler/xla/service/computation_placer.h
+++ b/tensorflow/compiler/xla/service/computation_placer.h
@@ -49,11 +49,7 @@ class DeviceAssignment : public Array2D<int> {
// Protocol buffer serialization and deserialization.
Status Serialize(DeviceAssignmentProto* proto) const;
-
- // Return a std::unique_ptr<DeviceAssignment> instead of a DeviceAssignment
- // directly because one of the supported TF platforms (mac) does not compile
- // due to a StatusOr of an incomplete type (DeviceAssignment).
- static StatusOr<std::unique_ptr<DeviceAssignment>> Deserialize(
+ static StatusOr<DeviceAssignment> Deserialize(
const DeviceAssignmentProto& proto);
};
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 759d27e1f3..da8d983e1a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -359,6 +359,7 @@ Status AppendIRToFile(const string& file_name, const string& ir_module_string) {
StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
std::unique_ptr<HloModule> module, HloDumper dump_hlo,
se::StreamExecutor* stream_exec) {
+ VLOG(1) << "Compiling: " << module->name();
TF_RET_CHECK(stream_exec != nullptr);
std::call_once(llvm_command_line_options_initialized,
&InitializeLLVMCommandLineOptions, module->config());
@@ -403,6 +404,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
module->config().debug_options().xla_dump_debug_json_to();
if (CpuParallelBackendRequested(module->config())) {
+ VLOG(1) << "Using parallel cpu backend";
+
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
// DependencyHloOrdering is used for the parallel emitter because the order
@@ -497,6 +500,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
.set_ir_module_string(ir_module_string);
}
} else {
+ VLOG(1) << "Using sequential cpu backend";
+
// Select an order for emitting the HLO instructions for each
// computation. Using this sequence enables tighter buffer liveness analysis
// and reduced memory usage (as compared to using DependencyHloOrdering).
@@ -562,6 +567,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
}
}
+ VLOG(1) << "Compilation finished";
return std::move(cpu_executable);
}
@@ -663,6 +669,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
std::vector<std::unique_ptr<AotCompilationResult>> results;
for (size_t i = 0; i < modules.size(); ++i) {
HloModule* module = modules[i].get();
+ VLOG(1) << "Compiling ahead-of-time: " << module->name();
TF_RETURN_IF_ERROR(RunHloPasses(module, dump_hlo));
@@ -741,6 +748,8 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
std::move(object_file_data), std::move(buffer_sizes),
result_slice.index()));
}
+
+ VLOG(1) << "Compilation finished";
return std::move(results);
}
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index db0a8b36cd..5b21ae3d2a 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -949,20 +949,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
IrArray::Index sliced_index(index.size());
for (int i = 0; i < index.size(); ++i) {
- int64 stride = hlo->slice_stride(i);
- if (stride != 1) {
- sliced_index[i] = ir_builder_->CreateAdd(
- ir_builder_->CreateMul(
- index[i], llvm::ConstantInt::get(index[i]->getType(),
- stride)),
- llvm::ConstantInt::get(index[i]->getType(),
- hlo->slice_starts(i)));
- } else {
- sliced_index[i] = ir_builder_->CreateAdd(
- index[i],
- llvm::ConstantInt::get(index[i]->getType(),
- hlo->slice_starts(i)));
- }
+ sliced_index[i] = ir_builder_->CreateAdd(
+ index[i], llvm::ConstantInt::get(index[i]->getType(),
+ hlo->slice_starts(i)));
}
return operand_to_generator.at(hlo->operand(0))(sliced_index);
};
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index b8c6162084..4e130de311 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -80,7 +80,6 @@ HloInstruction* MaybePaddedAndSlicedInput(
std::vector<int64> start_indices(input->shape().dimensions_size(), 0);
std::vector<int64> limit_indices(input->shape().dimensions().begin(),
input->shape().dimensions().end());
- std::vector<int64> strides(input->shape().dimensions_size(), 1);
for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) {
int64 dim = conv_dnums.spatial_dimensions(i);
// If dimension "dim" has negative padding, increase the start index or
@@ -93,9 +92,9 @@ HloInstruction* MaybePaddedAndSlicedInput(
input = computation->AddInstruction(HloInstruction::CreateSlice(
ShapeInference::InferSliceShape(input->shape(), start_indices,
- limit_indices, strides)
+ limit_indices)
.ConsumeValueOrDie(),
- input, start_indices, limit_indices, strides));
+ input, start_indices, limit_indices));
}
return input;
@@ -355,8 +354,6 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
std::vector<int64> limit_indices(
new_backward_conv->shape().dimensions().begin(),
new_backward_conv->shape().dimensions().end());
- std::vector<int64> strides(new_backward_conv->shape().dimensions_size(),
- 1LL);
for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
int64 padding_low = backward_conv->window().dimensions(i).padding_low();
int64 padding_high = backward_conv->window().dimensions(i).padding_high();
@@ -376,13 +373,13 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
// Replace the old backward convolution with the slice.
CHECK(ShapeUtil::Compatible(
ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
- limit_indices, strides)
+ limit_indices)
.ConsumeValueOrDie(),
backward_conv->shape()));
TF_CHECK_OK(computation->ReplaceWithNewInstruction(
backward_conv,
HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv,
- start_indices, limit_indices, strides)));
+ start_indices, limit_indices)));
return true;
}
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 1c60b06ddd..a643bc4076 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -147,7 +147,6 @@ TEST_F(HloConstantFoldingTest, Slice) {
const int64 dimensions[] = {11, 8, 7, 5, 9};
const int64 slice_start[] = {4, 2, 3, 1, 5};
const int64 slice_limits[] = {10, 8, 6, 5, 9};
- const int64 slice_strides[] = {1, 1, 1, 1, 1};
TF_ASSIGN_OR_ASSERT_OK(auto literal,
LiteralTestUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
@@ -155,7 +154,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
HloInstruction::CreateConstant(std::move(literal)));
Shape shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});
builder.AddInstruction(HloInstruction::CreateSlice(
- shape, literal_instruction, slice_start, slice_limits, slice_strides));
+ shape, literal_instruction, slice_start, slice_limits));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 9117ab9653..99b73dea29 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -306,13 +306,11 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides) {
+ tensorflow::gtl::ArraySlice<int64> limit_indices) {
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape));
instruction->AppendOperand(operand);
instruction->slice_starts_.assign(start_indices.begin(), start_indices.end());
instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end());
- instruction->slice_strides_.assign(strides.begin(), strides.end());
return instruction;
}
@@ -854,8 +852,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return CreateReshape(shape, new_operands[0]);
case HloOpcode::kSlice:
CHECK_EQ(new_operands.size(), 1);
- return CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_,
- slice_strides_);
+ return CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_);
case HloOpcode::kDynamicSlice:
return CreateDynamicSlice(shape, new_operands[0], new_operands[1],
dynamic_slice_sizes_);
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index d29c0935fc..37cbb0b769 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -174,8 +174,7 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateSlice(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices,
- tensorflow::gtl::ArraySlice<int64> strides);
+ tensorflow::gtl::ArraySlice<int64> limit_indices);
// Creates a slice instruction, where the first operand is sliced by
// start indices specified in the second operand, and by size specfied in
@@ -663,15 +662,6 @@ class HloInstruction {
return slice_limits_;
}
- // Returns the stride in the given dimension for a slice node.
- //
- // Precondition: opcode() == HloOpcode::kSlice
- int64 slice_stride(int64 dimension) const {
- CHECK_EQ(HloOpcode::kSlice, opcode_);
- return slice_strides_[dimension];
- }
- const std::vector<int64>& slice_strides() const { return slice_strides_; }
-
// Returns the size of the slice in the given dimension for a dynamic
// slice node.
//
@@ -917,7 +907,6 @@ class HloInstruction {
// Describes the [begin, end) index range for a slice.
std::vector<int64> slice_starts_;
std::vector<int64> slice_limits_;
- std::vector<int64> slice_strides_;
// The bit sizes for a reduce-precision operation.
int32 exponent_bits_;
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 1a861cd16b..8a1e705711 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -67,8 +67,7 @@ class HloRematerializationTest : public HloTestBase {
/*dimension=*/0));
auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice(
vec1_shape_, concat_1, /*start_indices=*/{0},
- /*limit_indices=*/{1},
- /*strides=*/{1}));
+ /*limit_indices=*/{1}));
auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1},
/*dimension=*/0));
@@ -76,8 +75,7 @@ class HloRematerializationTest : public HloTestBase {
// which is necessary to use this computation in a while.
builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2,
/*start_indices=*/{0},
- /*limit_indices=*/{1},
- /*strides=*/{1}));
+ /*limit_indices=*/{1}));
return builder.Build();
}
@@ -105,8 +103,7 @@ class HloRematerializationTest : public HloTestBase {
HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
auto slice_1 = builder.AddInstruction(
HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
- /*limit_indices=*/{1},
- /*strides=*/{1}));
+ /*limit_indices=*/{1}));
auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
vec1_shape_, while_cond, while_body, slice_1));
auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate(
@@ -114,8 +111,7 @@ class HloRematerializationTest : public HloTestBase {
/*dimension=*/0));
builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat,
/*start_indices=*/{0},
- /*limit_indices=*/{1},
- /*strides=*/{1}));
+ /*limit_indices=*/{1}));
return builder.Build();
}
@@ -357,7 +353,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
/*dimension=*/0));
builder.AddInstruction(HloInstruction::CreateSlice(
vec1024_shape_, concat, /*start_indices=*/{0},
- /*limit_indices=*/{1024}, /*slices=*/{1}));
+ /*limit_indices=*/{1024}));
subcomputation = module->AddEmbeddedComputation(builder.Build());
}
@@ -473,7 +469,7 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
/*dimension=*/0));
builder.AddInstruction(HloInstruction::CreateSlice(
vec1024_shape_, concat, /*start_indices=*/{0},
- /*limit_indices=*/{1024}, /*slices=*/{1}));
+ /*limit_indices=*/{1024}));
subcomputation = module->AddEmbeddedComputation(builder.Build());
}
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index bcc9418d59..e348511c62 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -356,26 +356,9 @@ void EmitLogging(const char* tag, llvm::Value* value,
void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape,
bool is_pointer_to) {
- llvm::MDBuilder metadata_builder(instruction->getContext());
- llvm::MDNode* root = metadata_builder.createTBAARoot("XLA TBAA");
- string type_name;
- if (is_pointer_to) {
- type_name += "pointer-to ";
- }
- // Scalars do not have layout which makes it permissible to omit an explicit
- // layout. To make sure that equivalent scalar shapes have the same TBAA,
- // remove the (meaningless) explicit layout if one is present.
- if (!ShapeUtil::IsArray(shape) || ShapeUtil::IsScalar(shape)) {
- LayoutUtil::ClearLayout(&shape);
- } else {
- CHECK(shape.has_layout());
- }
- type_name += shape.ShortDebugString();
- llvm::MDNode* tbaa_node =
- metadata_builder.createTBAANode(llvm_ir::AsStringRef(type_name), root);
- instruction->setMetadata(llvm::LLVMContext::MD_tbaa,
- metadata_builder.createTBAAStructTagNode(
- tbaa_node, tbaa_node, /*Offset=*/0));
+ // TODO(b/62903316): TBAA metadata causes LLVM to miscompile generated code,
+ // most likely because the generated metadata is incorrect. Disable TBAA
+ // metadata while we resolve this.
}
void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 5e4df9ddd6..b332709995 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1135,8 +1135,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits,
- tensorflow::gtl::ArraySlice<int64> strides) {
+ tensorflow::gtl::ArraySlice<int64> limits) {
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s starts={%s} limits={%s}",
@@ -1159,13 +1158,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
int64 start_index = starts[dimension];
int64 limit_index = limits[dimension];
- int64 stride = strides[dimension];
if (start_index < 0) {
return InvalidArgument("negative start index to slice: %lld",
start_index);
}
- if (stride == 0) {
- return InvalidArgument("Zero stride");
+ if (limit_index < 0) {
+ return InvalidArgument("negative limit index to slice: %lld",
+ limit_index);
}
if (limit_index > arg.dimensions(dimension)) {
return InvalidArgument(
@@ -1173,21 +1172,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
"size (%lld)",
limit_index, arg.dimensions(dimension));
}
+ if (start_index > limit_index) {
+ return InvalidArgument(
+ "limit index (%lld) must be greater or equal to "
+ "start index (%lld) in slice",
+ limit_index, start_index);
+ }
VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension,
start_index);
VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension,
limit_index);
- if (stride > 0) {
- if (start_index > limit_index) {
- return InvalidArgument(
- "limit index (%lld) must be greater or equal to "
- "start index (%lld) in slice with positive stride",
- limit_index, start_index);
- }
- sizes.push_back((limit_index - start_index + stride - 1) / stride);
- } else {
- return InvalidArgument("Negative strides not supported");
- }
+
+ sizes.push_back(limits[dimension] - starts[dimension]);
}
return ShapeUtil::MakeShape(arg.element_type(), sizes);
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 42e4c7d39d..55c60e149d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -116,8 +116,7 @@ class ShapeInference {
// e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
static StatusOr<Shape> InferSliceShape(
const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
- tensorflow::gtl::ArraySlice<int64> limits,
- tensorflow::gtl::ArraySlice<int64> strides);
+ tensorflow::gtl::ArraySlice<int64> limits);
// Infers the shape produced by a dynamic slice operation of size specified
// in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 8c731ae297..7cff042a48 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -682,43 +682,16 @@ TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
auto inferred_status =
- ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {1, 1});
+ ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64});
ASSERT_IS_OK(inferred_status.status());
Shape inferred = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred));
}
-TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
- Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
- auto inferred_status =
- ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4});
- ASSERT_IS_OK(inferred_status.status());
- Shape inferred = inferred_status.ValueOrDie();
- ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred));
-}
-
-TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
- Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
- auto inferred_status =
- ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4});
- ASSERT_IS_OK(inferred_status.status());
- Shape inferred = inferred_status.ValueOrDie();
- ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), inferred));
-}
-
-TEST_F(ShapeInferenceTest, InferInvalidStride) {
- Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
- auto inferred_status =
- ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {0, 1});
- ASSERT_FALSE(inferred_status.ok());
- ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
- inferred_status.status().code());
-}
-
TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
auto inferred_status =
- ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1});
+ ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2});
ASSERT_FALSE(inferred_status.ok());
ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
inferred_status.status().code());
@@ -727,7 +700,7 @@ TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
Shape vector_shape = ShapeUtil::MakeShape(F32, {17});
auto inferred_status =
- ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1});
+ ShapeInference::InferSliceShape(vector_shape, {2}, {4});
ASSERT_TRUE(inferred_status.ok());
Shape inferred = inferred_status.ValueOrDie();
ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2})));
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index cd79e63caf..d25e5adee3 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -584,7 +584,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
if (add_additional_gte0_user) {
// Create 'slice' as an additional user of 'input'.
auto slice = builder.AddInstruction(
- HloInstruction::CreateSlice(update_shape, input, {0}, {3}, {1}));
+ HloInstruction::CreateSlice(update_shape, input, {0}, {3}));
// Modify 'update' to take 'slice' output.
update = builder.AddInstruction(HloInstruction::CreateBinary(
update_shape, HloOpcode::kAdd, update, slice));
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 92b8c7bb21..1f6e789379 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -744,8 +744,7 @@ StatusOr<ComputationDataHandle> UserComputation::AddSliceInstruction(
Shape new_shape,
ShapeInference::InferSliceShape(
operand->output_shape(), AsInt64Slice(slice_request.start_indices()),
- AsInt64Slice(slice_request.limit_indices()),
- AsInt64Slice(slice_request.stride())));
+ AsInt64Slice(slice_request.limit_indices())));
ComputationDataHandle handle = CreateComputationDataHandle();
@@ -2394,8 +2393,7 @@ void ComputationLowerer::Visit(
hlo_instruction = add_instruction(HloInstruction::CreateSlice(
request.output_shape(), operand,
AsInt64Slice(slice_request.start_indices()),
- AsInt64Slice(slice_request.limit_indices()),
- AsInt64Slice(slice_request.stride())));
+ AsInt64Slice(slice_request.limit_indices())));
break;
}
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 024988743c..bb7fbad000 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -1853,7 +1853,7 @@ TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
auto x = builder.Parameter(0, x_literal->shape(), "x");
auto y = builder.Parameter(1, y_literal->shape(), "y");
- auto slice = builder.Slice(x, {1}, {2}, {1});
+ auto slice = builder.Slice(x, {1}, {2});
builder.Sub(slice, y);
ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 63a630f9e5..7abef6a27b 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -365,9 +365,9 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) {
std::vector<xla::ComputationDataHandle> out_slices;
for (int i = 0; i < 4; ++i) {
// Slice off individual matrices and reshape to 2D tensors.
- auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
+ auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2});
x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2});
- auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
+ auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2});
y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2});
auto out = builder.Dot(x_slice, y_slice);
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index 7803d234fd..c8b91eafc7 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -210,7 +210,7 @@ XLA_TEST_F(FusionTest, Test) {
HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
HloOpcode::kSelect, const10, add8, const9));
auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}, {1, 1}));
+ ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}));
// CreateFusionInstruction needs the `instructions_to_fuse` argument in
// reverse topological order, so the first element in `instructions_to_fuse`
// must be the root.
diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
index 56c15e5ff7..df3d4fa21d 100644
--- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
+++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
@@ -36,7 +36,7 @@ XLA_TEST_F(SliceTest, Slice2D) {
ComputationBuilder builder(client_, "slice_2d");
auto original = builder.ConstantR2<float>(
{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}});
- builder.Slice(original, {2, 1}, {4, 3}, {1, 1});
+ builder.Slice(original, {2, 1}, {4, 3});
Array2D<float> expected({{8.0f, 9.0f}, {11.0f, 12.0f}});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
@@ -47,7 +47,7 @@ XLA_TEST_F(SliceTest, Slice3D) {
Array3D<float> array_3d(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}});
auto original = builder.ConstantR3FromArray3D<float>(array_3d);
- builder.Slice(original, {0, 0, 1}, {2, 1, 2}, {1, 1, 1});
+ builder.Slice(original, {0, 0, 1}, {2, 1, 2});
Array3D<float> expected_3d({{{2.0f}}, {{6.0f}}});
ComputeAndCompareR3<float>(&builder, expected_3d, {}, ErrorSpec(0.000001));
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index a7692fceb4..2065e9e813 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -325,7 +325,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
ComputationBuilder builder(client_, TestName());
auto input = builder.Parameter(0, original, "input");
// Use the slice operator to get an off-diagonal element.
- builder.Slice(input, {0, 1}, {1, 2}, {1, 1});
+ builder.Slice(input, {0, 1}, {1, 2});
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index 5e7d475662..97120df0c5 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -44,7 +44,7 @@ class SliceTest : public ClientLibraryTestBase {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<NativeT>(constant);
- builder.Slice(original, {2}, {4}, {1});
+ builder.Slice(original, {2}, {4});
const std::vector<NativeT> expected = {static_cast<NativeT>(2),
static_cast<NativeT>(3)};
@@ -55,7 +55,7 @@ class SliceTest : public ClientLibraryTestBase {
XLA_TEST_F(SliceTest, SliceZeroToZeroF32) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<float>({});
- builder.Slice(original, {0}, {0}, {1});
+ builder.Slice(original, {0}, {0});
ComputeAndCompareR1<float>(&builder, {}, {});
}
@@ -64,7 +64,7 @@ XLA_TEST_F(SliceTest, SliceTenToZeroF32) {
ComputationBuilder builder(client_, TestName());
std::vector<float> constant(10, 0.3);
auto original = builder.ConstantR1<float>(constant);
- builder.Slice(original, {7}, {7}, {1});
+ builder.Slice(original, {7}, {7});
ComputeAndCompareR1<float>(&builder, {}, {});
}
@@ -87,7 +87,7 @@ TEST_F(SliceTest, SliceTenToTen) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<float>(values);
- builder.Slice(original, {0}, {10}, {1});
+ builder.Slice(original, {0}, {10});
ComputeAndCompareR1<float>(&builder, values, {}, ErrorSpec(0.000001));
}
@@ -98,7 +98,7 @@ TEST_F(SliceTest, SliceLastFourOf1024) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<float>(values);
- builder.Slice(original, {1024 - 4}, {1024}, {1});
+ builder.Slice(original, {1024 - 4}, {1024});
const std::vector<float> expected = {1020, 1021, 1022, 1023};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001));
@@ -112,7 +112,7 @@ TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<float>(values);
- builder.Slice(original, {7}, {7 + 1024}, {1});
+ builder.Slice(original, {7}, {7 + 1024});
std::vector<float> expected(1024);
std::iota(values.begin(), values.end(), 7.0);
@@ -122,7 +122,7 @@ TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) {
XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
- builder.Slice(original, {0, 0}, {0, 0}, {1, 1});
+ builder.Slice(original, {0, 0}, {0, 0});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {});
}
@@ -130,7 +130,7 @@ XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 20));
- builder.Slice(original, {0, 15}, {0, 20}, {1, 1});
+ builder.Slice(original, {0, 15}, {0, 20});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 5), {});
}
@@ -138,7 +138,7 @@ XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
XLA_TEST_F(SliceTest, Slice3x0to2x0F32) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(3, 0));
- builder.Slice(original, {1, 0}, {3, 0}, {1, 1});
+ builder.Slice(original, {1, 0}, {3, 0});
ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {});
}
@@ -153,7 +153,7 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(values);
- builder.Slice(original, {128, 128}, {256, 256}, {1, 1});
+ builder.Slice(original, {128, 128}, {256, 256});
Array2D<float> expected(128, 128);
for (int row = 0; row < 128; ++row) {
@@ -171,7 +171,7 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(values);
- builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1});
+ builder.Slice(original, {0, 3072}, {1, 4096});
Array2D<float> expected(1, 1024);
std::iota(expected.data(), expected.data() + 1024, 3072.0);
@@ -192,7 +192,7 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) {
}
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(values);
- builder.Slice(original, {0, 0}, {16, 2}, {1, 1});
+ builder.Slice(original, {0, 0}, {16, 2});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
}
@@ -204,7 +204,7 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
ReferenceUtil::Slice4D(values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}});
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR4FromArray4D(values);
- builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1});
+ builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128});
ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
}
@@ -213,7 +213,6 @@ struct R2Spec {
int64 input_dim1;
std::array<int64, 2> slice_starts;
std::array<int64, 2> slice_limits;
- std::array<int64, 2> slice_strides;
Layout layout;
};
@@ -229,7 +228,7 @@ TEST_P(SliceR2Test, DoIt) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR2FromArray2D<int32>(input);
- builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
+ builder.Slice(a, spec.slice_starts, spec.slice_limits);
std::unique_ptr<Array2D<int32>> expected =
ReferenceUtil::Slice2D(input, spec.slice_starts, spec.slice_limits);
@@ -240,23 +239,19 @@ TEST_P(SliceR2Test, DoIt) {
INSTANTIATE_TEST_CASE_P(
SliceR2TestInstantiation, SliceR2Test,
::testing::Values(
- R2Spec {4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}},
- LayoutUtil::MakeLayout({0, 1})},
- R2Spec {4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}},
+ R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({0, 1})},
+ R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({0, 1})},
+ R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {256, 400, {{0, 300}}, {{256, 400}},
LayoutUtil::MakeLayout({1, 0})},
- R2Spec {16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}},
- LayoutUtil::MakeLayout({0, 1})},
- R2Spec {16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}},
+ R2Spec {500, 400, {{111, 123}}, {{300, 257}},
LayoutUtil::MakeLayout({1, 0})},
- R2Spec {256, 400, {{0, 300}}, {{256, 400}}, {{1, 1}},
+ R2Spec {500, 400, {{111, 123}}, {{300, 400}},
LayoutUtil::MakeLayout({1, 0})},
- R2Spec {500, 400, {{111, 123}}, {{300, 257}}, {{1, 1}},
+ R2Spec {384, 512, {{128, 256}}, {{256, 384}},
LayoutUtil::MakeLayout({1, 0})},
- R2Spec {500, 400, {{111, 123}}, {{300, 400}}, {{1, 1}},
- LayoutUtil::MakeLayout({1, 0})},
- R2Spec {384, 512, {{128, 256}}, {{256, 384}}, {{1, 1}},
- LayoutUtil::MakeLayout({1, 0})},
- R2Spec {357, 512, {{111, 256}}, {{301, 384}}, {{1, 1}},
+ R2Spec {357, 512, {{111, 256}}, {{301, 384}},
LayoutUtil::MakeLayout({1, 0})}
)
);
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index afa7d871c0..ccd2a95658 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -666,8 +666,7 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) {
auto build_condition = [this, v6s32](int count) {
ComputationBuilder builder(client_, TestName());
auto prev = builder.Reshape(
- builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0},
- {});
+ builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}), {0}, {});
builder.Gt(builder.ConstantR0<int32>(count), prev);
return builder.Build().ConsumeValueOrDie();
};
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 31f0c3147e..42d5c1d155 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -195,24 +195,16 @@ bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank);
// 2. permutation.size() == input.size().
template <template <typename...> class C, typename T>
std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
- C<T> input) {
- tensorflow::gtl::ArraySlice<T> data(input);
- CHECK(IsPermutation(permutation, data.size()));
- std::vector<T> output(data.size());
+ C<T> input_) {
+ tensorflow::gtl::ArraySlice<T> input(input_);
+ CHECK(IsPermutation(permutation, input.size()));
+ std::vector<T> output(input.size());
for (size_t i = 0; i < permutation.size(); ++i) {
- output[permutation[i]] = data[i];
+ output[permutation[i]] = input[i];
}
return output;
}
-// Override of the above that works around compile failures with gcc 7.1.1.
-// For details see https://github.com/tensorflow/tensorflow/issues/10843
-template <typename T>
-std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
- const std::vector<T>& input) {
- return Permute<std::vector, T>(permutation, input);
-}
-
// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
std::vector<int64> InversePermutation(
tensorflow::gtl::ArraySlice<int64> input_permutation);
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 86c72b3449..95c1f0995b 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -200,7 +200,7 @@ message OpMetadata {
string op_name = 2;
// Indicate a file and line that this op is associated to in a user's program.
//
- // e.g. it could be the file and line of user code that generated the op.
+ // e.g. it could be be the file and line of user code that generated the op.
string source_file = 3;
int32 source_line = 4;
}
@@ -369,7 +369,6 @@ message SliceRequest {
ComputationDataHandle operand = 2;
repeated int64 start_indices = 3;
repeated int64 limit_indices = 4;
- repeated int64 stride = 5;
}
message DynamicSliceRequest {