blob: 4897102974a18f61f2e35aef1f7d3b68da7e6bfd (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
|
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
#include <climits>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/public/tensor_shape.h"
namespace tensorflow {
// Functionality common to QueueInterface implementations.
class QueueBase : public QueueInterface {
public:
// As a possible value of 'capacity'.
static const int32 kUnbounded = INT_MAX;
// Args:
// component_dtypes: The types of each component in a queue-element tuple.
// component_shapes: The shapes of each component in a queue-element tuple,
// which must either be empty (if the shapes are not specified) or
// or have the same size as component_dtypes.
// name: A name to use for the queue.
QueueBase(const DataTypeVector& component_dtypes,
const std::vector<TensorShape>& component_shapes,
const string& name);
// Implementations of QueueInterface methods --------------------------------
const DataTypeVector& component_dtypes() const override {
return component_dtypes_;
}
// Other public methods -----------------------------------------------------
const std::vector<TensorShape>& component_shapes() const {
return component_shapes_;
}
protected:
// Returns the number of components in a queue-element tuple.
int32 num_components() const { return component_dtypes_.size(); }
// True if shapes were specified. If so, inputs will be validated
// against them, etc.
bool specified_shapes() const { return component_shapes_.size() > 0; }
// Code common to Validate*Tuple().
Status ValidateTupleCommon(const Tuple& tuple) const;
// Copies the index^th slice (in the first dimension) of parent into element.
static Status CopySliceToElement(const Tensor& parent, Tensor* element,
int index);
// Copies element into the index^th slice (in the first dimension) of parent.
static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
int index);
~QueueBase() override {}
// Helpers for implementing MatchesNodeDef().
static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);
Status MatchesNodeDefOp(const NodeDef& node_def, const string& op) const;
Status MatchesNodeDefCapacity(const NodeDef& node_def, int32 capacity) const;
Status MatchesNodeDefTypes(const NodeDef& node_def) const;
Status MatchesNodeDefShapes(const NodeDef& node_def) const;
const DataTypeVector component_dtypes_;
const std::vector<TensorShape> component_shapes_;
const string name_;
TF_DISALLOW_COPY_AND_ASSIGN(QueueBase);
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
|