aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2017-06-28 10:17:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-28 10:26:21 -0700
commit1ac185fcf6210c6fecaf2fc7f76f6289138dc474 (patch)
tree5c59d185bf8df88b63b8988c473926a409afe39a
parent6233d8a75c17bedc2b082f6ef32fb150ac2b28b9 (diff)
Don't include node_def.proto.h in node_def_util.h
The goal is to make kernels mostly independent of proto headers, which will let us lock down our .so imports. RELNOTES: n/a PiperOrigin-RevId: 160422647
-rw-r--r--tensorflow/core/framework/node_def_util.cc53
-rw-r--r--tensorflow/core/framework/node_def_util.h67
2 files changed, 102 insertions, 18 deletions
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 9b737e1f72..79feb20d53 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb_text.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb_text.h"
#include "tensorflow/core/framework/op_def_util.h"
@@ -611,4 +612,56 @@ Status AttachDef(const Status& status, const Node& node) {
return AttachDef(status, node.def());
}
+void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
+ node_def->mutable_attr()->insert(
+ AttrValueMap::value_type(name.ToString(), value));
+}
+
+#define ADD_NODE_ATTR(T) \
+ void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
+ AttrValue attr_value; \
+ SetAttrValue(value, &attr_value); \
+ AddNodeAttr(name, attr_value, node_def); \
+ }
+ADD_NODE_ATTR(StringPiece)
+ADD_NODE_ATTR(const char*)
+ADD_NODE_ATTR(int32)
+ADD_NODE_ATTR(int64)
+ADD_NODE_ATTR(float)
+ADD_NODE_ATTR(double)
+ADD_NODE_ATTR(bool)
+ADD_NODE_ATTR(DataType)
+ADD_NODE_ATTR(const PartialTensorShape&)
+ADD_NODE_ATTR(const Tensor&)
+ADD_NODE_ATTR(const TensorProto&)
+ADD_NODE_ATTR(const NameAttrList&)
+ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
+ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
+ADD_NODE_ATTR(gtl::ArraySlice<string>)
+ADD_NODE_ATTR(gtl::ArraySlice<int32>)
+ADD_NODE_ATTR(gtl::ArraySlice<int64>)
+ADD_NODE_ATTR(gtl::ArraySlice<float>)
+ADD_NODE_ATTR(gtl::ArraySlice<bool>)
+ADD_NODE_ATTR(const std::vector<bool>&)
+ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
+ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
+ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
+ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
+ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
+ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
+#undef ADD_NODE_ATTR
+
+void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
+ map->insert(AttrValueMap::value_type(name.ToString(), value));
+}
+
+#define ADD_ATTR(T) \
+ void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
+ AttrValue attr_value; \
+ SetAttrValue(value, &attr_value); \
+ AddAttr(name, attr_value, map); \
+ }
+ADD_ATTR(bool)
+#undef ADD_ATTR
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 1438abdec6..5d4864db66 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -21,7 +21,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
-#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -31,6 +30,9 @@ namespace tensorflow {
class Node;
+// We forward declare NodeDef so that kernels don't need to depend on protos
+class NodeDef;
+
// Name of the attribute used to encode node colocation constraints.
//
// Nodes can be co-located on the same device. Desire for explicit co-location
@@ -50,32 +52,61 @@ typedef protobuf::Map<string, AttrValue> AttrValueMap;
// Adds an attr with name <name> and value <value> to *node_def.
// The type of the attr is based on the type of value.
-template <class T>
-void AddNodeAttr(StringPiece name, T&& value, NodeDef* node_def) {
- AttrValue attr_value;
- SetAttrValue(std::forward<T>(value), &attr_value);
- node_def->mutable_attr()->insert(
- AttrValueMap::value_type(name.ToString(), attr_value));
-}
+void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
+void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
+void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
+void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def);
+void AddNodeAttr(StringPiece name, int64 value, NodeDef* node_def);
+void AddNodeAttr(StringPiece name, float value, NodeDef* node_def);
+void AddNodeAttr(StringPiece name, double value, NodeDef* node_def);
+void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def);
+void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def);
+void AddNodeAttr(StringPiece name, const PartialTensorShape& value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def);
+void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def);
+void AddNodeAttr(StringPiece name, const NameAttrList& value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<StringPiece> value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<const char*> value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<string> value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<int32> value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<int64> value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<float> value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<bool> value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, const std::vector<bool>& value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<DataType> value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShape> value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<PartialTensorShape> value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShapeProto> value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<Tensor> value,
+ NodeDef* node_def);
+void AddNodeAttr(StringPiece name, gtl::ArraySlice<NameAttrList> value,
+ NodeDef* node_def);
// Version to workaround C++'s "perfect" forwarding not being able to
// forward {...} initialization.
template <class T>
void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
NodeDef* node_def) {
- AttrValue attr_value;
- SetAttrValue(value, &attr_value);
- node_def->mutable_attr()->insert(
- AttrValueMap::value_type(name.ToString(), attr_value));
+ AddNodeAttr(name, gtl::ArraySlice<T>(value), node_def);
}
// Adds an attr to an attr value map.
-template <class T>
-void AddAttr(StringPiece name, T&& value, AttrValueMap* map) {
- AttrValue attr_value;
- SetAttrValue(value, &attr_value);
- map->insert(AttrValueMap::value_type(name.ToString(), attr_value));
-}
+void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map);
+void AddAttr(StringPiece name, bool value, AttrValueMap* map);
class AttrSlice {
public: