aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Vincent Vanhoucke <vanhoucke@google.com>2016-08-31 16:01:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-31 17:17:24 -0700
commit7a1210bdbdade7210d48db287065ecac950338aa (patch)
treec39be641a1072944866b16d3daed4204c9bb5543 /tensorflow/core
parent62c159ffe847eeb788550a32b8be572e41055022 (diff)
Fix ~63 ClangTidy - Performance findings in TensorFlow.
Change: 131891101
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.cc2
-rw-r--r--tensorflow/core/common_runtime/constant_folding.cc2
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc5
-rw-r--r--tensorflow/core/common_runtime/device_set.cc4
-rw-r--r--tensorflow/core/common_runtime/function_test.cc2
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc2
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc2
-rw-r--r--tensorflow/core/common_runtime/gpu/pool_allocator.cc11
-rw-r--r--tensorflow/core/common_runtime/simple_placer.cc8
-rw-r--r--tensorflow/core/framework/function.cc4
-rw-r--r--tensorflow/core/framework/function_testlib.cc6
-rw-r--r--tensorflow/core/framework/op_def_util.cc2
-rw-r--r--tensorflow/core/framework/op_kernel_test.cc4
-rw-r--r--tensorflow/core/graph/optimizer_cse_test.cc2
-rw-r--r--tensorflow/core/graph/quantize_training.cc2
-rw-r--r--tensorflow/core/kernels/argmax_op.cc2
-rw-r--r--tensorflow/core/kernels/attention_ops.cc2
-rw-r--r--tensorflow/core/kernels/candidate_sampler_ops.cc2
-rw-r--r--tensorflow/core/kernels/gather_nd_op.cc4
-rw-r--r--tensorflow/core/kernels/maxpooling_op.cc2
-rw-r--r--tensorflow/core/kernels/scan_ops.cc2
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_mem.cc3
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc6
-rw-r--r--tensorflow/core/util/tensor_slice_reader.cc3
-rw-r--r--tensorflow/core/util/tensor_slice_set.cc10
-rw-r--r--tensorflow/core/util/tensor_slice_writer.cc4
26 files changed, 53 insertions, 45 deletions
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 70b01d6485..f525d1d981 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -157,7 +157,7 @@ bool BFCAllocator::Extend(size_t rounded_bytes) {
InsertFreeChunkIntoBin(h);
// Invoke visitors on newly allocated region.
- for (auto visitor : region_visitors_) {
+ for (const auto& visitor : region_visitors_) {
visitor(mem_addr, bytes);
}
return true;
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 9bd162b72f..6a49c940b3 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -279,7 +279,7 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
edges_to_remove.push_back(out_edge);
}
}
- string node_name = n->name();
+ const string& node_name = n->name();
Node* constant_node;
auto builder = NodeDefBuilder(strings::StrCat(graph->NewName(node_name),
"__cf__", UniqueConstantId()),
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index 5dc8c33b2a..e55ef7d5ba 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include <atomic>
+#include <utility>
#include <vector>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
@@ -26,7 +27,9 @@ namespace {
struct RegistrationInfo {
RegistrationInfo(DeviceType s, DeviceType r, CopyTensor::CopyFunction cf)
- : sender_device_type(s), receiver_device_type(r), copy_function(cf) {}
+ : sender_device_type(std::move(s)),
+ receiver_device_type(r),
+ copy_function(cf) {}
DeviceType sender_device_type;
DeviceType receiver_device_type;
CopyTensor::CopyFunction copy_function;
diff --git a/tensorflow/core/common_runtime/device_set.cc b/tensorflow/core/common_runtime/device_set.cc
index 98c6c3843c..8ff93760d4 100644
--- a/tensorflow/core/common_runtime/device_set.cc
+++ b/tensorflow/core/common_runtime/device_set.cc
@@ -71,9 +71,9 @@ std::vector<DeviceType> DeviceSet::PrioritizedDeviceTypeList() const {
std::vector<DeviceType> result;
std::set<string> seen;
for (Device* d : devices_) {
- auto t = d->device_type();
+ const auto& t = d->device_type();
if (seen.insert(t).second) {
- result.emplace_back(DeviceType(t));
+ result.emplace_back(t);
}
}
std::sort(result.begin(), result.end(), DeviceTypeComparator);
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 2f5507a0c5..e263e62bd8 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -144,7 +144,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
void Init(const std::vector<FunctionDef>& flib) {
FunctionDefLibrary proto;
- for (auto fdef : flib) *(proto.add_function()) = fdef;
+ for (const auto& fdef : flib) *(proto.add_function()) = fdef;
delete lib_def_;
lib_def_ = new FunctionLibraryDefinition(OpRegistry::Global(), proto);
delete lib_;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
index 7506e35ff3..f18ee5efd8 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
@@ -95,7 +95,7 @@ void EventMgr::ThenDeleteTensors(perftools::gputools::Stream* stream,
FlushAccumulatedTensors();
}
accumulated_stream_ = stream;
- for (auto t : tensors) {
+ for (const auto& t : tensors) {
// accumulated_tensors_ takes over ownership of the reference to "t"
accumulated_tensors_->push_back(t);
accumulated_tensor_bytes_ += t.TotalBytes();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc
index 5b4812bb34..3aaaf87e79 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util_test.cc
@@ -129,7 +129,7 @@ TEST_F(GpuStreamUtilTest, StreamOverrides) {
// Nodes should be assigned to streams by op type.
for (const auto& it : node_to_stream_id) {
Node* n = g.FindNodeId(it.first);
- const string op = n->type_string();
+ const string& op = n->type_string();
const int stream = it.second;
if (op == "Const") {
EXPECT_EQ(stream, 90);
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.cc b/tensorflow/core/common_runtime/gpu/pool_allocator.cc
index b44108d1ac..e0362b38e6 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <sys/mman.h> // for munmap
#include <map>
+#include <utility>
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
@@ -31,7 +32,7 @@ namespace tensorflow {
PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize,
SubAllocator* allocator,
RoundUpInterface* size_rounder, string name)
- : name_(name),
+ : name_(std::move(name)),
has_size_limit_(pool_size_limit > 0),
auto_resize_(auto_resize),
pool_size_limit_(pool_size_limit),
@@ -125,7 +126,7 @@ void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
return PrepareChunk(r, alignment, num_bytes);
} else {
void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes);
- for (auto v : alloc_visitors_) {
+ for (const auto& v : alloc_visitors_) {
v(ptr, num_bytes);
}
return PrepareChunk(ptr, alignment, num_bytes);
@@ -137,7 +138,7 @@ void PoolAllocator::DeallocateRaw(void* ptr) {
ChunkPrefix* cp = FindPrefix(ptr);
CHECK_LE((void*)cp, (void*)ptr);
if (!has_size_limit_ && !auto_resize_) {
- for (auto v : free_visitors_) {
+ for (const auto& v : free_visitors_) {
v(cp, cp->num_bytes);
}
allocator_->Free(cp, cp->num_bytes);
@@ -160,7 +161,7 @@ void PoolAllocator::Clear() {
mutex_lock lock(mutex_);
for (auto iter : pool_) {
PtrRecord* pr = iter.second;
- for (auto v : free_visitors_) {
+ for (const auto& v : free_visitors_) {
v(pr->ptr, pr->num_bytes);
}
allocator_->Free(pr->ptr, pr->num_bytes);
@@ -217,7 +218,7 @@ void PoolAllocator::EvictOne() {
DCHECK(iter != pool_.end());
}
pool_.erase(iter);
- for (auto v : free_visitors_) {
+ for (const auto& v : free_visitors_) {
v(prec->ptr, prec->num_bytes);
}
allocator_->Free(prec->ptr, prec->num_bytes);
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
index 6e177da57f..2a9e0fa196 100644
--- a/tensorflow/core/common_runtime/simple_placer.cc
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -42,7 +42,7 @@ std::vector<Device*> FilterSupportedDevices(
const std::vector<Device*>& devices,
const DeviceTypeVector& supported_device_types) {
std::vector<Device*> filtered_devices;
- for (DeviceType d : supported_device_types) {
+ for (const DeviceType& d : supported_device_types) {
for (Device* device : devices) {
if (DeviceType(device->attributes().device_type()) == d) {
filtered_devices.emplace_back(device);
@@ -495,7 +495,7 @@ class ColocationGraph {
"' does not match any device");
}
- for (DeviceType d : member->supported_device_types) {
+ for (const DeviceType& d : member->supported_device_types) {
if (DeviceType(assigned_device->attributes().device_type()) == d) {
return Status::OK();
}
@@ -545,9 +545,9 @@ class ColocationGraph {
target->clear();
// Iterate in priority order.
- for (DeviceType device_type : temp) {
+ for (const DeviceType& device_type : temp) {
bool found = false;
- for (DeviceType other_device_type : other) {
+ for (const DeviceType& other_device_type : other) {
if (device_type == other_device_type) {
found = true;
break;
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 83676a90c5..bedc85ab4e 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -861,11 +861,11 @@ string DebugString(const GraphDef& instantiated_func_def) {
string DebugStringWhole(const GraphDef& gdef) {
string ret;
- for (auto fdef : gdef.library().function()) {
+ for (const auto& fdef : gdef.library().function()) {
strings::StrAppend(&ret, Print(fdef));
}
strings::StrAppend(&ret, "\n");
- for (auto ndef : gdef.node()) {
+ for (const auto& ndef : gdef.node()) {
strings::StrAppend(&ret, Print(ndef), "\n");
}
return ret;
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index 900ceed1a5..47db0f0339 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -31,11 +31,11 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
VersionDef* versions = g.mutable_versions();
versions->set_producer(TF_GRAPH_DEF_VERSION);
versions->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);
- for (auto n : nodes) {
+ for (const auto& n : nodes) {
*(g.add_node()) = n;
}
auto lib = g.mutable_library();
- for (auto f : funcs) {
+ for (const auto& f : funcs) {
*(lib->add_function()) = f;
}
return g;
@@ -49,7 +49,7 @@ NodeDef NDef(const string& name, const string& op,
NodeDef n;
n.set_name(name);
n.set_op(op);
- for (auto in : inputs) n.add_input(in);
+ for (const auto& in : inputs) n.add_input(in);
n.set_device(device);
for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
return n;
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index 5717488b1c..c36e6dd653 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -60,7 +60,7 @@ Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
const AttrValue& allowed_values(attr.allowed_values());
- for (auto allowed : allowed_values.list().s()) {
+ for (const auto& allowed : allowed_values.list().s()) {
if (str == allowed) {
return Status::OK();
}
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc
index db4b6037ef..b4556c9272 100644
--- a/tensorflow/core/framework/op_kernel_test.cc
+++ b/tensorflow/core/framework/op_kernel_test.cc
@@ -381,7 +381,7 @@ class OpKernelBuilderTest : public ::testing::Test {
DeviceTypeVector devices;
TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
bool found = false;
- for (DeviceType dt : devices) {
+ for (const DeviceType& dt : devices) {
if (dt == device_type) {
found = true;
}
@@ -414,7 +414,7 @@ class OpKernelBuilderTest : public ::testing::Test {
DeviceTypeVector devices;
if (errors::IsNotFound(status)) {
TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
- for (DeviceType dt : devices) {
+ for (const DeviceType& dt : devices) {
EXPECT_NE(dt, device_type);
}
} else {
diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc
index 0841bac93c..1091af4e45 100644
--- a/tensorflow/core/graph/optimizer_cse_test.cc
+++ b/tensorflow/core/graph/optimizer_cse_test.cc
@@ -326,7 +326,7 @@ TEST_F(OptimizerCSETest, Constant_Dedup) {
// A graph contains a bunch of constants.
Graph g(OpRegistry::Global());
- for (auto val : {a, b, c, d, d, c, b, a}) {
+ for (const auto& val : {a, b, c, d, d, c, b, a}) {
test::graph::Constant(&g, val); // Node name is n/_0, n/_1, ...
}
GraphDef gdef;
diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc
index 8521dff6fa..930d7bd15f 100644
--- a/tensorflow/core/graph/quantize_training.cc
+++ b/tensorflow/core/graph/quantize_training.cc
@@ -74,7 +74,7 @@ inline bool IsGradientNode(const Graph* graph, const Node* node) {
// Returns true if the root tensor op type is known, false otherwise.
bool FindType(const Graph* graph, const Node* node, bool* signed_input,
bool* range_given, float* input_min, float* input_max) {
- const string src_op = node->type_string();
+ const string& src_op = node->type_string();
if (src_op == "Const" || src_op == "Variable") {
*signed_input = true;
*range_given = false;
diff --git a/tensorflow/core/kernels/argmax_op.cc b/tensorflow/core/kernels/argmax_op.cc
index 595bd7bd5e..2f92a2da9f 100644
--- a/tensorflow/core/kernels/argmax_op.cc
+++ b/tensorflow/core/kernels/argmax_op.cc
@@ -67,7 +67,7 @@ class ArgOp : public OpKernel {
input.shape().DebugString()));
TensorShape output_shape;
- TensorShape input_shape = input.shape();
+ const TensorShape& input_shape = input.shape();
for (int d = 0; d < input_dims - 1; ++d) {
output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1));
}
diff --git a/tensorflow/core/kernels/attention_ops.cc b/tensorflow/core/kernels/attention_ops.cc
index 695068d315..cc8f122cab 100644
--- a/tensorflow/core/kernels/attention_ops.cc
+++ b/tensorflow/core/kernels/attention_ops.cc
@@ -41,7 +41,7 @@ class ExtractGlimpseOp : public OpKernel {
// depth).
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
- const TensorShape input_shape = input.shape();
+ const TensorShape& input_shape = input.shape();
const int32 num_dims = input_shape.dims();
OP_REQUIRES(
context, num_dims == 4,
diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc
index d64dca3d0b..6aa9059dc7 100644
--- a/tensorflow/core/kernels/candidate_sampler_ops.cc
+++ b/tensorflow/core/kernels/candidate_sampler_ops.cc
@@ -190,7 +190,7 @@ class ComputeAccidentalHitsOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& in_true_candidates = context->input(0);
- TensorShape in_true_candidates_shape = in_true_candidates.shape();
+ const TensorShape& in_true_candidates_shape = in_true_candidates.shape();
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(in_true_candidates_shape) &&
in_true_candidates_shape.dim_size(1) == num_true_,
errors::InvalidArgument(
diff --git a/tensorflow/core/kernels/gather_nd_op.cc b/tensorflow/core/kernels/gather_nd_op.cc
index c2a5192efb..73f30cdae3 100644
--- a/tensorflow/core/kernels/gather_nd_op.cc
+++ b/tensorflow/core/kernels/gather_nd_op.cc
@@ -53,7 +53,7 @@ class GatherNdOp : public OpKernel {
"index innermost dimension length must be <= params rank; saw: ",
indices.dim_size(indices.dims() - 1), " vs. ", params.dims()));
- TensorShape indices_shape(indices.shape());
+ const TensorShape& indices_shape(indices.shape());
const int64 indices_nd = indices_shape.dim_size(indices_shape.dims() - 1);
// Check that we have enough index space
@@ -79,7 +79,7 @@ class GatherNdOp : public OpKernel {
N_result *= indices_shape.dim_size(i);
}
- TensorShape params_shape(params.shape());
+ const TensorShape& params_shape(params.shape());
Index total_nd = params_shape.dims();
TensorShape result_shape(indices_shape);
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
index 97e2bfcad5..27888d3a31 100644
--- a/tensorflow/core/kernels/maxpooling_op.cc
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -272,7 +272,7 @@ class MaxPoolingGradOp : public OpKernel {
OP_REQUIRES(context, out_backprop.dims() == 4,
errors::InvalidArgument("out_backprop must be 4-dimensional"));
- TensorShape output_shape = tensor_in.shape();
+ const TensorShape& output_shape = tensor_in.shape();
Tensor tensor_out_dup;
OP_REQUIRES_OK(context,
diff --git a/tensorflow/core/kernels/scan_ops.cc b/tensorflow/core/kernels/scan_ops.cc
index 604e712b0f..2604b73844 100644
--- a/tensorflow/core/kernels/scan_ops.cc
+++ b/tensorflow/core/kernels/scan_ops.cc
@@ -58,7 +58,7 @@ public:
errors::InvalidArgument("ScanOp: Expected scan axis in the range [", 0,
", ", input.dims(), "), but got ", axis));
- TensorShape output_shape = input.shape();
+ const TensorShape& output_shape = input.shape();
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output));
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc
index 9a317f1fd2..ac12798322 100644
--- a/tensorflow/core/lib/jpeg/jpeg_mem.cc
+++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <algorithm>
#include <memory>
#include <string>
+#include <utility>
#include "tensorflow/core/lib/jpeg/jpeg_handle.h"
#include "tensorflow/core/platform/logging.h"
@@ -52,7 +53,7 @@ class FewerArgsForCompiler {
: datasize_(datasize),
flags_(flags),
pnwarn_(nwarn),
- allocate_output_(allocate_output),
+ allocate_output_(std::move(allocate_output)),
height_read_(0),
height_(0),
stride_(0) {
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 7426006cec..fc35c293d2 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -92,7 +92,7 @@ class GcsRandomAccessFile : public RandomAccessFile {
: bucket_(bucket),
object_(object),
auth_provider_(auth_provider),
- http_request_factory_(std::move(http_request_factory)),
+ http_request_factory_(http_request_factory),
read_ahead_bytes_(read_ahead_bytes) {}
/// The implementation of reads with a read-ahead buffer.
@@ -189,7 +189,7 @@ class GcsWritableFile : public WritableFile {
: bucket_(bucket),
object_(object),
auth_provider_(auth_provider),
- http_request_factory_(std::move(http_request_factory)) {
+ http_request_factory_(http_request_factory) {
if (GetTmpFilename(&tmp_content_filename_).ok()) {
outfile_.open(tmp_content_filename_,
std::ofstream::binary | std::ofstream::app);
@@ -208,7 +208,7 @@ class GcsWritableFile : public WritableFile {
: bucket_(bucket),
object_(object),
auth_provider_(auth_provider),
- http_request_factory_(std::move(http_request_factory)) {
+ http_request_factory_(http_request_factory) {
tmp_content_filename_ = tmp_content_filename;
outfile_.open(tmp_content_filename_,
std::ofstream::binary | std::ofstream::app);
diff --git a/tensorflow/core/util/tensor_slice_reader.cc b/tensorflow/core/util/tensor_slice_reader.cc
index 9ab81af43b..b40f5e7736 100644
--- a/tensorflow/core/util/tensor_slice_reader.cc
+++ b/tensorflow/core/util/tensor_slice_reader.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/util/tensor_slice_reader.h"
+#include <utility>
#include <vector>
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/framework/versions.h"
@@ -107,7 +108,7 @@ TensorSliceReader::TensorSliceReader(const string& filepattern,
TensorSliceReader::TensorSliceReader(const string& filepattern,
OpenTableFunction open_function,
int preferred_shard)
- : filepattern_(filepattern), open_function_(open_function) {
+ : filepattern_(filepattern), open_function_(std::move(open_function)) {
VLOG(1) << "TensorSliceReader for " << filepattern;
Status s = io::GetMatchingFiles(Env::Default(), filepattern, &fnames_);
if (!s.ok()) {
diff --git a/tensorflow/core/util/tensor_slice_set.cc b/tensorflow/core/util/tensor_slice_set.cc
index d4b9a4087c..4217df90ca 100644
--- a/tensorflow/core/util/tensor_slice_set.cc
+++ b/tensorflow/core/util/tensor_slice_set.cc
@@ -42,7 +42,7 @@ Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag,
// We check if there is any intersection between this slice and any of the
// registered slices.
if (slices_hull_.Overlaps(slice)) {
- for (const auto x : slices_) {
+ for (const auto& x : slices_) {
if (slice.Overlaps(x.second.slice)) {
return errors::Internal("Overlapping slices: existing slice = ",
x.first, ", new slice = ", str);
@@ -89,7 +89,7 @@ bool TensorSliceSet::Query(const TensorSlice& slice, float* data) const {
int64 overlap_size = 0;
TensorSlice intersection;
TensorShape inter_shape;
- for (const auto x : slices_) {
+ for (const auto& x : slices_) {
if (slice.Intersect(x.second.slice, &intersection)) {
s = intersection.SliceTensorShape(shape_, &inter_shape);
if (!s.ok()) {
@@ -103,7 +103,7 @@ bool TensorSliceSet::Query(const TensorSlice& slice, float* data) const {
// We have it!
// Now we need to copy the data to "data"
if (data) {
- for (const auto x : slices_) {
+ for (const auto& x : slices_) {
CopyDataFromTensorSliceToTensorSlice(shape_, x.second.slice, slice,
x.second.data, data);
}
@@ -146,7 +146,7 @@ bool TensorSliceSet::QueryMeta(
int64 overlap_size = 0;
TensorSlice intersection;
TensorShape inter_shape;
- for (const auto x : slices_) {
+ for (const auto& x : slices_) {
if (slice.Intersect(x.second.slice, &intersection)) {
s = intersection.SliceTensorShape(shape_, &inter_shape);
if (!s.ok()) {
@@ -180,7 +180,7 @@ Status RegisterTensorSlice(
tensor_slices->insert(std::make_pair(name, tss));
} else {
// Check if the shapes match
- TensorShape tss_shape(tss->shape());
+ const TensorShape& tss_shape(tss->shape());
if (!shape.IsSameSize(tss_shape)) {
return errors::Internal("Incompatible tensor shapes detected for tensor ",
name, ": existing = ", tss_shape.DebugString(),
diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc
index 8907aa6522..928d6fe72c 100644
--- a/tensorflow/core/util/tensor_slice_writer.cc
+++ b/tensorflow/core/util/tensor_slice_writer.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/util/tensor_slice_writer.h"
+#include <utility>
+
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/random/random.h"
@@ -81,7 +83,7 @@ Status CreateTableTensorSliceBuilder(const string& name,
TensorSliceWriter::TensorSliceWriter(const string& filename,
CreateBuilderFunction create_builder)
: filename_(filename),
- create_builder_(create_builder),
+ create_builder_(std::move(create_builder)),
tmpname_(strings::StrCat(filename, ".tempstate", random::New64())),
slices_(0) {
VersionDef* versions = sts_.mutable_meta()->mutable_versions();