aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/ops/const_op.h
blob: 2588fe1aff3e75f5e5990914ebbd02a780a13eca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
/* Copyright 2015 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CC_OPS_CONST_OP_H_
#define TENSORFLOW_CC_OPS_CONST_OP_H_

#include <initializer_list>
#include <vector>

#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/public/tensor.h"

namespace tensorflow {
namespace ops {

// Const() overloads for the numeric and bool element types declared below.
// Each returns the Node produced with `options`; the out-of-line overloads are
// defined elsewhere (presumably const_op.cc -- confirm there for details).
//
// For every TYPE, DECLARE_CONST(TYPE) declares:
//   Const(TYPE s, options)                          -- scalar
//   Const(gtl::ArraySlice<TYPE> v, options)         -- vector
//   Const(gtl::ArraySlice<TYPE> t, shape, options)  -- tensor of `shape`
// plus two inline brace-list conveniences, Const({...}, options) and
// Const({...}, shape, options), which forward to the ArraySlice versions.
//
// If a shape is specified, you may either provide the same number of values
// as the shape has elements, or a single value, which will be duplicated to
// fill out the Tensor.
#define DECLARE_CONST(TYPE)                                                  \
  Node* Const(TYPE s, const GraphDefBuilder::Options& options); /* Scalar */ \
  Node* Const(gtl::ArraySlice<TYPE> v,                                       \
              const GraphDefBuilder::Options& options); /* Vector */         \
  Node* Const(gtl::ArraySlice<TYPE> t, const TensorShape& shape,             \
              const GraphDefBuilder::Options& options); /* Tensor */         \
  inline Node* Const(std::initializer_list<TYPE> v, /* Vector using {...} */ \
                     const GraphDefBuilder::Options& options) {              \
    return Const(gtl::ArraySlice<TYPE>(v), options);                         \
  }                                                                          \
  inline Node* Const(std::initializer_list<TYPE> t, /* Tensor using {...} */ \
                     const TensorShape& shape,                               \
                     const GraphDefBuilder::Options& options) {              \
    return Const(gtl::ArraySlice<TYPE>(t), shape, options);                  \
  }

DECLARE_CONST(float);
DECLARE_CONST(double);
DECLARE_CONST(int32);
DECLARE_CONST(uint8);
DECLARE_CONST(int16);
DECLARE_CONST(int8);
DECLARE_CONST(complex64);
DECLARE_CONST(int64);
DECLARE_CONST(bool);

#undef DECLARE_CONST

// String overloads. These mirror the DECLARE_CONST set above, using
// StringPiece for the scalar form and `string` as the element type for the
// vector and tensor forms.
Node* Const(StringPiece s, const GraphDefBuilder::Options& options);
Node* Const(gtl::ArraySlice<string> v, const GraphDefBuilder::Options& options);
Node* Const(gtl::ArraySlice<string> t, const TensorShape& shape,
            const GraphDefBuilder::Options& options);
inline Node* Const(std::initializer_list<string> v,
                   const GraphDefBuilder::Options& options) {
  return Const(gtl::ArraySlice<string>(v), options);
}
inline Node* Const(std::initializer_list<string> t, const TensorShape& shape,
                   const GraphDefBuilder::Options& options) {
  return Const(gtl::ArraySlice<string>(t), shape, options);
}

// C-string overloads.
//
// Without these, a string literal such as Const("foo", options) binds to
// Const(bool, ...). The reason is overload ranking: array-to-pointer followed
// by pointer-to-bool is a *standard* conversion, which outranks the
// *user-defined* conversion to StringPiece. The call would silently build a
// boolean `true` constant.
//
// For the same reason, a brace list of literals such as {"a", "b"} would
// resolve to the std::initializer_list<bool> overload instead of the string
// one.
//
// These exact-match overloads route C strings to the string versions above.
// Every pointer passed must be non-null and point to a NUL-terminated string.
inline Node* Const(const char* s, const GraphDefBuilder::Options& options) {
  return Const(StringPiece(s), options);
}
inline Node* Const(std::initializer_list<const char*> v,
                   const GraphDefBuilder::Options& options) {
  // Materialize std::strings so the values can be viewed as ArraySlice<string>.
  const std::vector<string> values(v.begin(), v.end());
  return Const(gtl::ArraySlice<string>(values), options);
}
inline Node* Const(std::initializer_list<const char*> t,
                   const TensorShape& shape,
                   const GraphDefBuilder::Options& options) {
  const std::vector<string> values(t.begin(), t.end());
  return Const(gtl::ArraySlice<string>(values), shape, options);
}

// A Tensor of any type, given either as a Tensor or as a TensorProto.
Node* Const(const Tensor& t, const GraphDefBuilder::Options& options);
Node* Const(const TensorProto& proto, const GraphDefBuilder::Options& options);

// Equivalent to Const(gtl::ArraySlice<T>(), options): a vector constant of
// element type T with zero elements. T must be a type that has a
// gtl::ArraySlice<T> overload above (one of the DECLARE_CONST types, or
// `string`).
template <class T>
Node* EmptyConst(const GraphDefBuilder::Options& options) {
  return Const(gtl::ArraySlice<T>(), options);
}

// TODO(josh11b): Support other types (e.g. quantized ints, float16).

}  // namespace ops
}  // namespace tensorflow

#endif  // TENSORFLOW_CC_OPS_CONST_OP_H_