aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-08-06 07:39:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-06 07:42:13 -0700
commit2b8df9f4064d5cf21786ebc9da4d800233d1afa6 (patch)
tree5968a7e86b720a951df13e5de52d50b5a96e2931
parente70f94ee089abbb9eb70361b5cdef55aa9beb18b (diff)
[XLA] Clean up clang tidy readability warnings in compiler/xla
* lambda capture 'builder' is not used * using decl 'Printf' is unused * lambda capture 'this' is not used (17 times) * lambda capture 'buffer_liveness' is not used * lambda capture 'computation' is not used * lambda capture 'operand_to_generator' is not used * lambda capture 'M' is not used * using decl 'InvalidParameterArgument' is unused * lambda capture 'sum' is not used * lambda capture 's' is not used * lambda capture 'epsilon' is not used PiperOrigin-RevId: 207542895
-rw-r--r--tensorflow/compiler/xla/client/lib/prng.cc2
-rw-r--r--tensorflow/compiler/xla/literal_util.cc1
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.cc10
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc10
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_value.cc3
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc2
-rw-r--r--tensorflow/compiler/xla/service/service.cc1
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc9
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc4
-rw-r--r--tensorflow/compiler/xla/shape_tree_test.cc2
-rw-r--r--tensorflow/compiler/xla/shape_util.cc3
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/prng_test.cc2
21 files changed, 37 insertions, 42 deletions
diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc
index 3a744148fb..6ef8168948 100644
--- a/tensorflow/compiler/xla/client/lib/prng.cc
+++ b/tensorflow/compiler/xla/client/lib/prng.cc
@@ -56,7 +56,7 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
// Performs a single round of the Threefry2x32 algorithm, with a rotation
// amount 'rotation'.
- auto round = [builder](ThreeFry2x32State v, int rotation) {
+ auto round = [](ThreeFry2x32State v, int rotation) {
v[0] = v[0] + v[1];
v[1] = RotateLeftS32(v[1], rotation);
v[1] = v[0] ^ v[1];
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 548fbe8a83..356f12ed78 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
-using tensorflow::strings::Printf;
using tensorflow::strings::StrCat;
namespace xla {
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index ad14fe6f2c..862cbeeba6 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2006,7 +2006,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
// Builds a convolution from <options> and runs algebraic simplification on
// the computation. Returns a string description of the result of
// simplification.
- auto build_and_simplify = [&options, this]() -> string {
+ auto build_and_simplify = [&options]() -> string {
HloComputation::Builder b(TestName());
Window window;
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc
index 95b4cb6d2e..51ebc4763b 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.cc
+++ b/tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -109,11 +109,11 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
ResolveInternal(data));
for (const auto& shaped_buffer : replicated_buffers) {
std::vector<ShapeIndex> shape_indices;
- ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(),
- [this, &shape_indices](const Shape& /*subshape*/,
- const ShapeIndex& index) {
- shape_indices.push_back(index);
- });
+ ShapeUtil::ForEachSubshape(
+ shaped_buffer->on_device_shape(),
+ [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) {
+ shape_indices.push_back(index);
+ });
for (const ShapeIndex& index : shape_indices) {
TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
shaped_buffer->device_ordinal()));
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index e4d2e73b99..118a11c8de 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -877,8 +877,8 @@ Status BufferAssigner::AssignBuffersForComputation(
// important reuse case where an elementwise instruction reuses one of its
// operand's buffer. This improves locality.
std::sort(sorted_buffers.begin(), sorted_buffers.end(),
- [this, has_sequential_order, &liveness, &post_order_position,
- assignment](const LogicalBuffer* a, const LogicalBuffer* b) {
+ [has_sequential_order, &liveness, &post_order_position, assignment](
+ const LogicalBuffer* a, const LogicalBuffer* b) {
// Primary sort is by decreasing buffer size.
const int64 a_size = assignment->buffer_size_(*a);
const int64 b_size = assignment->buffer_size_(*b);
@@ -1441,9 +1441,9 @@ void BufferAssigner::BuildColocatedBufferSets(
const HloInstruction* while_hlo = instruction;
ShapeUtil::ForEachSubshape(
while_hlo->shape(),
- [this, while_hlo, &points_to_analysis, &buffer_liveness,
- buffer_size, computation, colocated_buffer_sets](
- const Shape& /*subshape*/, const ShapeIndex& index) {
+ [this, while_hlo, &points_to_analysis, buffer_size,
+ colocated_buffer_sets](const Shape& /*subshape*/,
+ const ShapeIndex& index) {
std::vector<const LogicalBuffer*> colocated_set;
// Add while.init.
AddBufferToColocatedSet(while_hlo->operand(0), index,
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
index c433bddc84..c35569c661 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -220,7 +220,7 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// The body adds the reduced value of the Infeed data (first tuple element)
// to the previous accumulator, and returns the accumulator and the continue
// flag (second tuple element) as a tuple.
- const auto build_body = [this, &result_shape](const Shape& infeed_shape) {
+ const auto build_body = [&result_shape](const Shape& infeed_shape) {
XlaComputation body;
XlaBuilder builder("body");
auto prev = Parameter(&builder, 0, result_shape, "prev");
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index f883eb828c..574ae0c903 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -2134,7 +2134,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
return EmitElementalDot(hlo, operand_to_generator, dot_result_index);
};
default:
- return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
+ return [hlo](const IrArray::Index& index) {
return Unimplemented("Unhandled opcode for elemental IR emission: %s",
HloOpcodeString(hlo->opcode()).c_str());
};
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index bb71c79fd7..bb7736efa6 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -293,7 +293,7 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
// the respective location in ShapedBuffer.
std::set<se::DeviceMemoryBase> buffers_in_result;
TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachMutableElementWithStatus(
- [&buffer_allocations, &buffers_in_result, &shaped_buffer, this](
+ [&buffer_allocations, &buffers_in_result, this](
const ShapeIndex& index, se::DeviceMemoryBase* device_memory) {
const auto& sources = this->GetRootPointsToSet().element(index);
// The points-to set is unambiguous so the set should be a
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
index c48fca355b..cf44458a2e 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
@@ -328,7 +328,7 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module,
if (linker.linkInModule(
std::move(libdevice_module), llvm::Linker::Flags::LinkOnlyNeeded,
[](Module& M, const StringSet<>& GVS) {
- internalizeModule(M, [&M, &GVS](const GlobalValue& GV) {
+ internalizeModule(M, [&GVS](const GlobalValue& GV) {
return !GV.hasName() || (GVS.count(GV.getName()) == 0);
});
})) {
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 995521aed0..a2cefd2621 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -533,12 +533,12 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) {
// TODO(b/33004697): Compute correct cost here, taking the actual number of
// replicas into account.
double flops = 0.0;
- ShapeUtil::ForEachSubshape(
- crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) {
- if (ShapeUtil::IsArray(subshape)) {
- flops += ShapeUtil::ElementsIn(subshape);
- }
- });
+ ShapeUtil::ForEachSubshape(crs->shape(),
+ [&](const Shape& subshape, const ShapeIndex&) {
+ if (ShapeUtil::IsArray(subshape)) {
+ flops += ShapeUtil::ElementsIn(subshape);
+ }
+ });
current_properties_[kFlopsKey] = flops;
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 2ec31a9148..4755c4a0cf 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -2365,7 +2365,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
- auto make_cond = [this, &data_shape]() {
+ auto make_cond = [&data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Cond");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
@@ -2374,7 +2374,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
return builder.Build();
};
- auto make_body = [this, &data_shape]() {
+ auto make_body = [&data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Body");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index cf9ceed5b2..9ec983c2bc 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -282,7 +282,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
TF_ASSERT_OK_AND_ASSIGN(
SequentialHloOrdering::HloModuleSequence sequence,
ScheduleComputationsInModule(*module,
- [&TUPLE_SIZE](const BufferValue& buffer) {
+ [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(
buffer.shape(), TUPLE_SIZE);
},
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 4e3c9df3a0..7fd99fc930 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -283,8 +283,7 @@ std::ostream& operator<<(std::ostream& out,
string InstructionValueSet::ToString() const {
string out =
StrCat("InstructionValueSet(", ShapeUtil::HumanString(shape()), ")\n");
- ForEachElement([this, &out](const ShapeIndex& index,
- const HloValueSet& value_set) {
+ ForEachElement([&out](const ShapeIndex& index, const HloValueSet& value_set) {
StrAppend(&out, " ", index.ToString(), " : ", value_set.ToString(), "\n");
});
return out;
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index c5df6bd223..b5a9d6e8e7 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1228,7 +1228,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs(
const PointsToSet& points_to_set =
constraints->points_to_analysis().GetPointsToSet(instruction);
return points_to_set.ForEachElementWithStatus(
- [this, &shape_layout, constraints](
+ [&shape_layout, constraints](
const ShapeIndex& index,
const PointsToSet::BufferList& buffers) -> Status {
if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index ce070bc5b6..212db0643c 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -56,7 +56,6 @@ limitations under the License.
using ::tensorflow::strings::Printf;
using ::tensorflow::strings::StrCat;
-using ::xla::source_map_util::InvalidParameterArgument;
namespace xla {
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 0effdc80a4..0447807a41 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -232,8 +232,7 @@ Status TuplePointsToAnalysis::HandleGetTupleElement(
// Copy the points-to set (and tuple sources) at index {element_index} of the
// operand to the points-to set for this GetTupleElement instruction.
points_to_set.ForEachMutableElement(
- [&, this](const ShapeIndex& target_index,
- PointsToSet::BufferList* points_to) {
+ [&](const ShapeIndex& target_index, PointsToSet::BufferList* points_to) {
// Construct an index into the operand by prepending element_index to
// the index for the GetTupleElement instruction's points-to set.
ShapeIndex src_index;
@@ -308,7 +307,7 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
// Recursively copy the points to set of the operand tuple {0} to the output
// element {0}.
points_to_set.ForEachMutableElement(
- [this, &points_to_set, &operand_points_to_set](
+ [&points_to_set, &operand_points_to_set](
const ShapeIndex& index, PointsToSet::BufferList* buffers) {
if (index.empty() || index[0] != 0) {
return;
@@ -517,7 +516,7 @@ Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
const HloInstruction* instruction,
TuplePointsToAnalysis::BufferDefinitionVector* buffers) {
GetPointsToSet(instruction)
- .ForEachElement([this, buffers, instruction](
+ .ForEachElement([buffers, instruction](
const ShapeIndex& index,
const PointsToSet::BufferList& source_buffers) {
// Add buffers which 'instruction' is the source of.
@@ -547,7 +546,7 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction);
const PointsToSet& src_points_to_set = GetPointsToSet(src);
dst_points_to_set.ForEachMutableElement(
- [this, &dst_points_to_set, &src_points_to_set](
+ [&dst_points_to_set, &src_points_to_set](
const ShapeIndex& index, PointsToSet::BufferList* buffers) {
*buffers = src_points_to_set.element(index);
for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 2e5f646804..10d382e8ab 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1118,7 +1118,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
- auto make_cond = [this, &data_shape]() {
+ auto make_cond = [&data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Cond");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
@@ -1127,7 +1127,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
return builder.Build();
};
- auto make_body = [this, &data_shape]() {
+ auto make_body = [&data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Body");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc
index 4391078b64..c4c958be4a 100644
--- a/tensorflow/compiler/xla/shape_tree_test.cc
+++ b/tensorflow/compiler/xla/shape_tree_test.cc
@@ -172,7 +172,7 @@ TEST_F(ShapeTreeTest, TupleShape) {
// Write zero to all data elements.
shape_tree.ForEachMutableElement(
- [&sum](const ShapeIndex& /*index*/, int* data) { *data = 0; });
+ [](const ShapeIndex& /*index*/, int* data) { *data = 0; });
EXPECT_EQ(0, shape_tree.element({}));
EXPECT_EQ(0, shape_tree.element({0}));
EXPECT_EQ(0, shape_tree.element({1}));
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 9ea41b7c92..34869cc507 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -596,8 +596,7 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
};
auto comma_list_to_int64s =
- [&s,
- string_to_int64](const string& input) -> StatusOr<std::vector<int64>> {
+ [string_to_int64](const string& input) -> StatusOr<std::vector<int64>> {
std::vector<int64> results;
for (const string& piece : tensorflow::str_util::Split(input, ',')) {
TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece));
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index d372d1ca43..24b17b7100 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -733,7 +733,7 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
var4D, [epsilon](float a) { return a + epsilon; });
auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
- var_add_epsilon, [epsilon](float a) { return 1 / std::sqrt(a); });
+ var_add_epsilon, [](float a) { return 1 / std::sqrt(a); });
auto grad_output_times_var =
*ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 029af69573..326e13b386 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -182,7 +182,7 @@ XLA_TEST_F(PrngTest, Uniformity256) {
XLA_TEST_F(PrngTest, MapUsingRng) {
// Build a x -> (x + U[0,1)) computation.
- auto build_sum_rng = [this](XlaBuilder& builder) {
+ auto build_sum_rng = [](XlaBuilder& builder) {
auto b = builder.CreateSubBuilder("sum_with_rng");
auto x = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "input");
Add(x,