diff options
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/framework/node_def_builder.cc | 49 | ||||
-rw-r--r-- | tensorflow/core/framework/node_def_builder.h | 50 | ||||
-rw-r--r-- | tensorflow/core/framework/node_def_util.cc | 16 | ||||
-rw-r--r-- | tensorflow/core/framework/node_def_util.h | 60 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.cc | 40 | ||||
-rw-r--r-- | tensorflow/core/framework/op_kernel.h | 34 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor.cc | 17 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor.h | 55 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_shape.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_shape.h | 9 | ||||
-rw-r--r-- | tensorflow/core/graph/node_builder.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/graph/node_builder.h | 17 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/strcat.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/strcat.h | 7 | ||||
-rw-r--r-- | tensorflow/core/util/mirror_pad_mode.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/util/mirror_pad_mode.h | 2 | ||||
-rw-r--r-- | tensorflow/core/util/padding.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/util/padding.h | 2 |
18 files changed, 218 insertions, 166 deletions
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc index b6f5838528..f3091ad286 100644 --- a/tensorflow/core/framework/node_def_builder.cc +++ b/tensorflow/core/framework/node_def_builder.cc @@ -22,11 +22,24 @@ limitations under the License. namespace tensorflow { -NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name, +NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt) + : node(n.ToString()), index(i), data_type(dt) {} + +NodeDefBuilder::NodeOut::NodeOut() { + // uninitialized, call Reset() before use. +} + +void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) { + node = n.ToString(); + index = i; + data_type = dt; +} + +NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name, const OpRegistryInterface* op_registry) { - node_def_.set_name(name); + node_def_.set_name(name.ToString()); Status status; - op_def_ = op_registry->LookUp(op_name, &status); + op_def_ = op_registry->LookUp(op_name.ToString(), &status); if (op_def_ == nullptr) { errors_.push_back(status.error_message()); inputs_specified_ = 0; @@ -35,9 +48,9 @@ NodeDefBuilder::NodeDefBuilder(const string& name, const string& op_name, } } -NodeDefBuilder::NodeDefBuilder(const string& name, const OpDef* op_def) +NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def) : op_def_(op_def) { - node_def_.set_name(name); + node_def_.set_name(name.ToString()); Initialize(); } @@ -72,7 +85,7 @@ NodeDefBuilder& NodeDefBuilder::Input(FakeInputFunctor fake_input) { } void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg, - const string& src_node, int src_index, + StringPiece src_node, int src_index, DataType dt) { AddInput(src_node, src_index); @@ -129,7 +142,7 @@ void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg, } } -void NodeDefBuilder::AddInput(const string& src_node, int src_index) { +void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) { if (src_node.empty()) { errors_.push_back("Empty input node name"); } else if (src_node[0] == '^') { @@ -138,7 +151,7 @@ void NodeDefBuilder::AddInput(const string& src_node, int src_index) { } else if (src_index > 0) { node_def_.add_input(strings::StrCat(src_node, ":", src_index)); } else { - node_def_.add_input(src_node); + node_def_.add_input(src_node.ToString()); } } @@ -160,6 +173,16 @@ void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg, } } +NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) { + control_inputs_.push_back(src_node.ToString()); + return *this; +} + +NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) { + node_def_.set_device(device_spec.ToString()); + return *this; +} + Status NodeDefBuilder::Finalize(NodeDef* node_def) const { const std::vector<string>* errors_ptr = &errors_; std::vector<string> errors_storage; @@ -206,4 +229,14 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def) const { } } +void NodeDefBuilder::CheckInconsistency(StringPiece attr_name, + const AttrValue& found, + const AttrValue& attr_value) { + if (!AreAttrValuesEqual(found, attr_value)) { + errors_.push_back(strings::StrCat( + "Inconsistent values for attr '", attr_name, "' ", + SummarizeAttrValue(found), " vs. ", SummarizeAttrValue(attr_value))); + } +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h index 72ad37ab93..f8b3a8441a 100644 --- a/tensorflow/core/framework/node_def_builder.h +++ b/tensorflow/core/framework/node_def_builder.h @@ -32,7 +32,8 @@ namespace tensorflow { class NodeDefBuilder; typedef std::function<Status(const OpDef&, int, const NodeDef&, - NodeDefBuilder*)> FakeInputFunctor; + NodeDefBuilder*)> + FakeInputFunctor; // This is a helper for creating a NodeDef. Automatically sets attrs // that can be inferred from the inputs, and uses default values @@ -49,14 +50,9 @@ class NodeDefBuilder { public: // To specify an output to be consumed by one of the Input() methods below. struct NodeOut { - NodeOut(const string& n, int i, DataType dt) - : node(n), index(i), data_type(dt) {} - NodeOut() {} // uninitialized, call Reset() before use. - void Reset(const string& n, int i, DataType dt) { - node = n; - index = i; - data_type = dt; - } + NodeOut(StringPiece n, int i, DataType dt); + NodeOut(); // uninitialized, call Reset() before use. + void Reset(StringPiece n, int i, DataType dt); string node; int index; DataType data_type; @@ -66,16 +62,16 @@ class NodeDefBuilder { // the Op plus a registry) for the NodeDef. Other fields are // specified by calling the methods below. // REQUIRES: The OpDef must satisfy ValidateOpDef(). - NodeDefBuilder(const string& name, const string& op_name, + NodeDefBuilder(StringPiece name, StringPiece op_name, const OpRegistryInterface* op_registry = OpRegistry::Global()); // REQUIRES: in addition, *op_def must outlive *this. - NodeDefBuilder(const string& name, const OpDef* op_def); + NodeDefBuilder(StringPiece name, const OpDef* op_def); // You must call one Input() function per input_arg in the Op, // *and in the same order as the input_args appear in the OpDef.* // For inputs that take a single tensor. - NodeDefBuilder& Input(const string& src_node, int src_index, DataType dt) { + NodeDefBuilder& Input(StringPiece src_node, int src_index, DataType dt) { const OpDef::ArgDef* arg = NextArgDef(); if (arg != nullptr) SingleInput(arg, src_node, src_index, dt); return *this; @@ -96,25 +92,18 @@ class NodeDefBuilder { NodeDefBuilder& Input(FakeInputFunctor fake_input); // Specify that this node must only run after src_node. - NodeDefBuilder& ControlInput(const string& src_node) { - control_inputs_.push_back(src_node); - return *this; - } + NodeDefBuilder& ControlInput(StringPiece src_node); // Constrains what devices this node may be scheduled on. - NodeDefBuilder& Device(const string& device_spec) { - node_def_.set_device(device_spec); - return *this; - } + NodeDefBuilder& Device(StringPiece device_spec); // Sets the attr, if not already set. If already set with a different // value, an error will be returned from Finalize(). template <class T> - NodeDefBuilder& Attr(const string& attr_name, T&& value); + NodeDefBuilder& Attr(StringPiece attr_name, T&& value); // Note: overload needed to allow {...} expressions for value. template <class T> - NodeDefBuilder& Attr(const string& attr_name, - std::initializer_list<T> value) { + NodeDefBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value) { Attr<std::initializer_list<T>>(attr_name, std::move(value)); return *this; } @@ -141,13 +130,13 @@ class NodeDefBuilder { bool NextArgAvailable(); // These do the main work of the Input() methods. - void SingleInput(const OpDef::ArgDef* input_arg, const string& src_node, + void SingleInput(const OpDef::ArgDef* input_arg, StringPiece src_node, int src_index, DataType dt); void ListInput(const OpDef::ArgDef* input_arg, gtl::ArraySlice<NodeOut> src_list); // Add "src_node:src_index" to the list of inputs in the node_def_. - void AddInput(const string& src_node, int src_index); + void AddInput(StringPiece src_node, int src_index); // Generate an error if you can't pass dt when expected is expected. void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected, @@ -161,6 +150,9 @@ class NodeDefBuilder { return input_arg->is_ref() ? MakeRefType(dt) : dt; } + void CheckInconsistency(StringPiece attr_name, const AttrValue& found, + const AttrValue& attr_value); + const OpDef* op_def_; NodeDef node_def_; int inputs_specified_; @@ -171,18 +163,14 @@ class NodeDefBuilder { // IMPLEMENTATION ------------------------------------------------------------- template <class T> -NodeDefBuilder& NodeDefBuilder::Attr(const string& attr_name, T&& value) { +NodeDefBuilder& NodeDefBuilder::Attr(StringPiece attr_name, T&& value) { const AttrValue* found = AttrSlice(node_def_).Find(attr_name); if (found == nullptr) { AddNodeAttr(attr_name, std::forward<T>(value), &node_def_); } else { AttrValue attr_value; SetAttrValue(std::forward<T>(value), &attr_value); - if (!AreAttrValuesEqual(*found, attr_value)) { - errors_.push_back(strings::StrCat( - "Inconsistent values for attr '", attr_name, "' ", - SummarizeAttrValue(*found), " vs. ", SummarizeAttrValue(attr_value))); - } + CheckInconsistency(attr_name, *found, attr_value); } return *this; } diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index 641411892d..002adfc250 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -72,13 +72,13 @@ string SummarizeNodeDef(const NodeDef& node_def) { return ret; } -const AttrValue* AttrSlice::Find(const string& attr_name) const { - auto iter = attrs_->find(attr_name); +const AttrValue* AttrSlice::Find(StringPiece attr_name) const { + auto iter = attrs_->find(attr_name.ToString()); if (iter == attrs_->end()) return nullptr; return &iter->second; } -Status AttrSlice::Find(const string& attr_name, +Status AttrSlice::Find(StringPiece attr_name, const AttrValue** attr_value) const { *attr_value = Find(attr_name); if (*attr_value != nullptr) { @@ -97,7 +97,7 @@ Status AttrSlice::Find(const string& attr_name, // The ... is to allow the caller to inject some value validation code. Use // just ; if no additional validation code is needed. #define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \ - Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \ + Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, \ TYPE* value) { \ const AttrValue* attr_value; \ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \ @@ -107,7 +107,7 @@ Status AttrSlice::Find(const string& attr_name, *value = CAST; \ return Status::OK(); \ } \ - Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \ + Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, \ std::vector<TYPE>* value) { \ const AttrValue* attr_value; \ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \ @@ -149,7 +149,7 @@ DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t; #undef DEFINE_GET_ATTR -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, DataTypeVector* value) { const AttrValue* attr_value; TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); @@ -160,7 +160,7 @@ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, return Status::OK(); } -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, const TensorProto** value) { const AttrValue* attr_value; TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); @@ -169,7 +169,7 @@ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, return Status::OK(); } -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, const NameAttrList** value) { const AttrValue* attr_value; TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index d70bb6dd37..db7c98149a 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -38,20 +38,22 @@ typedef protobuf::Map<string, AttrValue> AttrValueMap; // Adds an attr with name <name> and value <value> to *node_def. // The type of the attr is based on the type of value. template <class T> -void AddNodeAttr(const string& name, T&& value, NodeDef* node_def) { +void AddNodeAttr(StringPiece name, T&& value, NodeDef* node_def) { AttrValue attr_value; SetAttrValue(std::forward<T>(value), &attr_value); - node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value)); + node_def->mutable_attr()->insert( + AttrValueMap::value_type(name.ToString(), attr_value)); } // Version to workaround C++'s "perfect" forwarding not being able to // forward {...} initialization. template <class T> -void AddNodeAttr(const string& name, std::initializer_list<T> value, +void AddNodeAttr(StringPiece name, std::initializer_list<T> value, NodeDef* node_def) { AttrValue attr_value; SetAttrValue(value, &attr_value); - node_def->mutable_attr()->insert(AttrValueMap::value_type(name, attr_value)); + node_def->mutable_attr()->insert( + AttrValueMap::value_type(name.ToString(), attr_value)); } class AttrSlice { @@ -62,11 +64,11 @@ class AttrSlice { // Returns the attr with attr_name if found. Otherwise, returns // nullptr. - const AttrValue* Find(const string& attr_name) const; + const AttrValue* Find(StringPiece attr_name) const; // Returns the attr_value for attr_name if found. Otherwise, returns a // NotFound status. - Status Find(const string& attr_name, const AttrValue** attr_value) const; + Status Find(StringPiece attr_name, const AttrValue** attr_value) const; private: const NodeDef* ndef_; @@ -76,58 +78,58 @@ class AttrSlice { // Look up the attr with name attr_name and set *value to its value. If no // attr with attr_name is found in node_def, or the attr does not have // a matching type, a non-ok status will be returned. -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, string* value); // type: "string" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, int64* value); // type: "int" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, int32* value); // type: "int" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, float* value); // type: "float" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, bool* value); // type: "bool" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, DataType* value); // type: "type" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, TensorShapeProto* value); // type: "shape" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, TensorShape* value); // type: "shape" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, PartialTensorShape* value); // type: "shape" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, Tensor* value); // type: "tensor" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, std::vector<string>* value); // type "list(string)" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, std::vector<int64>* value); // type "list(int)" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, std::vector<int32>* value); // type "list(int)" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, std::vector<float>* value); // type "list(float)" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, std::vector<bool>* value); // type "list(bool)" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, std::vector<DataType>* value); // type "list(type)" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, DataTypeVector* value); // type "list(type)" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, std::vector<TensorShapeProto>* value); // type "list(shape)" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, std::vector<TensorShape>* value); // type "list(shape)" Status GetNodeAttr( - const AttrSlice& attrs, const string& attr_name, + const AttrSlice& attrs, StringPiece attr_name, std::vector<PartialTensorShape>* value); // type "list(shape)" -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, std::vector<Tensor>* value); // type: "list(tensor)" // This version avoids copying the TensorProto. // REQUIRES: Must not use *value beyond the lifetime of node_def. -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, const TensorProto** value); // type: "tensor" // This version avoids copying the NameAttrList. // REQUIRES: Must not use *value beyond the lifetime of node_def. -Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, +Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, const NameAttrList** value); // type: "func" // Computes the input and output types for a specific node. diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index c86c259651..c984c8dc93 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -93,9 +93,9 @@ OpKernel::OpKernel(OpKernelConstruction* context) OpKernel::~OpKernel() {} -Status OpKernel::InputRange(const string& input_name, int* start, +Status OpKernel::InputRange(StringPiece input_name, int* start, int* stop) const { - const auto result = input_name_map_.find(input_name); + const auto result = input_name_map_.find(input_name.ToString()); if (result == input_name_map_.end()) { return errors::InvalidArgument("Unknown input name: ", input_name); } else { @@ -105,9 +105,9 @@ Status OpKernel::InputRange(const string& input_name, int* start, } } -Status OpKernel::OutputRange(const string& output_name, int* start, +Status OpKernel::OutputRange(StringPiece output_name, int* start, int* stop) const { - const auto result = output_name_map_.find(output_name); + const auto result = output_name_map_.find(output_name.ToString()); if (result == output_name_map_.end()) { return errors::InvalidArgument("Unknown output name: ", output_name); } else { @@ -236,7 +236,7 @@ void OpKernelContext::really_record_tensor_reference(const Tensor& tensor) { referenced_tensors_.Add(tensor); } -Status OpKernelContext::input(const string& name, const Tensor** tensor) { +Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); if (stop != start + 1) { @@ -254,7 +254,7 @@ Status OpKernelContext::input(const string& name, const Tensor** tensor) { return Status::OK(); } -Status OpKernelContext::input_ref_mutex(const string& name, mutex** out_mutex) { +Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); if (stop != start + 1) { @@ -329,7 +329,7 @@ void OpKernelContext::delete_ref_input(int index, bool lock_held) { } } -Status OpKernelContext::mutable_input(const string& name, Tensor* tensor, +Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor, bool lock_held) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); @@ -353,7 +353,7 @@ Status OpKernelContext::mutable_input(const string& name, Tensor* tensor, return Status::OK(); } -Status OpKernelContext::replace_ref_input(const string& name, +Status OpKernelContext::replace_ref_input(StringPiece name, const Tensor& tensor, bool lock_held) { int start, stop; @@ -371,14 +371,14 @@ Status OpKernelContext::replace_ref_input(const string& name, return Status::OK(); } -Status OpKernelContext::input_list(const string& name, OpInputList* list) { +Status OpKernelContext::input_list(StringPiece name, OpInputList* list) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); *list = OpInputList(this, start, stop); return Status::OK(); } -Status OpKernelContext::mutable_input_list(const string& name, +Status OpKernelContext::mutable_input_list(StringPiece name, OpMutableInputList* list) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); @@ -386,7 +386,7 @@ Status OpKernelContext::mutable_input_list(const string& name, return Status::OK(); } -Status OpKernelContext::output_list(const string& name, OpOutputList* list) { +Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); *list = OpOutputList(this, start, stop); @@ -401,7 +401,7 @@ Status OpKernelContext::allocate_output(int index, const TensorShape& shape, return allocate_output(index, shape, output, attr); } -Status OpKernelContext::allocate_output(const string& name, +Status OpKernelContext::allocate_output(StringPiece name, const TensorShape& shape, Tensor** tensor) { int start, stop; @@ -415,7 +415,7 @@ Status OpKernelContext::allocate_output(const string& name, return allocate_output(start, shape, tensor); } -Status OpKernelContext::allocate_output(const string& name, +Status OpKernelContext::allocate_output(StringPiece name, const TensorShape& shape, Tensor** tensor, AllocatorAttributes attr) { @@ -494,7 +494,7 @@ Status OpKernelContext::allocate_persistent(DataType type, return s; } -Status OpKernelContext::set_output(const string& name, const Tensor& tensor) { +Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { @@ -525,7 +525,7 @@ void OpKernelContext::set_output_ref(int index, mutex* mu, outputs_[index] = TensorValue(mu, tensor_for_ref); } -Status OpKernelContext::set_output_ref(const string& name, mutex* mu, +Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); @@ -539,7 +539,7 @@ Status OpKernelContext::set_output_ref(const string& name, mutex* mu, return Status::OK(); } -Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) { +Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { @@ -552,7 +552,7 @@ Status OpKernelContext::mutable_output(const string& name, Tensor** tensor) { return Status::OK(); } -Status OpKernelContext::release_output(const string& name, TensorValue* value) { +Status OpKernelContext::release_output(StringPiece name, TensorValue* value) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); if (stop != start + 1) { @@ -615,8 +615,8 @@ static KernelRegistry* GlobalKernelRegistryTyped() { return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry()); } -static string Key(const string& op_type, DeviceType device_type, - const string& label) { +static string Key(StringPiece op_type, DeviceType device_type, + StringPiece label) { return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":", label); } @@ -832,7 +832,7 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device, namespace { -bool FindArgInOp(const string& arg_name, +bool FindArgInOp(StringPiece arg_name, const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { for (const auto& arg : args) { if (arg_name == arg.name()) { diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 46bae0a0b6..e2cabb4b68 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -122,8 +122,8 @@ class OpKernel { return output_memory_types_; } - Status InputRange(const string& input_name, int* start, int* stop) const; - Status OutputRange(const string& output_name, int* start, int* stop) const; + Status InputRange(StringPiece input_name, int* start, int* stop) const; + Status OutputRange(StringPiece output_name, int* start, int* stop) const; // We allow legacy scalars within Google up until GraphDef version 6. // TODO(irving): Remove when we can drop support for GraphDef version 5. @@ -295,7 +295,7 @@ class OpKernelConstruction { // attr with attr_name is found in def(), or the attr does not have // a matching type, a non-ok status will be returned. template <class T> - Status GetAttr(const string& attr_name, T* value) const; + Status GetAttr(StringPiece attr_name, T* value) const; // May be used, e.g., to get GPU handles, etc. // TODO(tucker): Add example usage. @@ -558,14 +558,14 @@ class OpKernelContext { // use mutable_input below. // REQUIRES: !IsRefType(input_dtype(index)) // REQUIRES: the named input must not be a list. - Status input(const string& name, const Tensor** tensor); + Status input(StringPiece name, const Tensor** tensor); // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. May only be used for non-Ref // inputs. For Ref inputs use mutable_input below. // REQUIRES: !IsRefType(input_dtype(index)) - Status input_list(const string& name, OpInputList* list); + Status input_list(StringPiece name, OpInputList* list); // For mutable inputs, use the following together to make sure there // is no concurrent access to mutable_input(), e.g.: @@ -577,7 +577,7 @@ class OpKernelContext { // REQUIRES: IsRefType(input_dtype(index)) // TODO(mrry): Convert this to return Status. mutex* input_ref_mutex(int index); - Status input_ref_mutex(const string& name, mutex** out_mutex); + Status input_ref_mutex(StringPiece name, mutex** out_mutex); // Returns a mutable input tensor. Must be used to access Ref // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may @@ -596,7 +596,7 @@ class OpKernelContext { // the input mutex will be acquired before returning the Tensor. // REQUIRES: the named input must not be a list. // REQUIRES: the named input must be a ref tensor. - Status mutable_input(const string& name, Tensor* tensor, bool lock_held); + Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held); // Returns the named list-valued mutable input in "list", as defined // in the OpDef. If the named input is not list-valued, returns a @@ -604,7 +604,7 @@ class OpKernelContext { // stored in the Tensor buffer may be modified, and modifications // will be visible to other Ops reading the same ref tensor. // REQUIRES: the named input must be a ref tensor. - Status mutable_input_list(const string& name, OpMutableInputList* list); + Status mutable_input_list(StringPiece name, OpMutableInputList* list); // Replace the corresponding Ref Input to use the storage buffer // used by tensor. If !lock_held the input mutex will be acquired @@ -616,7 +616,7 @@ class OpKernelContext { // buffer used by tensor. If !lock_held the input mutex will be // acquired before returning the Tensor. // REQUIRES: IsRefType(input_dtype(index)). - Status replace_ref_input(const string& name, const Tensor& tensor, + Status replace_ref_input(StringPiece name, const Tensor& tensor, bool lock_held); // Set the output Ref Tensor at output_index to be an alias of the @@ -647,7 +647,7 @@ class OpKernelContext { // Returns the named list-valued output in "list", as defined in the OpDef. // If the named output is not list-valued, returns a one-element list. - Status output_list(const string& name, OpOutputList* list); + Status output_list(StringPiece name, OpOutputList* list); // If output_required(index) returns true, the OpKernel's Compute() method // should call allocate_output(index, ...), set_output(index, ...), @@ -712,7 +712,7 @@ class OpKernelContext { // REQUIRES: !IsRefType(expected_output_dtype(index)) Status allocate_output(int index, const TensorShape& shape, Tensor** tensor) TF_MUST_USE_RESULT; - Status allocate_output(const string& name, const TensorShape& shape, + Status allocate_output(StringPiece name, const TensorShape& shape, Tensor** tensor) TF_MUST_USE_RESULT; // The following methods use the supplied attributes instead of // those in output_attr_array. The caller is responsible for @@ -721,7 +721,7 @@ class OpKernelContext { // device. See comment above. Status allocate_output(int index, const TensorShape& shape, Tensor** tensor, AllocatorAttributes attr) TF_MUST_USE_RESULT; - Status allocate_output(const string& name, const TensorShape& shape, + Status allocate_output(StringPiece name, const TensorShape& shape, Tensor** tensor, AllocatorAttributes attr) TF_MUST_USE_RESULT; @@ -766,19 +766,19 @@ class OpKernelContext { // output_memory_types[index]. See comment above. // TODO(mrry): Convert this to return Status. void set_output(int index, const Tensor& tensor); - Status set_output(const string& name, const Tensor& tensor); + Status set_output(StringPiece name, const Tensor& tensor); // To output a reference. Caller retains ownership of mu and tensor_for_ref, // and they must outlive all uses within the step. See comment above. // REQUIRES: IsRefType(expected_output_dtype(index)) // TODO(mrry): Convert this to return Status. void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref); - Status set_output_ref(const string& name, mutex* mu, Tensor* tensor_for_ref); + Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref); // Returns nullptr if allocate_output() or set_output() have not been called. // TODO(mrry): Convert this to return Status. Tensor* mutable_output(int index); - Status mutable_output(const string& name, Tensor** tensor); + Status mutable_output(StringPiece name, Tensor** tensor); // Transfers ownership of an output tensor to the caller. // NOTE: For non-reference outputs, the caller takes responsibility @@ -786,7 +786,7 @@ class OpKernelContext { // responsibility for deletion. // TODO(mrry): Convert this to return Status. TensorValue release_output(int index); - Status release_output(const string& name, TensorValue* value); + Status release_output(StringPiece name, TensorValue* value); // Records device specific state about how the input tensors were // computed. @@ -1074,7 +1074,7 @@ class OpKernelRegistrar { // Template and inline method implementations, please ignore template <class T> -Status OpKernelConstruction::GetAttr(const string& attr_name, T* value) const { +Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const { return GetNodeAttr(def(), attr_name, value); } diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 5420a0d976..77273d9ba3 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -369,6 +369,20 @@ bool Tensor::IsInitialized() const { return buf_ != nullptr && buf_->data() != nullptr; } +void Tensor::CheckType(DataType expected_dtype) const { + CHECK_EQ(dtype(), expected_dtype); +} + +void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const { + CHECK_EQ(dtype(), expected_dtype); + CHECK(IsAligned()); +} + +void Tensor::CheckIsAlignedAndSingleElement() const { + CHECK(IsAligned()); + CHECK_EQ(1, NumElements()) << "Must have a one element tensor"; +} + Tensor::~Tensor() { UnrefIfNonNull(buf_); } void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) { @@ -551,7 +565,8 @@ bool Tensor::FromProto(Allocator* a, const TensorProto& proto) { set_dtype(proto.dtype()); UnrefIfNonNull(buf_); buf_ = p; - // TODO(misard) add tracking of which kernels and steps are calling FromProto. + // TODO(misard) add tracking of which kernels and steps are calling + // FromProto. if (IsInitialized() && LogMemory::IsEnabled()) { LogMemory::RecordTensorAllocation("Unknown (from Proto)", LogMemory::UNKNOWN_STEP_ID, *this); diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index a6143e95c0..708c98f409 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -370,7 +370,15 @@ class Tensor { void UnsafeCopyFromInternal(const Tensor&, const TensorShape&); private: + void CheckType(DataType expected_dtype) const; + void CheckTypeAndIsAligned(DataType expected_dtype) const; + void CheckIsAlignedAndSingleElement() const; void set_dtype(DataType t) { shape_.set_data_type(t); } + template <size_t NDIMS> + void FillDimsAndValidateCompatibleShape( + gtl::ArraySlice<int64> new_sizes, + Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const; + TensorShape shape_; TensorBuffer* buf_; @@ -440,48 +448,46 @@ T* Tensor::base() const { template <typename T, size_t NDIMS> typename TTypes<T, NDIMS>::Tensor Tensor::tensor() { - CHECK_EQ(dtype(), DataTypeToEnum<T>::v()); - CHECK(IsAligned()); + CheckTypeAndIsAligned(DataTypeToEnum<T>::v()); return typename TTypes<T, NDIMS>::Tensor(base<T>(), shape().AsEigenDSizes<NDIMS>()); } template <typename T, size_t NDIMS> typename TTypes<T, NDIMS>::ConstTensor Tensor::tensor() const { - CHECK(IsAligned()); - CHECK_EQ(dtype(), DataTypeToEnum<T>::v()); + CheckTypeAndIsAligned(DataTypeToEnum<T>::v()); return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(), shape().AsEigenDSizes<NDIMS>()); } -template <typename T, size_t NDIMS> -typename TTypes<T, NDIMS>::Tensor Tensor::shaped( - gtl::ArraySlice<int64> new_sizes) { - CHECK(IsAligned()); - CHECK_EQ(dtype(), DataTypeToEnum<T>::v()); +template <size_t NDIMS> +void Tensor::FillDimsAndValidateCompatibleShape( + gtl::ArraySlice<int64> new_sizes, + Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const { CHECK_EQ(NDIMS, new_sizes.size()); int64 new_num_elements = 1; - Eigen::array<Eigen::DenseIndex, NDIMS> dims; for (size_t d = 0; d < NDIMS; d++) { new_num_elements *= new_sizes[d]; - dims[d] = new_sizes[d]; + (*dims)[d] = new_sizes[d]; } CHECK_EQ(new_num_elements, NumElements()); +} + +template <typename T, size_t NDIMS> +typename TTypes<T, NDIMS>::Tensor Tensor::shaped( + gtl::ArraySlice<int64> new_sizes) { + CheckTypeAndIsAligned(DataTypeToEnum<T>::v()); + Eigen::array<Eigen::DenseIndex, NDIMS> dims; + FillDimsAndValidateCompatibleShape<NDIMS>(new_sizes, &dims); return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims); } template <typename T, size_t NDIMS> typename TTypes<T, NDIMS>::UnalignedTensor Tensor::unaligned_shaped( gtl::ArraySlice<int64> new_sizes) { - CHECK_EQ(dtype(), DataTypeToEnum<T>::v()); - CHECK_EQ(NDIMS, new_sizes.size()); - int64 new_num_elements = 1; + CheckType(DataTypeToEnum<T>::v()); Eigen::array<Eigen::DenseIndex, NDIMS> dims; - for (size_t d = 0; d < NDIMS; d++) { - new_num_elements *= new_sizes[d]; - dims[d] = new_sizes[d]; - } - CHECK_EQ(new_num_elements, NumElements()); + FillDimsAndValidateCompatibleShape<NDIMS>(new_sizes, &dims); return typename TTypes<T, NDIMS>::UnalignedTensor(base<T>(), dims); } @@ -501,8 +507,7 @@ void Tensor::FillDimsAndValidateCompatibleShape( template <typename T, size_t NDIMS> typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped( gtl::ArraySlice<int64> new_sizes) const { - CHECK(IsAligned()); - CHECK_EQ(dtype(), DataTypeToEnum<T>::v()); + CheckTypeAndIsAligned(DataTypeToEnum<T>::v()); Eigen::array<Eigen::DenseIndex, NDIMS> dims; FillDimsAndValidateCompatibleShape(&dims, new_sizes); return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims); @@ -511,7 +516,7 @@ typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped( template <typename T, size_t NDIMS> typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped( gtl::ArraySlice<int64> new_sizes) const { - CHECK_EQ(dtype(), DataTypeToEnum<T>::v()); + CheckType(DataTypeToEnum<T>::v()); Eigen::array<Eigen::DenseIndex, NDIMS> dims; FillDimsAndValidateCompatibleShape(&dims, new_sizes); return typename TTypes<T, NDIMS>::UnalignedConstTensor(base<T>(), dims); @@ -519,15 +524,13 @@ typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped( template <typename T> typename TTypes<T>::Scalar Tensor::scalar() { - CHECK(IsAligned()); - CHECK_EQ(1, NumElements()) << "Must have a one element tensor"; + CheckIsAlignedAndSingleElement(); return typename TTypes<T>::Scalar(base<T>()); } template <typename T> typename TTypes<T>::ConstScalar Tensor::scalar() const { - CHECK(IsAligned()); - CHECK_EQ(1, NumElements()) << "Must have a one element tensor"; + CheckIsAlignedAndSingleElement(); return typename TTypes<T>::ConstScalar(base<T>()); } diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 7cc98513c8..ee59a79d38 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -32,6 +32,16 @@ static void AppendTo(const TensorShape& s, gtl::InlinedVector<int64, 8>* vals) { } } +void TensorShape::CheckDimsEqual(int NDIMS) const { + CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS + << " for a tensor of " << dims() << " dimensions"; +} + +void TensorShape::CheckDimsAtLeast(int NDIMS) const { + CHECK_GE(NDIMS, dims()) << "Asking for tensor of at least " << NDIMS + << " for a tensor of " << dims() << " dimensions"; +} + bool TensorShape::IsValid(const TensorShapeProto& proto) { int64 num_elements = 1; for (const auto& d : proto.dim()) { diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index bd80215849..e341ceddfb 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -143,6 +143,9 @@ class TensorShape { void RecomputeNumElements(); + void CheckDimsEqual(int NDIMS) const; + void CheckDimsAtLeast(int NDIMS) const; + // We use 16 bytes to represent a TensorShape. Because we need to // be able to support full 64-bit dimension sizes and an arbitrary // number of dimensions for a Tensor, but most tensor dimensions are @@ -266,16 +269,14 @@ class TensorShapeUtils { template <int NDIMS> Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const { - CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS - << " for a tensor of " << dims() << " dimensions"; + CheckDimsEqual(NDIMS); return AsEigenDSizesWithPadding<NDIMS>(); } template <int NDIMS> Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding() const { - CHECK_GE(NDIMS, dims()) << "Asking for tensor of " << NDIMS - << " for a tensor of " << dims() << " dimensions"; + CheckDimsAtLeast(NDIMS); Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes; for (int d = 0; d < dims(); d++) { dsizes[d] = dim_size(d); diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index 8de02cc2ca..df59950a16 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -28,17 +28,17 @@ NodeBuilder::NodeOut::NodeOut(Node* n, int i) // NOLINT(runtime/explicit) index(i), dt(SafeGetOutput(node, i, &error)) {} -NodeBuilder::NodeOut::NodeOut(const string& name, int i, DataType t) - : node(nullptr), error(false), name(name), index(i), dt(t) {} +NodeBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType t) + : node(nullptr), error(false), name(n.ToString()), index(i), dt(t) {} NodeBuilder::NodeOut::NodeOut() : node(nullptr), error(true), index(0), dt(DT_FLOAT) {} -NodeBuilder::NodeBuilder(const string& name, const string& op_name, +NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name, const OpRegistryInterface* op_registry) : def_builder_(name, op_name, op_registry) {} -NodeBuilder::NodeBuilder(const string& name, const OpDef* op_def) +NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def) : def_builder_(name, op_def) {} NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) { @@ -90,7 +90,7 @@ NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) { return *this; } -NodeBuilder& NodeBuilder::Device(const string& device_spec) { +NodeBuilder& NodeBuilder::Device(StringPiece device_spec) { def_builder_.Device(device_spec); return *this; } diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h index 147311f320..50c41e2221 100644 --- a/tensorflow/core/graph/node_builder.h +++ b/tensorflow/core/graph/node_builder.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -54,7 +55,7 @@ class NodeBuilder { // useful when preparing a graph for ExtendSession or creating a // back edge to a node that hasn't been added to the graph yet, // but will be. - NodeOut(const string& name, int i, DataType t); + NodeOut(StringPiece name, int i, DataType t); // Default constructor for std::vector<NodeOut>. NodeOut(); @@ -74,9 +75,9 @@ class NodeBuilder { // the Op plus a registry) for the Node. Other fields are // specified by calling the methods below. // REQUIRES: The OpDef must satisfy ValidateOpDef(). - NodeBuilder(const string& name, const string& op_name, + NodeBuilder(StringPiece name, StringPiece op_name, const OpRegistryInterface* op_registry = OpRegistry::Global()); - NodeBuilder(const string& name, const OpDef* op_def); + NodeBuilder(StringPiece name, const OpDef* op_def); // You must call one Input() function per input_arg in the Op, // *and in the same order as the input_args appear in the OpDef.* @@ -94,7 +95,7 @@ class NodeBuilder { // Sets the "requested device spec" in the NodeDef (not the // "assigned device" in the Node). - NodeBuilder& Device(const string& device_spec); + NodeBuilder& Device(StringPiece device_spec); // Set the value of an attr. attr_name must match the name of one of // attrs defined by the Op, and value must have the corresponding type @@ -102,9 +103,9 @@ class NodeBuilder { // types for value). Note that attrs will be set automatically if // they can be determined by the inputs. template <class T> - NodeBuilder& Attr(const string& attr_name, T&& value); + NodeBuilder& Attr(StringPiece attr_name, T&& value); template <class T> - NodeBuilder& Attr(const string& attr_name, std::initializer_list<T> value); + NodeBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value); // Validates the described node and adds it to *graph, adding edges // for all (non-back) inputs. If created_node is not nullptr, @@ -138,13 +139,13 @@ class NodeBuilder { // IMPLEMENTATION ------------------------------------------------------------- template <class T> -NodeBuilder& NodeBuilder::Attr(const string& attr_name, T&& value) { +NodeBuilder& NodeBuilder::Attr(StringPiece attr_name, T&& value) { def_builder_.Attr(attr_name, std::forward<T>(value)); return *this; } template <class T> -NodeBuilder& NodeBuilder::Attr(const string& attr_name, +NodeBuilder& NodeBuilder::Attr(StringPiece attr_name, std::initializer_list<T> value) { def_builder_.Attr(attr_name, value); return *this; diff --git a/tensorflow/core/lib/strings/strcat.cc b/tensorflow/core/lib/strings/strcat.cc index 4050727dc0..0c659e236c 100644 --- a/tensorflow/core/lib/strings/strcat.cc +++ b/tensorflow/core/lib/strings/strcat.cc @@ -84,6 +84,8 @@ static char *Append4(char *out, const AlphaNum &x1, const AlphaNum &x2, return out + x4.size(); } +string StrCat(const AlphaNum &a) { return string(a.data(), a.size()); } + string StrCat(const AlphaNum &a, const AlphaNum &b) { string result; gtl::STLStringResizeUninitialized(&result, a.size() + b.size()); diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h index 33b6028153..d7d5352a88 100644 --- a/tensorflow/core/lib/strings/strcat.h +++ b/tensorflow/core/lib/strings/strcat.h @@ -172,9 +172,6 @@ string StrCat(const AlphaNum &a, const AlphaNum &b, string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, const AlphaNum &d) TF_MUST_USE_RESULT; -// inline definitions must be duplicated due to TF_MUST_USE_RESULT -inline string StrCat(const AlphaNum &a) { return string(a.data(), a.size()); } - namespace internal { // Do not call directly - this is not part of the public API. @@ -190,8 +187,8 @@ string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, const AV &... args) TF_MUST_USE_RESULT; template <typename... AV> -inline string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, - const AlphaNum &d, const AlphaNum &e, const AV &... args) { +string StrCat(const AlphaNum &a, const AlphaNum &b, const AlphaNum &c, + const AlphaNum &d, const AlphaNum &e, const AV &... args) { return internal::CatPieces({a.Piece(), b.Piece(), c.Piece(), d.Piece(), e.Piece(), static_cast<const AlphaNum &>(args).Piece()...}); diff --git a/tensorflow/core/util/mirror_pad_mode.cc b/tensorflow/core/util/mirror_pad_mode.cc index 5d38c9b1de..a20c1fba74 100644 --- a/tensorflow/core/util/mirror_pad_mode.cc +++ b/tensorflow/core/util/mirror_pad_mode.cc @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { -Status GetNodeAttr(const NodeDef& node_def, const string& attr_name, +Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name, MirrorPadMode* value) { string str_value; TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attr_name, &str_value)); diff --git a/tensorflow/core/util/mirror_pad_mode.h b/tensorflow/core/util/mirror_pad_mode.h index ae96fdff73..c865d35e28 100644 --- a/tensorflow/core/util/mirror_pad_mode.h +++ b/tensorflow/core/util/mirror_pad_mode.h @@ -44,7 +44,7 @@ string GetMirrorPadModeAttrString(); class NodeDef; // Specialization to parse an attribute directly into a MirrorPadMode enum. -Status GetNodeAttr(const NodeDef& node_def, const string& attr_name, +Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name, MirrorPadMode* value); } // end namespace tensorflow diff --git a/tensorflow/core/util/padding.cc b/tensorflow/core/util/padding.cc index b3064d6420..ce36d4cc78 100644 --- a/tensorflow/core/util/padding.cc +++ b/tensorflow/core/util/padding.cc @@ -20,7 +20,7 @@ limitations under the License. namespace tensorflow { -Status GetNodeAttr(const NodeDef& node_def, const string& attr_name, +Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name, Padding* value) { string str_value; TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attr_name, &str_value)); diff --git a/tensorflow/core/util/padding.h b/tensorflow/core/util/padding.h index d6d970d27a..989a71b24f 100644 --- a/tensorflow/core/util/padding.h +++ b/tensorflow/core/util/padding.h @@ -44,7 +44,7 @@ enum Padding { string GetPaddingAttrString(); // Specialization to parse an attribute directly into a Padding enum. -Status GetNodeAttr(const NodeDef& node_def, const string& attr_name, +Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name, Padding* value); } // end namespace tensorflow |