aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <nobody@tensorflow.org> 2016-03-24 08:44:27 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-03-24 10:02:49 -0700
commit: 821920c1f25968f5dfcd2f8999b293ebedf85957 (patch)
tree: b4dce0285bac385630f658582bf27558f80cac51
parent: 5b5b8412f0684a548e1e9001421e5d095cda0142 (diff)
Several changes to reduce TensorFlow code size (especially important
for mobile apps):

(1) Change many interfaces in node_def_builder.h, node_def_util.h,
    op_kernel.h, node_builder.h, mirror_pad_mode.h, padding.h to use
    'StringPiece', rather than 'const string&'. The interfaces that were
    changed tend to be heavily used in the registration of ops and kernels,
    and often caused extra string construction code to be emitted in the
    macro expansion of each op or kernel registration.

(2) Move some repetitive CHECK operations into non-inlined routines in
    tensor.cc, rather than having them in inlined or templated routines in
    tensor.h (new Tensor::CheckDataType, Tensor::CheckTypeAndIsAligned, and
    Tensor::CheckIsAlignedAndSingleElement routines).

(3) Factored out internal template<size_t NDIMS>
    Tensor::FillDimsAndValidateCompatibleShape routine, to be shared across
    more specialized templated routines (typically specialized on both
    DataType and NDIMS).

(4) Added new non-inlined TensorShape::CheckDimsMatch(int NDIMS) routine in
    tensor_shape.cc, that can be called from various TensorShape routines
    templated on NDIMS.

(5) Don't inline single-argument StrCat, since it involves a string
    creation, etc.

(6) Remove inline keyword from template <typename... AV> StrCat version
    that handles 5 or more arguments.

Reduces text size for
third_party/tensorflow/core/libandroid_tensorflow_lib.so built in Google
build environment by 1.43%, as measured by:

  % blaze build -c opt --config=android_arm \
      third_party/tensorflow/core:android_tensorflow_lib
  % size blaze-bin/third_party/tensorflow/core/libandroid_tensorflow_lib.so

