diff options
-rw-r--r-- | tensorflow/core/common_runtime/function_test.cc | 66 | ||||
-rw-r--r-- | tensorflow/core/kernels/transpose_op.cc | 22 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 212 | ||||
-rw-r--r-- | tensorflow/core/ops/array_ops_test.cc | 15 | ||||
-rw-r--r-- | tensorflow/core/ops/math_ops.cc | 92 | ||||
-rw-r--r-- | tensorflow/python/framework/op_def_library.py | 50 | ||||
-rw-r--r-- | tensorflow/python/framework/op_def_library_test.py | 40 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 54 | ||||
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 25 |
9 files changed, 430 insertions, 146 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index e263e62bd8..523547b4eb 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -534,11 +534,11 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { (n2:float, n3:float) -> (n9:float) { n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [0] values: >]() n10 = Const[dtype=float, value=Tensor<type: float shape: [] values: 2>]() - n6 = Shape[T=float](n2) + n6 = Shape[T=float, out_type=int32](n2) n5 = Mul[T=float](n3, n10) - n7 = BroadcastGradientArgs(n6, n11) - n8 = Sum[T=float, keep_dims=false](n5, n7) - n9 = Reshape[T=float](n8, n6) + n7 = BroadcastGradientArgs[T=int32](n6, n11) + n8 = Sum[T=float, Tidx=int32, keep_dims=false](n5, n7) + n9 = Reshape[T=float, Tshape=int32](n8, n6) } )P"; EXPECT_EQ(e2, DebugString(g)); @@ -555,13 +555,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) { (n7:float, n5:float, n2:float) -> (n14:float, n11:float) { n3 = Identity[T=float](n2) n4 = Identity[T=float](n2) - n6 = Shape[T=float](n5) - n8 = Shape[T=float](n7) - n9 = BroadcastGradientArgs(n8, n6) - n10 = Sum[T=float, keep_dims=false](n3, n9:1) - n13 = Sum[T=float, keep_dims=false](n4, n9) - n11 = Reshape[T=float](n10, n6) - n14 = Reshape[T=float](n13, n8) + n6 = Shape[T=float, out_type=int32](n5) + n8 = Shape[T=float, out_type=int32](n7) + n9 = BroadcastGradientArgs[T=int32](n8, n6) + n10 = Sum[T=float, Tidx=int32, keep_dims=false](n3, n9:1) + n13 = Sum[T=float, Tidx=int32, keep_dims=false](n4, n9) + n11 = Reshape[T=float, Tshape=int32](n10, n6) + n14 = Reshape[T=float, Tshape=int32](n13, n8) } )P"; EXPECT_EQ(e0, DebugString(g)); @@ -576,14 +576,14 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) { const char* e0 = R"P( (n6:float, n3:float, n2:float) -> (n14:float, n11:float) { n4 = Mul[T=float](n2, n3) - n5 = Shape[T=float](n3) + n5 = Shape[T=float, out_type=int32](n3) n7 = Mul[T=float](n6, n2) - n8 = Shape[T=float](n6) - n9 = BroadcastGradientArgs(n8, n5) - n10 = Sum[T=float, keep_dims=false](n7, n9:1) - n13 = Sum[T=float, keep_dims=false](n4, n9) - n11 = Reshape[T=float](n10, n5) - n14 = Reshape[T=float](n13, n8) + n8 = Shape[T=float, out_type=int32](n6) + n9 = BroadcastGradientArgs[T=int32](n8, n5) + n10 = Sum[T=float, Tidx=int32, keep_dims=false](n7, n9:1) + n13 = Sum[T=float, Tidx=int32, keep_dims=false](n4, n9) + n11 = Reshape[T=float, Tshape=int32](n10, n5) + n14 = Reshape[T=float, Tshape=int32](n13, n8) } )P"; EXPECT_EQ(e0, DebugString(g)); @@ -643,10 +643,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { n24 = Identity[T=float](n4) n14 = Add[T=float](n24, n25) n15 = Rank[T=float](n14) - n16 = Range(n11, n15, n10) + n16 = Range[Tidx=int32](n11, n15, n10) n20 = ZerosLike[T=int32](n15) - n17 = Sum[T=float, keep_dims=false](n14, n16) - n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, keep_dims=false]](n14, n16, n26) + n17 = Sum[T=float, Tidx=int32, keep_dims=false](n14, n16) + n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, Tidx=int32, keep_dims=false]](n14, n16, n26) n21 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Add[T=float]](n24, n25, n19) n28 = Identity[T=float](n21:1) n27 = Identity[T=float](n21) @@ -662,23 +662,23 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) { n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]() n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]() n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]() - n19 = Shape[T=float](n3) + n19 = Shape[T=float, out_type=int32](n3) n8 = Add[T=float](n4, n3) - n20 = Shape[T=float](n4) + n20 = Shape[T=float, out_type=int32](n4) n9 = Rank[T=float](n8) - n14 = Shape[T=float](n8) - n21 = BroadcastGradientArgs(n20, n19) - n10 = Range(n7, n9, n11) - n12 = Shape[T=int32](n10) + n14 = Shape[T=float, out_type=int32](n8) + n21 = BroadcastGradientArgs[T=int32](n20, n19) + n10 = Range[Tidx=int32](n7, n9, n11) + n12 = Shape[T=int32, out_type=int32](n10) n13 = Fill[T=int32](n12, n11) n15 = DynamicStitch[N=2, T=int32](n10, n10, n14, n13) - n16 = Reshape[T=float](n2, n15) + n16 = Reshape[T=float, Tshape=int32](n2, n15) n17 = Div[T=int32](n14, n15) - n18 = Tile[T=float](n16, n17) - n24 = Sum[T=float, keep_dims=false](n18, n21) - n22 = Sum[T=float, keep_dims=false](n18, n21:1) - n25 = Reshape[T=float](n24, n20) - n23 = Reshape[T=float](n22, n19) + n18 = Tile[T=float, Tmultiples=int32](n16, n17) + n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21) + n22 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21:1) + n25 = Reshape[T=float, Tshape=int32](n24, n20) + n23 = Reshape[T=float, Tshape=int32](n22, n19) } )P"; EXPECT_EQ(e2, DebugString(g)); diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index fbde0c9626..35429a382e 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -169,11 +169,12 @@ Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, out); } -#define REGISTER(T) \ - REGISTER_KERNEL_BUILDER(Name("Transpose") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<T>("T") \ - .HostMemory("perm"), \ +#define REGISTER(T) \ + REGISTER_KERNEL_BUILDER(Name("Transpose") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<int32>("Tperm") \ + .HostMemory("perm"), \ TransposeCpuOp); TF_CALL_ALL_TYPES(REGISTER) REGISTER(bfloat16); @@ -187,11 +188,12 @@ Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, out); } -#define REGISTER(T) \ - REGISTER_KERNEL_BUILDER(Name("Transpose") \ - .Device(DEVICE_GPU) \ - .TypeConstraint<T>("T") \ - .HostMemory("perm"), \ +#define REGISTER(T) \ + REGISTER_KERNEL_BUILDER(Name("Transpose") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<int32>("Tperm") \ + .HostMemory("perm"), \ TransposeGpuOp); TF_CALL_POD_TYPES(REGISTER); #undef REGISTER diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index fc1d031d4b..4ab47754df 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -39,6 +39,34 @@ Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack, return Status::OK(); } +template <typename T> +std::vector<int64> AsInt64(const Tensor* tensor, int num_elements) { + std::vector<int64> ret(num_elements); + auto data = tensor->vec<T>(); + for (int i = 0; i < num_elements; ++i) { + ret[i] = data(i); + } + return ret; +} + +template <typename T> +Status PadKnown(InferenceContext* c, ShapeHandle input, + const Tensor* paddings_t, int32 num_dims) { + // paddings_t is known. + std::vector<DimensionHandle> dims(num_dims); + auto paddings_data = paddings_t->matrix<T>(); + for (int i = 0; i < num_dims; ++i) { + const T pad0 = paddings_data(i, 0); + const T pad1 = paddings_data(i, 1); + if (pad0 < 0 || pad1 < 0) { + return errors::InvalidArgument("Paddings must be non-negative"); + } + TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i])); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); +} + Status PadShapeFn(InferenceContext* c) { // Paddings is a matrix of [input_rank, 2]. ShapeHandle paddings; @@ -73,19 +101,11 @@ Status PadShapeFn(InferenceContext* c) { const auto num_dims = c->Value(n_dim); DCHECK_EQ(num_dims, paddings_t->shape().dim_size(0)); - // paddings_t is known. - auto paddings_data = paddings_t->matrix<int32>(); - std::vector<DimensionHandle> dims(num_dims); - for (int i = 0; i < num_dims; ++i) { - const int32 pad0 = paddings_data(i, 0); - const int32 pad1 = paddings_data(i, 1); - if (pad0 < 0 || pad1 < 0) { - return errors::InvalidArgument("Paddings must be non-negative"); - } - TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i])); + if (paddings_t->dtype() == DT_INT32) { + return PadKnown<int32>(c, input, paddings_t, num_dims); + } else { + return PadKnown<int64>(c, input, paddings_t, num_dims); } - c->set_output(0, c->MakeShape(dims)); - return Status::OK(); } } // namespace @@ -1127,9 +1147,10 @@ message: Prefix of the error message. // -------------------------------------------------------------------------- REGISTER_OP("Reshape") .Input("tensor: T") - .Input("shape: int32") + .Input("shape: Tshape") .Output("output: T") .Attr("T: type") + .Attr("Tshape: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle in = c->input(0); ShapeHandle out; @@ -1247,8 +1268,9 @@ shape: Defines the shape of the output tensor. // -------------------------------------------------------------------------- REGISTER_OP("InvertPermutation") - .Input("x: int32") - .Output("y: int32") + .Input("x: T") + .Output("y: T") + .Attr("T: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle x; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x)); @@ -1281,9 +1303,10 @@ y: 1-D. // -------------------------------------------------------------------------- REGISTER_OP("Transpose") .Input("x: T") - .Input("perm: int32") + .Input("perm: Tperm") .Output("y: T") .Attr("T: type") + .Attr("Tperm: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle input = c->input(0); ShapeHandle perm_shape = c->input(1); @@ -1318,9 +1341,15 @@ REGISTER_OP("Transpose") // all shape informantion, otherwise we can only return rank information, // but no information for the dimensions. if (perm != nullptr) { - auto flat_perm = perm->flat<int32>(); + std::vector<int64> data; + if (perm->dtype() == DT_INT32) { + data = AsInt64<int32>(perm, rank); + } else { + data = AsInt64<int64>(perm, rank); + } + for (int32 i = 0; i < rank; ++i) { - int32 in_idx = flat_perm(i); + int64 in_idx = data[i]; if (in_idx >= rank) { return errors::InvalidArgument( "perm dim ", in_idx, " is out of range of input rank ", rank); @@ -1347,8 +1376,9 @@ The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: REGISTER_OP("Unique") .Input("x: T") .Output("y: T") - .Output("idx: int32") + .Output("idx: out_idx") .Attr("T: type") + .Attr("out_idx: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); c->set_output(1, c->input(0)); @@ -1382,9 +1412,10 @@ idx: 1-D. REGISTER_OP("UniqueWithCounts") .Input("x: T") .Output("y: T") - .Output("idx: int32") - .Output("count: int32") + .Output("idx: out_idx") + .Output("count: out_idx") .Attr("T: type") + .Attr("out_idx: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { auto uniq = c->Vector(InferenceContext::kUnknownDim); c->set_output(0, uniq); @@ -1439,8 +1470,9 @@ Status ShapeShapeFn(InferenceContext* c) { // -------------------------------------------------------------------------- REGISTER_OP("Shape") .Input("input: T") - .Output("output: int32") + .Output("output: out_type") .Attr("T: type") + .Attr("out_type: {int32, int64} = DT_INT32") .SetShapeFn(ShapeShapeFn) .Doc(R"doc( Returns the shape of a tensor. @@ -1458,9 +1490,10 @@ shape(t) ==> [2, 2, 3] REGISTER_OP("ShapeN") .Input("input: N * T") - .Output("output: N * int32") + .Output("output: N * out_type") .Attr("N: int") .Attr("T: type") + .Attr("out_type: {int32, int64} = DT_INT32") .SetShapeFn(ShapeShapeFn) .Doc(R"doc( Returns shape of tensors. @@ -1606,8 +1639,9 @@ of the tensor. Rank is also known as "order", "degree", or "ndims." // -------------------------------------------------------------------------- REGISTER_OP("Size") .Input("input: T") - .Output("output: int32") + .Output("output: out_type") .Attr("T: type") + .Attr("out_type: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ScalarShape) .Doc(R"doc( Returns the size of a tensor. @@ -1854,11 +1888,13 @@ shape must be exactly the shape produced by the slice of `ref`. // TODO(aselle): Fix this documentation once StridedSliceAssign Supports // broadcasting. // -------------------------------------------------------------------------- + REGISTER_OP("Tile") .Input("input: T") - .Input("multiples: int32") + .Input("multiples: Tmultiples") .Output("output: T") .Attr("T: type") + .Attr("Tmultiples: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle input; ShapeHandle multiples; @@ -1882,12 +1918,15 @@ REGISTER_OP("Tile") return Status::OK(); } - // Multiply each input dimension by its corresponding value - // from the multiples tensor. - auto multiples_data = multiples_t->vec<int32>(); + std::vector<int64> data; + if (multiples_t->dtype() == DT_INT32) { + data = AsInt64<int32>(multiples_t, rank); + } else { + data = AsInt64<int64>(multiples_t, rank); + } std::vector<DimensionHandle> dims(rank); for (int i = 0; i < rank; ++i) { - const int32 multiple = multiples_data(i); + const int64 multiple = data[i]; TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, i), multiple, &dims[i])); } c->set_output(0, c->MakeShape(dims)); @@ -1968,10 +2007,11 @@ where(input) ==> [[0, 0, 0], // -------------------------------------------------------------------------- REGISTER_OP("BroadcastGradientArgs") - .Input("s0: int32") - .Input("s1: int32") - .Output("r0: int32") - .Output("r1: int32") + .Input("s0: T") + .Input("s1: T") + .Output("r0: T") + .Output("r1: T") + .Attr("T: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { // TODO(mrry): Implement constant_value for BroadcastGradientArgs? ShapeHandle unused; @@ -1990,9 +2030,10 @@ This is typically used by gradient computations for a broadcasting operation. // -------------------------------------------------------------------------- REGISTER_OP("Pad") .Input("input: T") - .Input("paddings: int32") + .Input("paddings: Tpaddings") .Output("output: T") .Attr("T: type") + .Attr("Tpaddings: {int32, int64} = DT_INT32") .SetShapeFn(PadShapeFn) .Doc(R"doc( Pads a tensor with zeros. @@ -2025,9 +2066,10 @@ pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] // -------------------------------------------------------------------------- REGISTER_OP("MirrorPad") .Input("input: T") - .Input("paddings: int32") + .Input("paddings: Tpaddings") .Output("output: T") .Attr("T: type") + .Attr("Tpaddings: {int32, int64} = DT_INT32") .Attr(GetMirrorPadModeAttrString()) .SetShapeFn(PadShapeFn) .Doc(R"doc( @@ -2071,11 +2113,33 @@ output: The padded tensor. )doc"); // -------------------------------------------------------------------------- +namespace { +template <typename T> +Status MirrorPadKnown(InferenceContext* c, ShapeHandle input, + const Tensor* paddings_t, int32 input_rank) { + auto paddings_data = paddings_t->matrix<T>(); + std::vector<DimensionHandle> dims(input_rank); + for (int i = 0; i < input_rank; ++i) { + const int64 pad0 = static_cast<int64>(paddings_data(i, 0)); + const int64 pad1 = static_cast<int64>(paddings_data(i, 1)); + if (pad0 < 0 || pad1 < 0) { + return errors::InvalidArgument("Paddings must be non-negative"); + } + + TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i])); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); +} + +} // namespace + REGISTER_OP("MirrorPadGrad") .Input("input: T") - .Input("paddings: int32") + .Input("paddings: Tpaddings") .Output("output: T") .Attr("T: type") + .Attr("Tpaddings: {int32, int64} = DT_INT32") .Attr(GetMirrorPadModeAttrString()) .SetShapeFn([](InferenceContext* c) { ShapeHandle paddings; @@ -2103,20 +2167,11 @@ REGISTER_OP("MirrorPadGrad") return Status::OK(); } - auto paddings_data = paddings_t->matrix<int32>(); - std::vector<DimensionHandle> dims(input_rank); - for (int i = 0; i < input_rank; ++i) { - const int64 pad0 = static_cast<int64>(paddings_data(i, 0)); - const int64 pad1 = static_cast<int64>(paddings_data(i, 1)); - if (pad0 < 0 || pad1 < 0) { - return errors::InvalidArgument("Paddings must be non-negative"); - } - - TF_RETURN_IF_ERROR( - c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i])); + if (paddings_t->dtype() == DT_INT32) { + return MirrorPadKnown<int32>(c, input, paddings_t, input_rank); + } else { + return MirrorPadKnown<int64>(c, input, paddings_t, input_rank); } - c->set_output(0, c->MakeShape(dims)); - return Status::OK(); }) .Doc(R"doc( Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor. @@ -2218,9 +2273,10 @@ shape: The (possibly partial) shape of the tensor. // -------------------------------------------------------------------------- REGISTER_OP("ExpandDims") .Input("input: T") - .Input("dim: int32") + .Input("dim: Tdim") .Output("output: T") .Attr("T: type") + .Attr("Tdim: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle input = c->input(0); ShapeHandle expand_dim; @@ -2231,7 +2287,14 @@ REGISTER_OP("ExpandDims") c->set_output(0, c->UnknownShape()); return Status::OK(); } - int32 which_dim = dim_t->flat<int32>()(0); + + int64 which_dim; + if (dim_t->dtype() == DT_INT32) { + which_dim = static_cast<int64>(dim_t->flat<int32>()(0)); + } else { + which_dim = dim_t->flat<int64>()(0); + } + if (which_dim < 0) { which_dim += c->Rank(input) + 1; } @@ -2388,8 +2451,9 @@ REGISTER_OP("ListDiff") .Input("x: T") .Input("y: T") .Output("out: T") - .Output("idx: int32") + .Output("idx: out_idx") .Attr("T: type") + .Attr("out_idx: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); @@ -2434,9 +2498,10 @@ idx: 1-D. Positions of `x` values preserved in `out`. // -------------------------------------------------------------------------- REGISTER_OP("SpaceToBatch") .Input("input: T") - .Input("paddings: int32") + .Input("paddings: Tpaddings") .Output("output: T") .Attr("T: type") + .Attr("Tpaddings: {int32, int64} = DT_INT32") .Attr("block_size: int >= 2") .SetShapeFn([](InferenceContext* c) { ShapeHandle input; @@ -2470,11 +2535,20 @@ REGISTER_OP("SpaceToBatch") output_height = c->UnknownDim(); output_width = c->UnknownDim(); } else { - auto pad_matrix = paddings_t->matrix<int32>(); - const int32 pad_top = pad_matrix(0, 0); - const int32 pad_bottom = pad_matrix(0, 1); - const int32 pad_left = pad_matrix(1, 0); - const int32 pad_right = pad_matrix(1, 1); + int64 pad_top, pad_bottom, pad_left, pad_right; + if (paddings_t->dtype() == DT_INT32) { + auto pad_matrix = paddings_t->matrix<int32>(); + pad_top = pad_matrix(0, 0); + pad_bottom = pad_matrix(0, 1); + pad_left = pad_matrix(1, 0); + pad_right = pad_matrix(1, 1); + } else { + auto pad_matrix = paddings_t->matrix<int64>(); + pad_top = pad_matrix(0, 0); + pad_bottom = pad_matrix(0, 1); + pad_left = pad_matrix(1, 0); + pad_right = pad_matrix(1, 1); + } if (pad_top < 0 || pad_bottom < 0 || pad_left < 0 || pad_right < 0) { return errors::InvalidArgument("Paddings cannot be negative."); @@ -2599,10 +2673,11 @@ regular convolution. // -------------------------------------------------------------------------- REGISTER_OP("BatchToSpace") .Input("input: T") - .Input("crops: int32") + .Input("crops: Tidx") .Output("output: T") .Attr("T: type") .Attr("block_size: int >= 2") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); @@ -2640,11 +2715,20 @@ REGISTER_OP("BatchToSpace") output_height = c->UnknownDim(); output_width = c->UnknownDim(); } else { - auto crops_matrix = crops_t->matrix<int32>(); - const int32 crops_top = crops_matrix(0, 0); - const int32 crops_bottom = crops_matrix(0, 1); - const int32 crops_left = crops_matrix(1, 0); - const int32 crops_right = crops_matrix(1, 1); + int64 crops_top, crops_bottom, crops_left, crops_right; + if (crops_t->dtype() == DT_INT32) { + auto crops_matrix = crops_t->matrix<int32>(); + crops_top = crops_matrix(0, 0); + crops_bottom = crops_matrix(0, 1); + crops_left = crops_matrix(1, 0); + crops_right = crops_matrix(1, 1); + } else { + auto crops_matrix = crops_t->matrix<int64>(); + crops_top = crops_matrix(0, 0); + crops_bottom = crops_matrix(0, 1); + crops_left = crops_matrix(1, 0); + crops_right = crops_matrix(1, 1); + } if (crops_top < 0 || crops_bottom < 0 || crops_left < 0 || crops_right < 0) { diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 42c69f0950..b3c07fedae 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -293,8 +293,8 @@ TEST(ArrayOpsTest, PadD_ShapeFn) { // Make the paddings tensor known and verify padding values get added. // E.g., if padding is ((1,10),(2,20),(3,30)) then values 11,22,23 are added // to input dims to get output. - Tensor paddings_t(DT_INT32, TensorShape{3, 2}); - test::FillValues<int32>(&paddings_t, {1, 10, 2, 20, 3, 30}); + Tensor paddings_t(DT_INT64, TensorShape{3, 2}); + test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30}); op.input_tensors[1] = &paddings_t; INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]"); INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]"); @@ -326,8 +326,8 @@ TEST(ArrayOpsTest, MirrorPadGrad_ShapeFn) { // Make the paddings tensor known and verify padding values get // subtracted. E.g., if padding is ((1,10),(2,20),(3,30)) then // values 11,22,23 are subtracted to input dims to get output. - Tensor paddings_t(DT_INT32, TensorShape{3, 2}); - test::FillValues<int32>(&paddings_t, {1, 10, 2, 20, 3, 30}); + Tensor paddings_t(DT_INT64, TensorShape{3, 2}); + test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30}); op.input_tensors[1] = &paddings_t; INFER_OK(op, "[111,222,333];[3,2]", "[100,200,300]"); @@ -823,6 +823,9 @@ TEST(ArrayOpsTest, Tile_ShapeFn) { Tensor multiples = test::AsTensor<int32>({2, 3, 4, 5}); op.input_tensors[1] = &multiples; INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]"); + // Test 64-bit tensor type + multiples = test::AsTensor<int64>({2, 3, 4, 5}); + INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]"); } TEST(ArrayOpsTest, EditDistance_ShapeFn) { @@ -933,6 +936,8 @@ TEST(ArrayOpsTest, SpaceToBatch_ShapeFn) { Tensor paddings = test::AsTensor<int32>({4, 2, 2, 4}, {{2, 2}}); op.input_tensors[1] = &paddings; INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]"); + paddings = test::AsTensor<int64>({4, 2, 2, 4}, {{2, 2}}); + INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]"); // Bad paddings values paddings = test::AsTensor<int32>({1, 2, 3, 4}, {{2, 2}}); @@ -970,7 +975,7 @@ TEST(ArrayOpsTest, BatchToSpace_ShapeFn) { INFER_ERROR("BatchToSpace requires crops with shape [2,2]", op, "[4,8,8,3];[2,3]"); - Tensor croppings = test::AsTensor<int32>({4, 2, 2, 4}, {{2, 2}}); + Tensor croppings = test::AsTensor<int64>({4, 2, 2, 4}, {{2, 2}}); op.input_tensors[1] = &croppings; INFER_OK(op, "[4,8,8,3];[2,2]", "[1,10,10,d0_3]"); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 7c5a6cf11b..f034b58a7b 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1015,10 +1015,11 @@ matrix multiply on one platform was 30% zero values in the sparse matrix. // dimensions of the input. REGISTER_OP("Sum") .Input("input: T") - .Input("reduction_indices: int32") + .Input("reduction_indices: Tidx") .Output("output: T") .Attr("keep_dims: bool = false") .Attr("T: numbertype") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ReductionShape) .Doc(R"doc( Computes the sum of elements across dimensions of a tensor. @@ -1036,10 +1037,11 @@ output: The reduced tensor. REGISTER_OP("Mean") .Input("input: T") - .Input("reduction_indices: int32") + .Input("reduction_indices: Tidx") .Output("output: T") .Attr("keep_dims: bool = false") .Attr("T: numbertype") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ReductionShape) .Doc(R"doc( Computes the mean of elements across dimensions of a tensor. @@ -1057,10 +1059,11 @@ output: The reduced tensor. REGISTER_OP("Prod") .Input("input: T") - .Input("reduction_indices: int32") + .Input("reduction_indices: Tidx") .Output("output: T") .Attr("keep_dims: bool = false") .Attr("T: numbertype") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ReductionShape) .Doc(R"doc( Computes the product of elements across dimensions of a tensor. @@ -1078,10 +1081,11 @@ output: The reduced tensor. REGISTER_OP("Min") .Input("input: T") - .Input("reduction_indices: int32") + .Input("reduction_indices: Tidx") .Output("output: T") .Attr("keep_dims: bool = false") .Attr("T: numbertype") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ReductionShape) .Doc(R"doc( Computes the minimum of elements across dimensions of a tensor. @@ -1099,10 +1103,11 @@ output: The reduced tensor. REGISTER_OP("Max") .Input("input: T") - .Input("reduction_indices: int32") + .Input("reduction_indices: Tidx") .Output("output: T") .Attr("keep_dims: bool = false") .Attr("T: numbertype") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ReductionShape) .Doc(R"doc( Computes the maximum of elements across dimensions of a tensor. @@ -1149,7 +1154,13 @@ Status ArgOpShape(shape_inference::InferenceContext* c) { return Status::OK(); } - const int32 dimension_val = dim_t->scalar<int32>()(); + int64 dimension_val; + if (dim_t->dtype() == DT_INT32) { + dimension_val = dim_t->scalar<int32>()(); + } else { + dimension_val = dim_t->scalar<int64>()(); + } + if (dimension_val < 0 || dimension_val >= input_rank) { return errors::InvalidArgument("Dimension (", dimension_val, ") must be in the range [0, ", input_rank, @@ -1172,9 +1183,10 @@ Status ArgOpShape(shape_inference::InferenceContext* c) { REGISTER_OP("ArgMax") .Input("input: T") - .Input("dimension: int32") + .Input("dimension: Tidx") .Output("output: int64") .Attr("T: numbertype") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(ArgOpShape) .Doc(R"doc( Returns the index with the largest value across dimensions of a tensor. @@ -1185,9 +1197,10 @@ dimension: int32, 0 <= dimension < rank(input). Describes which dimension REGISTER_OP("ArgMin") .Input("input: T") - .Input("dimension: int32") + .Input("dimension: Tidx") .Output("output: int64") .Attr("T: numbertype") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(ArgOpShape) .Doc(R"doc( Returns the index with the smallest value across dimensions of a tensor. @@ -1490,10 +1503,11 @@ output: Has same shape as data, except for the first `segment_ids.rank` REGISTER_OP("SparseSegmentSum") .Input("data: T") - .Input("indices: int32") + .Input("indices: Tidx") .Input("segment_ids: int32") .Output("output: T") .Attr("T: realnumbertype") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(SparseSegmentReductionShapeFn) .Doc(R"doc( Computes the sum along sparse segments of a tensor. @@ -1538,10 +1552,11 @@ output: Has same shape as data, except for dimension 0 which REGISTER_OP("SparseSegmentMean") .Input("data: T") - .Input("indices: int32") + .Input("indices: Tidx") .Input("segment_ids: int32") .Output("output: T") .Attr("T: {float, double}") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(SparseSegmentReductionShapeFn) .Doc(R"doc( Computes the mean along sparse segments of a tensor. @@ -1564,11 +1579,12 @@ output: Has same shape as data, except for dimension 0 which REGISTER_OP("SparseSegmentMeanGrad") .Input("grad: T") - .Input("indices: int32") + .Input("indices: Tidx") .Input("segment_ids: int32") .Input("output_dim0: int32") .Output("output: T") .Attr("T: {float, double}") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(SparseSegmentReductionGradShapeFn) .Doc(R"doc( Computes gradients for SparseSegmentMean. @@ -1584,10 +1600,11 @@ output_dim0: dimension 0 of "data" passed to SparseSegmentMean op. REGISTER_OP("SparseSegmentSqrtN") .Input("data: T") - .Input("indices: int32") + .Input("indices: Tidx") .Input("segment_ids: int32") .Output("output: T") .Attr("T: {float, double}") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(SparseSegmentReductionShapeFn) .Doc(R"doc( Computes the sum along sparse segments of a tensor divided by the sqrt of N. @@ -1609,11 +1626,12 @@ output: Has same shape as data, except for dimension 0 which REGISTER_OP("SparseSegmentSqrtNGrad") .Input("grad: T") - .Input("indices: int32") + .Input("indices: Tidx") .Input("segment_ids: int32") .Input("output_dim0: int32") .Output("output: T") .Attr("T: {float, double}") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(SparseSegmentReductionGradShapeFn) .Doc(R"doc( Computes gradients for SparseSegmentSqrtN. @@ -1629,9 +1647,10 @@ output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op. REGISTER_OP("All") .Input("input: bool") - .Input("reduction_indices: int32") + .Input("reduction_indices: Tidx") .Output("output: bool") .Attr("keep_dims: bool = false") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ReductionShape) .Doc(R"doc( Computes the "logical and" of elements across dimensions of a tensor. @@ -1649,9 +1668,10 @@ output: The reduced tensor. REGISTER_OP("Any") .Input("input: bool") - .Input("reduction_indices: int32") + .Input("reduction_indices: Tidx") .Attr("keep_dims: bool = false") .Output("output: bool") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ReductionShape) .Doc(R"doc( Computes the "logical or" of elements across dimensions of a tensor. @@ -1670,10 +1690,11 @@ output: The reduced tensor. // -------------------------------------------------------------------------- REGISTER_OP("Range") - .Input("start: int32") - .Input("limit: int32") - .Input("delta: int32") - .Output("output: int32") + .Input("start: Tidx") + .Input("limit: Tidx") + .Input("delta: Tidx") + .Output("output: Tidx") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused), @@ -1689,9 +1710,17 @@ REGISTER_OP("Range") c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); return Status::OK(); } - const int32 start = start_t->scalar<int32>()(); - const int32 limit = limit_t->scalar<int32>()(); - const int32 delta = delta_t->scalar<int32>()(); + // TODO + int64 start, limit, delta; + if (start_t->dtype() == DT_INT32) { + start = start_t->scalar<int32>()(); + limit = limit_t->scalar<int32>()(); + delta = delta_t->scalar<int32>()(); + } else { + start = start_t->scalar<int64>()(); + limit = limit_t->scalar<int64>()(); + delta = delta_t->scalar<int64>()(); + } if (start > limit) { return errors::InvalidArgument("Requires start <= limit: ", start, "/", limit); @@ -1699,7 +1728,7 @@ REGISTER_OP("Range") if (delta <= 0) { return errors::InvalidArgument("Requires delta > 0: ", delta); } - const int32 size = (limit - start + delta - 1) / delta; + const int64 size = (limit - start + delta - 1) / delta; c->set_output(0, c->Vector(size)); return Status::OK(); }) @@ -1727,9 +1756,10 @@ output: 1-D. REGISTER_OP("LinSpace") .Input("start: T") .Input("stop: T") - .Input("num: int32") + .Input("num: Tidx") .Output("output: T") .Attr("T: {float, double}") + .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused), @@ -1743,7 +1773,13 @@ REGISTER_OP("LinSpace") c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); return Status::OK(); } - const int64 num = num_t->scalar<int32>()(); + + int64 num; + if (num_t->dtype() == DT_INT32) { + num = num_t->scalar<int32>()(); + } else { + num = num_t->scalar<int64>()(); + } if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num); c->set_output(0, c->Vector(num)); return Status::OK(); @@ -2057,11 +2093,12 @@ product: Pairwise cross product of the vectors in `a` and `b`. REGISTER_OP("Cumsum") .Input("x: T") - .Input("axis: int32") + .Input("axis: Tidx") .Attr("exclusive: bool = false") .Attr("reverse: bool = false") .Output("out: T") .Attr("T: numbertype") + .Attr("Tidx: {int32, int64} = DT_INT32") .Doc(R"doc( Compute the cumulative sum of the tensor `x` along `axis`. @@ -2092,11 +2129,12 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) ==> [b + c, c, 0] REGISTER_OP("Cumprod") .Input("x: T") - .Input("axis: int32") + .Input("axis: Tidx") .Attr("exclusive: bool = false") .Attr("reverse: bool = false") .Output("out: T") .Attr("T: numbertype") + .Attr("Tidx: {int32, int64} = DT_INT32") .Doc(R"doc( Compute the cumulative product of the tensor `x` along `axis`. diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index b2e1dfce27..b34bb068d7 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -349,6 +349,23 @@ class OpDefLibrary(object): (op_type_name, producer, deprecation_version, op_def.deprecation.explanation)) + # Fill in the list of default types for all "type" attrs. This + # will be used to choose a preferred dtype to convert to in the + # absence of input type information. + # + # TODO(b/31302892): Currently the defaults don't work in the right + # way if you have two inputs, one of whose type resolution depends + # on the other. Handling this will require restructuring this code + # significantly. + default_type_attr_map = {} + for attr_def in op_def.attr: + if attr_def.type != "type": + continue + key = attr_def.name + if attr_def.HasField("default_value"): + default_type_attr_map[key] = dtypes.as_dtype( + attr_def.default_value.type) + # Requires that op_def has passed validation (using the C++ # ValidateOpDef() from ../framework/op_def_util.h). attrs = {} @@ -390,6 +407,7 @@ class OpDefLibrary(object): # In cases where we expect all elements of the list to have the # same dtype, try to cast non-Tensor elements to that type. dtype = None + default_dtype = None if input_arg.type != types_pb2.DT_INVALID: dtype = input_arg.type elif input_arg.number_attr: @@ -401,11 +419,19 @@ class OpDefLibrary(object): dtype = t.dtype break + # dtype still not found, prefer using the default dtype + # from the attr. + if dtype is None and input_arg.type_attr in default_type_attr_map: + default_dtype = default_type_attr_map[input_arg.type_attr] + try: if not input_arg.is_ref and dtype: dtype = dtypes.as_dtype(dtype).base_dtype values = ops.convert_n_to_tensor( - values, name=input_arg.name, dtype=dtype if dtype else None, + values, + name=input_arg.name, + dtype=dtype if dtype else None, + preferred_dtype=default_dtype, as_ref=input_arg.is_ref) if input_arg.number_attr and len( set(v.dtype.base_dtype for v in values)) > 1: @@ -444,14 +470,24 @@ class OpDefLibrary(object): # In cases where we have an expected type, try to convert non-Tensor # arguments to that type. dtype = None + default_dtype = None if input_arg.type != types_pb2.DT_INVALID: dtype = input_arg.type elif input_arg.type_attr in attrs: dtype = attrs[input_arg.type_attr] + elif input_arg.type_attr in default_type_attr_map: + # The dtype could not be inferred solely from the inputs, + # so we prefer the attr's default, so code that adds a new attr + # with a default is backwards compatible. + default_dtype = default_type_attr_map[input_arg.type_attr] + try: values = ops.convert_to_tensor( - values, name=input_arg.name, dtype=dtype, - as_ref=input_arg.is_ref) + values, + name=input_arg.name, + dtype=dtype, + as_ref=input_arg.is_ref, + preferred_dtype=default_dtype) except ValueError: # What type does convert_to_tensor think it has? observed = ops.convert_to_tensor(values, @@ -462,6 +498,14 @@ class OpDefLibrary(object): raise TypeError("%s expected type of %s." % (prefix, dtypes.as_dtype(input_arg.type).name)) else: + # Update the maps with the default, if needed. + k = input_arg.type_attr + if k in default_type_attr_map: + if k not in attrs: + attrs[k] = default_type_attr_map[k] + if k not in inferred_from: + inferred_from[k] = "Default in OpDef" + raise TypeError( "%s type %s of argument '%s'." % (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name, diff --git a/tensorflow/python/framework/op_def_library_test.py b/tensorflow/python/framework/op_def_library_test.py index 27933cc193..d95bb45cb8 100644 --- a/tensorflow/python/framework/op_def_library_test.py +++ b/tensorflow/python/framework/op_def_library_test.py @@ -47,6 +47,8 @@ ops.RegisterShape("AttrShape")(None) ops.RegisterShape("AttrShapeList")(None) ops.RegisterShape("AttrPartialShape")(None) ops.RegisterShape("AttrPartialShapeList")(None) +ops.RegisterShape("AttrTypeDefault")(None) +ops.RegisterShape("AttrListTypeDefault")(None) ops.RegisterShape("Binary")(None) ops.RegisterShape("ComplexStruct")(None) ops.RegisterShape("InPolymorphicTwice")(None) @@ -868,6 +870,44 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): name: 'x' op: 'ReservedAttr' attr { key: 'range' value { i: 7 } } """, op.node_def) + def testDefaultAttrType(self): + self._add_op("name: 'AttrTypeDefault' " + "input_arg { name: 'a' type_attr: 'T' } " + "attr { name: 'T' type: 'type' " + " default_value { type: DT_INT32 } }") + + # Give an input whose type has no obvious output type. + op = self._lib.apply_op("AttrTypeDefault", a=[], name="n") + self.assertProtoEquals(""" + name: 'n' op: 'AttrTypeDefault' input: 'n/a' + attr { key: 'T' value { type: DT_INT32 } } + """, op.node_def) + + # Give an input whose type can be inferred as different + # than the default. + op = self._lib.apply_op("AttrTypeDefault", a=[1.0], name="f") + self.assertProtoEquals(""" + name: 'f' op: 'AttrTypeDefault' input: 'f/a' + attr { key: 'T' value { type: DT_FLOAT } } + """, op.node_def) + + def testDefaultListAttrType(self): + self._add_op("name: 'AttrListTypeDefault' " + "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } " + "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } " + "attr { name: 'T' type: 'type' " + " default_value { type: DT_INT32 } }" + "attr { name: 'N' type: 'int' }") + + # Give an input whose type can be inferred as different + # than the default. + op = self._lib.apply_op("AttrListTypeDefault", a=[1.0], b=[2.0], name="n") + self.assertProtoEquals(""" + name: 'n' op: 'AttrListTypeDefault' input: 'n/a_0' input: 'n/b_0' + attr { key: 'T' value { type: DT_FLOAT } } + attr { key: 'N' value { i: 1 } } + """, op.node_def) + def testNIntsIn(self): self._add_op("name: 'NIntsIn' " "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } " diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index f47ecd1b36..a5e253eabf 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -577,7 +577,11 @@ _tensor_conversion_func_registry = { register_dense_tensor_like_type(Tensor) -def convert_to_tensor(value, dtype=None, name=None, as_ref=False): +def convert_to_tensor(value, + dtype=None, + name=None, + as_ref=False, + preferred_dtype=None): """Converts the given `value` to a `Tensor`. This function converts Python objects of various types to `Tensor` @@ -610,6 +614,11 @@ def convert_to_tensor(value, dtype=None, name=None, as_ref=False): name: Optional name to use if a new `Tensor` is created. as_ref: True if we want the result as a ref tensor. Only used if a new `Tensor` is created. + preferred_dtype: Optional element type for the returned tensor, + used when dtype is None. In some cases, a caller may not have a + dtype in mind when converting to a tensor, so preferred_dtype + can be used as a soft preference. If the conversion to + `preferred_dtype` is not possible, this argument has no effect. Returns: A `Tensor` based on `value`. @@ -625,9 +634,31 @@ def convert_to_tensor(value, dtype=None, name=None, as_ref=False): for _, funcs_at_priority in sorted(_tensor_conversion_func_registry.items()): for base_type, conversion_func in funcs_at_priority: if isinstance(value, base_type): - ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref) + # If dtype is None but preferred_dtype is not None, we try to + # cast to preferred_dtype first. + ret = None + if dtype is None and preferred_dtype is not None: + try: + ret = conversion_func( + value, dtype=preferred_dtype, name=name, as_ref=as_ref) + except (TypeError, ValueError): + # Could not coerce the conversion to use the preferred dtype. + ret = None + + if ret is not None and ret is not NotImplemented: + if (ret.dtype.base_dtype != + dtypes.as_dtype(preferred_dtype).base_dtype): + raise TypeError("convert_to_tensor did not convert to " + "the preferred dtype: %s vs %s " % + (ret.dtype.base_dtype, + dtypes.as_dtype(preferred_dtype).base_dtype)) + + if ret is None: + ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref) + if ret is NotImplemented: continue + if not isinstance(ret, Tensor): raise RuntimeError( "%sConversion function %r for type %s returned non-Tensor: %r" @@ -644,7 +675,11 @@ def convert_to_tensor(value, dtype=None, name=None, as_ref=False): % (error_prefix, value, type(value))) -def convert_n_to_tensor(values, dtype=None, name=None, as_ref=False): +def convert_n_to_tensor(values, + dtype=None, + name=None, + as_ref=False, + preferred_dtype=None): """Converts `values` to a list of `Tensor` objects. Args: @@ -654,6 +689,11 @@ def convert_n_to_tensor(values, dtype=None, name=None, as_ref=False): created, in which case element `i` will be given the name `name + '_' + i`. as_ref: True if the caller wants the results as ref tensors. + preferred_dtype: Optional element type for the returned tensors, + used when dtype is None. In some cases, a caller may not have a + dtype in mind when converting to a tensor, so preferred_dtype + can be used as a soft preference. If the conversion to + `preferred_dtype` is not possible, this argument has no effect. Returns: A list of `Tensor` and/or `IndexedSlices` objects. @@ -669,7 +709,13 @@ def convert_n_to_tensor(values, dtype=None, name=None, as_ref=False): ret = [] for i, value in enumerate(values): n = None if name is None else "%s_%d" % (name, i) - ret.append(convert_to_tensor(value, dtype=dtype, name=n, as_ref=as_ref)) + ret.append( + convert_to_tensor( + value, + dtype=dtype, + name=n, + as_ref=as_ref, + preferred_dtype=preferred_dtype)) return ret diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 98ad4a2d60..eac85ac844 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -280,6 +280,31 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertEqual(tensor_shape.unknown_shape(), _apply_op(g, "an_op", [], [dtypes.float32]).get_shape()) + def testConvertToTensorPreferred(self): + with self.test_session(): + values = [2, 3, 5, 7] + tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32) + self.assertEqual(dtypes.float32, tensor.dtype) + + with self.test_session(): + # Convert empty tensor to anything + values = [] + tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) + self.assertEqual(dtypes.int64, tensor.dtype) + + with self.test_session(): + # The preferred dtype is a type error and will convert to + # float32 instead. + values = [1.23] + tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) + self.assertEqual(dtypes.float32, tensor.dtype) + + def testConvertToInvalidTensorType(self): + with self.assertRaises(TypeError): + # Forcing an invalid dtype should fail with a type error. + values = [1.23] + _ = ops.convert_to_tensor(values, dtype=dtypes.int64) + def testNoConvert(self): # Operation cannot be converted to Tensor op = control_flow_ops.no_op() |