aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu/ops/infeed_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tpu/ops/infeed_ops.cc')
-rw-r--r--tensorflow/contrib/tpu/ops/infeed_ops.cc10
1 files changed, 2 insertions, 8 deletions
diff --git a/tensorflow/contrib/tpu/ops/infeed_ops.cc b/tensorflow/contrib/tpu/ops/infeed_ops.cc
index c12e83137a..849c4a1102 100644
--- a/tensorflow/contrib/tpu/ops/infeed_ops.cc
+++ b/tensorflow/contrib/tpu/ops/infeed_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -26,14 +27,7 @@ REGISTER_OP("InfeedDequeue")
.Attr("dtype: type")
.Attr("shape: shape")
.SetIsStateful()
- .SetShapeFn([](InferenceContext* c) {
- PartialTensorShape shape;
- TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
- ShapeHandle out;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
- c->set_output(0, out);
- return Status::OK();
- })
+ .SetShapeFn(shape_inference::ExplicitShape)
.Doc(R"doc(
A placeholder op for a value that will be fed into the computation.