aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/ops/const_op.cc
blob: e428e4f35ece1cc1b4bb26d12c7c23d1df37d14a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include "tensorflow/cc/ops/const_op.h"

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace ops {

namespace {
// Returns the op type name ("Const") used both as the op to instantiate and
// as the base for auto-generated node names. The string is a function-local
// static, so it is constructed on first use rather than at static-init time.
const string& OpName() {
  static const string kConstOpName("Const");
  return kConstOpName;
}
}  // namespace

// Defines Const(T s, options) for one element type: the scalar is viewed as a
// one-element slice and forwarded to the slice+shape overload with a rank-0
// shape.
#define DEFINE_CONST_SCALAR(T)                                                \
  Node* Const(T s, const GraphDefBuilder::Options& options) {                 \
    const gtl::ArraySlice<T> single_value(&s, 1);                             \
    return Const(single_value, TensorShape({}), options);                     \
  }

// Defines Const(gtl::ArraySlice<T> v, options) for one element type: emits a
// rank-1 constant whose only dimension equals the number of values in `v`,
// by forwarding to the slice+shape overload.
#define DEFINE_CONST_VECTOR(T)                                                \
  Node* Const(gtl::ArraySlice<T> v,                                           \
              const GraphDefBuilder::Options& options) {                      \
    const int64 length = static_cast<int64>(v.size());                        \
    return Const(v, TensorShape({length}), options);                          \
  }

// Defines Const(gtl::ArraySlice<TYPE> t, shape, options) for one element
// type. This is the overload the scalar and vector helpers forward to.
//
// * Exactly one value: a TensorProto is filled with the dtype, `shape`, and
//   the single value. The value is written by the statement(s) passed as the
//   variadic macro arguments, which may refer to the locals `proto` and `t`.
//   The element count of `shape` is NOT compared with the single value; this
//   presumably relies on TensorProto's convention of repeating the last stored
//   value to fill the shape (NOTE(review): confirm against tensor.proto).
// * Otherwise: the number of values must equal shape.num_elements(). On a
//   mismatch an InvalidArgument status is recorded in `options` and nullptr is
//   returned right away, rather than finalizing a builder that never received
//   its "dtype"/"value" attrs. The count is checked before the Tensor is
//   constructed, so a bad request does not allocate a (possibly large) buffer.
//
// Returns nullptr immediately if `options` already carries an error. The
// NodeBuilder is still created before the count check so node naming via
// GetNameForOp() happens exactly as before.
#define DEFINE_CONST_TENSOR(TYPE, ...)                                         \
  Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape,               \
              const GraphDefBuilder::Options& options) {                       \
    if (options.HaveError()) return nullptr;                                   \
    NodeBuilder node_builder(options.GetNameForOp(OpName()), OpName(),         \
                             options.op_registry());                           \
    const DataType dt = DataTypeToEnum<TYPE>::v();                             \
    if (t.size() == 1) {                                                       \
      TensorProto proto;                                                       \
      proto.set_dtype(dt);                                                     \
      shape.AsProto(proto.mutable_tensor_shape());                             \
      __VA_ARGS__;                                                             \
      node_builder.Attr("dtype", dt).Attr("value", proto);                     \
    } else {                                                                   \
      const int64 expected_elements = shape.num_elements();                    \
      if (expected_elements != static_cast<int64>(t.size())) {                 \
        options.UpdateStatus(errors::InvalidArgument(                          \
            t.size(), " values provided to Const() != ", expected_elements,    \
            " elements for shape ", shape.ShortDebugString()));                \
        return nullptr;                                                        \
      }                                                                        \
      Tensor tensor(dt, shape);                                                \
      std::copy_n(t.data(), t.size(), tensor.flat<TYPE>().data());             \
      node_builder.Attr("dtype", dt).Attr("value", tensor);                    \
    }                                                                          \
    return options.FinalizeBuilder(&node_builder);                             \
  }

