diff options
author | Peter Hawkins <phawkins@google.com> | 2017-05-09 09:14:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-10 16:29:47 -0700 |
commit | 1d0b8c007b8bc7f77dd63c74f02d87185071f038 (patch) | |
tree | b72df4064224d66c62bb4a126efb06fa214fa439 /tensorflow/core | |
parent | b9845c6d0d5dc601fb3b58206a7070aa8937af4f (diff) |
Remove unnecessary copies of value parameters.
PiperOrigin-RevId: 155511618
Diffstat (limited to 'tensorflow/core')
43 files changed, 146 insertions, 102 deletions
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index ed5b87f2f2..ec0c9405dd 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -1213,7 +1213,8 @@ class ExecutorState { GUARDED_BY(mu_); // The unique name of a frame. - inline string MakeFrameName(FrameState* frame, int64 iter_id, string name) { + inline string MakeFrameName(FrameState* frame, int64 iter_id, + const string& name) { return strings::StrCat(frame->frame_name, ";", iter_id, ";", name); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_tracer_test.cc b/tensorflow/core/common_runtime/gpu/gpu_tracer_test.cc index b1be278ab4..aaa25ad345 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_tracer_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_tracer_test.cc @@ -82,7 +82,7 @@ class GPUTracerTest : public ::testing::Test { } protected: - void ExpectFailure(Status status, error::Code code) { + void ExpectFailure(const Status& status, error::Code code) { EXPECT_FALSE(status.ok()); if (!status.ok()) { LOG(INFO) << "Status message: " << status.error_message(); diff --git a/tensorflow/core/common_runtime/memory_types.cc b/tensorflow/core/common_runtime/memory_types.cc index 80c483e70b..db053dd2fa 100644 --- a/tensorflow/core/common_runtime/memory_types.cc +++ b/tensorflow/core/common_runtime/memory_types.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/memory_types.h" +#include <utility> + #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/node_builder.h" @@ -43,8 +45,8 @@ struct EndpointEq { }; static Status ProcessMemoryTypes( - DeviceType device_type, const Graph* g, - std::function<Status(const Edge*, MemoryType, MemoryType)> fn) { + const DeviceType& device_type, const Graph* g, + const std::function<Status(const Edge*, MemoryType, MemoryType)>& fn) { if (device_type != DEVICE_GPU) { // On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always // compatible. @@ -88,17 +90,18 @@ static Status ProcessMemoryTypes( return Status::OK(); } -Status ValidateMemoryTypes(DeviceType device_type, const Graph* g) { - return ProcessMemoryTypes(device_type, g, [g](const Edge* e, MemoryType sm, - MemoryType dm) { - if (sm == dm) { - return Status::OK(); - } - return errors::Internal( - "Memory type mismatch (", sm, " ", dm, ") between :", e->src()->id(), - ":", e->src_output(), " and ", e->dst()->id(), ":", e->dst_input(), - " : from ", e->src()->DebugString(), " to ", e->dst()->DebugString()); - }); +Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) { + return ProcessMemoryTypes( + device_type, g, [g](const Edge* e, MemoryType sm, MemoryType dm) { + if (sm == dm) { + return Status::OK(); + } + return errors::Internal( + "Memory type mismatch (", sm, " ", dm, + ") between :", e->src()->id(), ":", e->src_output(), " and ", + e->dst()->id(), ":", e->dst_input(), " : from ", + e->src()->DebugString(), " to ", e->dst()->DebugString()); + }); } static Node* Send(Graph* g, const string& device_name, bool host, @@ -132,8 +135,8 @@ static Node* Recv(Graph* g, const string& device_name, bool host, return ret; } -Status EnsureMemoryTypes(DeviceType device_type, const string& device_name, - Graph* g) { +Status EnsureMemoryTypes(const DeviceType& device_type, + const string& device_name, Graph* g) { struct Item { const Edge* edge; MemoryType sm; @@ -185,7 +188,7 @@ Status EnsureMemoryTypes(DeviceType device_type, const string& device_name, return ValidateMemoryTypes(device_type, g); } -Status MemoryTypeForOutput(DeviceType device_type, const Graph* g, +Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g, const Node* n, int index, MemoryType* memory_type) { MemoryTypeVector inp_mvec; MemoryTypeVector out_mvec; diff --git a/tensorflow/core/common_runtime/memory_types.h b/tensorflow/core/common_runtime/memory_types.h index ccbb8cffb1..fa0a7595f3 100644 --- a/tensorflow/core/common_runtime/memory_types.h +++ b/tensorflow/core/common_runtime/memory_types.h @@ -24,7 +24,7 @@ namespace tensorflow { // Returns an error iff *g running on a single device of 'device_type' // has memory type mismatch for any edge's source and destination. -Status ValidateMemoryTypes(DeviceType device_type, const Graph* g); +Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g); // Updates '*g' so that every edge's source and destination has // compatible memory types by inserting proper HostSend/Recv and @@ -35,12 +35,12 @@ Status ValidateMemoryTypes(DeviceType device_type, const Graph* g); // Returns OK if '*g' is updated properly (ValidateMemoryTypes(g) must // be OK). Otherwise, returns an error and '*g' may be in an // invalidate state and the caller should discard it. -Status EnsureMemoryTypes(DeviceType device_type, const string& device_name, - Graph* g); +Status EnsureMemoryTypes(const DeviceType& device_type, + const string& device_name, Graph* g); // Get the memory type for 'index'th output of node 'n' in graph 'g', when // running on 'device_type'. -Status MemoryTypeForOutput(DeviceType device_type, const Graph* g, +Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g, const Node* n, int index, MemoryType* memory_type); } // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc index 24f27af5f1..69ed58b33c 100644 --- a/tensorflow/core/common_runtime/simple_placer_test.cc +++ b/tensorflow/core/common_runtime/simple_placer_test.cc @@ -237,7 +237,7 @@ class SimplePlacerTest : public ::testing::Test { Status ReferenceTestHelper(const string& variable_op_type, const string& assign_op_type, - DeviceType expected_device_type); + const DeviceType& expected_device_type); }; #define EXPECT_COLOCATED(g, name_a, name_b) \ @@ -500,9 +500,9 @@ TEST_F(SimplePlacerTest, TestAssignedGpuDeviceToCpuDevice) { // Build a graph containing a Variable op of "variable_op_type" and an // Assign op of "assign_op_type", and expect all of the ops to be // placed on a device of type "expected_device_type". -Status SimplePlacerTest::ReferenceTestHelper(const string& variable_op_type, - const string& assign_op_type, - DeviceType expected_device_type) { +Status SimplePlacerTest::ReferenceTestHelper( + const string& variable_op_type, const string& assign_op_type, + const DeviceType& expected_device_type) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); diff --git a/tensorflow/core/debug/debug_gateway.cc b/tensorflow/core/debug/debug_gateway.cc index 24b9dd799a..1031ea843e 100644 --- a/tensorflow/core/debug/debug_gateway.cc +++ b/tensorflow/core/debug/debug_gateway.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/debug/debug_gateway.h" +#include <utility> + #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/session_factory.h" #include "tensorflow/core/framework/tensor.h" @@ -56,11 +58,11 @@ DebugGateway::~DebugGateway() { } void DebugGateway::SetNodeCompletionCallback(NodeCompletionCallback callback) { - comp_cb_ = callback; + comp_cb_ = std::move(callback); } void DebugGateway::SetNodeValueCallback(NodeValueCallback callback) { - val_cb_ = callback; + val_cb_ = std::move(callback); } void DebugGateway::CopyTensor(const string& node_name, const int output_slot, diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 50c5d90fc9..73758ade03 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -63,7 +63,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts, std::unique_ptr<SimpleClientGraph> cg, const SessionOptions& session_opts, - StatsPublisherFactory stats_publisher_factory, + const StatsPublisherFactory& stats_publisher_factory, SimpleGraphExecutionState* execution_state, bool is_partial, WorkerCacheInterface* worker_cache) : session_handle_(handle), diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc index c3b76ed31b..bf72d9a7fc 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h" +#include <utility> + #include "tensorflow/core/distributed_runtime/call_options.h" #include "tensorflow/core/distributed_runtime/master_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h" @@ -29,7 +31,7 @@ namespace tensorflow { // that uses gRPC to talk to the Master service. class GrpcRemoteMaster : public MasterInterface { public: - explicit GrpcRemoteMaster(SharedGrpcChannelPtr client_channel) + explicit GrpcRemoteMaster(const SharedGrpcChannelPtr& client_channel) : stub_(grpc::MasterService::NewStub(client_channel)) {} ~GrpcRemoteMaster() override {} @@ -106,7 +108,7 @@ class GrpcRemoteMaster : public MasterInterface { } }; -MasterInterface* NewGrpcMaster(SharedGrpcChannelPtr channel) { +MasterInterface* NewGrpcMaster(const SharedGrpcChannelPtr& channel) { return new GrpcRemoteMaster(channel); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h index 881a6b10e3..d661caaa60 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { // Returns a MasterInterface wrapped around the gRPC channel `channel`. -MasterInterface* NewGrpcMaster(SharedGrpcChannelPtr channel); +MasterInterface* NewGrpcMaster(const SharedGrpcChannelPtr& channel); } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index 36626e1a33..2b1a47a93f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h" +#include <utility> + #include "grpc++/grpc++.h" #include "tensorflow/core/common_runtime/process_util.h" @@ -37,7 +39,7 @@ class GrpcRemoteWorker : public WorkerInterface { explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, WorkerCacheLogger* logger) - : channel_(channel), + : channel_(std::move(channel)), cq_(completion_queue), getstatus_(Method(GrpcWorkerMethod::kGetStatus)), createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)), @@ -272,7 +274,7 @@ class GrpcRemoteWorker : public WorkerInterface { WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, ::grpc::CompletionQueue* completion_queue, WorkerCacheLogger* logger) { - return new GrpcRemoteWorker(channel, completion_queue, logger); + return new GrpcRemoteWorker(std::move(channel), completion_queue, logger); } } // namespace tensorflow diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index a387d49613..9c47c1da2f 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/function.h" #include <unordered_map> +#include <utility> #include <vector> #include "tensorflow/core/framework/function.pb_text.h" @@ -601,7 +602,8 @@ string Print(const GraphDef& gdef) { return out; } -Status AddDefaultAttrs(const string& op, GetFunctionSignature get_function, +Status AddDefaultAttrs(const string& op, + const GetFunctionSignature& get_function, InstantiateAttrValueMap* attrs) { const OpDef* op_def = nullptr; TF_RETURN_IF_ERROR(get_function(op, &op_def)); @@ -987,7 +989,7 @@ Status InstantiateFunction(const FunctionDef& fdef, for (const auto& aval : attr_values) { m.insert({aval.first, aval.second.proto}); } - return InstantiateFunction(fdef, m, get_function, result); + return InstantiateFunction(fdef, m, std::move(get_function), result); } string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs) { diff --git a/tensorflow/core/framework/graph_def_util_test.cc b/tensorflow/core/framework/graph_def_util_test.cc index 8c76a74a4a..1ac322e48e 100644 --- a/tensorflow/core/framework/graph_def_util_test.cc +++ b/tensorflow/core/framework/graph_def_util_test.cc @@ -28,7 +28,7 @@ limitations under the License. namespace tensorflow { namespace { -Status FinalizeOpDef(OpDefBuilder b, OpDef* op_def) { +Status FinalizeOpDef(const OpDefBuilder& b, OpDef* op_def) { OpRegistrationData op_reg_data; const Status s = b.Finalize(&op_reg_data); *op_def = op_reg_data.op_def; diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index 14d8d91490..c1dde1504a 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/framework/memory_types.h" +#include <utility> + #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" @@ -64,7 +66,7 @@ MemoryType MTypeFromDType(const DataType dtype) { } // namespace Status MemoryTypesForNode(const OpRegistryInterface* op_registry, - DeviceType device_type, const NodeDef& ndef, + const DeviceType& device_type, const NodeDef& ndef, MemoryTypeVector* inp_mtypes, MemoryTypeVector* out_mtypes) { // Look up the Op registered for this op name. diff --git a/tensorflow/core/framework/memory_types.h b/tensorflow/core/framework/memory_types.h index 3d4ca7597a..e35e22f590 100644 --- a/tensorflow/core/framework/memory_types.h +++ b/tensorflow/core/framework/memory_types.h @@ -28,7 +28,7 @@ namespace tensorflow { // REQUIRES: * '*_memory_types' is not nullptr. // * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). Status MemoryTypesForNode(const OpRegistryInterface* op_registry, - DeviceType device_type, const NodeDef& ndef, + const DeviceType& device_type, const NodeDef& ndef, MemoryTypeVector* input_memory_types, MemoryTypeVector* output_memory_types); diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc index a6ffd5c596..bde5bb2c39 100644 --- a/tensorflow/core/framework/op_def_builder_test.cc +++ b/tensorflow/core/framework/op_def_builder_test.cc @@ -73,7 +73,7 @@ class OpDefBuilderTest : public ::testing::Test { } } - void ExpectFailure(const OpDefBuilder& builder, string error) { + void ExpectFailure(const OpDefBuilder& builder, const string& error) { OpRegistrationData op_reg_data; Status status = builder.Finalize(&op_reg_data); EXPECT_FALSE(status.ok()); diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index 6fad379b76..b53daeed0b 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include <unordered_map> +#include <utility> #include <vector> #include "tensorflow/core/framework/attr_value_util.h" @@ -807,7 +808,7 @@ static KernelRegistry* GlobalKernelRegistryTyped() { return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry()); } -static string Key(StringPiece op_type, DeviceType device_type, +static string Key(StringPiece op_type, const DeviceType& device_type, StringPiece label) { return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":", label); @@ -892,7 +893,8 @@ Status AttrsMatch(const NodeDef& node_def, const KernelDef& kernel_def, return Status::OK(); } -Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def, +Status FindKernelRegistration(const DeviceType& device_type, + const NodeDef& node_def, const KernelRegistration** reg, bool* was_attr_mismatch) { *reg = nullptr; @@ -924,7 +926,7 @@ Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def, } // namespace -Status FindKernelDef(DeviceType device_type, const NodeDef& node_def, +Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, const KernelDef** def, string* kernel_class_name) { const KernelRegistration* reg = nullptr; bool was_attr_mismatch; @@ -1006,8 +1008,8 @@ std::unique_ptr<OpKernel> CreateOpKernel( DeviceType device_type, DeviceBase* device, Allocator* allocator, const NodeDef& node_def, int graph_def_version, Status* status) { OpKernel* kernel = nullptr; - *status = CreateOpKernel(device_type, device, allocator, nullptr, node_def, - graph_def_version, &kernel); + *status = CreateOpKernel(std::move(device_type), device, allocator, nullptr, + node_def, graph_def_version, &kernel); return std::unique_ptr<OpKernel>(kernel); } diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index d926d7db19..969142ca23 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -18,6 +18,7 @@ limitations under the License. #include <functional> +#include <utility> #include <vector> #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/cancellation.h" @@ -227,7 +228,7 @@ class OpKernelConstruction { const DataTypeSlice& output_types, const MemoryTypeSlice& output_memory_types, int graph_def_version, Status* status) - : device_type_(device_type), + : device_type_(std::move(device_type)), device_(device), allocator_(allocator), def_(node_def), @@ -1253,7 +1254,7 @@ void* GlobalKernelRegistry(); // If node_def has a corresponding kernel registered on device_type, // returns OK and fill in the kernel def and kernel_class_name. <def> and // <kernel_class_name> may be null. -Status FindKernelDef(DeviceType device_type, const NodeDef& node_def, +Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, const KernelDef** def, string* kernel_class_name); // Writes a list of all registered kernels to LOG(INFO), to help users debug diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index 46d4dbd86a..e8e931b52e 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include <memory> +#include <utility> #include <vector> #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/attr_value_util.h" @@ -133,8 +134,8 @@ class OpKernelTest : public ::testing::Test { const DataTypeVector& outputs) { Status status; std::unique_ptr<OpKernel> op(CreateOpKernel( - device_type, &device_, cpu_allocator(), CreateNodeDef(op_type, inputs), - TF_GRAPH_DEF_VERSION, &status)); + std::move(device_type), &device_, cpu_allocator(), + CreateNodeDef(op_type, inputs), TF_GRAPH_DEF_VERSION, &status)); EXPECT_TRUE(status.ok()) << status; EXPECT_TRUE(op != nullptr); if (op != nullptr) { @@ -148,9 +149,9 @@ class OpKernelTest : public ::testing::Test { NodeDef node_def; protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def); Status status; - std::unique_ptr<OpKernel> op(CreateOpKernel(device_type, &device_, - cpu_allocator(), node_def, - TF_GRAPH_DEF_VERSION, &status)); + std::unique_ptr<OpKernel> op( + CreateOpKernel(std::move(device_type), &device_, cpu_allocator(), + node_def, TF_GRAPH_DEF_VERSION, &status)); EXPECT_TRUE(op == nullptr); EXPECT_FALSE(status.ok()); if (!status.ok()) { @@ -384,7 +385,7 @@ class OpKernelBuilderTest : public ::testing::Test { } std::unique_ptr<OpKernel> ExpectSuccess(const string& op_type, - DeviceType device_type, + const DeviceType& device_type, const std::vector<string>& attrs, DataTypeSlice input_types = {}) { Status status; @@ -423,7 +424,7 @@ class OpKernelBuilderTest : public ::testing::Test { return op; } - void ExpectFailure(const string& op_type, DeviceType device_type, + void ExpectFailure(const string& op_type, const DeviceType& device_type, const std::vector<string>& attrs, error::Code code) { Status status; const NodeDef def = CreateNodeDef(op_type, attrs); diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index c907bbb69f..424b4579d3 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -810,7 +810,8 @@ TEST(Tensor, Slice_Basic) { namespace { template <typename T> -Tensor MkTensor(DataType dt, TensorShape shape, std::vector<T> init_values) { +Tensor MkTensor(DataType dt, const TensorShape& shape, + std::vector<T> init_values) { Tensor x(dt, shape); const int limit = x.NumElements(); int vi = 0; diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc index a374f848a1..dc396e468a 100644 --- a/tensorflow/core/framework/types.cc +++ b/tensorflow/core/framework/types.cc @@ -169,7 +169,9 @@ bool DataTypeFromString(StringPiece sp, DataType* dt) { return false; } -string DeviceTypeString(DeviceType device_type) { return device_type.type(); } +string DeviceTypeString(const DeviceType& device_type) { + return device_type.type(); +} string DataTypeSliceString(const DataTypeSlice types) { string out; diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index 932d788f23..0a81b1cb9f 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -82,7 +82,7 @@ typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector; // Convert the enums to strings for errors: string DataTypeString(DataType dtype); -string DeviceTypeString(DeviceType device_type); +string DeviceTypeString(const DeviceType& device_type); string DataTypeSliceString(const DataTypeSlice dtypes); inline string DataTypeVectorString(const DataTypeVector& dtypes) { return DataTypeSliceString(dtypes); diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index 502b7b26da..4afc878f76 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -51,8 +51,8 @@ class GraphTest : public ::testing::Test { GraphTest() : graph_(OpRegistry::Global()) {} ~GraphTest() override {} - static void VerifyNodes(Node* node, std::vector<Node*> expected_in, - std::vector<Node*> expected_out) { + static void VerifyNodes(Node* node, const std::vector<Node*>& expected_in, + const std::vector<Node*>& expected_out) { std::vector<Node*> in; for (const Edge* e : node->in_edges()) { in.push_back(e->src()); diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc index 59dff60ea3..a679eac0e7 100644 --- a/tensorflow/core/graph/optimizer_cse.cc +++ b/tensorflow/core/graph/optimizer_cse.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/core/graph/optimizer_cse.h" #include <unordered_map> +#include <utility> #include <vector> #include "tensorflow/core/graph/algorithm.h" @@ -52,7 +53,7 @@ class OptimizerCSE { public: explicit OptimizerCSE(Graph* g) : g_(g) {} - bool Optimize(std::function<bool(const Node*)> consider_fn); + bool Optimize(const std::function<bool(const Node*)>& consider_fn); private: struct Scratch; @@ -180,7 +181,8 @@ bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) { return true; } -bool OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) { +bool OptimizerCSE::Optimize( + const std::function<bool(const Node*)>& consider_fn) { // This very simple implementation works if the whole graph is one // giant basic block (because we just traverse nodes in a // topological order). This simple implementation works well @@ -232,7 +234,8 @@ bool OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) { return changed; } -bool OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn) { +bool OptimizeCSE(Graph* g, + const std::function<bool(const Node*)>& consider_fn) { OptimizerCSE opt(g); return opt.Optimize(consider_fn); } diff --git a/tensorflow/core/graph/optimizer_cse.h b/tensorflow/core/graph/optimizer_cse.h index 24ec5658d8..b8f3230c70 100644 --- a/tensorflow/core/graph/optimizer_cse.h +++ b/tensorflow/core/graph/optimizer_cse.h @@ -29,7 +29,8 @@ namespace tensorflow { // during the common subexpression elimination. // // Returns true if and only if 'g' is mutated. -extern bool OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn); +extern bool OptimizeCSE(Graph* g, + const std::function<bool(const Node*)>& consider_fn); } // namespace tensorflow diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc index 1091af4e45..94250240eb 100644 --- a/tensorflow/core/graph/optimizer_cse_test.cc +++ b/tensorflow/core/graph/optimizer_cse_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/graph/optimizer_cse.h" +#include <utility> #include <vector> #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/barrier_ops.cc b/tensorflow/core/kernels/barrier_ops.cc index 03880b9827..83633a1dd9 100644 --- a/tensorflow/core/kernels/barrier_ops.cc +++ b/tensorflow/core/kernels/barrier_ops.cc @@ -88,7 +88,7 @@ class Barrier : public ResourceBase { template <typename T> void TryInsertMany(const Tensor& keys, int component_index, const Tensor& values, OpKernelContext* ctx, - DoneCallback callback) { + const DoneCallback& callback) { TensorShape element_shape = values.shape(); OP_REQUIRES_ASYNC( ctx, keys.NumElements() == 0 || element_shape.num_elements() > 0, @@ -195,7 +195,8 @@ class Barrier : public ResourceBase { } void TryTakeMany(int num_elements, bool allow_small_batch, int64 timeout, - OpKernelContext* ctx, IndicesKeysValuesCallback callback) { + OpKernelContext* ctx, + const IndicesKeysValuesCallback& callback) { int num_elements_to_deliver = num_elements; { mutex_lock lock(mu_); @@ -247,7 +248,7 @@ class Barrier : public ResourceBase { } void Close(OpKernelContext* ctx, bool cancel_pending_enqueues, - DoneCallback callback) { + const DoneCallback& callback) { mutex_lock lock(mu_); // We're allowed to close twice if the first close wasn't a // cancel but the second one is. @@ -399,7 +400,8 @@ class Barrier : public ResourceBase { } void CloseQueueLocked(OpKernelContext* ctx, bool cancel_pending_enqueues, - DoneCallback callback) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + const DoneCallback& callback) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { // CloseQueueLocked may only be called with mu_ held. if (!cancel_pending_enqueues && queue_closed_) { callback(); diff --git a/tensorflow/core/kernels/conv_ops_fused.cc b/tensorflow/core/kernels/conv_ops_fused.cc index 219e6d5e97..f7348f1077 100644 --- a/tensorflow/core/kernels/conv_ops_fused.cc +++ b/tensorflow/core/kernels/conv_ops_fused.cc @@ -74,8 +74,9 @@ enum SamplingMode { // my_vector[current] *= 10.0f; // } // }); -void FusedConvParallelFor(OpKernelContext* context, int64 begin, int64 end, - std::function<void(int64, int64)> task_function) { +void FusedConvParallelFor( + OpKernelContext* context, int64 begin, int64 end, + const std::function<void(int64, int64)>& task_function) { // On iOS, the thread management imposes a very big performance penalty, so // just call the function directly with no multithreading. #if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM) diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc index b122e7f0e8..cd9aa4a53e 100644 --- a/tensorflow/core/kernels/conv_ops_test.cc +++ b/tensorflow/core/kernels/conv_ops_test.cc @@ -116,8 +116,9 @@ class FusedResizePadConvOpTest : public OpsTestBase { int input_depth, int resize_width, int resize_height, int y_padding, int x_padding, int filter_size, int filter_count, - bool resize_align_corners, string pad_mode, - int stride, string padding) { + bool resize_align_corners, + const string& pad_mode, int stride, + const string& padding) { auto root = tensorflow::Scope::NewRootScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) @@ -170,8 +171,8 @@ class FusedResizePadConvOpTest : public OpsTestBase { void CompareFusedPadOnlyAndSeparate(int input_width, int input_height, int input_depth, int y_padding, int x_padding, int filter_size, - int filter_count, string pad_mode, - int stride, string padding) { + int filter_count, const string& pad_mode, + int stride, const string& padding) { auto root = tensorflow::Scope::NewRootScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc index 6a748d3462..c5470f81eb 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/crop_and_resize_op.cc @@ -79,13 +79,13 @@ static inline Status ParseAndCheckBoxSizes(const Tensor& boxes, template <typename Device> inline void RunIfBoxIndexIsValid( OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index, - int batch_size, Callback compute, Callback done); + int batch_size, const Callback& compute, const Callback& done); // Specialization of CheckValidBoxIndex for a CPUDevice. template <> inline void RunIfBoxIndexIsValid<CPUDevice>( OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index, - int batch_size, Callback compute, Callback done) { + int batch_size, const Callback& compute, const Callback& done) { const int num_boxes = box_index.dimension(0); for (int b = 0; b < num_boxes; ++b) { OP_REQUIRES_ASYNC( @@ -690,7 +690,7 @@ namespace { template <> inline void RunIfBoxIndexIsValid<GPUDevice>( OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index, - int batch_size, Callback compute, Callback done) { + int batch_size, const Callback& compute, const Callback& done) { const int num_boxes = box_index.dimension(0); if (num_boxes == 0) { compute(); diff --git a/tensorflow/core/kernels/debug_ops_test.cc b/tensorflow/core/kernels/debug_ops_test.cc index 495db92ef6..487f045cc8 100644 --- a/tensorflow/core/kernels/debug_ops_test.cc +++ b/tensorflow/core/kernels/debug_ops_test.cc @@ -36,7 +36,7 @@ namespace tensorflow { class DebugIdentityOpTest : public OpsTestBase { protected: - Status Init(DataType input_type, const std::vector<string> debug_urls) { + Status Init(DataType input_type, const std::vector<string>& debug_urls) { env_ = Env::Default(); TF_CHECK_OK(NodeDefBuilder("op", "DebugIdentity") diff --git a/tensorflow/core/kernels/reverse_op_test.cc b/tensorflow/core/kernels/reverse_op_test.cc index 19e25b887d..c6193f378d 100644 --- a/tensorflow/core/kernels/reverse_op_test.cc +++ b/tensorflow/core/kernels/reverse_op_test.cc @@ -120,7 +120,7 @@ static SessionOptions GetOptions(int intra_threads) { // Creates a Graph which "reduce"s a 3D float tensor of "num" elements // into a scalar. -static Graph* Reverse(TensorShape shape, int reverse_axis) { +static Graph* Reverse(const TensorShape& shape, int reverse_axis) { Graph* g = new Graph(OpRegistry::Global()); Tensor data(DT_FLOAT, shape); data.flat<float>().setRandom(); diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index 006ef988b5..80d4901740 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -15,6 +15,7 @@ limitations under the License. #include <unordered_map> +#include <utility> #include <vector> #include "tensorflow/core/kernels/save_restore_tensor.h" @@ -79,7 +80,7 @@ void SaveTensors( VLOG(1) << "About to save tensors to file " << filename_t.flat<string>()(0) << "..."; checkpoint::TensorSliceWriter writer(filename_t.flat<string>()(0), - builder_func); + std::move(builder_func)); Status s; auto tensor_names_flat = tensor_names_t.flat<string>(); diff --git a/tensorflow/core/kernels/segment_reduction_ops_test.cc b/tensorflow/core/kernels/segment_reduction_ops_test.cc index 0a281835a4..bdf3c12ff9 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_test.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_test.cc @@ -40,8 +40,9 @@ limitations under the License. namespace tensorflow { template <typename Index> -static void BM_SegmentReduction(int iters, string reduction, Index num_rows, - Index num_cols, Index segment_size) { +static void BM_SegmentReduction(int iters, const string& reduction, + Index num_rows, Index num_cols, + Index segment_size) { testing::StopTiming(); std::unique_ptr<Device> device( DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc index 2b10ebeaf7..c8e514df80 100644 --- a/tensorflow/core/lib/core/threadpool.cc +++ b/tensorflow/core/lib/core/threadpool.cc @@ -47,7 +47,7 @@ struct EigenEnvironment { const string& name) : env_(env), thread_options_(thread_options), name_(name) {} - EnvThread* CreateThread(std::function<void()> f) { + EnvThread* CreateThread(const std::function<void()>& f) { return env_->StartThread(thread_options_, name_, [=]() { // Set the processor flag to flush denormals to zero. port::ScopedFlushDenormal flush; diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc index f9846968af..e27904ea12 100644 --- a/tensorflow/core/lib/jpeg/jpeg_mem.cc +++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc @@ -337,7 +337,8 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { uint8* Uncompress(const void* srcdata, int datasize, const UncompressFlags& flags, int64* nwarn, std::function<uint8*(int, int, int)> allocate_output) { - FewerArgsForCompiler argball(datasize, flags, nwarn, allocate_output); + FewerArgsForCompiler argball(datasize, flags, nwarn, + std::move(allocate_output)); uint8* const dstdata = UncompressLow(srcdata, &argball); const float fraction_read = diff --git a/tensorflow/core/lib/random/random_distributions_test.cc b/tensorflow/core/lib/random/random_distributions_test.cc index 531ed78109..28ff5bf6e8 100644 --- a/tensorflow/core/lib/random/random_distributions_test.cc +++ b/tensorflow/core/lib/random/random_distributions_test.cc @@ -70,7 +70,7 @@ void FillRandomsWithSingles(PhiloxRandom gen, // z_limit: the maximum z-test we would consider the test to pass; template <typename T> bool CheckSamplesMoments(const std::vector<T>& samples, - std::function<double(int)> theoretical_moments, + const std::function<double(int)>& theoretical_moments, int max_moments, int stride, T z_limit) { const T* const samples_data = &samples[0]; const int samples_size = samples.size(); diff --git a/tensorflow/core/ops/training_ops_test.cc b/tensorflow/core/ops/training_ops_test.cc index 9c3489211c..da66fbe4ba 100644 --- a/tensorflow/core/ops/training_ops_test.cc +++ b/tensorflow/core/ops/training_ops_test.cc @@ -21,9 +21,9 @@ limitations under the License. namespace tensorflow { // Used for testing the grad+indices handling for SparseApplyXYZ tests. -static void TestGradAndIndicesErrorHandling(ShapeInferenceTestOp op, +static void TestGradAndIndicesErrorHandling(const ShapeInferenceTestOp& op, string shape_spec_middle, - string shape_spec_end = "") { + const string& shape_spec_end = "") { auto shape_spec = [&shape_spec_middle, shape_spec_end]( const char* var_spec, const char* grad_indices_spec) { return strings::StrCat(var_spec, ";", shape_spec_middle, ";", diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc index a3d5b9a6e4..97d6617a04 100644 --- a/tensorflow/core/platform/cloud/oauth_client.cc +++ b/tensorflow/core/platform/cloud/oauth_client.cc @@ -43,7 +43,8 @@ constexpr char kJwtType[] = "JWT"; constexpr char kGrantType[] = "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer"; -Status ReadJsonValue(Json::Value json, const string& name, Json::Value* value) { +Status ReadJsonValue(const Json::Value& json, const string& name, + Json::Value* value) { if (!value) { return errors::FailedPrecondition("'value' cannot be nullptr."); } @@ -55,7 +56,8 @@ Status ReadJsonValue(Json::Value json, const string& name, Json::Value* value) { return Status::OK(); } -Status ReadJsonString(Json::Value json, const string& name, string* value) { +Status ReadJsonString(const Json::Value& json, const string& name, + string* value) { Json::Value json_value; TF_RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value)); if (!json_value.isString()) { @@ -66,7 +68,7 @@ Status ReadJsonString(Json::Value json, const string& name, string* value) { return Status::OK(); } -Status ReadJsonInt(Json::Value json, const string& name, int64* value) { +Status ReadJsonInt(const Json::Value& json, const string& name, int64* value) { Json::Value json_value; TF_RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value)); if (!json_value.isIntegral()) { diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc index d729963616..2fdd989c9b 100644 --- a/tensorflow/core/platform/env.cc +++ b/tensorflow/core/platform/env.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include <deque> +#include <utility> #include <vector> #if defined(__APPLE__) #include <mach-o/dyld.h> @@ -95,7 +96,7 @@ Status Env::GetRegisteredFileSystemSchemes(std::vector<string>* schemes) { Status Env::RegisterFileSystem(const string& scheme, FileSystemRegistry::Factory factory) { - return file_system_registry_->Register(scheme, factory); + return file_system_registry_->Register(scheme, std::move(factory)); } Status Env::NewRandomAccessFile(const string& fname, diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc index 3d7553e6da..2abda45714 100644 --- a/tensorflow/core/platform/file_system.cc +++ b/tensorflow/core/platform/file_system.cc @@ -37,7 +37,7 @@ constexpr int kNumThreads = 8; // Run a function in parallel using a ThreadPool, but skip the ThreadPool // on the iOS platform due to its problems with more than a few threads. -void ForEach(int first, int last, std::function<void(int)> f) { +void ForEach(int first, int last, const std::function<void(int)>& f) { #if TARGET_OS_IPHONE for (int i = first; i < last; i++) { f(i); diff --git a/tensorflow/core/util/tensor_slice_reader.cc b/tensorflow/core/util/tensor_slice_reader.cc index e750b130b9..cd49034719 100644 --- a/tensorflow/core/util/tensor_slice_reader.cc +++ b/tensorflow/core/util/tensor_slice_reader.cc @@ -102,7 +102,8 @@ TensorSliceReader::TensorSliceReader(const string& filepattern) TensorSliceReader::TensorSliceReader(const string& filepattern, OpenTableFunction open_function) - : TensorSliceReader(filepattern, open_function, kLoadAllShards) {} + : TensorSliceReader(filepattern, std::move(open_function), kLoadAllShards) { +} TensorSliceReader::TensorSliceReader(const string& filepattern, OpenTableFunction open_function, diff --git a/tensorflow/core/util/tensor_slice_reader_cache.cc b/tensorflow/core/util/tensor_slice_reader_cache.cc index cbd2922f54..0f009d7de5 100644 --- a/tensorflow/core/util/tensor_slice_reader_cache.cc +++ b/tensorflow/core/util/tensor_slice_reader_cache.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/util/tensor_slice_reader_cache.h" +#include <utility> + #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/logging.h" @@ -36,7 +38,8 @@ const TensorSliceReader* TensorSliceReaderCacheWrapper::GetReader( if (!cache_) { cache_ = new TensorSliceReaderCache; } - return cache_->GetReader(filepattern, open_function, preferred_shard); + return cache_->GetReader(filepattern, std::move(open_function), + preferred_shard); } TensorSliceReaderCache::TensorSliceReaderCache() {} diff --git a/tensorflow/core/util/tensor_slice_reader_test.cc b/tensorflow/core/util/tensor_slice_reader_test.cc index 8545697886..f4859262e1 100644 --- a/tensorflow/core/util/tensor_slice_reader_test.cc +++ b/tensorflow/core/util/tensor_slice_reader_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <utility> + #include "tensorflow/core/util/tensor_slice_reader.h" #include "tensorflow/core/framework/types.h" @@ -48,8 +50,9 @@ namespace { // // We assume this is a row-major matrix. -void SimpleFloatHelper(TensorSliceWriter::CreateBuilderFunction create_function, - TensorSliceReader::OpenTableFunction open_function) { +void SimpleFloatHelper( + const TensorSliceWriter::CreateBuilderFunction& create_function, + TensorSliceReader::OpenTableFunction open_function) { const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint"); TensorShape shape({4, 5}); @@ -108,7 +111,7 @@ void SimpleFloatHelper(TensorSliceWriter::CreateBuilderFunction create_function, // Now we need to read the tensor slices const string filepattern = strings::StrCat(fname_base, "_*"); - TensorSliceReader reader(filepattern, open_function); + TensorSliceReader reader(filepattern, std::move(open_function)); TF_EXPECT_OK(reader.status()); EXPECT_EQ(2, reader.num_files()); @@ -171,9 +174,10 @@ TEST(TensorSliceReaderTest, SimpleFloat) { } template <typename T, typename U> -void SimpleIntXHelper(TensorSliceWriter::CreateBuilderFunction create_function, - TensorSliceReader::OpenTableFunction open_function, - const string& checkpoint_file) { +void SimpleIntXHelper( + const TensorSliceWriter::CreateBuilderFunction& create_function, + TensorSliceReader::OpenTableFunction open_function, + const string& checkpoint_file) { const string fname_base = io::JoinPath(testing::TmpDir(), checkpoint_file); TensorShape shape({4, 5}); @@ -232,7 +236,7 @@ void SimpleIntXHelper(TensorSliceWriter::CreateBuilderFunction create_function, // Now we need to read the tensor slices const string filepattern = strings::StrCat(fname_base, "_*"); - TensorSliceReader reader(filepattern, open_function); + TensorSliceReader reader(filepattern, std::move(open_function)); TF_EXPECT_OK(reader.status()); EXPECT_EQ(2, reader.num_files()); @@ -304,8 +308,8 @@ TEST_SIMPLE_INT(int8, int32) TEST_SIMPLE_INT(uint8, int32) void CachedTensorSliceReaderTesterHelper( - TensorSliceWriter::CreateBuilderFunction create_function, - TensorSliceReader::OpenTableFunction open_function) { + const TensorSliceWriter::CreateBuilderFunction& create_function, + const TensorSliceReader::OpenTableFunction& open_function) { const string fname_base = io::JoinPath(testing::TmpDir(), "float_checkpoint"); TensorShape shape({4, 5}); |