diff options
author | Jojy George Varghese <jojy.varghese@gmail.com> | 2017-03-01 16:43:46 -0800 |
---|---|---|
committer | Andrew Harp <andrewharp@users.noreply.github.com> | 2017-03-01 19:43:46 -0500 |
commit | d97f45ba949b11d721aa5b81c12d655a8487aa38 (patch) | |
tree | bdcd93060dd06532123966fc246d8b65b9797bea | |
parent | 51331365b60dd042ebc0849ea4e1ab6a396b60fd (diff) |
Fix compile warnings (#7752)
* Fixed compile warnings.
Most of the warnings are related to comparison of signed vs unsigned types.
* Addressed code review.
22 files changed, 78 insertions, 50 deletions
diff --git a/tensorflow/compiler/jit/graph_to_functiondef.cc b/tensorflow/compiler/jit/graph_to_functiondef.cc index d4482c87ff..ce943471fb 100644 --- a/tensorflow/compiler/jit/graph_to_functiondef.cc +++ b/tensorflow/compiler/jit/graph_to_functiondef.cc @@ -185,7 +185,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name, } // Add regular inputs - for (int i = 0; i < in_edges.size(); ++i) { + for (std::vector<const Edge*>::size_type i = 0; i < in_edges.size(); ++i) { const Edge* edge = in_edges[i]; if (edge == nullptr) { return errors::InvalidArgument( diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index 87d5de09d1..2139ffed4b 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -76,7 +76,7 @@ struct GraphCycles::Rep { GraphCycles::GraphCycles() : rep_(new Rep) {} GraphCycles::~GraphCycles() { - for (int i = 0; i < rep_->nodes_.size(); i++) { + for (Vec<Node*>::size_type i = 0; i < rep_->nodes_.size(); i++) { delete rep_->nodes_[i]; } delete rep_; @@ -85,7 +85,7 @@ GraphCycles::~GraphCycles() { bool GraphCycles::CheckInvariants() const { Rep* r = rep_; NodeSet ranks; // Set of ranks seen so far. - for (int32 x = 0; x < r->nodes_.size(); x++) { + for (Vec<Node*>::size_type x = 0; x < r->nodes_.size(); x++) { Node* nx = r->nodes_[x]; if (nx->visited) { LOG(FATAL) << "Did not clear visited marker on node " << x; @@ -259,7 +259,7 @@ static void Reorder(GraphCycles::Rep* r) { r->deltaf_.end(), r->merged_.begin()); // Assign the ranks in order to the collected list. - for (int32 i = 0; i < r->list_.size(); i++) { + for (Vec<int32>::size_type i = 0; i < r->list_.size(); i++) { r->nodes_[r->list_[i]]->rank = r->merged_[i]; } } @@ -277,7 +277,7 @@ static void Sort(const Vec<Node*>& nodes, Vec<int32>* delta) { } static void MoveToList(GraphCycles::Rep* r, Vec<int32>* src, Vec<int32>* dst) { - for (int32 i = 0; i < src->size(); i++) { + for (Vec<int32>::size_type i = 0; i < src->size(); i++) { int32 w = (*src)[i]; (*src)[i] = r->nodes_[w]->rank; // Replace src entry with its rank r->nodes_[w]->visited = false; // Prepare for future DFS calls @@ -286,7 +286,7 @@ static void MoveToList(GraphCycles::Rep* r, Vec<int32>* src, Vec<int32>* dst) { } static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32>& nodes) { - for (int32 i = 0; i < nodes.size(); i++) { + for (Vec<int32>::size_type i = 0; i < nodes.size(); i++) { r->nodes_[nodes[i]]->visited = false; } } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index efc8dfce93..45f3420cc1 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -225,7 +225,8 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args, parameters.reserve(args.size()); variables.reserve(args.size()); - for (int i = 0; i < args.size(); ++i) { + for (std::vector<XlaCompiler::Argument>::size_typ i = 0; + i < args.size(); ++i) { XlaContext::Argument& context_arg = (*context_args)[i]; context_arg.name = args[i].name; context_arg.value.constant_value = args[i].constant_value; @@ -262,7 +263,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args, input_shapes->resize(parameters.size()); input_mapping->resize(parameters.size()); - for (int i = 0; i < input_shapes->size(); ++i) { + for (std::vector<int>::size_type i = 0; i < input_shapes->size(); ++i) { const XlaCompiler::Argument& arg = args[parameters[i]]; // Computes the shapes of non-constant arguments. xla::PrimitiveType type; @@ -276,12 +277,12 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args, xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(*input_shapes); xla::ComputationDataHandle tuple = builder->Parameter(0, tuple_shape, "arg_tuple"); - for (int i = 0; i < input_shapes->size(); ++i) { + for (std::vector<int>::size_type i = 0; i < input_shapes->size(); ++i) { (*context_args)[parameters[i]].value.handle = builder->GetTupleElement(tuple, i); } } else { - for (int i = 0; i < input_shapes->size(); ++i) { + for (std::vector<int>::size_type i = 0; i < input_shapes->size(); ++i) { (*context_args)[parameters[i]].value.handle = builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i)); } @@ -413,7 +414,8 @@ Status XlaCompiler::CompileGraph(string const& name, VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; result->outputs.resize(context->retvals().size()); - for (int i = 0; i < context->retvals().size(); ++i) { + for (std::vector<XlaContext::HandleOrConstant>::size_type i = 0; + i < context->retvals().size(); ++i) { const XlaContext::HandleOrConstant& retval = context->retvals()[i]; if (retval.is_constant) { OutputDescription& output = result->outputs[i]; @@ -457,7 +459,8 @@ Status XlaCompiler::CompileGraph(string const& name, // Converts the output shapes to TensorShapes. int computation_output = 0; - for (int i = 0; i < context->retvals().size(); ++i) { + for (std::vector<XlaContext::HandleOrConstant>::size_type i = 0; + i < context->retvals().size(); ++i) { const XlaContext::HandleOrConstant& retval = context->retvals()[i]; if (!retval.is_constant) { CHECK_LT(computation_output, num_nonconst_outputs); @@ -474,7 +477,8 @@ Status XlaCompiler::CompileGraph(string const& name, } } - for (int i = 0; i < result->variable_updates.size(); ++i) { + for (std::vector<XlaCompiler::VariableWrite>::size_type i = 0; + i < result->variable_updates.size(); ++i) { if (num_computation_outputs > 1) { result->variable_updates[i].shape = XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape( diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index c4430dab65..2a8a4b321a 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -301,7 +301,8 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel( } std::vector<std::unique_ptr<GlobalData>> outputs; - for (int64 i = 0; i < computations.size(); ++i) { + for (tensorflow::gtl::ArraySlice<ComputationInstance>::size_type i = 0; + i < computations.size(); ++i) { outputs.push_back( MakeUnique<GlobalData>(stub_, response.responses(i).output())); if (computations[i].execution_profile != nullptr) { diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 2ee6b51c18..2cb86a2514 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -106,7 +106,9 @@ bool ComputationBuilder::MakeWindow( tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, tensorflow::gtl::ArraySlice<int64> lhs_dilation, tensorflow::gtl::ArraySlice<int64> rhs_dilation, Window* window) { - const auto verify_size = [&](const int64 x, const char* x_name) { + const auto verify_size = [&]( + const tensorflow::gtl::ArraySlice<int64>::size_type x, + const char* x_name) { if (x == 0 || x == window_dimensions.size()) { return true; } else { @@ -440,7 +442,8 @@ ComputationDataHandle ComputationBuilder::Collapse( // Don't support out-of-order collapse here. // Checks that the collapsed dimensions are in order and consecutive. - for (int i = 1; i < dims_to_collapse.size(); ++i) { + for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1; + i < dims_to_collapse.size(); ++i) { if (dims_to_collapse[i] - 1 != dims_to_collapse[i - 1]) { NoteError(InvalidArgument( "Collapsed dimensions are not in order and consecutive.")); @@ -681,14 +684,15 @@ ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( std::vector<int64> base_area_dimensions( dimension_numbers.spatial_dimensions_size()); - for (int i = 0; i < base_area_dimensions.size(); ++i) { + for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size() ; + ++i) { base_area_dimensions[i] = lhs_shape->dimensions(dimension_numbers.spatial_dimensions(i)); } std::vector<int64> window_dimensions( dimension_numbers.kernel_spatial_dimensions_size()); - for (int i = 0; i < window_dimensions.size(); ++i) { + for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) { window_dimensions[i] = rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); } @@ -740,7 +744,7 @@ ComputationDataHandle ComputationBuilder::ConvGeneralDilated( std::vector<int64> window_dimensions( dimension_numbers.kernel_spatial_dimensions_size()); - for (int i = 0; i < window_dimensions.size(); ++i) { + for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) { window_dimensions[i] = rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); } diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index 281fa10408..8d75815711 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -35,7 +35,8 @@ std::vector<std::pair<int64, int64>> MakePadding( return low_high_padding; case Padding::kSame: - for (int64 i = 0; i < input_dimensions.size(); ++i) { + for (tensorflow::gtl::ArraySlice<int64>::size_type i = 0; + i < input_dimensions.size(); ++i) { int64 input_dimension = input_dimensions[i]; int64 window_dimension = window_dimensions[i]; int64 window_stride = window_strides[i]; diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 4edbfd2482..eb937c3614 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -32,7 +32,8 @@ namespace xla { // Padding and nested layouts not supported yet. DCHECK_EQ(0, shape.layout().padded_dimensions_size()); - for (int i = 0; i < multi_index.size(); ++i) { + for (tensorflow::gtl::ArraySlice<int64>::size_type i = 0; + i < multi_index.size(); ++i) { DCHECK_GE(multi_index[i], 0); DCHECK_LT(multi_index[i], shape.dimensions(i)) << "indexing beyond extent in dimension " << i << ":" diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index a123213401..9bc336badd 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -136,7 +136,8 @@ tensorflow::Status AllocationTracker::DeallocateShape( TF_RET_CHECK(ShapeUtil::TupleElementCount(shape) == elements.size()) << "tuple has unexpected number of elements: " << elements.size() << " != " << ShapeUtil::TupleElementCount(shape); - for (int i = 0; i < elements.size(); ++i) { + for (std::vector<se::DeviceMemoryBase>::size_type i = 0; + i < elements.size(); ++i) { VLOG(2) << "recursing onto the tuple elements"; TF_RETURN_IF_ERROR(DeallocateShape(backend, device_ordinal, &elements[i], shape.tuple_shapes(i), @@ -170,7 +171,8 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple( executor, allocation->device_memory(), allocation->shape())); std::vector<GlobalDataHandle> element_handles; - for (int i = 0; i < element_bases.size(); ++i) { + for (std::vector<se::DeviceMemoryBase>::size_type i = 0; + i < element_bases.size(); ++i) { element_handles.push_back(RegisterInternal( allocation->backend(), allocation->device_ordinal(), element_bases[i], ShapeUtil::GetSubshape(allocation->shape(), {i}), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 08f417401d..14a92e0be6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -517,9 +517,9 @@ CpuCompiler::CompileAheadOfTime( error.c_str()); } - llvm::Reloc::Model reloc_model; - llvm::PICLevel::Level pic_level; - llvm::PIELevel::Level pie_level; + llvm::Reloc::Model reloc_model = llvm::Reloc::Static; + llvm::PICLevel::Level pic_level = llvm::PICLevel::NotPIC; + llvm::PIELevel::Level pie_level = llvm::PIELevel::Default; switch (options.relocation_model()) { case CpuAotCompilationOptions::RelocationModel::Static: reloc_model = llvm::Reloc::Static; @@ -569,7 +569,8 @@ CpuCompiler::CompileAheadOfTime( } std::vector<std::unique_ptr<AotCompilationResult>> results; - for (int i = 0; i < hlo_modules.size(); ++i) { + for (std::vector<std::unique_ptr<HloModule>>::size_type i = 0; + i < hlo_modules.size(); ++i) { HloModule* hlo_module = hlo_modules[i].get(); HloModuleConfig* module_config = module_configs[i].get(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index c71756c18b..7e47558e94 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -506,7 +506,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window, llvm_ir::IrArray::Index input_index(index.size()); llvm::Value* in_bounds_condition = nullptr; - for (int64 i = 0; i < index.size(); ++i) { + for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( index[i], ir_builder_.getInt64(window.dimensions(i).stride())); input_index[i] = ir_builder_.CreateNSWSub( @@ -1111,7 +1111,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = index.begin(); - for (int64 i = 0; i < input_index.size(); ++i) { + for (auto i = 0; i < input_index.size(); ++i) { if (input_index[i] == nullptr) { input_index[i] = *it++; } @@ -1180,7 +1180,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { // output_index := edge_padding_low + operand_index * (interior_padding + 1) const PaddingConfig& padding_config = pad->padding_config(); llvm_ir::IrArray::Index output_index; - for (int64 i = 0; i < operand_index.size(); ++i) { + for (auto i = 0; i < operand_index.size(); ++i) { llvm::Value* offset = ir_builder_.CreateMul( operand_index[i], ir_builder_.getInt64(padding_config.dimensions(i).interior_padding() + @@ -1294,7 +1294,7 @@ Status IrEmitter::HandleCustomCall( llvm_ir::EmitAllocaAtFunctionEntryWithCount( i8_ptr_type, ir_builder_.getInt32(operands.size()), "cc_operands_alloca", &ir_builder_); - for (int i = 0; i < operands.size(); ++i) { + for (auto i = 0; i < operands.size(); ++i) { const HloInstruction* operand = operands[i]; llvm::Value* operand_as_i8ptr = ir_builder_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); @@ -1659,7 +1659,7 @@ void IrEmitter::EmitArrayFunctionCallInto( ir_builder_.getInt32(parameter_addresses.size()), tensorflow::strings::StrCat(name, "_parameter_addresses"), &ir_builder_); - for (int i = 0; i < parameter_addresses.size(); ++i) { + for (auto i = 0; i < parameter_addresses.size(); ++i) { llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast( parameter_addresses[i], ir_builder_.getInt8PtrTy(), llvm_ir::AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 5b1a5a16d1..6ba65db570 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -40,7 +40,8 @@ Executable::ExecuteOnStreams( std::vector<perftools::gputools::DeviceMemoryBase> return_values( run_options.size()); - for (int64 i = 0; i < run_options.size(); ++i) { + for (tensorflow::gtl::ArraySlice<const ExecutableRunOptions>::size_type i = 0; + i < run_options.size(); ++i) { // We cannot BlockHostUntilDone() on the already-launched executions in case // of error, since if the executions communicate, the initially launched // executions may never complete if not all executions are running. diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index aa512f242a..3c3f1dbf57 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -118,7 +118,7 @@ GenericTransferManager::ShallowCopyTupleFromDevice( // Create a DeviceMemoryBase from each void* pointer. std::vector<se::DeviceMemoryBase> destination; - for (int i = 0; i < element_pointers.size(); ++i) { + for (std::vector<void*>::size_type i = 0; i < element_pointers.size(); ++i) { if (element_pointers[i] == nullptr && !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) { return FailedPrecondition("tuple contains nullptr at element %d", i); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 67c80bf93b..1cb03db4ee 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -276,7 +276,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( auto index = padded_index; llvm::Value* in_bounds = llvm::ConstantInt::get(ir_builder_->getInt1Ty(), 1); - for (int i = 0; i < index.size(); ++i) { + for (auto i = 0; i < index.size(); ++i) { auto index_typed_const = [=](int64 n) { return llvm::ConstantInt::get(index[i]->getType(), n); }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index be441929bb..d17ef5a67e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -430,12 +430,12 @@ Status IrEmitter::HandleDot(HloInstruction* dot, // and lhs indexes with the reduction dimensions removed. The terms from the // rhs index are the lower dimensions in the index so we add them first. llvm_ir::IrArray::Index target_index; - for (int dimension = 0; dimension < lhs_index.size(); ++dimension) { + for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { target_index.push_back(lhs_index[dimension]); } } - for (int dimension = 0; dimension < rhs_index.size(); ++dimension) { + for (size_t dimension = 0; dimension < rhs_index.size(); ++dimension) { if (dimension != rhs_reduction_dimension) { target_index.push_back(rhs_index[dimension]); } @@ -513,7 +513,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = index.begin(); - for (int64 i = 0; i < input_index.size(); ++i) { + for (auto i = 0; i < input_index.size(); ++i) { if (input_index[i] == nullptr) { input_index[i] = *it++; } @@ -614,7 +614,7 @@ llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( llvm_ir::IrArray::Index index = loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix); // Verify every dimension except the reduction dimension was set in the index. - for (int dimension = 0; dimension < index.size(); ++dimension) { + for (auto dimension = 0; dimension < index.size(); ++dimension) { if (dimension == reduction_dimension) { DCHECK_EQ(nullptr, index[dimension]); } else { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index d7d8722ccc..19634cd33e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -135,7 +135,9 @@ string InstructionSequenceGraph( std::vector<HloInstruction*> param_instructions; for (auto& instruction : instructions) { if (instruction->opcode() == HloOpcode::kParameter) { - int64 param_number = instruction->parameter_number(); + std::vector<HloInstruction*>::size_type param_number = + instruction->parameter_number(); + if (param_instructions.size() < param_number + 1) { param_instructions.resize(param_number + 1, nullptr); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 34a6bb8a52..c60e13fc9c 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -101,6 +101,8 @@ bool IsExpensive(const HloInstruction& instruction) { case HloOpcode::kRecv: return true; } + + return false; } bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer) { @@ -124,7 +126,8 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) { std::vector<HloInstruction*> post_order(post_order_list.begin(), post_order_list.end()); tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index; - for (int i = 0; i < post_order.size(); ++i) { + for (std::vector<HloInstruction*>::size_type i = 0; i < post_order.size(); + ++i) { InsertOrDie(&post_order_index, post_order[i], i); } diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 249e31cf26..4c2fe99e46 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -256,7 +256,8 @@ StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments( tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments, const Backend* backend, int device_ordinal) { std::vector<const Allocation*> allocations; - for (int i = 0; i < arguments.size(); ++i) { + for (tensorflow::gtl::ArraySlice<const GlobalDataHandle*>::size_type i = 0; + i < arguments.size(); ++i) { auto allocation_status = allocation_tracker_.Resolve(*arguments[i]); if (!allocation_status.ok()) { return Status(allocation_status.status().code(), @@ -295,7 +296,8 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig( program_shape.parameters_size(), arguments.size()); } - for (int i = 0; i < arguments.size(); ++i) { + for (tensorflow::gtl::ArraySlice<const Allocation*>::size_type i = 0; + i < arguments.size(); ++i) { // Verify that shape of arguments matches the shape of the arguments in the // ProgramShape. if (!ShapeUtil::Compatible(arguments[i]->shape(), @@ -383,7 +385,8 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables( hlo_dumper, std::move(executors))); if (!other_directory_path.empty()) { - for (int64 i = 0; i < versioned_handles.size(); ++i) { + for (std::vector<VersionedComputationHandle>::size_type i = 0; + i < versioned_handles.size(); ++i) { executables[i]->set_session_module(std::move(session_modules[i])); } } @@ -523,7 +526,8 @@ Service::ExecuteParallelAndRegisterResult( // Asynchronously launch all executables. std::vector<GlobalDataHandle> result_handles; - for (int64 i = 0; i < executables.size(); i++) { + for (tensorflow::gtl::ArraySlice<Executable*>::size_type i = 0; + i < executables.size(); i++) { TF_ASSIGN_OR_RETURN( perftools::gputools::DeviceMemoryBase result, executables[i]->ExecuteAsyncOnStream(&run_options[i], arguments[i])); diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index a34c64d328..a758bb92aa 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -69,7 +69,10 @@ bool IsBinaryInstalled(const string& binary_name) { for (const string& dir : str_util::Split(path, ':')) { const string binary_path = io::JoinPath(dir, binary_name); char absolute_path[PATH_MAX + 1]; - ::realpath(binary_path.c_str(), absolute_path); + if (::realpath(binary_path.c_str(), absolute_path) == NULL) { + LOG(ERROR) << "Invalid binary path: " << binary_path; + return false; + } struct stat statinfo; int result = ::stat(absolute_path, &statinfo); if (result < 0) { diff --git a/tensorflow/core/kernels/hexagon/hexagon_rewriter_transform.cc b/tensorflow/core/kernels/hexagon/hexagon_rewriter_transform.cc index d90ff82eff..fb4378d737 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_rewriter_transform.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_rewriter_transform.cc @@ -47,7 +47,7 @@ Status RewriteQuantizedStrippedModelForHexagon( "graph execute op..."; std::vector<GraphTransferer::InputNodeInfo> inputs; std::vector<string> outputs; - for (int i = 0; i < context.input_names.size(); ++i) { + for (auto i = 0; i < context.input_names.size(); ++i) { const string& input_name = context.input_names.at(i); // Get input shape diff --git a/tensorflow/core/kernels/record_yielder.cc b/tensorflow/core/kernels/record_yielder.cc index e391752289..12ab794114 100644 --- a/tensorflow/core/kernels/record_yielder.cc +++ b/tensorflow/core/kernels/record_yielder.cc @@ -114,7 +114,7 @@ void RecordYielder::MainLoop() { std::shuffle(filenames.begin(), filenames.end(), shuffle_rnd); // Left-shift the filename list. - const int64 num = filenames.size(); + const std::vector<string>::size_type num = filenames.size(); int64 shift; if (0 <= opts_.file_shuffle_shift_ratio && opts_.file_shuffle_shift_ratio < 1) { @@ -129,7 +129,7 @@ void RecordYielder::MainLoop() { for (int i = 0; i < N; ++i) { Shard* shard = &shards[i]; shard->index = i; - for (int j = i; j < filenames.size(); j += N) { + for (std::vector<string>::size_type j = i; j < filenames.size(); j += N) { shard->filenames.push_back(filenames[j]); } thread_->Schedule([this, shard]() { ShardLoop(shard); }); diff --git a/tensorflow/core/kernels/record_yielder.h b/tensorflow/core/kernels/record_yielder.h index 503644f3b8..44f7c9511f 100644 --- a/tensorflow/core/kernels/record_yielder.h +++ b/tensorflow/core/kernels/record_yielder.h @@ -142,7 +142,7 @@ class RecordYielder { // any. return stop_ || !status_.ok() || (epoch_end_ && !buf_.empty()) || (!epoch_end_ && - buf_.size() >= std::max<int64>(1, opts_.bufsize / 2)); + buf_.size() >= std::max<uint64>(1, opts_.bufsize / 2)); } void MainLoop(); diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 44c6d70636..07ef8df1b5 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -157,6 +157,7 @@ string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) { case dnn::DepthToSpaceLayout::DepthHeightWidth: return "DepthToSpaceLayout::DepthHeightWidth"; } + return "unknown DepthToSpaceLayout"; } // Used together with PARAM to VLOG calls made to the stream. Intended |