Change: 118036659
-rw-r--r--  tensorflow/core/framework/node_def_builder.cc  49
-rw-r--r--  tensorflow/core/framework/node_def_builder.h   50
-rw-r--r--  tensorflow/core/framework/node_def_util.cc     16
-rw-r--r--  tensorflow/core/framework/node_def_util.h      60
-rw-r--r--  tensorflow/core/framework/op_kernel.cc         40
-rw-r--r--  tensorflow/core/framework/op_kernel.h          34
-rw-r--r--  tensorflow/core/framework/tensor.cc            17
-rw-r--r--  tensorflow/core/framework/tensor.h             55
-rw-r--r--  tensorflow/core/framework/tensor_shape.cc      10
-rw-r--r--  tensorflow/core/framework/tensor_shape.h        9
-rw-r--r--  tensorflow/core/graph/node_builder.cc          10
-rw-r--r--  tensorflow/core/graph/node_builder.h           17
-rw-r--r--  tensorflow/core/lib/strings/strcat.cc           2
-rw-r--r--  tensorflow/core/lib/strings/strcat.h            7
-rw-r--r--  tensorflow/core/util/mirror_pad_mode.cc         2
-rw-r--r--  tensorflow/core/util/mirror_pad_mode.h          2
-rw-r--r--  tensorflow/core/util/padding.cc                 2
-rw-r--r--  tensorflow/core/util/padding.h                  2
18 files changed, 218 insertions, 166 deletions
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc
index b6f5838528..f3091ad286 100644
--- a/tensorflow/core/framework/node_def_builder.cc
+++ b/tensorflow/core/framework/node_def_builder.cc
@@ -22,11 +22,24 @@ limitations under the License.
namespace tensorflow {
-NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name,
+NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt)
+ : node(n.ToString()), index(i), data_type(dt) {}
+
+NodeDefBuilder::NodeOut::NodeOut() {
+ // uninitialized, call Reset() before use.
+}
+
+void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) {
+ node = n.ToString();
+ index = i;
+ data_type = dt;
+}
+
+NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry) {
- node_def_.set_name(name);
+ node_def_.set_name(name.ToString());
Status status;
- op_def_ = op_registry->LookUp(op_name, &status);
+ op_def_ = op_registry->LookUp(op_name.ToString(), &status);
if (op_def_ == nullptr) {
errors_.push_back(status.error_message());
inputs_specified_ = 0;
@@ -35,9 +48,9 @@ NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name,
}
}
-NodeDefBuilder::NodeDefBuilder(const string& name, const OpDef* op_def)
+NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def)
: op_def_(op_def) {
- node_def_.set_name(name);
+ node_def_.set_name(name.ToString());
Initialize();
}
@@ -72,7 +85,7 @@ NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
}
void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
- const string& src_node, int src_index,
+ StringPiece src_node, int src_index,
DataType dt) {
AddInput(src_node, src_index);
@@ -129,7 +142,7 @@ void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg,
}
}
-void NodeDefBuilder::AddInput(const string& src_node, int src_index) {
+void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) {
if (src_node.empty()) {
errors_.push_back("Empty input node name");
} else if (src_node[0] == '^') {
@@ -138,7 +151,7 @@ void NodeDefBuilder::AddInput(const string& src_node, int src_index) {
} else if (src_index > 0) {
node_def_.add_input(strings::StrCat(src_node, ":", src_index));
} else {
- node_def_.add_input(src_node);
+ node_def_.add_input(src_node.ToString());
}
}
@@ -160,6 +173,16 @@ void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
}
}
+NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) {
+ control_inputs_.push_back(src_node.ToString());
+ return *this;
+}
+
+NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) {
+ node_def_.set_device(device_spec.ToString());
+ return *this;
+}
+
Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
const std::vector<string>* errors_ptr = &errors_;
std::vector<string> errors_storage;
@@ -206,4 +229,14 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
}
}
+void NodeDefBuilder::CheckInconsistency(StringPiece attr_name,
+ const AttrValue& found,
+ const AttrValue& attr_value) {
+ if (!AreAttrValuesEqual(found, attr_value)) {
+ errors_.push_back(strings::StrCat(
+ "Inconsistent values for attr '", attr_name, "' ",
+ SummarizeAttrValue(found), " vs. ", SummarizeAttrValue(attr_value)));
+ }
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h
index 72ad37ab93..f8b3a8441a 100644
--- a/tensorflow/core/framework/node_def_builder.h
+++ b/tensorflow/core/framework/node_def_builder.h
@@ -32,7 +32,8 @@ namespace tensorflow {
class NodeDefBuilder;
typedef std::function<Status(const OpDef&, int, const NodeDef&,
- NodeDefBuilder*)> FakeInputFunctor;
+ NodeDefBuilder*)>
+ FakeInputFunctor;
// This is a helper for creating a NodeDef. Automatically sets attrs
// that can be inferred from the inputs, and uses default values
@@ -49,14 +50,9 @@ class NodeDefBuilder {
public:
// To specify an output to be consumed by one of the Input() methods below.
struct NodeOut {
- NodeOut(const string& n, int i, DataType dt)
- : node(n), index(i), data_type(dt) {}
- NodeOut() {} // uninitialized, call Reset() before use.
- void Reset(const string& n, int i, DataType dt) {
- node = n;
- index = i;
- data_type = dt;
- }
+ NodeOut(StringPiece n, int i, DataType dt);
+ NodeOut(); // uninitialized, call Reset() before use.
+ void Reset(StringPiece n, int i, DataType dt);
string node;
int index;
DataType data_type;
@@ -66,16 +62,16 @@ class NodeDefBuilder {
// the Op plus a registry) for the NodeDef. Other fields are
// specified by calling the methods below.
// REQUIRES: The OpDef must satisfy ValidateOpDef().
- NodeDefBuilder(const string& name, const string& op_name,
+ NodeDefBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry = OpRegistry::Global());
// REQUIRES: in addition, *op_def must outlive *this.
- NodeDefBuilder(const string& name, const OpDef* op_def);
+ NodeDefBuilder(StringPiece name, const OpDef* op_def);
// You must call one Input() function per input_arg in the Op,
// *and in the same order as the input_args appear in the OpDef.*
// For inputs that take a single tensor.
- NodeDefBuilder& Input(const string& src_node, int src_index, DataType dt) {
+ NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt) {
const OpDef::ArgDef* arg = NextArgDef();
if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
return *this;
@@ -96,25 +92,18 @@ class NodeDefBuilder {
NodeDefBuilder& Input(FakeInputFunctor fake_input);
// Specify that this node must only run after src_node.
- NodeDefBuilder& ControlInput(const string& src_node) {
- control_inputs_.push_back(src_node);
- return *this;
- }
+ NodeDefBuilder& ControlInput(StringPiece src_node);
// Constrains what devices this node may be scheduled on.
- NodeDefBuilder& Device(const string& device_spec) {
- node_def_.set_device(device_spec);
- return *this;
- }
+ NodeDefBuilder& Device(StringPiece device_spec);
// Sets the attr, if not already set. If already set with a different
// value, an error will be returned from Finalize().
template <class T>
- NodeDefBuilder& Attr(const string& attr_name, T&& value);
+ NodeDefBuilder& Attr(StringPiece attr_name, T&& value);
// Note: overload needed to allow {...} expressions for value.
template <class T>
- NodeDefBuilder& Attr(const string& attr_name,
- std::initializer_list<T> value) {
+ NodeDefBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value) {
Attr<std::initializer_list<T>>(attr_name, std::move(value));
return *this;
}
@@ -141,13 +130,13 @@ class NodeDefBuilder {
bool NextArgAvailable();
// These do the main work of the Input() methods.
- void SingleInput(const OpDef::ArgDef* input_arg, const string& src_node,
+ void SingleInput(const OpDef::ArgDef* input_arg, StringPiece src_node,
int src_index, DataType dt);
void ListInput(const OpDef::ArgDef* input_arg,
gtl::ArraySlice<NodeOut> src_list);
// Add "src_node:src_index" to the list of inputs in the node_def_.
- void AddInput(const string& src_node, int src_index);
+ void AddInput(StringPiece src_node, int src_index);
// Generate an error if you can't pass dt when expected is expected.
void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected,
@@ -161,6 +150,9 @@ class NodeDefBuilder {
return input_arg->is_ref() ? MakeRefType(dt) : dt;
}
+ void CheckInconsistency(StringPiece attr_name, const AttrValue& found,
+ const AttrValue& attr_value);
+
const OpDef* op_def_;
NodeDef node_def_;
int inputs_specified_;
@@ -171,18 +163,14 @@ class NodeDefBuilder {
// IMPLEMENTATION -------------------------------------------------------------
template <class T>
-NodeDefBuilder& NodeDefBuilder::Attr(const string& attr_name, T&& value) {
+NodeDefBuilder& NodeDefBuilder::Attr(StringPiece attr_name, T&& value) {
const AttrValue* found = AttrSlice(node_def_).Find(attr_name);
if (found == nullptr) {
AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
} else {
AttrValue attr_value;
SetAttrValue(std::forward<T>(value), &attr_value);
- if (!AreAttrValuesEqual(*found, attr_value)) {
- errors_.push_back(strings::StrCat(
- "Inconsistent values for attr '", attr_name, "' ",
- SummarizeAttrValue(*found), " vs. ", SummarizeAttrValue(attr_value)));
- }
+ CheckInconsistency(attr_name, *found, attr_value);
}
return *this;
}
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 641411892d..002adfc250 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -72,13 +72,13 @@ string SummarizeNodeDef(const NodeDef& node_def) {
return ret;
}
-const AttrValue* AttrSlice::Find(const string& attr_name) const {
- auto iter = attrs_->find(attr_name);
+const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
+ auto iter = attrs_->find(attr_name.ToString());
if (iter == attrs_->end()) return nullptr;
return &iter->second;
}
-Status AttrSlice::Find(const string& attr_name,
+Status AttrSlice::Find(StringPiece attr_name,
const AttrValue** attr_value) const {
*attr_value = Find(attr_name);
if (*attr_value != nullptr) {
@@ -97,7 +97,7 @@ Status AttrSlice::Find(const string& attr_name,
// The ... is to allow the caller to inject some value validation code. Use
// just ; if no additional validation code is needed.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
- Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \
+ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, \
TYPE* value) { \
const AttrValue* attr_value; \
TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
@@ -107,7 +107,7 @@ Status AttrSlice::Find(const string& attr_name,
*value = CAST; \
return Status::OK(); \
} \
- Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \
+ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, \
std::vector<TYPE>* value) { \
const AttrValue* attr_value; \
TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
@@ -149,7 +149,7 @@ DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t;
#undef DEFINE_GET_ATTR
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
DataTypeVector* value) {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
@@ -160,7 +160,7 @@ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
return Status::OK();
}
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const TensorProto** value) {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
@@ -169,7 +169,7 @@ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
return Status::OK();
}
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const NameAttrList** value) {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index d70bb6dd37..db7c98149a 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -38,20 +38,22 @@ typedef protobuf::Map<string, AttrValue> AttrValueMap;
// Adds an attr with name <name> and value <value> to *node_def.
// The type of the attr is based on the type of value.
template <class T>
-void AddNodeAttr(const string& name, T&& value, NodeDef* node_def) {
+void AddNodeAttr(StringPiece name, T&& value, NodeDef* node_def) {
AttrValue attr_value;
SetAttrValue(std::forward<T>(value), &attr_value);
- node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
+ node_def->mutable_attr()->insert(
+ AttrValueMap::value_type(name.ToString(), attr_value));
}
// Version to workaround C++'s "perfect" forwarding not being able to
// forward {...} initialization.
template <class T>
-void AddNodeAttr(const string& name, std::initializer_list<T> value,
+void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
NodeDef* node_def) {
AttrValue attr_value;
SetAttrValue(value, &attr_value);
- node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
+ node_def->mutable_attr()->insert(
+ AttrValueMap::value_type(name.ToString(), attr_value));
}
class AttrSlice {
@@ -62,11 +64,11 @@ class AttrSlice {
// Returns the attr with attr_name if found. Otherwise, returns
// nullptr.
- const AttrValue* Find(const string& attr_name) const;
+ const AttrValue* Find(StringPiece attr_name) const;
// Returns the attr_value for attr_name if found. Otherwise, returns a
// NotFound status.
- Status Find(const string& attr_name, const AttrValue** attr_value) const;
+ Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
private:
const NodeDef* ndef_;
@@ -76,58 +78,58 @@ class AttrSlice {
// Look up the attr with name attr_name and set *value to its value. If no
// attr with attr_name is found in node_def, or the attr does not have
// a matching type, a non-ok status will be returned.
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
string* value); // type: "string"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
int64* value); // type: "int"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
int32* value); // type: "int"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
float* value); // type: "float"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
bool* value); // type: "bool"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
DataType* value); // type: "type"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
TensorShapeProto* value); // type: "shape"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
TensorShape* value); // type: "shape"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
PartialTensorShape* value); // type: "shape"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
Tensor* value); // type: "tensor"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<string>* value); // type "list(string)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<int64>* value); // type "list(int)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<int32>* value); // type "list(int)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<float>* value); // type "list(float)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<bool>* value); // type "list(bool)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<DataType>* value); // type "list(type)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
DataTypeVector* value); // type "list(type)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<TensorShapeProto>* value); // type "list(shape)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<TensorShape>* value); // type "list(shape)"
Status GetNodeAttr(
- const AttrSlice& attrs, const string& attr_name,
+ const AttrSlice& attrs, StringPiece attr_name,
std::vector<PartialTensorShape>* value); // type "list(shape)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<Tensor>* value); // type: "list(tensor)"
// This version avoids copying the TensorProto.
// REQUIRES: Must not use *value beyond the lifetime of node_def.
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const TensorProto** value); // type: "tensor"
// This version avoids copying the NameAttrList.
// REQUIRES: Must not use *value beyond the lifetime of node_def.
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const NameAttrList** value); // type: "func"
// Computes the input and output types for a specific node.
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index c86c259651..c984c8dc93 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -93,9 +93,9 @@ OpKernel::OpKernel(OpKernelConstruction* context)
OpKernel::~OpKernel() {}
-Status OpKernel::InputRange(const string& input_name, int* start,
+Status OpKernel::InputRange(StringPiece input_name, int* start,
int* stop) const {
- const auto result = input_name_map_.find(input_name);
+ const auto result = input_name_map_.find(input_name.ToString());
if (result == input_name_map_.end()) {
return errors::InvalidArgument("Unknown input name: ", input_name);
} else {
@@ -105,9 +105,9 @@ Status OpKernel::InputRange(const string& input_name, int* start,
}
}
-Status OpKernel::OutputRange(const string& output_name, int* start,
+Status OpKernel::OutputRange(StringPiece output_name, int* start,
int* stop) const {
- const auto result = output_name_map_.find(output_name);
+ const auto result = output_name_map_.find(output_name.ToString());
if (result == output_name_map_.end()) {
return errors::InvalidArgument("Unknown output name: ", output_name);
} else {
@@ -236,7 +236,7 @@ void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) {
referenced_tensors_.Add(tensor);
}
-Status OpKernelContext::input(const string& name, const Tensor** tensor) {
+Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
@@ -254,7 +254,7 @@ Status OpKernelContext::input(const string& name, const Tensor** tensor) {
return Status::OK();
}
-Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) {
+Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
@@ -329,7 +329,7 @@ void OpKernelContext::delete_ref_input(int index, bool lock_held) {
}
}
-Status OpKernelContext::mutable_input(const string& name, Tensor* tensor,
+Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
bool lock_held) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
@@ -353,7 +353,7 @@ Status OpKernelContext::mutable_input(const string& name, Tensor* tensor,
return Status::OK();
}
-Status OpKernelContext::replace_ref_input(const string& name,
+Status OpKernelContext::replace_ref_input(StringPiece name,
const Tensor& tensor,
bool lock_held) {
int start, stop;
@@ -371,14 +371,14 @@ Status OpKernelContext::replace_ref_input(const string& name,
return Status::OK();
}
-Status OpKernelContext::input_list(const string& name, OpInputList* list) {
+Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
*list = OpInputList(this, start, stop);
return Status::OK();
}
-Status OpKernelContext::mutable_input_list(const string& name,
+Status OpKernelContext::mutable_input_list(StringPiece name,
OpMutableInputList* list) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
@@ -386,7 +386,7 @@ Status OpKernelContext::mutable_input_list(const string& name,
return Status::OK();
}
-Status OpKernelContext::output_list(const string& name, OpOutputList* list) {
+Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
*list = OpOutputList(this, start, stop);
@@ -401,7 +401,7 @@ Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
return allocate_output(index, shape, output, attr);
}
-Status OpKernelContext::allocate_output(const string& name,
+Status OpKernelContext::allocate_output(StringPiece name,
const TensorShape& shape,
Tensor** tensor) {
int start, stop;
@@ -415,7 +415,7 @@ Status OpKernelContext::allocate_output(const string& name,
return allocate_output(start, shape, tensor);
}
-Status OpKernelContext::allocate_output(const string& name,
+Status OpKernelContext::allocate_output(StringPiece name,
const TensorShape& shape,
Tensor** tensor,
AllocatorAttributes attr) {
@@ -494,7 +494,7 @@ Status OpKernelContext::allocate_persistent(DataType type,
return s;
}
-Status OpKernelContext::set_output(const string& name, const Tensor& tensor) {
+Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
@@ -525,7 +525,7 @@ void OpKernelContext::set_output_ref(int index, mutex* mu,
outputs_[index] = TensorValue(mu, tensor_for_ref);
}
-Status OpKernelContext::set_output_ref(const string& name, mutex* mu,
+Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
Tensor* tensor_for_ref) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
@@ -539,7 +539,7 @@ Status OpKernelContext::set_output_ref(const string& name, mutex* mu,
return Status::OK();
}
-Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) {
+Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
@@ -552,7 +552,7 @@ Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) {
return Status::OK();
}
-Status OpKernelContext::release_output(const string& name, TensorValue* value) {
+Status OpKernelContext::release_output(StringPiece name, TensorValue* value) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
@@ -615,8 +615,8 @@ static KernelRegistry* GlobalKernelRegistryTyped() {
return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
}
-static string Key(const string& op_type, DeviceType device_type,
- const string& label) {
+static string Key(StringPiece op_type, DeviceType device_type,
+ StringPiece label) {
return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
label);
}
@@ -832,7 +832,7 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
namespace {
-bool FindArgInOp(const string& arg_name,
+bool FindArgInOp(StringPiece arg_name,
const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
for (const auto& arg : args) {
if (arg_name == arg.name()) {
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 46bae0a0b6..e2cabb4b68 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -122,8 +122,8 @@ class OpKernel {
return output_memory_types_;
}
- Status InputRange(const string& input_name, int* start, int* stop) const;
- Status OutputRange(const string& output_name, int* start, int* stop) const;
+ Status InputRange(StringPiece input_name, int* start, int* stop) const;
+ Status OutputRange(StringPiece output_name, int* start, int* stop) const;
// We allow legacy scalars within Google up until GraphDef version 6.
// TODO(irving): Remove when we can drop support for GraphDef version 5.
@@ -295,7 +295,7 @@ class OpKernelConstruction {
// attr with attr_name is found in def(), or the attr does not have
// a matching type, a non-ok status will be returned.
template <class T>
- Status GetAttr(const string& attr_name, T* value) const;
+ Status GetAttr(StringPiece attr_name, T* value) const;
// May be used, e.g., to get GPU handles, etc.
// TODO(tucker): Add example usage.
@@ -558,14 +558,14 @@ class OpKernelContext {
// use mutable_input below.
// REQUIRES: !IsRefType(input_dtype(index))
// REQUIRES: the named input must not be a list.
- Status input(const string& name, const Tensor** tensor);
+ Status input(StringPiece name, const Tensor** tensor);
// Returns the named list-valued immutable input in "list", as
// defined in the OpDef. If the named output is not list-valued,
// returns a one-element list. May only be used for non-Ref
// inputs. For Ref inputs use mutable_input below.
// REQUIRES: !IsRefType(input_dtype(index))
- Status input_list(const string& name, OpInputList* list);
+ Status input_list(StringPiece name, OpInputList* list);
// For mutable inputs, use the following together to make sure there
// is no concurrent access to mutable_input(), e.g.:
@@ -577,7 +577,7 @@ class OpKernelContext {
// REQUIRES: IsRefType(input_dtype(index))
// TODO(mrry): Convert this to return Status.
mutex* input_ref_mutex(int index);
- Status input_ref_mutex(const string& name, mutex** out_mutex);
+ Status input_ref_mutex(StringPiece name, mutex** out_mutex);
// Returns a mutable input tensor. Must be used to access Ref
// inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
@@ -596,7 +596,7 @@ class OpKernelContext {
// the input mutex will be acquired before returning the Tensor.
// REQUIRES: the named input must not be a list.
// REQUIRES: the named input must be a ref tensor.
- Status mutable_input(const string& name, Tensor* tensor, bool lock_held);
+ Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
// Returns the named list-valued mutable input in "list", as defined
// in the OpDef. If the named input is not list-valued, returns a
@@ -604,7 +604,7 @@ class OpKernelContext {
// stored in the Tensor buffer may be modified, and modifications
// will be visible to other Ops reading the same ref tensor.
// REQUIRES: the named input must be a ref tensor.
- Status mutable_input_list(const string& name, OpMutableInputList* list);
+ Status mutable_input_list(StringPiece name, OpMutableInputList* list);
// Replace the corresponding Ref Input to use the storage buffer
// used by tensor. If !lock_held the input mutex will be acquired
@@ -616,7 +616,7 @@ class OpKernelContext {
// buffer used by tensor. If !lock_held the input mutex will be
// acquired before returning the Tensor.
// REQUIRES: IsRefType(input_dtype(index)).
- Status replace_ref_input(const string& name, const Tensor& tensor,
+ Status replace_ref_input(StringPiece name, const Tensor& tensor,
bool lock_held);
// Set the output Ref Tensor at output_index to be an alias of the
@@ -647,7 +647,7 @@ class OpKernelContext {
// Returns the named list-valued output in "list", as defined in the OpDef.
// If the named output is not list-valued, returns a one-element list.
- Status output_list(const string& name, OpOutputList* list);
+ Status output_list(StringPiece name, OpOutputList* list);
// If output_required(index) returns true, the OpKernel's Compute() method
// should call allocate_output(index, ...), set_output(index, ...),
@@ -712,7 +712,7 @@ class OpKernelContext {
// REQUIRES: !IsRefType(expected_output_dtype(index))
Status allocate_output(int index, const TensorShape& shape,
Tensor** tensor) TF_MUST_USE_RESULT;
- Status allocate_output(const string& name, const TensorShape& shape,
+ Status allocate_output(StringPiece name, const TensorShape& shape,
Tensor** tensor) TF_MUST_USE_RESULT;
// The following methods use the supplied attributes instead of
// those in output_attr_array. The caller is responsible for
@@ -721,7 +721,7 @@ class OpKernelContext {
// device. See comment above.
Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
AllocatorAttributes attr) TF_MUST_USE_RESULT;
- Status allocate_output(const string& name, const TensorShape& shape,
+ Status allocate_output(StringPiece name, const TensorShape& shape,
Tensor** tensor,
AllocatorAttributes attr) TF_MUST_USE_RESULT;
@@ -766,19 +766,19 @@ class OpKernelContext {
// output_memory_types[index]. See comment above.
// TODO(mrry): Convert this to return Status.
void set_output(int index, const Tensor& tensor);
- Status set_output(const string& name, const Tensor& tensor);
+ Status set_output(StringPiece name, const Tensor& tensor);
// To output a reference. Caller retains ownership of mu and tensor_for_ref,
// and they must outlive all uses within the step. See comment above.
// REQUIRES: IsRefType(expected_output_dtype(index))
// TODO(mrry): Convert this to return Status.
void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
- Status set_output_ref(const string& name, mutex* mu, Tensor* tensor_for_ref);
+ Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);
// Returns nullptr if allocate_output() or set_output() have not been called.
// TODO(mrry): Convert this to return Status.
Tensor* mutable_output(int index);
- Status mutable_output(const string& name, Tensor** tensor);
+ Status mutable_output(StringPiece name, Tensor** tensor);
// Transfers ownership of an output tensor to the caller.
// NOTE: For non-reference outputs, the caller takes responsibility
@@ -786,7 +786,7 @@ class OpKernelContext {
// responsibility for deletion.
// TODO(mrry): Convert this to return Status.
TensorValue release_output(int index);
- Status release_output(const string& name, TensorValue* value);
+ Status release_output(StringPiece name, TensorValue* value);
// Records device specific state about how the input tensors were
// computed.
@@ -1074,7 +1074,7 @@ class OpKernelRegistrar {
// Template and inline method implementations, please ignore
template <class T>
-Status OpKernelConstruction::GetAttr(const string& attr_name, T* value) const {
+Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
return GetNodeAttr(def(), attr_name, value);
}
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 5420a0d976..77273d9ba3 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -369,6 +369,20 @@ bool Tensor::IsInitialized() const {
return buf_ != nullptr && buf_->data() != nullptr;
}
+void Tensor::CheckType(DataType expected_dtype) const {  // Out-of-line helper: CHECKs that dtype() equals expected_dtype.
+ CHECK_EQ(dtype(), expected_dtype);
+}
+
+void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {  // Out-of-line helper: CHECKs the dtype first, then IsAligned().
+ CHECK_EQ(dtype(), expected_dtype);
+ CHECK(IsAligned());
+}
+
+void Tensor::CheckIsAlignedAndSingleElement() const {  // Out-of-line helper: CHECKs IsAligned() and that NumElements() == 1.
+ CHECK(IsAligned());
+ CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
+}
+
Tensor::~Tensor() { UnrefIfNonNull(buf_); }
void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
@@ -551,7 +565,8 @@ bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
set_dtype(proto.dtype());
UnrefIfNonNull(buf_);
buf_ = p;
- // TODO(misard) add tracking of which kernels and steps are calling FromProto.
+ // TODO(misard) add tracking of which kernels and steps are calling
+ // FromProto.
if (IsInitialized() && LogMemory::IsEnabled()) {
LogMemory::RecordTensorAllocation("Unknown (from Proto)",
LogMemory::UNKNOWN_STEP_ID, *this);
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index a6143e95c0..708c98f409 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -370,7 +370,15 @@ class Tensor {
void UnsafeCopyFromInternal(const Tensor&, const TensorShape&);
private:
+ void CheckType(DataType expected_dtype) const;
+ void CheckTypeAndIsAligned(DataType expected_dtype) const;
+ void CheckIsAlignedAndSingleElement() const;
void set_dtype(DataType t) { shape_.set_data_type(t); }
+ template <size_t NDIMS>
+ void FillDimsAndValidateCompatibleShape(
+ gtl::ArraySlice<int64> new_sizes,
+ Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;
+
TensorShape shape_;
TensorBuffer* buf_;
@@ -440,48 +448,46 @@ T* Tensor::base() const {
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::tensor() {
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
- CHECK(IsAligned());
+ CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
return typename TTypes<T, NDIMS>::Tensor(base<T>(),
shape().AsEigenDSizes<NDIMS>());
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::tensor() const {
- CHECK(IsAligned());
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(),
shape().AsEigenDSizes<NDIMS>());
}
-template <typename T, size_t NDIMS>
-typename TTypes<T, NDIMS>::Tensor Tensor::shaped(
- gtl::ArraySlice<int64> new_sizes) {
- CHECK(IsAligned());
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+template <size_t NDIMS>
+void Tensor::FillDimsAndValidateCompatibleShape(
+ gtl::ArraySlice<int64> new_sizes,
+ Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const {
CHECK_EQ(NDIMS, new_sizes.size());
int64 new_num_elements = 1;
- Eigen::array<Eigen::DenseIndex, NDIMS> dims;
for (size_t d = 0; d < NDIMS; d++) {
new_num_elements *= new_sizes[d];
- dims[d] = new_sizes[d];
+ (*dims)[d] = new_sizes[d];
}
CHECK_EQ(new_num_elements, NumElements());
+}
+
+template <typename T, size_t NDIMS>
+typename TTypes<T, NDIMS>::Tensor Tensor::shaped(
+ gtl::ArraySlice<int64> new_sizes) {
+ CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
+ Eigen::array<Eigen::DenseIndex, NDIMS> dims;
+ FillDimsAndValidateCompatibleShape<NDIMS>(new_sizes, &dims);
return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedTensor Tensor::unaligned_shaped(
gtl::ArraySlice<int64> new_sizes) {
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
- CHECK_EQ(NDIMS, new_sizes.size());
- int64 new_num_elements = 1;
+ CheckType(DataTypeToEnum<T>::v());
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
- for (size_t d = 0; d < NDIMS; d++) {
- new_num_elements *= new_sizes[d];
- dims[d] = new_sizes[d];
- }
- CHECK_EQ(new_num_elements, NumElements());
+ FillDimsAndValidateCompatibleShape<NDIMS>(new_sizes, &dims);
return typename TTypes<T, NDIMS>::UnalignedTensor(base<T>(), dims);
}
@@ -501,8 +507,7 @@ void Tensor::FillDimsAndValidateCompatibleShape(
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
gtl::ArraySlice<int64> new_sizes) const {
- CHECK(IsAligned());
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
FillDimsAndValidateCompatibleShape(&dims, new_sizes);
return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
@@ -511,7 +516,7 @@ typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped(
gtl::ArraySlice<int64> new_sizes) const {
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ CheckType(DataTypeToEnum<T>::v());
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
FillDimsAndValidateCompatibleShape(&dims, new_sizes);
return typename TTypes<T, NDIMS>::UnalignedConstTensor(base<T>(), dims);
@@ -519,15 +524,13 @@ typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped(
template <typename T>
typename TTypes<T>::Scalar Tensor::scalar() {
- CHECK(IsAligned());
- CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
+ CheckIsAlignedAndSingleElement();
return typename TTypes<T>::Scalar(base<T>());
}
template <typename T>
typename TTypes<T>::ConstScalar Tensor::scalar() const {
- CHECK(IsAligned());
- CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
+ CheckIsAlignedAndSingleElement();
return typename TTypes<T>::ConstScalar(base<T>());
}
diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc
index 7cc98513c8..ee59a79d38 100644
--- a/tensorflow/core/framework/tensor_shape.cc
+++ b/tensorflow/core/framework/tensor_shape.cc
@@ -32,6 +32,16 @@ static void AppendTo(const TensorShape& s, gtl::InlinedVector<int64, 8>* vals) {
}
}
+void TensorShape::CheckDimsEqual(int NDIMS) const {  // Out-of-line helper: CHECKs that this shape has exactly NDIMS dimensions.
+ CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS
+ << " for a tensor of " << dims() << " dimensions";
+}
+
+void TensorShape::CheckDimsAtLeast(int NDIMS) const {  // Out-of-line helper: CHECKs NDIMS >= dims(), i.e. this shape has no more than NDIMS dimensions.
+ CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
+ << " for a tensor of " << dims() << " dimensions";  // NOTE(review): name and message say "at least", but the comparison bounds dims() from above -- confirm intended wording.
+}
+
bool TensorShape::IsValid(const TensorShapeProto& proto) {
int64 num_elements = 1;
for (const auto& d : proto.dim()) {
diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h
index bd80215849..e341ceddfb 100644
--- a/tensorflow/core/framework/tensor_shape.h
+++ b/tensorflow/core/framework/tensor_shape.h
@@ -143,6 +143,9 @@ class TensorShape {
void RecomputeNumElements();
+ void CheckDimsEqual(int NDIMS) const;
+ void CheckDimsAtLeast(int NDIMS) const;
+
// We use 16 bytes to represent a TensorShape. Because we need to
// be able to support full 64-bit dimension sizes and an arbitrary
// number of dimensions for a Tensor, but most tensor dimensions are
@@ -266,16 +269,14 @@ class TensorShapeUtils {
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const {
- CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS
- << " for a tensor of " << dims() << " dimensions";
+ CheckDimsEqual(NDIMS);
return AsEigenDSizesWithPadding<NDIMS>();
}
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding()
const {
- CHECK_GE(NDIMS, dims()) << "Asking for tensor of " << NDIMS
- << " for a tensor of " << dims() << " dimensions";
+ CheckDimsAtLeast(NDIMS);
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
for (int d = 0; d < dims(); d++) {
dsizes[d] = dim_size(d);
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index 8de02cc2ca..df59950a16 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -28,17 +28,17 @@ NodeBuilder::NodeOut::NodeOut(Node* n, int i) // NOLINT(runtime/explicit)
index(i),
dt(SafeGetOutput(node, i, &error)) {}
-NodeBuilder::NodeOut::NodeOut(const string& name, int i, DataType t)
- : node(nullptr), error(false), name(name), index(i), dt(t) {}
+NodeBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType t)
+ : node(nullptr), error(false), name(n.ToString()), index(i), dt(t) {}
NodeBuilder::NodeOut::NodeOut()
: node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
-NodeBuilder::NodeBuilder(const string& name, const string& op_name,
+NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry)
: def_builder_(name, op_name, op_registry) {}
-NodeBuilder::NodeBuilder(const string& name, const OpDef* op_def)
+NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
: def_builder_(name, op_def) {}
NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
@@ -90,7 +90,7 @@ NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
return *this;
}
-NodeBuilder& NodeBuilder::Device(const string& device_spec) {
+NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
def_builder_.Device(device_spec);
return *this;
}
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
index 147311f320..50c41e2221 100644
--- a/tensorflow/core/graph/node_builder.h
+++ b/tensorflow/core/graph/node_builder.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -54,7 +55,7 @@ class NodeBuilder {
// useful when preparing a graph for ExtendSession or creating a
// back edge to a node that hasn't been added to the graph yet,
// but will be.
- NodeOut(const string& name, int i, DataType t);
+ NodeOut(StringPiece name, int i, DataType t);
// Default constructor for std::vector<NodeOut>.
NodeOut();
@@ -74,9 +75,9 @@ class NodeBuilder {
// the Op plus a registry) for the Node. Other fields are
// specified by calling the methods below.
// REQUIRES: The OpDef must satisfy ValidateOpDef().
- NodeBuilder(const string& name, const string& op_name,
+ NodeBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry = OpRegistry::Global());
- NodeBuilder(const string& name, const OpDef* op_def);
+ NodeBuilder(StringPiece name, const OpDef* op_def);
// You must call one Input() function per input_arg in the Op,
// *and in the same order as the input_args appear in the OpDef.*
@@ -94,7 +95,7 @@ class NodeBuilder {
// Sets the "requested device spec" in the NodeDef (not the
// "assigned device" in the Node).
- NodeBuilder& Device(const string& device_spec);
+ NodeBuilder& Device(StringPiece device_spec);
// Set the value of an attr. attr_name must match the name of one of
// attrs defined by the Op, and value must have the corresponding type
@@ -102,9 +103,9 @@ class NodeBuilder {
// types for value). Note that attrs will be set automatically if
// they can be determined by the inputs.
template <class T>
- NodeBuilder& Attr(const string& attr_name, T&& value);
+ NodeBuilder& Attr(StringPiece attr_name, T&& value);
template <class T>
- NodeBuilder& Attr(const string& attr_name, std::initializer_list<T> value);
+ NodeBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value);
// Validates the described node and adds it to *graph, adding edges
// for all (non-back) inputs. If created_node is not nullptr,
@@ -138,13 +139,13 @@ class NodeBuilder {
// IMPLEMENTATION -------------------------------------------------------------
template <class T>
-NodeBuilder& NodeBuilder::Attr(const string& attr_name, T&& value) {
+NodeBuilder& NodeBuilder::Attr(StringPiece attr_name, T&& value) {
def_builder_.Attr(attr_name, std::forward<T>(value));
return *this;
}
template <class T>
-NodeBuilder& NodeBuilder::Attr(const string& attr_name,
+NodeBuilder& NodeBuilder::Attr(StringPiece attr_name,
std::initializer_list<T> value) {
def_builder_.Attr(attr_name, value);
return *this;
diff --git a/tensorflow/core/lib/strings/strcat.cc b/tensorflow/core/lib/strings/strcat.cc
index 4050727dc0..0c659e236c 100644
--- a/tensorflow/core/lib/strings/strcat.cc
+++ b/tensorflow/core/lib/strings/strcat.cc
@@ -84,6 +84,8 @@ static char *Append4(char *out, const AlphaNum &x1, const AlphaNum &x2,
return out + x4.size();
}
+string StrCat(const AlphaNum &a) { return string(a.data(), a.size()); }  // Single-argument overload, defined out of line here: copies a's bytes into a new string.
+
string StrCat(const AlphaNum &a, const AlphaNum &b) {
string result;
gtl::STLStringResizeUninitialized(&result, a.size() + b.size());
diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h
index 33b6028153..d7d5352a88 100644
--- a/tensorflow/core/lib/strings/strcat.h
+++ b/tensorflow/core/lib/strings/strcat.h
@@ -172,9 +172,6 @@ string StrCat(const AlphaNum &a, const AlphaNum &b,
string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
const AlphaNum &d) TF_MUST_USE_RESULT;
-// inline definitions must be duplicated due to TF_MUST_USE_RESULT
-inline string StrCat(const AlphaNum &a) { return string(a.data(), a.size()); }
-
namespace internal {
// Do not call directly - this is not part of the public API.
@@ -190,8 +187,8 @@ string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
const AV &... args) TF_MUST_USE_RESULT;
template <typename... AV>
-inline string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
- const AlphaNum &d, const AlphaNum &e, const AV &... args) {
+string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
+ const AlphaNum &d, const AlphaNum &e, const AV &... args) {
return internal::CatPieces({a.Piece(), b.Piece(), c.Piece(), d.Piece(),
e.Piece(),
static_cast<const AlphaNum &>(args).Piece()...});
diff --git a/tensorflow/core/util/mirror_pad_mode.cc b/tensorflow/core/util/mirror_pad_mode.cc
index 5d38c9b1de..a20c1fba74 100644
--- a/tensorflow/core/util/mirror_pad_mode.cc
+++ b/tensorflow/core/util/mirror_pad_mode.cc
@@ -21,7 +21,7 @@ limitations under the License.
namespace tensorflow {
-Status GetNodeAttr(const NodeDef& node_def, const string& attr_name,
+Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name,
MirrorPadMode* value) {
string str_value;
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attr_name, &str_value));
diff --git a/tensorflow/core/util/mirror_pad_mode.h b/tensorflow/core/util/mirror_pad_mode.h
index ae96fdff73..c865d35e28 100644
--- a/tensorflow/core/util/mirror_pad_mode.h
+++ b/tensorflow/core/util/mirror_pad_mode.h
@@ -44,7 +44,7 @@ string GetMirrorPadModeAttrString();
class NodeDef;
// Specialization to parse an attribute directly into a MirrorPadMode enum.
-Status GetNodeAttr(const NodeDef& node_def, const string& attr_name,
+Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name,
MirrorPadMode* value);
} // end namespace tensorflow
diff --git a/tensorflow/core/util/padding.cc b/tensorflow/core/util/padding.cc
index b3064d6420..ce36d4cc78 100644
--- a/tensorflow/core/util/padding.cc
+++ b/tensorflow/core/util/padding.cc
@@ -20,7 +20,7 @@ limitations under the License.
namespace tensorflow {
-Status GetNodeAttr(const NodeDef& node_def, const string& attr_name,
+Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name,
Padding* value) {
string str_value;
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attr_name, &str_value));
diff --git a/tensorflow/core/util/padding.h b/tensorflow/core/util/padding.h
index d6d970d27a..989a71b24f 100644
--- a/tensorflow/core/util/padding.h
+++ b/tensorflow/core/util/padding.h
@@ -44,7 +44,7 @@ enum Padding {
string GetPaddingAttrString();
// Specialization to parse an attribute directly into a Padding enum.
-Status GetNodeAttr(const NodeDef& node_def, const string& attr_name,
+Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name,
Padding* value);
} // end namespace tensorflow