1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
|
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace ops {
namespace {

// Op type name used by every Const() overload in this file, both as the op
// passed to NodeBuilder and as the argument to Options::GetNameForOp().
// The string lives in a function-local static, so it is constructed the first
// time this accessor is called and the same object is returned afterwards.
const string& OpName() {
  static const string op_type("Const");
  return op_type;
}

}  // namespace
// Defines Const(TYPE s, options): the scalar overload for TYPE.
// The by-value argument is viewed as a one-element slice and forwarded,
// together with an empty (rank-0) TensorShape, to the
// Const(ArraySlice<TYPE>, TensorShape, options) overload.
#define DEFINE_CONST_SCALAR(TYPE)                                    \
  Node* Const(TYPE s, const GraphDefBuilder::Options& options) {     \
    const gtl::ArraySlice<TYPE> single_value(&s, 1);                 \
    return Const(single_value, TensorShape({}), options);            \
  }
// Defines Const(ArraySlice<TYPE> v, options): the vector overload for TYPE.
// The values are forwarded to the Const(ArraySlice<TYPE>, TensorShape,
// options) overload with a one-dimensional shape whose only dimension is
// v.size().
#define DEFINE_CONST_VECTOR(TYPE)                                          \
  Node* Const(gtl::ArraySlice<TYPE> v,                                     \
              const GraphDefBuilder::Options& options) {                   \
    const TensorShape vector_shape({static_cast<int64>(v.size())});        \
    return Const(v, vector_shape, options);                                \
  }
// Defines Const(ArraySlice<TYPE> t, const TensorShape& shape, options), the
// overload that the scalar and vector forms forward to.
//
// The variadic arguments are statement(s) spliced verbatim into the
// single-value branch. They are expected to append `*t.begin()` to the
// repeated field of the local TensorProto named `proto` that matches TYPE, so
// the identifiers `t` and `proto` are part of this macro's contract with its
// users and must not be renamed.
//
//  * Exactly one value in `t`: the "value" attr is a TensorProto carrying
//    `shape` and whatever the spliced statements add; the element count of
//    `shape` is not checked in this branch.
//  * Any other number of values: they are copied into a Tensor of `shape`.
//    If the counts disagree, an InvalidArgument status is recorded on
//    `options` and no attrs are set before FinalizeBuilder() is called.
//
// Returns nullptr immediately when `options` already holds an error;
// otherwise returns the result of options.FinalizeBuilder().
//
// NOTE: only /* */ comments may be used inside the macro body, since a //
// comment would swallow the line-continuation backslash.
#define DEFINE_CONST_TENSOR(TYPE, ...)                                        \
  Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape,              \
              const GraphDefBuilder::Options& options) {                      \
    if (options.HaveError()) {                                                \
      return nullptr;                                                         \
    }                                                                         \
    NodeBuilder builder(options.GetNameForOp(OpName()), OpName(),             \
                        options.op_registry());                               \
    const DataType element_type = DataTypeToEnum<TYPE>::v();                  \
    if (t.size() != 1) {                                                      \
      /* General case: materialize a Tensor holding all the values. */        \
      Tensor values(element_type, shape);                                     \
      if (values.NumElements() == static_cast<int64>(t.size())) {             \
        std::copy_n(t.data(), t.size(), values.flat<TYPE>().data());          \
        builder.Attr("dtype", element_type).Attr("value", values);            \
      } else {                                                                \
        options.UpdateStatus(errors::InvalidArgument(                         \
            t.size(), " values provided to Const() != ",                      \
            values.NumElements(), " elements for shape ",                     \
            shape.ShortDebugString()));                                       \
      }                                                                       \
    } else {                                                                  \
      /* Single value: emit a proto with `shape` and one value entry. */      \
      TensorProto proto;                                                      \
      proto.set_dtype(element_type);                                          \
      shape.AsProto(proto.mutable_tensor_shape());                            \
      __VA_ARGS__;                                                            \
      builder.Attr("dtype", element_type).Attr("value", proto);               \
    }                                                                         \
    return options.FinalizeBuilder(&builder);                                 \
  }
