Diffstat (limited to 'tensorflow/core')
-rw-r--r--  tensorflow/core/framework/node_def_builder.cc  | 49
-rw-r--r--  tensorflow/core/framework/node_def_builder.h   | 50
-rw-r--r--  tensorflow/core/framework/node_def_util.cc     | 16
-rw-r--r--  tensorflow/core/framework/node_def_util.h      | 60
-rw-r--r--  tensorflow/core/framework/op_kernel.cc         | 40
-rw-r--r--  tensorflow/core/framework/op_kernel.h          | 34
-rw-r--r--  tensorflow/core/framework/tensor.cc            | 17
-rw-r--r--  tensorflow/core/framework/tensor.h             | 55
-rw-r--r--  tensorflow/core/framework/tensor_shape.cc      | 10
-rw-r--r--  tensorflow/core/framework/tensor_shape.h       |  9
-rw-r--r--  tensorflow/core/graph/node_builder.cc          | 10
-rw-r--r--  tensorflow/core/graph/node_builder.h           | 17
-rw-r--r--  tensorflow/core/lib/strings/strcat.cc          |  2
-rw-r--r--  tensorflow/core/lib/strings/strcat.h           |  7
-rw-r--r--  tensorflow/core/util/mirror_pad_mode.cc        |  2
-rw-r--r--  tensorflow/core/util/mirror_pad_mode.h         |  2
-rw-r--r--  tensorflow/core/util/padding.cc                |  2
-rw-r--r--  tensorflow/core/util/padding.h                 |  2
18 files changed, 218 insertions, 166 deletions
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc
index b6f5838528..f3091ad286 100644
--- a/tensorflow/core/framework/node_def_builder.cc
+++ b/tensorflow/core/framework/node_def_builder.cc
@@ -22,11 +22,24 @@ limitations under the License.
namespace tensorflow {
-NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name,
+NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt)
+ : node(n.ToString()), index(i), data_type(dt) {}
+
+NodeDefBuilder::NodeOut::NodeOut() {
+ // uninitialized, call Reset() before use.
+}
+
+void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) {
+ node = n.ToString();
+ index = i;
+ data_type = dt;
+}
+
+NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry) {
- node_def_.set_name(name);
+ node_def_.set_name(name.ToString());
Status status;
- op_def_ = op_registry->LookUp(op_name, &status);
+ op_def_ = op_registry->LookUp(op_name.ToString(), &status);
if (op_def_ == nullptr) {
errors_.push_back(status.error_message());
inputs_specified_ = 0;
@@ -35,9 +48,9 @@ NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name,
}
}
-NodeDefBuilder::NodeDefBuilder(const string& name, const OpDef* op_def)
+NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def)
: op_def_(op_def) {
- node_def_.set_name(name);
+ node_def_.set_name(name.ToString());
Initialize();
}
@@ -72,7 +85,7 @@ NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) {
}
void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg,
- const string& src_node, int src_index,
+ StringPiece src_node, int src_index,
DataType dt) {
AddInput(src_node, src_index);
@@ -129,7 +142,7 @@ void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg,
}
}
-void NodeDefBuilder::AddInput(const string& src_node, int src_index) {
+void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) {
if (src_node.empty()) {
errors_.push_back("Empty input node name");
} else if (src_node[0] == '^') {
@@ -138,7 +151,7 @@ void NodeDefBuilder::AddInput(const string& src_node, int src_index) {
} else if (src_index > 0) {
node_def_.add_input(strings::StrCat(src_node, ":", src_index));
} else {
- node_def_.add_input(src_node);
+ node_def_.add_input(src_node.ToString());
}
}
@@ -160,6 +173,16 @@ void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
}
}
+NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) {
+ control_inputs_.push_back(src_node.ToString());
+ return *this;
+}
+
+NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) {
+ node_def_.set_device(device_spec.ToString());
+ return *this;
+}
+
Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
const std::vector<string>* errors_ptr = &errors_;
std::vector<string> errors_storage;
@@ -206,4 +229,14 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
}
}
+void NodeDefBuilder::CheckInconsistency(StringPiece attr_name,
+ const AttrValue& found,
+ const AttrValue& attr_value) {
+ if (!AreAttrValuesEqual(found, attr_value)) {
+ errors_.push_back(strings::StrCat(
+ "Inconsistent values for attr '", attr_name, "' ",
+ SummarizeAttrValue(found), " vs. ", SummarizeAttrValue(attr_value)));
+ }
+}
+
} // namespace tensorflow
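With the StringPiece overloads above, the builder accepts string literals and
pieces of larger buffers directly, converting to an owned string only where
the NodeDef actually stores one (the ToString() calls). A minimal usage
sketch, not part of this change (op, node, and attr names are illustrative):

    NodeDef node_def;
    Status status = NodeDefBuilder("my_matmul", "MatMul")
                        .Input("a", 0, DT_FLOAT)
                        .Input("b", 0, DT_FLOAT)
                        .Attr("transpose_a", false)
                        .Finalize(&node_def);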
diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h
index 72ad37ab93..f8b3a8441a 100644
--- a/tensorflow/core/framework/node_def_builder.h
+++ b/tensorflow/core/framework/node_def_builder.h
@@ -32,7 +32,8 @@ namespace tensorflow {
class NodeDefBuilder;
typedef std::function<Status(const OpDef&, int, const NodeDef&,
- NodeDefBuilder*)> FakeInputFunctor;
+ NodeDefBuilder*)>
+ FakeInputFunctor;
// This is a helper for creating a NodeDef. Automatically sets attrs
// that can be inferred from the inputs, and uses default values
@@ -49,14 +50,9 @@ class NodeDefBuilder {
public:
// To specify an output to be consumed by one of the Input() methods below.
struct NodeOut {
- NodeOut(const string& n, int i, DataType dt)
- : node(n), index(i), data_type(dt) {}
- NodeOut() {} // uninitialized, call Reset() before use.
- void Reset(const string& n, int i, DataType dt) {
- node = n;
- index = i;
- data_type = dt;
- }
+ NodeOut(StringPiece n, int i, DataType dt);
+ NodeOut(); // uninitialized, call Reset() before use.
+ void Reset(StringPiece n, int i, DataType dt);
string node;
int index;
DataType data_type;
@@ -66,16 +62,16 @@ class NodeDefBuilder {
// the Op plus a registry) for the NodeDef. Other fields are
// specified by calling the methods below.
// REQUIRES: The OpDef must satisfy ValidateOpDef().
- NodeDefBuilder(const string& name, const string& op_name,
+ NodeDefBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry = OpRegistry::Global());
// REQUIRES: in addition, *op_def must outlive *this.
- NodeDefBuilder(const string& name, const OpDef* op_def);
+ NodeDefBuilder(StringPiece name, const OpDef* op_def);
// You must call one Input() function per input_arg in the Op,
// *and in the same order as the input_args appear in the OpDef.*
// For inputs that take a single tensor.
- NodeDefBuilder& Input(const string& src_node, int src_index, DataType dt) {
+ NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt) {
const OpDef::ArgDef* arg = NextArgDef();
if (arg != nullptr) SingleInput(arg, src_node, src_index, dt);
return *this;
@@ -96,25 +92,18 @@ class NodeDefBuilder {
NodeDefBuilder& Input(FakeInputFunctor fake_input);
// Specify that this node must only run after src_node.
- NodeDefBuilder& ControlInput(const string& src_node) {
- control_inputs_.push_back(src_node);
- return *this;
- }
+ NodeDefBuilder& ControlInput(StringPiece src_node);
// Constrains what devices this node may be scheduled on.
- NodeDefBuilder& Device(const string& device_spec) {
- node_def_.set_device(device_spec);
- return *this;
- }
+ NodeDefBuilder& Device(StringPiece device_spec);
// Sets the attr, if not already set. If already set with a different
// value, an error will be returned from Finalize().
template <class T>
- NodeDefBuilder& Attr(const string& attr_name, T&& value);
+ NodeDefBuilder& Attr(StringPiece attr_name, T&& value);
// Note: overload needed to allow {...} expressions for value.
template <class T>
- NodeDefBuilder& Attr(const string& attr_name,
- std::initializer_list<T> value) {
+ NodeDefBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value) {
Attr<std::initializer_list<T>>(attr_name, std::move(value));
return *this;
}
@@ -141,13 +130,13 @@ class NodeDefBuilder {
bool NextArgAvailable();
// These do the main work of the Input() methods.
- void SingleInput(const OpDef::ArgDef* input_arg, const string& src_node,
+ void SingleInput(const OpDef::ArgDef* input_arg, StringPiece src_node,
int src_index, DataType dt);
void ListInput(const OpDef::ArgDef* input_arg,
gtl::ArraySlice<NodeOut> src_list);
// Add "src_node:src_index" to the list of inputs in the node_def_.
- void AddInput(const string& src_node, int src_index);
+ void AddInput(StringPiece src_node, int src_index);
// Generate an error if you can't pass dt when expected is expected.
void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected,
@@ -161,6 +150,9 @@ class NodeDefBuilder {
return input_arg->is_ref() ? MakeRefType(dt) : dt;
}
+ void CheckInconsistency(StringPiece attr_name, const AttrValue& found,
+ const AttrValue& attr_value);
+
const OpDef* op_def_;
NodeDef node_def_;
int inputs_specified_;
@@ -171,18 +163,14 @@ class NodeDefBuilder {
// IMPLEMENTATION -------------------------------------------------------------
template <class T>
-NodeDefBuilder& NodeDefBuilder::Attr(const string& attr_name, T&& value) {
+NodeDefBuilder& NodeDefBuilder::Attr(StringPiece attr_name, T&& value) {
const AttrValue* found = AttrSlice(node_def_).Find(attr_name);
if (found == nullptr) {
AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
} else {
AttrValue attr_value;
SetAttrValue(std::forward<T>(value), &attr_value);
- if (!AreAttrValuesEqual(*found, attr_value)) {
- errors_.push_back(strings::StrCat(
- "Inconsistent values for attr '", attr_name, "' ",
- SummarizeAttrValue(*found), " vs. ", SummarizeAttrValue(attr_value)));
- }
+ CheckInconsistency(attr_name, *found, attr_value);
}
return *this;
}
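The retained initializer_list overload is what allows brace expressions: with
only the forwarding template, a call like Attr("strides", {1, 1, 1, 1}) could
not deduce T. A sketch of both overloads in use (the attr names are those of
the standard Conv2D op, shown purely for illustration):

    NodeDefBuilder b("conv", "Conv2D");
    b.Attr("strides", {1, 1, 1, 1});  // std::initializer_list<int> overload
    b.Attr("padding", "SAME");        // forwarding Attr<T>(StringPiece, T&&)

Factoring the mismatch handling into the non-template CheckInconsistency()
also means each Attr<T> instantiation no longer stamps out its own copy of
the StrCat/SummarizeAttrValue error path.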
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 641411892d..002adfc250 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -72,13 +72,13 @@ string SummarizeNodeDef(const NodeDef& node_def) {
return ret;
}
-const AttrValue* AttrSlice::Find(const string& attr_name) const {
- auto iter = attrs_->find(attr_name);
+const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
+ auto iter = attrs_->find(attr_name.ToString());
if (iter == attrs_->end()) return nullptr;
return &iter->second;
}
-Status AttrSlice::Find(const string& attr_name,
+Status AttrSlice::Find(StringPiece attr_name,
const AttrValue** attr_value) const {
*attr_value = Find(attr_name);
if (*attr_value != nullptr) {
@@ -97,7 +97,7 @@ Status AttrSlice::Find(const string& attr_name,
// The ... is to allow the caller to inject some value validation code. Use
// just ; if no additional validation code is needed.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
- Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \
+ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, \
TYPE* value) { \
const AttrValue* attr_value; \
TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
@@ -107,7 +107,7 @@ Status AttrSlice::Find(const string& attr_name,
*value = CAST; \
return Status::OK(); \
} \
- Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \
+ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, \
std::vector<TYPE>* value) { \
const AttrValue* attr_value; \
TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
@@ -149,7 +149,7 @@ DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t;
#undef DEFINE_GET_ATTR
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
DataTypeVector* value) {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
@@ -160,7 +160,7 @@ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
return Status::OK();
}
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const TensorProto** value) {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
@@ -169,7 +169,7 @@ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
return Status::OK();
}
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const NameAttrList** value) {
const AttrValue* attr_value;
TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index d70bb6dd37..db7c98149a 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -38,20 +38,22 @@ typedef protobuf::Map<string, AttrValue> AttrValueMap;
// Adds an attr with name <name> and value <value> to *node_def.
// The type of the attr is based on the type of value.
template <class T>
-void AddNodeAttr(const string& name, T&& value, NodeDef* node_def) {
+void AddNodeAttr(StringPiece name, T&& value, NodeDef* node_def) {
AttrValue attr_value;
SetAttrValue(std::forward<T>(value), &attr_value);
- node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
+ node_def->mutable_attr()->insert(
+ AttrValueMap::value_type(name.ToString(), attr_value));
}
// Version to workaround C++'s "perfect" forwarding not being able to
// forward {...} initialization.
template <class T>
-void AddNodeAttr(const string& name, std::initializer_list<T> value,
+void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
NodeDef* node_def) {
AttrValue attr_value;
SetAttrValue(value, &attr_value);
- node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value));
+ node_def->mutable_attr()->insert(
+ AttrValueMap::value_type(name.ToString(), attr_value));
}
class AttrSlice {
@@ -62,11 +64,11 @@ class AttrSlice {
// Returns the attr with attr_name if found. Otherwise, returns
// nullptr.
- const AttrValue* Find(const string& attr_name) const;
+ const AttrValue* Find(StringPiece attr_name) const;
// Returns the attr_value for attr_name if found. Otherwise, returns a
// NotFound status.
- Status Find(const string& attr_name, const AttrValue** attr_value) const;
+ Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
private:
const NodeDef* ndef_;
@@ -76,58 +78,58 @@ class AttrSlice {
// Look up the attr with name attr_name and set *value to its value. If no
// attr with attr_name is found in node_def, or the attr does not have
// a matching type, a non-ok status will be returned.
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
string* value); // type: "string"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
int64* value); // type: "int"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
int32* value); // type: "int"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
float* value); // type: "float"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
bool* value); // type: "bool"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
DataType* value); // type: "type"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
TensorShapeProto* value); // type: "shape"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
TensorShape* value); // type: "shape"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
PartialTensorShape* value); // type: "shape"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
Tensor* value); // type: "tensor"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<string>* value); // type "list(string)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<int64>* value); // type "list(int)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<int32>* value); // type "list(int)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<float>* value); // type "list(float)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<bool>* value); // type "list(bool)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<DataType>* value); // type "list(type)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
DataTypeVector* value); // type "list(type)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<TensorShapeProto>* value); // type "list(shape)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<TensorShape>* value); // type "list(shape)"
Status GetNodeAttr(
- const AttrSlice& attrs, const string& attr_name,
+ const AttrSlice& attrs, StringPiece attr_name,
std::vector<PartialTensorShape>* value); // type "list(shape)"
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<Tensor>* value); // type: "list(tensor)"
// This version avoids copying the TensorProto.
// REQUIRES: Must not use *value beyond the lifetime of node_def.
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const TensorProto** value); // type: "tensor"
// This version avoids copying the NameAttrList.
// REQUIRES: Must not use *value beyond the lifetime of node_def.
-Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const NameAttrList** value); // type: "func"
// Computes the input and output types for a specific node.
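Every GetNodeAttr overload now takes the attr name as a StringPiece, matching
the builder API; note that AttrSlice::Find still calls ToString() internally
for the map lookup, so this is an interface-uniformity change rather than an
allocation win at the lookup itself. A sketch of typical use (assumes a
NodeDef carrying an int attr named "N", and relies on AttrSlice's converting
constructor from NodeDef):

    int64 n;
    TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "N", &n));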
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index c86c259651..c984c8dc93 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -93,9 +93,9 @@ OpKernel::OpKernel(OpKernelConstruction* context)
OpKernel::~OpKernel() {}
-Status OpKernel::InputRange(const string& input_name, int* start,
+Status OpKernel::InputRange(StringPiece input_name, int* start,
int* stop) const {
- const auto result = input_name_map_.find(input_name);
+ const auto result = input_name_map_.find(input_name.ToString());
if (result == input_name_map_.end()) {
return errors::InvalidArgument("Unknown input name: ", input_name);
} else {
@@ -105,9 +105,9 @@ Status OpKernel::InputRange(const string& input_name, int* start,
}
}
-Status OpKernel::OutputRange(const string& output_name, int* start,
+Status OpKernel::OutputRange(StringPiece output_name, int* start,
int* stop) const {
- const auto result = output_name_map_.find(output_name);
+ const auto result = output_name_map_.find(output_name.ToString());
if (result == output_name_map_.end()) {
return errors::InvalidArgument("Unknown output name: ", output_name);
} else {
@@ -236,7 +236,7 @@ void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) {
referenced_tensors_.Add(tensor);
}
-Status OpKernelContext::input(const string& name, const Tensor** tensor) {
+Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
@@ -254,7 +254,7 @@ Status OpKernelContext::input(const string& name, const Tensor** tensor) {
return Status::OK();
}
-Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) {
+Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
@@ -329,7 +329,7 @@ void OpKernelContext::delete_ref_input(int index, bool lock_held) {
}
}
-Status OpKernelContext::mutable_input(const string& name, Tensor* tensor,
+Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
bool lock_held) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
@@ -353,7 +353,7 @@ Status OpKernelContext::mutable_input(const string& name, Tensor* tensor,
return Status::OK();
}
-Status OpKernelContext::replace_ref_input(const string& name,
+Status OpKernelContext::replace_ref_input(StringPiece name,
const Tensor& tensor,
bool lock_held) {
int start, stop;
@@ -371,14 +371,14 @@ Status OpKernelContext::replace_ref_input(const string& name,
return Status::OK();
}
-Status OpKernelContext::input_list(const string& name, OpInputList* list) {
+Status OpKernelContext::input_list(StringPiece name, OpInputList* list) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
*list = OpInputList(this, start, stop);
return Status::OK();
}
-Status OpKernelContext::mutable_input_list(const string& name,
+Status OpKernelContext::mutable_input_list(StringPiece name,
OpMutableInputList* list) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
@@ -386,7 +386,7 @@ Status OpKernelContext::mutable_input_list(const string& name,
return Status::OK();
}
-Status OpKernelContext::output_list(const string& name, OpOutputList* list) {
+Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
*list = OpOutputList(this, start, stop);
@@ -401,7 +401,7 @@ Status OpKernelContext::allocate_output(int index, const TensorShape& shape,
return allocate_output(index, shape, output, attr);
}
-Status OpKernelContext::allocate_output(const string& name,
+Status OpKernelContext::allocate_output(StringPiece name,
const TensorShape& shape,
Tensor** tensor) {
int start, stop;
@@ -415,7 +415,7 @@ Status OpKernelContext::allocate_output(const string& name,
return allocate_output(start, shape, tensor);
}
-Status OpKernelContext::allocate_output(const string& name,
+Status OpKernelContext::allocate_output(StringPiece name,
const TensorShape& shape,
Tensor** tensor,
AllocatorAttributes attr) {
@@ -494,7 +494,7 @@ Status OpKernelContext::allocate_persistent(DataType type,
return s;
}
-Status OpKernelContext::set_output(const string& name, const Tensor& tensor) {
+Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
@@ -525,7 +525,7 @@ void OpKernelContext::set_output_ref(int index, mutex* mu,
outputs_[index] = TensorValue(mu, tensor_for_ref);
}
-Status OpKernelContext::set_output_ref(const string& name, mutex* mu,
+Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
Tensor* tensor_for_ref) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
@@ -539,7 +539,7 @@ Status OpKernelContext::set_output_ref(const string& name, mutex* mu,
return Status::OK();
}
-Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) {
+Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
@@ -552,7 +552,7 @@ Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) {
return Status::OK();
}
-Status OpKernelContext::release_output(const string& name, TensorValue* value) {
+Status OpKernelContext::release_output(StringPiece name, TensorValue* value) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
@@ -615,8 +615,8 @@ static KernelRegistry* GlobalKernelRegistryTyped() {
return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
}
-static string Key(const string& op_type, DeviceType device_type,
- const string& label) {
+static string Key(StringPiece op_type, DeviceType device_type,
+ StringPiece label) {
return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
label);
}
@@ -832,7 +832,7 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
namespace {
-bool FindArgInOp(const string& arg_name,
+bool FindArgInOp(StringPiece arg_name,
const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
for (const auto& arg : args) {
if (arg_name == arg.name()) {
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 46bae0a0b6..e2cabb4b68 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -122,8 +122,8 @@ class OpKernel {
return output_memory_types_;
}
- Status InputRange(const string& input_name, int* start, int* stop) const;
- Status OutputRange(const string& output_name, int* start, int* stop) const;
+ Status InputRange(StringPiece input_name, int* start, int* stop) const;
+ Status OutputRange(StringPiece output_name, int* start, int* stop) const;
// We allow legacy scalars within Google up until GraphDef version 6.
// TODO(irving): Remove when we can drop support for GraphDef version 5.
@@ -295,7 +295,7 @@ class OpKernelConstruction {
// attr with attr_name is found in def(), or the attr does not have
// a matching type, a non-ok status will be returned.
template <class T>
- Status GetAttr(const string& attr_name, T* value) const;
+ Status GetAttr(StringPiece attr_name, T* value) const;
// May be used, e.g., to get GPU handles, etc.
// TODO(tucker): Add example usage.
@@ -558,14 +558,14 @@ class OpKernelContext {
// use mutable_input below.
// REQUIRES: !IsRefType(input_dtype(index))
// REQUIRES: the named input must not be a list.
- Status input(const string& name, const Tensor** tensor);
+ Status input(StringPiece name, const Tensor** tensor);
// Returns the named list-valued immutable input in "list", as
// defined in the OpDef. If the named output is not list-valued,
// returns a one-element list. May only be used for non-Ref
// inputs. For Ref inputs use mutable_input below.
// REQUIRES: !IsRefType(input_dtype(index))
- Status input_list(const string& name, OpInputList* list);
+ Status input_list(StringPiece name, OpInputList* list);
// For mutable inputs, use the following together to make sure there
// is no concurrent access to mutable_input(), e.g.:
@@ -577,7 +577,7 @@ class OpKernelContext {
// REQUIRES: IsRefType(input_dtype(index))
// TODO(mrry): Convert this to return Status.
mutex* input_ref_mutex(int index);
- Status input_ref_mutex(const string& name, mutex** out_mutex);
+ Status input_ref_mutex(StringPiece name, mutex** out_mutex);
// Returns a mutable input tensor. Must be used to access Ref
// inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may
@@ -596,7 +596,7 @@ class OpKernelContext {
// the input mutex will be acquired before returning the Tensor.
// REQUIRES: the named input must not be a list.
// REQUIRES: the named input must be a ref tensor.
- Status mutable_input(const string& name, Tensor* tensor, bool lock_held);
+ Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held);
// Returns the named list-valued mutable input in "list", as defined
// in the OpDef. If the named input is not list-valued, returns a
@@ -604,7 +604,7 @@ class OpKernelContext {
// stored in the Tensor buffer may be modified, and modifications
// will be visible to other Ops reading the same ref tensor.
// REQUIRES: the named input must be a ref tensor.
- Status mutable_input_list(const string& name, OpMutableInputList* list);
+ Status mutable_input_list(StringPiece name, OpMutableInputList* list);
// Replace the corresponding Ref Input to use the storage buffer
// used by tensor. If !lock_held the input mutex will be acquired
@@ -616,7 +616,7 @@ class OpKernelContext {
// buffer used by tensor. If !lock_held the input mutex will be
// acquired before returning the Tensor.
// REQUIRES: IsRefType(input_dtype(index)).
- Status replace_ref_input(const string& name, const Tensor& tensor,
+ Status replace_ref_input(StringPiece name, const Tensor& tensor,
bool lock_held);
// Set the output Ref Tensor at output_index to be an alias of the
@@ -647,7 +647,7 @@ class OpKernelContext {
// Returns the named list-valued output in "list", as defined in the OpDef.
// If the named output is not list-valued, returns a one-element list.
- Status output_list(const string& name, OpOutputList* list);
+ Status output_list(StringPiece name, OpOutputList* list);
// If output_required(index) returns true, the OpKernel's Compute() method
// should call allocate_output(index, ...), set_output(index, ...),
@@ -712,7 +712,7 @@ class OpKernelContext {
// REQUIRES: !IsRefType(expected_output_dtype(index))
Status allocate_output(int index, const TensorShape& shape,
Tensor** tensor) TF_MUST_USE_RESULT;
- Status allocate_output(const string& name, const TensorShape& shape,
+ Status allocate_output(StringPiece name, const TensorShape& shape,
Tensor** tensor) TF_MUST_USE_RESULT;
// The following methods use the supplied attributes instead of
// those in output_attr_array. The caller is responsible for
@@ -721,7 +721,7 @@ class OpKernelContext {
// device. See comment above.
Status allocate_output(int index, const TensorShape& shape, Tensor** tensor,
AllocatorAttributes attr) TF_MUST_USE_RESULT;
- Status allocate_output(const string& name, const TensorShape& shape,
+ Status allocate_output(StringPiece name, const TensorShape& shape,
Tensor** tensor,
AllocatorAttributes attr) TF_MUST_USE_RESULT;
@@ -766,19 +766,19 @@ class OpKernelContext {
// output_memory_types[index]. See comment above.
// TODO(mrry): Convert this to return Status.
void set_output(int index, const Tensor& tensor);
- Status set_output(const string& name, const Tensor& tensor);
+ Status set_output(StringPiece name, const Tensor& tensor);
// To output a reference. Caller retains ownership of mu and tensor_for_ref,
// and they must outlive all uses within the step. See comment above.
// REQUIRES: IsRefType(expected_output_dtype(index))
// TODO(mrry): Convert this to return Status.
void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
- Status set_output_ref(const string& name, mutex* mu, Tensor* tensor_for_ref);
+ Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref);
// Returns nullptr if allocate_output() or set_output() have not been called.
// TODO(mrry): Convert this to return Status.
Tensor* mutable_output(int index);
- Status mutable_output(const string& name, Tensor** tensor);
+ Status mutable_output(StringPiece name, Tensor** tensor);
// Transfers ownership of an output tensor to the caller.
// NOTE: For non-reference outputs, the caller takes responsibility
@@ -786,7 +786,7 @@ class OpKernelContext {
// responsibility for deletion.
// TODO(mrry): Convert this to return Status.
TensorValue release_output(int index);
- Status release_output(const string& name, TensorValue* value);
+ Status release_output(StringPiece name, TensorValue* value);
// Records device specific state about how the input tensors were
// computed.
@@ -1074,7 +1074,7 @@ class OpKernelRegistrar {
// Template and inline method implementations, please ignore
template <class T>
-Status OpKernelConstruction::GetAttr(const string& attr_name, T* value) const {
+Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const {
return GetNodeAttr(def(), attr_name, value);
}
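The name-based accessors above are usually exercised from a kernel's
Compute(); a minimal sketch, assuming an op whose OpDef declares an input "a"
and an output "out" (names illustrative):

    void Compute(OpKernelContext* ctx) override {
      const Tensor* a = nullptr;
      OP_REQUIRES_OK(ctx, ctx->input("a", &a));
      Tensor* out = nullptr;
      OP_REQUIRES_OK(ctx, ctx->allocate_output("out", a->shape(), &out));
      // ... compute into *out ...
    }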
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 5420a0d976..77273d9ba3 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -369,6 +369,20 @@ bool Tensor::IsInitialized() const {
return buf_ != nullptr && buf_->data() != nullptr;
}
+void Tensor::CheckType(DataType expected_dtype) const {
+ CHECK_EQ(dtype(), expected_dtype);
+}
+
+void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
+ CHECK_EQ(dtype(), expected_dtype);
+ CHECK(IsAligned());
+}
+
+void Tensor::CheckIsAlignedAndSingleElement() const {
+ CHECK(IsAligned());
+ CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
+}
+
Tensor::~Tensor() { UnrefIfNonNull(buf_); }
void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) {
@@ -551,7 +565,8 @@ bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
set_dtype(proto.dtype());
UnrefIfNonNull(buf_);
buf_ = p;
- // TODO(misard) add tracking of which kernels and steps are calling FromProto.
+ // TODO(misard) add tracking of which kernels and steps are calling
+ // FromProto.
if (IsInitialized() && LogMemory::IsEnabled()) {
LogMemory::RecordTensorAllocation("Unknown (from Proto)",
LogMemory::UNKNOWN_STEP_ID, *this);
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index a6143e95c0..708c98f409 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -370,7 +370,15 @@ class Tensor {
void UnsafeCopyFromInternal(const Tensor&, const TensorShape&);
private:
+ void CheckType(DataType expected_dtype) const;
+ void CheckTypeAndIsAligned(DataType expected_dtype) const;
+ void CheckIsAlignedAndSingleElement() const;
void set_dtype(DataType t) { shape_.set_data_type(t); }
+ template <size_t NDIMS>
+ void FillDimsAndValidateCompatibleShape(
+ gtl::ArraySlice<int64> new_sizes,
+ Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;
+
TensorShape shape_;
TensorBuffer* buf_;
@@ -440,48 +448,46 @@ T* Tensor::base() const {
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::tensor() {
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
- CHECK(IsAligned());
+ CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
return typename TTypes<T, NDIMS>::Tensor(base<T>(),
shape().AsEigenDSizes<NDIMS>());
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::tensor() const {
- CHECK(IsAligned());
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(),
shape().AsEigenDSizes<NDIMS>());
}
-template <typename T, size_t NDIMS>
-typename TTypes<T, NDIMS>::Tensor Tensor::shaped(
- gtl::ArraySlice<int64> new_sizes) {
- CHECK(IsAligned());
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+template <size_t NDIMS>
+void Tensor::FillDimsAndValidateCompatibleShape(
+ gtl::ArraySlice<int64> new_sizes,
+ Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const {
CHECK_EQ(NDIMS, new_sizes.size());
int64 new_num_elements = 1;
- Eigen::array<Eigen::DenseIndex, NDIMS> dims;
for (size_t d = 0; d < NDIMS; d++) {
new_num_elements *= new_sizes[d];
- dims[d] = new_sizes[d];
+ (*dims)[d] = new_sizes[d];
}
CHECK_EQ(new_num_elements, NumElements());
+}
+
+template <typename T, size_t NDIMS>
+typename TTypes<T, NDIMS>::Tensor Tensor::shaped(
+ gtl::ArraySlice<int64> new_sizes) {
+ CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
+ Eigen::array<Eigen::DenseIndex, NDIMS> dims;
+ FillDimsAndValidateCompatibleShape<NDIMS>(new_sizes, &dims);
return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedTensor Tensor::unaligned_shaped(
gtl::ArraySlice<int64> new_sizes) {
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
- CHECK_EQ(NDIMS, new_sizes.size());
- int64 new_num_elements = 1;
+ CheckType(DataTypeToEnum<T>::v());
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
- for (size_t d = 0; d < NDIMS; d++) {
- new_num_elements *= new_sizes[d];
- dims[d] = new_sizes[d];
- }
- CHECK_EQ(new_num_elements, NumElements());
+ FillDimsAndValidateCompatibleShape<NDIMS>(new_sizes, &dims);
return typename TTypes<T, NDIMS>::UnalignedTensor(base<T>(), dims);
}
@@ -501,8 +507,7 @@ void Tensor::FillDimsAndValidateCompatibleShape(
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
gtl::ArraySlice<int64> new_sizes) const {
- CHECK(IsAligned());
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ CheckTypeAndIsAligned(DataTypeToEnum<T>::v());
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
FillDimsAndValidateCompatibleShape(&dims, new_sizes);
return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
@@ -511,7 +516,7 @@ typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped(
gtl::ArraySlice<int64> new_sizes) const {
- CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
+ CheckType(DataTypeToEnum<T>::v());
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
FillDimsAndValidateCompatibleShape(&dims, new_sizes);
return typename TTypes<T, NDIMS>::UnalignedConstTensor(base<T>(), dims);
@@ -519,15 +524,13 @@ typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped(
template <typename T>
typename TTypes<T>::Scalar Tensor::scalar() {
- CHECK(IsAligned());
- CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
+ CheckIsAlignedAndSingleElement();
return typename TTypes<T>::Scalar(base<T>());
}
template <typename T>
typename TTypes<T>::ConstScalar Tensor::scalar() const {
- CHECK(IsAligned());
- CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
+ CheckIsAlignedAndSingleElement();
return typename TTypes<T>::ConstScalar(base<T>());
}
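The out-of-line helpers preserve the old inline CHECK behavior while
shrinking every template instantiation. A sketch of what they guard (shapes
illustrative):

    Tensor t(DT_FLOAT, TensorShape({2, 3}));
    auto mat  = t.tensor<float, 2>();     // CheckTypeAndIsAligned(DT_FLOAT)
    auto flat = t.shaped<float, 1>({6});  // also checks 6 == NumElements()
    // t.shaped<float, 1>({4}) would CHECK-fail inside
    // FillDimsAndValidateCompatibleShape.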
diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc
index 7cc98513c8..ee59a79d38 100644
--- a/tensorflow/core/framework/tensor_shape.cc
+++ b/tensorflow/core/framework/tensor_shape.cc
@@ -32,6 +32,16 @@ static void AppendTo(const TensorShape& s, gtl::InlinedVector<int64, 8>* vals) {
}
}
+void TensorShape::CheckDimsEqual(int NDIMS) const {
+ CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS
+ << " for a tensor of " << dims() << " dimensions";
+}
+
+void TensorShape::CheckDimsAtLeast(int NDIMS) const {
+ CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS
+ << " for a tensor of " << dims() << " dimensions";
+}
+
bool TensorShape::IsValid(const TensorShapeProto& proto) {
int64 num_elements = 1;
for (const auto& d : proto.dim()) {
diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h
index bd80215849..e341ceddfb 100644
--- a/tensorflow/core/framework/tensor_shape.h
+++ b/tensorflow/core/framework/tensor_shape.h
@@ -143,6 +143,9 @@ class TensorShape {
void RecomputeNumElements();
+ void CheckDimsEqual(int NDIMS) const;
+ void CheckDimsAtLeast(int NDIMS) const;
+
// We use 16 bytes to represent a TensorShape. Because we need to
// be able to support full 64-bit dimension sizes and an arbitrary
// number of dimensions for a Tensor, but most tensor dimensions are
@@ -266,16 +269,14 @@ class TensorShapeUtils {
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const {
- CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS
- << " for a tensor of " << dims() << " dimensions";
+ CheckDimsEqual(NDIMS);
return AsEigenDSizesWithPadding<NDIMS>();
}
template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding()
const {
- CHECK_GE(NDIMS, dims()) << "Asking for tensor of " << NDIMS
- << " for a tensor of " << dims() << " dimensions";
+ CheckDimsAtLeast(NDIMS);
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
for (int d = 0; d < dims(); d++) {
dsizes[d] = dim_size(d);
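From the caller's side the two checks differ only in strictness; a sketch
(the loop that fills the trailing, padded dimensions falls outside the hunk
shown above):

    TensorShape shape({4, 5});
    auto exact  = shape.AsEigenDSizes<2>();             // CheckDimsEqual(2): ok
    auto padded = shape.AsEigenDSizesWithPadding<4>();  // CheckDimsAtLeast(4): ok
    // shape.AsEigenDSizes<3>() would CHECK-fail with the message above.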
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index 8de02cc2ca..df59950a16 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -28,17 +28,17 @@ NodeBuilder::NodeOut::NodeOut(Node* n, int i) // NOLINT(runtime/explicit)
index(i),
dt(SafeGetOutput(node, i, &error)) {}
-NodeBuilder::NodeOut::NodeOut(const string& name, int i, DataType t)
- : node(nullptr), error(false), name(name), index(i), dt(t) {}
+NodeBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType t)
+ : node(nullptr), error(false), name(n.ToString()), index(i), dt(t) {}
NodeBuilder::NodeOut::NodeOut()
: node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
-NodeBuilder::NodeBuilder(const string& name, const string& op_name,
+NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry)
: def_builder_(name, op_name, op_registry) {}
-NodeBuilder::NodeBuilder(const string& name, const OpDef* op_def)
+NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
: def_builder_(name, op_def) {}
NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
@@ -90,7 +90,7 @@ NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
return *this;
}
-NodeBuilder& NodeBuilder::Device(const string& device_spec) {
+NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
def_builder_.Device(device_spec);
return *this;
}
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
index 147311f320..50c41e2221 100644
--- a/tensorflow/core/graph/node_builder.h
+++ b/tensorflow/core/graph/node_builder.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
@@ -54,7 +55,7 @@ class NodeBuilder {
// useful when preparing a graph for ExtendSession or creating a
// back edge to a node that hasn't been added to the graph yet,
// but will be.
- NodeOut(const string& name, int i, DataType t);
+ NodeOut(StringPiece name, int i, DataType t);
// Default constructor for std::vector<NodeOut>.
NodeOut();
@@ -74,9 +75,9 @@ class NodeBuilder {
// the Op plus a registry) for the Node. Other fields are
// specified by calling the methods below.
// REQUIRES: The OpDef must satisfy ValidateOpDef().
- NodeBuilder(const string& name, const string& op_name,
+ NodeBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry = OpRegistry::Global());
- NodeBuilder(const string& name, const OpDef* op_def);
+ NodeBuilder(StringPiece name, const OpDef* op_def);
// You must call one Input() function per input_arg in the Op,
// *and in the same order as the input_args appear in the OpDef.*
@@ -94,7 +95,7 @@ class NodeBuilder {
// Sets the "requested device spec" in the NodeDef (not the
// "assigned device" in the Node).
- NodeBuilder& Device(const string& device_spec);
+ NodeBuilder& Device(StringPiece device_spec);
// Set the value of an attr. attr_name must match the name of one of
// attrs defined by the Op, and value must have the corresponding type
@@ -102,9 +103,9 @@ class NodeBuilder {
// types for value). Note that attrs will be set automatically if
// they can be determined by the inputs.
template <class T>
- NodeBuilder& Attr(const string& attr_name, T&& value);
+ NodeBuilder& Attr(StringPiece attr_name, T&& value);
template <class T>
- NodeBuilder& Attr(const string& attr_name, std::initializer_list<T> value);
+ NodeBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value);
// Validates the described node and adds it to *graph, adding edges
// for all (non-back) inputs. If created_node is not nullptr,
@@ -138,13 +139,13 @@ class NodeBuilder {
// IMPLEMENTATION -------------------------------------------------------------
template <class T>
-NodeBuilder& NodeBuilder::Attr(const string& attr_name, T&& value) {
+NodeBuilder& NodeBuilder::Attr(StringPiece attr_name, T&& value) {
def_builder_.Attr(attr_name, std::forward<T>(value));
return *this;
}
template <class T>
-NodeBuilder& NodeBuilder::Attr(const string& attr_name,
+NodeBuilder& NodeBuilder::Attr(StringPiece attr_name,
std::initializer_list<T> value) {
def_builder_.Attr(attr_name, value);
return *this;
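NodeBuilder forwards to the NodeDefBuilder shown earlier, so the same
literal-friendly calls work when adding nodes to a Graph. A sketch (op name
and device string illustrative; `graph` and `input_node` are assumed to
already exist):

    Node* node = nullptr;
    Status s = NodeBuilder("id", "Identity")
                   .Input(input_node, 0)
                   .Device("/cpu:0")
                   .Finalize(graph, &node);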
diff --git a/tensorflow/core/lib/strings/strcat.cc b/tensorflow/core/lib/strings/strcat.cc
index 4050727dc0..0c659e236c 100644
--- a/tensorflow/core/lib/strings/strcat.cc
+++ b/tensorflow/core/lib/strings/strcat.cc
@@ -84,6 +84,8 @@ static char *Append4(char *out, const AlphaNum &x1, const AlphaNum &x2,
return out + x4.size();
}
+string StrCat(const AlphaNum &a) { return string(a.data(), a.size()); }
+
string StrCat(const AlphaNum &a, const AlphaNum &b) {
string result;
gtl::STLStringResizeUninitialized(&result, a.size() + b.size());
diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h
index 33b6028153..d7d5352a88 100644
--- a/tensorflow/core/lib/strings/strcat.h
+++ b/tensorflow/core/lib/strings/strcat.h
@@ -172,9 +172,6 @@ string StrCat(const AlphaNum &a, const AlphaNum &b,
string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
const AlphaNum &d) TF_MUST_USE_RESULT;
-// inline definitions must be duplicated due to TF_MUST_USE_RESULT
-inline string StrCat(const AlphaNum &a) { return string(a.data(), a.size()); }
-
namespace internal {
// Do not call directly - this is not part of the public API.
@@ -190,8 +187,8 @@ string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
const AV &... args) TF_MUST_USE_RESULT;
template <typename... AV>
-inline string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
- const AlphaNum &d, const AlphaNum &e, const AV &... args) {
+string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c,
+ const AlphaNum &d, const AlphaNum &e, const AV &... args) {
return internal::CatPieces({a.Piece(), b.Piece(), c.Piece(), d.Piece(),
e.Piece(),
static_cast<const AlphaNum &>(args).Piece()...});
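Moving the one-argument StrCat into strcat.cc removes the header's duplicated
inline body, which previously existed only so the definition could carry
TF_MUST_USE_RESULT alongside the declaration. Call sites are unchanged; a
sketch (`piece` stands for any AlphaNum-convertible value):

    string s = strings::StrCat("dims=", 4);  // multi-arg overloads
    string t = strings::StrCat(piece);       // one-arg, now defined in strcat.cc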
diff --git a/tensorflow/core/util/mirror_pad_mode.cc b/tensorflow/core/util/mirror_pad_mode.cc
index 5d38c9b1de..a20c1fba74 100644
--- a/tensorflow/core/util/mirror_pad_mode.cc
+++ b/tensorflow/core/util/mirror_pad_mode.cc
@@ -21,7 +21,7 @@ limitations under the License.
namespace tensorflow {
-Status GetNodeAttr(const NodeDef& node_def, const string& attr_name,
+Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name,
MirrorPadMode* value) {
string str_value;
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attr_name, &str_value));
diff --git a/tensorflow/core/util/mirror_pad_mode.h b/tensorflow/core/util/mirror_pad_mode.h
index ae96fdff73..c865d35e28 100644
--- a/tensorflow/core/util/mirror_pad_mode.h
+++ b/tensorflow/core/util/mirror_pad_mode.h
@@ -44,7 +44,7 @@ string GetMirrorPadModeAttrString();
class NodeDef;
// Specialization to parse an attribute directly into a MirrorPadMode enum.
-Status GetNodeAttr(const NodeDef& node_def, const string& attr_name,
+Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name,
MirrorPadMode* value);
} // end namespace tensorflow
diff --git a/tensorflow/core/util/padding.cc b/tensorflow/core/util/padding.cc
index b3064d6420..ce36d4cc78 100644
--- a/tensorflow/core/util/padding.cc
+++ b/tensorflow/core/util/padding.cc
@@ -20,7 +20,7 @@ limitations under the License.
namespace tensorflow {
-Status GetNodeAttr(const NodeDef& node_def, const string& attr_name,
+Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name,
Padding* value) {
string str_value;
TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attr_name, &str_value));
diff --git a/tensorflow/core/util/padding.h b/tensorflow/core/util/padding.h
index d6d970d27a..989a71b24f 100644
--- a/tensorflow/core/util/padding.h
+++ b/tensorflow/core/util/padding.h
@@ -44,7 +44,7 @@ enum Padding {
string GetPaddingAttrString();
// Specialization to parse an attribute directly into a Padding enum.
-Status GetNodeAttr(const NodeDef& node_def, const string& attr_name,
+Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name,
Padding* value);
} // end namespace tensorflow
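Kernels normally reach this specialization through
OpKernelConstruction::GetAttr, which (per op_kernel.h above) forwards to
GetNodeAttr on the kernel's NodeDef. A sketch, assuming the op registers a
"padding" attr:

    Padding padding;
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding));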