diff options
author | Geoffrey Irving <geoffreyi@google.com> | 2017-06-28 10:17:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-28 10:26:21 -0700 |
commit | 1ac185fcf6210c6fecaf2fc7f76f6289138dc474 (patch) | |
tree | 5c59d185bf8df88b63b8988c473926a409afe39a | |
parent | 6233d8a75c17bedc2b082f6ef32fb150ac2b28b9 (diff) |
Don't include node_def.proto.h in node_def_util.h
The goal is to make kernels mostly independent of proto headers, which will let us lock down our .so imports.
RELNOTES: n/a
PiperOrigin-RevId: 160422647
-rw-r--r-- | tensorflow/core/framework/node_def_util.cc | 53 | ||||
-rw-r--r-- | tensorflow/core/framework/node_def_util.h | 67 |
2 files changed, 102 insertions, 18 deletions
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index 9b737e1f72..79feb20d53 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/graph.pb_text.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb_text.h" #include "tensorflow/core/framework/op_def_util.h" @@ -611,4 +612,56 @@ Status AttachDef(const Status& status, const Node& node) { return AttachDef(status, node.def()); } +void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) { + node_def->mutable_attr()->insert( + AttrValueMap::value_type(name.ToString(), value)); +} + +#define ADD_NODE_ATTR(T) \ + void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \ + AttrValue attr_value; \ + SetAttrValue(value, &attr_value); \ + AddNodeAttr(name, attr_value, node_def); \ + } +ADD_NODE_ATTR(StringPiece) +ADD_NODE_ATTR(const char*) +ADD_NODE_ATTR(int32) +ADD_NODE_ATTR(int64) +ADD_NODE_ATTR(float) +ADD_NODE_ATTR(double) +ADD_NODE_ATTR(bool) +ADD_NODE_ATTR(DataType) +ADD_NODE_ATTR(const PartialTensorShape&) +ADD_NODE_ATTR(const Tensor&) +ADD_NODE_ATTR(const TensorProto&) +ADD_NODE_ATTR(const NameAttrList&) +ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>) +ADD_NODE_ATTR(gtl::ArraySlice<const char*>) +ADD_NODE_ATTR(gtl::ArraySlice<string>) +ADD_NODE_ATTR(gtl::ArraySlice<int32>) +ADD_NODE_ATTR(gtl::ArraySlice<int64>) +ADD_NODE_ATTR(gtl::ArraySlice<float>) +ADD_NODE_ATTR(gtl::ArraySlice<bool>) +ADD_NODE_ATTR(const std::vector<bool>&) +ADD_NODE_ATTR(gtl::ArraySlice<DataType>) +ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>) +ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>) +ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>) +ADD_NODE_ATTR(gtl::ArraySlice<Tensor>) +ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>) +#undef ADD_NODE_ATTR + +void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) { + map->insert(AttrValueMap::value_type(name.ToString(), value)); +} + +#define ADD_ATTR(T) \ + void AddAttr(StringPiece name, T value, AttrValueMap* map) { \ + AttrValue attr_value; \ + SetAttrValue(value, &attr_value); \ + AddAttr(name, attr_value, map); \ + } +ADD_ATTR(bool) +#undef ADD_ATTR + } // namespace tensorflow diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 1438abdec6..5d4864db66 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -21,7 +21,6 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/attr_value_util.h" -#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -31,6 +30,9 @@ namespace tensorflow { class Node; +// We forward declare NodeDef so that kernels don't need to depend on protos +class NodeDef; + // Name of the attribute used to encode node colocation constraints. // // Nodes can be co-located on the same device. Desire for explicit co-location @@ -50,32 +52,61 @@ typedef protobuf::Map<string, AttrValue> AttrValueMap; // Adds an attr with name <name> and value <value> to *node_def. // The type of the attr is based on the type of value. -template <class T> -void AddNodeAttr(StringPiece name, T&& value, NodeDef* node_def) { - AttrValue attr_value; - SetAttrValue(std::forward<T>(value), &attr_value); - node_def->mutable_attr()->insert( - AttrValueMap::value_type(name.ToString(), attr_value)); -} +void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, int64 value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, float value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, double value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, const PartialTensorShape& value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def); +void AddNodeAttr(StringPiece name, const NameAttrList& value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<StringPiece> value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<const char*> value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<string> value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<int32> value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<int64> value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<float> value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<bool> value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, const std::vector<bool>& value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<DataType> value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShape> value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<PartialTensorShape> value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShapeProto> value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<Tensor> value, + NodeDef* node_def); +void AddNodeAttr(StringPiece name, gtl::ArraySlice<NameAttrList> value, + NodeDef* node_def); // Version to workaround C++'s "perfect" forwarding not being able to // forward {...} initialization. template <class T> void AddNodeAttr(StringPiece name, std::initializer_list<T> value, NodeDef* node_def) { - AttrValue attr_value; - SetAttrValue(value, &attr_value); - node_def->mutable_attr()->insert( - AttrValueMap::value_type(name.ToString(), attr_value)); + AddNodeAttr(name, gtl::ArraySlice<T>(value), node_def); } // Adds an attr to an attr value map. -template <class T> -void AddAttr(StringPiece name, T&& value, AttrValueMap* map) { - AttrValue attr_value; - SetAttrValue(value, &attr_value); - map->insert(AttrValueMap::value_type(name.ToString(), attr_value)); -} +void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map); +void AddAttr(StringPiece name, bool value, AttrValueMap* map); class AttrSlice { public: |