// Emits all three Const() overloads for TYPE in one go: the scalar form, the
// vector form, and the (values, shape) tensor form. The variadic arguments
// are passed through unchanged to DEFINE_CONST_TENSOR, which splices them in
// as the statement(s) that store a single value into the local `proto`.
#define DEFINE_CONST_IMPL(TYPE, ...) \
DEFINE_CONST_SCALAR(TYPE) \
DEFINE_CONST_VECTOR(TYPE) \
DEFINE_CONST_TENSOR(TYPE, __VA_ARGS__)
// Shorthand for element types whose single value is stored with one call to
// a repeated-field adder: FIELD is token-pasted after `add_` to form the
// TensorProto method name (e.g. FIELD=float_val -> proto.add_float_val), and
// that method is invoked with the first element of `t`.
#define DEFINE_CONST(TYPE, FIELD) \
DEFINE_CONST_IMPL(TYPE, proto.add_##FIELD(*t.begin());)
DEFINE_CONST(float, float_val);
DEFINE_CONST(double, double_val);
DEFINE_CONST(int32, int_val);
DEFINE_CONST(uint8, int_val);
DEFINE_CONST(int16, int_val);
DEFINE_CONST(int8, int_val);
DEFINE_CONST(int64, int64_val);
DEFINE_CONST(bool, bool_val);
DEFINE_CONST_IMPL(complex64, proto.add_scomplex_val(t.begin()->real());
proto.add_scomplex_val(t.begin()->imag()););
// Scalar string overload.
// Builds a node of op type OpName() whose "dtype" attr is DT_STRING and whose
// "value" attr is a TensorProto with an empty (rank-0) shape and a single
// string_val entry holding the s.size() bytes starting at s.data().
// Returns nullptr without building anything when `options` already holds an
// error; otherwise returns the result of options.FinalizeBuilder().
Node* Const(StringPiece s, const GraphDefBuilder::Options& options) {
  if (options.HaveError()) {
    return nullptr;
  }

  // Assemble the scalar payload first; it does not depend on the builder.
  TensorProto scalar_proto;
  scalar_proto.set_dtype(DT_STRING);
  TensorShape({}).AsProto(scalar_proto.mutable_tensor_shape());
  scalar_proto.add_string_val(s.data(), s.size());

  NodeBuilder builder(options.GetNameForOp(OpName()), OpName(),
                      options.op_registry());
  builder.Attr("dtype", DT_STRING).Attr("value", scalar_proto);
  return options.FinalizeBuilder(&builder);
}
// String vector and (values, shape) tensor overloads. DEFINE_CONST_SCALAR is
// deliberately not instantiated for string here; the StringPiece overload
// defined above provides the scalar form instead.
// In the single-element tensor case the first string of `t` is appended to
// string_val.
DEFINE_CONST_VECTOR(string)
DEFINE_CONST_TENSOR(string, proto.add_string_val(*t.begin());)
// Every macro-generated overload has been emitted at this point; undefine the
// helper macros so they are no longer visible to the code that follows.
#undef DEFINE_CONST
#undef DEFINE_CONST_IMPL
#undef DEFINE_CONST_TENSOR
#undef DEFINE_CONST_VECTOR
#undef DEFINE_CONST_SCALAR
// Tensor overload.
// Builds a node of op type OpName() with "dtype" set to t.dtype() and "value"
// set to the tensor itself.
// Returns nullptr without building anything when `options` already holds an
// error; otherwise returns the result of options.FinalizeBuilder().
Node* Const(const Tensor& t, const GraphDefBuilder::Options& options) {
  if (options.HaveError()) {
    return nullptr;
  }

  NodeBuilder builder(options.GetNameForOp(OpName()), OpName(),
                      options.op_registry());
  builder.Attr("dtype", t.dtype()).Attr("value", t);
  return options.FinalizeBuilder(&builder);
}
// TensorProto overload.
// Builds a node of op type OpName() with "dtype" set to proto.dtype() and
// "value" set to `proto` as given; the proto's contents are not inspected
// here beyond reading its dtype.
// Returns nullptr without building anything when `options` already holds an
// error; otherwise returns the result of options.FinalizeBuilder().
Node* Const(const TensorProto& proto,
            const GraphDefBuilder::Options& options) {
  if (options.HaveError()) {
    return nullptr;
  }

  NodeBuilder builder(options.GetNameForOp(OpName()), OpName(),
                      options.op_registry());
  builder.Attr("dtype", proto.dtype()).Attr("value", proto);
  return options.FinalizeBuilder(&builder);
}
} // namespace ops
} // namespace tensorflow
|