aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu/ops/infeed_ops.cc
blob: c12e83137aa8f4f74d63f3dac1bb490cd0977648 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

// Registers the InfeedDequeue op: a stateful op with no inputs and a single
// output whose element type and shape come from the `dtype` and `shape` attrs.
REGISTER_OP("InfeedDequeue")
    .Output("output: dtype")
    .Attr("dtype: type")
    .Attr("shape: shape")
    .SetShapeFn([](InferenceContext* ctx) {
      // Shape inference copies the `shape` attr onto output 0; any failure to
      // read or convert the attr is propagated to the caller.
      PartialTensorShape attr_shape;
      TF_RETURN_IF_ERROR(ctx->GetAttr("shape", &attr_shape));

      ShapeHandle inferred;
      TF_RETURN_IF_ERROR(
          ctx->MakeShapeFromPartialTensorShape(attr_shape, &inferred));
      ctx->set_output(0, inferred);
      return Status::OK();
    })
    .SetIsStateful()
    .Doc(R"doc(
A placeholder op for a value that will be fed into the computation.

output: A tensor that will be provided using the infeed mechanism.
dtype: The type of elements in the tensor.
shape: The shape of the tensor.
)doc");

REGISTER_OP("InfeedEnqueue")
    .Input("input: dtype")
    .Attr("dtype: type")
    .Attr("shape: shape = {}")
    .Attr("device_ordinal: int = -1")
    .SetIsStateful()
    .Doc(R"doc(
An op which feeds a single Tensor value into the computation.

input: A tensor that will be provided using the infeed mechanism.
dtype: The type of elements in the tensor.
shape: The shape of the tensor.
device_ordinal: The TPU device to use. This should be -1 when the Op
is running on a TPU device, and >= 0 when the Op is running on the CPU
device.
)doc");

REGISTER_OP("InfeedEnqueueTuple")
    .Input("inputs: dtypes")
    .Attr("dtypes: list(type)")
    .Attr("shapes: list(shape)")
    .Attr("device_ordinal: int = -1")
    .SetIsStateful()
    .Doc(R"doc(
An op which feeds multiple Tensor values into the computation as an XLA tuple.

inputs: A list of tensors that will be provided using the infeed mechanism.
dtypes: The element types of each element in `inputs`.
shapes: The shapes of each tensor in `inputs`.
device_ordinal: The TPU device to use. This should be -1 when the Op
is running on a TPU device, and >= 0 when the Op is running on the CPU
device.
)doc");

// Registers the InfeedDequeueTuple op: a stateful op with no inputs and one
// output per entry of `dtypes`; output i takes its shape from `shapes[i]`.
REGISTER_OP("InfeedDequeueTuple")
    .Output("outputs: dtypes")
    .Attr("dtypes: list(type)")
    .Attr("shapes: list(shape)")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      std::vector<PartialTensorShape> shapes;
      TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));

      // The number of outputs is driven by `dtypes`, while the loop below is
      // driven by `shapes`. The two attrs are declared independently, so
      // reject a length mismatch here instead of indexing past the outputs
      // (too many shapes) or leaving some outputs without a shape (too few).
      const int num_shapes = static_cast<int>(shapes.size());
      if (num_shapes != c->num_outputs()) {
        return errors::InvalidArgument(
            "InfeedDequeueTuple: `shapes` has ", num_shapes,
            " entries but the op has ", c->num_outputs(),
            " outputs; `shapes` and `dtypes` must be the same length.");
      }

      for (int i = 0; i < num_shapes; ++i) {
        ShapeHandle out;
        TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &out));
        c->set_output(i, out);
      }
      return Status::OK();
    })
    .Doc(R"doc(
A placeholder op for multiple values that will be fed into the computation
simultaneously as an XLA tuple.

outputs: A list of tensors that will be provided using the infeed mechanism.
dtypes: The element types of each element in `outputs`.
shapes: The shapes of each tensor in `outputs`.
)doc");

}  // namespace tensorflow