// Defines the scalar, vector, and slice+shape Const() overloads for element
// type T. The variadic arguments are forwarded unchanged to
// DEFINE_CONST_TENSOR as the statement(s) that store one value into `proto`.
#define DEFINE_CONST_IMPL(T, ...) \
  DEFINE_CONST_SCALAR(T)          \
  DEFINE_CONST_VECTOR(T)          \
  DEFINE_CONST_TENSOR(T, __VA_ARGS__)

// Shorthand for element types whose single value is stored with one call to
// TensorProto::add_<PROTO_FIELD>(), where PROTO_FIELD names the repeated
// field (e.g. float_val, int_val).
#define DEFINE_CONST(T, PROTO_FIELD) \
  DEFINE_CONST_IMPL(T, proto.add_##PROTO_FIELD(*t.begin());)

DEFINE_CONST(float, float_val);
DEFINE_CONST(double, double_val);
DEFINE_CONST(int32, int_val);
DEFINE_CONST(uint8, int_val);
DEFINE_CONST(int16, int_val);
DEFINE_CONST(int8, int_val);
DEFINE_CONST(int64, int64_val);
DEFINE_CONST(bool, bool_val);

DEFINE_CONST_IMPL(complex64, proto.add_scomplex_val(t.begin()->real());
                  proto.add_scomplex_val(t.begin()->imag()););

// Builds a scalar DT_STRING "Const" node whose value holds the bytes of `s`
// (s.data(), s.size()). Returns nullptr without building anything if
// `options` already carries an error; otherwise returns the result of
// options.FinalizeBuilder().
Node* Const(StringPiece s, const GraphDefBuilder::Options& options) {
  if (options.HaveError()) {
    return nullptr;
  }

  TensorProto scalar_proto;
  scalar_proto.set_dtype(DT_STRING);
  TensorShape({}).AsProto(scalar_proto.mutable_tensor_shape());
  scalar_proto.add_string_val(s.data(), s.size());

  NodeBuilder builder(options.GetNameForOp(OpName()), OpName(),
                      options.op_registry());
  builder.Attr("dtype", DT_STRING);
  builder.Attr("value", scalar_proto);
  return options.FinalizeBuilder(&builder);
}

// string gets only the vector and slice+shape overloads from the macros; a
// scalar string overload is provided separately as Const(StringPiece, ...).
// The single-value path stores the first (only) element in string_val.
DEFINE_CONST_VECTOR(string)
DEFINE_CONST_TENSOR(string, proto.add_string_val(t[0]);)

// The DEFINE_CONST* helpers are private to this file; drop them (in the order
// they were defined) so they do not stay defined for the rest of the
// translation unit.
#undef DEFINE_CONST_SCALAR
#undef DEFINE_CONST_VECTOR
#undef DEFINE_CONST_TENSOR
#undef DEFINE_CONST_IMPL
#undef DEFINE_CONST

// Builds a "Const" node whose "value" attr is set from `t` and whose "dtype"
// attr is t.dtype(). Returns nullptr without building anything if `options`
// already carries an error; otherwise returns the result of
// options.FinalizeBuilder().
Node* Const(const Tensor& t, const GraphDefBuilder::Options& options) {
  if (options.HaveError()) {
    return nullptr;
  }
  const DataType dtype = t.dtype();
  NodeBuilder builder(options.GetNameForOp(OpName()), OpName(),
                      options.op_registry());
  builder.Attr("dtype", dtype);
  builder.Attr("value", t);
  return options.FinalizeBuilder(&builder);
}

// Builds a "Const" node directly from an already-serialized TensorProto: the
// "dtype" attr is proto.dtype() and the "value" attr is `proto` itself. The
// proto's contents are not validated here (NOTE(review): any checking is
// presumably deferred to node finalization — confirm). Returns nullptr
// without building anything if `options` already carries an error.
Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options) {
  if (options.HaveError()) {
    return nullptr;
  }
  NodeBuilder builder(options.GetNameForOp(OpName()), OpName(),
                      options.op_registry());
  builder.Attr("dtype", proto.dtype());
  builder.Attr("value", proto);
  return options.FinalizeBuilder(&builder);
}

}  // namespace ops
}  // namespace tensorflow