author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-11-30 15:48:06 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-11-30 15:51:50 -0800
commit    186caed810c0e9a9ee9a3f1e0f8bea50764ce5df (patch)
tree      6518391420b13090f64a6decfe3ae1a433d246f1
parent    0472116d163eeb77d51cabdc5fc67be917048870 (diff)
Add int64 support to XLA Shape op.
PiperOrigin-RevId: 177519992
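The change replaces the hard-coded DT_INT32 output with an out_type-driven TensorShapeToConstant helper, keeping the old int32 overflow guard on the narrowing path. As a standalone illustration of that guard (plain C++ stand-ins below, not the TensorFlow Tensor/Status types), the narrowing behaves roughly like this:

// Standalone sketch of the int32 narrowing guard in TensorShapeToConstant.
// std::vector<int64_t> stands in for TensorShape, and a bool return stands
// in for errors::InvalidArgument; only the check itself mirrors the kernel.
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Like FastBoundsCheck in the kernel: each dimension must fit in int32
// before it can be narrowed; otherwise the caller needs out_type=int64.
bool NarrowShapeToInt32(const std::vector<int64_t>& dims,
                        std::vector<int32_t>* out) {
  for (int64_t d : dims) {
    if (d < 0 || d > std::numeric_limits<int32_t>::max()) return false;
    out->push_back(static_cast<int32_t>(d));
  }
  return true;
}

int main() {
  std::vector<int32_t> out;
  std::cout << NarrowShapeToInt32({2, 3, 4}, &out) << "\n";    // 1: fits in int32
  out.clear();
  std::cout << NarrowShapeToInt32({1LL << 40}, &out) << "\n";  // 0: int64 required
}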
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/shape_op.cc | 76
1 file changed, 45 insertions(+), 31 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 24a99f253d..06838d1625 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -25,58 +25,72 @@ limitations under the License.
namespace tensorflow {
namespace {
+// Converts a TensorShape to a constant Tensor.
+//
+// The input TensorShape input_shape is used to populate the elements of
+// shape_constant, which is modified in place.
+Status TensorShapeToConstant(const TensorShape& input_shape,
+ Tensor* shape_constant) {
+ const int dims = input_shape.dims();
+ if (shape_constant->dtype() == DT_INT32) {
+ auto vec = shape_constant->vec<int32>();
+ for (int i = 0; i < dims; ++i) {
+ int64 dim_size = input_shape.dim_size(i);
+ if (!FastBoundsCheck(dim_size, std::numeric_limits<int32>::max())) {
+ return errors::InvalidArgument(
+ "Shape with out_type=int32 does not support tensors > int32max",
+ " but dim ", i, " is ", dim_size);
+ }
+ vec(i) = static_cast<int32>(dim_size);
+ }
+ } else {
+ auto vec = shape_constant->vec<int64>();
+ for (int i = 0; i < dims; ++i) {
+ int64 dim_size = input_shape.dim_size(i);
+ vec(i) = dim_size;
+ }
+ }
+ return Status::OK();
+}
+
class ShapeOp : public XlaOpKernel {
public:
- explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
+ }
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
- const int rank = input_shape.dims();
- Tensor shape_constant(DT_INT32, TensorShape({rank}));
- auto vec = shape_constant.vec<int32>();
- // TODO(dga): support int64. b/28119922.
- for (int i = 0; i < rank; ++i) {
- int64 dim_size = input_shape.dim_size(i);
- OP_REQUIRES(
- ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
- errors::InvalidArgument("Shape does not support tensors > int32max",
- " but dim ", i, " is ", dim_size));
- vec(i) = static_cast<int32>(dim_size);
- }
-
+ Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
+ OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
ctx->SetConstantOutput(0, shape_constant);
}
+
+ private:
+ DataType out_dtype_;
};
REGISTER_XLA_OP(Name("Shape"), ShapeOp);
class ShapeNOp : public XlaOpKernel {
public:
- explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
+ }
void Compile(XlaOpKernelContext* ctx) override {
for (int i = 0; i < ctx->num_inputs(); ++i) {
- const TensorShape shape = ctx->InputShape(i);
- const int dims = shape.dims();
- Tensor shape_constant(DT_INT32, TensorShape({dims}));
- auto vec = shape_constant.vec<int32>();
-
- // TODO(dga): support int64. b/28119922.
- for (int j = 0; j < dims; ++j) {
- int64 dim_size = shape.dim_size(j);
- OP_REQUIRES(
- ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
- errors::InvalidArgument("Shape does not support tensors > int32max",
- " but shape ", i, " dim ", j, " is ",
- dim_size));
- vec(j) = static_cast<int32>(dim_size);
- }
-
+ const TensorShape input_shape = ctx->InputShape(i);
+ Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
+ OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
ctx->SetConstantOutput(i, shape_constant);
}
}
bool IsExpensive() override { return false; }
+
+ private:
+ DataType out_dtype_;
};
REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp);