about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/core/common_runtime/function_test.cc 66
-rw-r--r-- tensorflow/core/kernels/transpose_op.cc 22
-rw-r--r-- tensorflow/core/ops/array_ops.cc 212
-rw-r--r-- tensorflow/core/ops/array_ops_test.cc 15
-rw-r--r-- tensorflow/core/ops/math_ops.cc 92
-rw-r--r-- tensorflow/python/framework/op_def_library.py 50
-rw-r--r-- tensorflow/python/framework/op_def_library_test.py 40
-rw-r--r-- tensorflow/python/framework/ops.py 54
-rw-r--r-- tensorflow/python/framework/ops_test.py 25
9 files changed, 430 insertions, 146 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index e263e62bd8..523547b4eb 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -534,11 +534,11 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
(n2:float, n3:float) -> (n9:float) {
n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [0] values: >]()
n10 = Const[dtype=float, value=Tensor<type: float shape: [] values: 2>]()
- n6 = Shape[T=float](n2)
+ n6 = Shape[T=float, out_type=int32](n2)
n5 = Mul[T=float](n3, n10)
- n7 = BroadcastGradientArgs(n6, n11)
- n8 = Sum[T=float, keep_dims=false](n5, n7)
- n9 = Reshape[T=float](n8, n6)
+ n7 = BroadcastGradientArgs[T=int32](n6, n11)
+ n8 = Sum[T=float, Tidx=int32, keep_dims=false](n5, n7)
+ n9 = Reshape[T=float, Tshape=int32](n8, n6)
}
)P";
EXPECT_EQ(e2, DebugString(g));
@@ -555,13 +555,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) {
(n7:float, n5:float, n2:float) -> (n14:float, n11:float) {
n3 = Identity[T=float](n2)
n4 = Identity[T=float](n2)
- n6 = Shape[T=float](n5)
- n8 = Shape[T=float](n7)
- n9 = BroadcastGradientArgs(n8, n6)
- n10 = Sum[T=float, keep_dims=false](n3, n9:1)
- n13 = Sum[T=float, keep_dims=false](n4, n9)
- n11 = Reshape[T=float](n10, n6)
- n14 = Reshape[T=float](n13, n8)
+ n6 = Shape[T=float, out_type=int32](n5)
+ n8 = Shape[T=float, out_type=int32](n7)
+ n9 = BroadcastGradientArgs[T=int32](n8, n6)
+ n10 = Sum[T=float, Tidx=int32, keep_dims=false](n3, n9:1)
+ n13 = Sum[T=float, Tidx=int32, keep_dims=false](n4, n9)
+ n11 = Reshape[T=float, Tshape=int32](n10, n6)
+ n14 = Reshape[T=float, Tshape=int32](n13, n8)
}
)P";
EXPECT_EQ(e0, DebugString(g));
@@ -576,14 +576,14 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) {
const char* e0 = R"P(
(n6:float, n3:float, n2:float) -> (n14:float, n11:float) {
n4 = Mul[T=float](n2, n3)
- n5 = Shape[T=float](n3)
+ n5 = Shape[T=float, out_type=int32](n3)
n7 = Mul[T=float](n6, n2)
- n8 = Shape[T=float](n6)
- n9 = BroadcastGradientArgs(n8, n5)
- n10 = Sum[T=float, keep_dims=false](n7, n9:1)
- n13 = Sum[T=float, keep_dims=false](n4, n9)
- n11 = Reshape[T=float](n10, n5)
- n14 = Reshape[T=float](n13, n8)
+ n8 = Shape[T=float, out_type=int32](n6)
+ n9 = BroadcastGradientArgs[T=int32](n8, n5)
+ n10 = Sum[T=float, Tidx=int32, keep_dims=false](n7, n9:1)
+ n13 = Sum[T=float, Tidx=int32, keep_dims=false](n4, n9)
+ n11 = Reshape[T=float, Tshape=int32](n10, n5)
+ n14 = Reshape[T=float, Tshape=int32](n13, n8)
}
)P";
EXPECT_EQ(e0, DebugString(g));
@@ -643,10 +643,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
n24 = Identity[T=float](n4)
n14 = Add[T=float](n24, n25)
n15 = Rank[T=float](n14)
- n16 = Range(n11, n15, n10)
+ n16 = Range[Tidx=int32](n11, n15, n10)
n20 = ZerosLike[T=int32](n15)
- n17 = Sum[T=float, keep_dims=false](n14, n16)
- n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, keep_dims=false]](n14, n16, n26)
+ n17 = Sum[T=float, Tidx=int32, keep_dims=false](n14, n16)
+ n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, Tidx=int32, keep_dims=false]](n14, n16, n26)
n21 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Add[T=float]](n24, n25, n19)
n28 = Identity[T=float](n21:1)
n27 = Identity[T=float](n21)
@@ -662,23 +662,23 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]()
n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
- n19 = Shape[T=float](n3)
+ n19 = Shape[T=float, out_type=int32](n3)
n8 = Add[T=float](n4, n3)
- n20 = Shape[T=float](n4)
+ n20 = Shape[T=float, out_type=int32](n4)
n9 = Rank[T=float](n8)
- n14 = Shape[T=float](n8)
- n21 = BroadcastGradientArgs(n20, n19)
- n10 = Range(n7, n9, n11)
- n12 = Shape[T=int32](n10)
+ n14 = Shape[T=float, out_type=int32](n8)
+ n21 = BroadcastGradientArgs[T=int32](n20, n19)
+ n10 = Range[Tidx=int32](n7, n9, n11)
+ n12 = Shape[T=int32, out_type=int32](n10)
n13 = Fill[T=int32](n12, n11)
n15 = DynamicStitch[N=2, T=int32](n10, n10, n14, n13)
- n16 = Reshape[T=float](n2, n15)
+ n16 = Reshape[T=float, Tshape=int32](n2, n15)
n17 = Div[T=int32](n14, n15)
- n18 = Tile[T=float](n16, n17)
- n24 = Sum[T=float, keep_dims=false](n18, n21)
- n22 = Sum[T=float, keep_dims=false](n18, n21:1)
- n25 = Reshape[T=float](n24, n20)
- n23 = Reshape[T=float](n22, n19)
+ n18 = Tile[T=float, Tmultiples=int32](n16, n17)
+ n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21)
+ n22 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21:1)
+ n25 = Reshape[T=float, Tshape=int32](n24, n20)
+ n23 = Reshape[T=float, Tshape=int32](n22, n19)
}
)P";
EXPECT_EQ(e2, DebugString(g));
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index fbde0c9626..35429a382e 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -169,11 +169,12 @@ Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
out);
}
-#define REGISTER(T) \
- REGISTER_KERNEL_BUILDER(Name("Transpose") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("perm"), \
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER(Name("Transpose") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tperm") \
+ .HostMemory("perm"), \
TransposeCpuOp);
TF_CALL_ALL_TYPES(REGISTER)
REGISTER(bfloat16);
@@ -187,11 +188,12 @@ Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
out);
}
-#define REGISTER(T) \
- REGISTER_KERNEL_BUILDER(Name("Transpose") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("perm"), \
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER(Name("Transpose") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tperm") \
+ .HostMemory("perm"), \
TransposeGpuOp);
TF_CALL_POD_TYPES(REGISTER);
#undef REGISTER
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index fc1d031d4b..4ab47754df 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -39,6 +39,34 @@ Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
return Status::OK();
}
+template <typename T>
+std::vector<int64> AsInt64(const Tensor* tensor, int num_elements) {
+ std::vector<int64> ret(num_elements);
+ auto data = tensor->vec<T>();
+ for (int i = 0; i < num_elements; ++i) {
+ ret[i] = data(i);
+ }
+ return ret;
+}
+
+template <typename T>
+Status PadKnown(InferenceContext* c, ShapeHandle input,
+ const Tensor* paddings_t, int32 num_dims) {
+ // paddings_t is known.
+ std::vector<DimensionHandle> dims(num_dims);
+ auto paddings_data = paddings_t->matrix<T>();
+ for (int i = 0; i < num_dims; ++i) {
+ const T pad0 = paddings_data(i, 0);
+ const T pad1 = paddings_data(i, 1);
+ if (pad0 < 0 || pad1 < 0) {
+ return errors::InvalidArgument("Paddings must be non-negative");
+ }
+ TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i]));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+}
+
Status PadShapeFn(InferenceContext* c) {
// Paddings is a matrix of [input_rank, 2].
ShapeHandle paddings;
@@ -73,19 +101,11 @@ Status PadShapeFn(InferenceContext* c) {
const auto num_dims = c->Value(n_dim);
DCHECK_EQ(num_dims, paddings_t->shape().dim_size(0));
- // paddings_t is known.
- auto paddings_data = paddings_t->matrix<int32>();
- std::vector<DimensionHandle> dims(num_dims);
- for (int i = 0; i < num_dims; ++i) {
- const int32 pad0 = paddings_data(i, 0);
- const int32 pad1 = paddings_data(i, 1);
- if (pad0 < 0 || pad1 < 0) {
- return errors::InvalidArgument("Paddings must be non-negative");
- }
- TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i]));
+ if (paddings_t->dtype() == DT_INT32) {
+ return PadKnown<int32>(c, input, paddings_t, num_dims);
+ } else {
+ return PadKnown<int64>(c, input, paddings_t, num_dims);
}
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
}
} // namespace
@@ -1127,9 +1147,10 @@ message: Prefix of the error message.
// --------------------------------------------------------------------------
REGISTER_OP("Reshape")
.Input("tensor: T")
- .Input("shape: int32")
+ .Input("shape: Tshape")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tshape: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle in = c->input(0);
ShapeHandle out;
@@ -1247,8 +1268,9 @@ shape: Defines the shape of the output tensor.
// --------------------------------------------------------------------------
REGISTER_OP("InvertPermutation")
- .Input("x: int32")
- .Output("y: int32")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle x;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x));
@@ -1281,9 +1303,10 @@ y: 1-D.
// --------------------------------------------------------------------------
REGISTER_OP("Transpose")
.Input("x: T")
- .Input("perm: int32")
+ .Input("perm: Tperm")
.Output("y: T")
.Attr("T: type")
+ .Attr("Tperm: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
ShapeHandle perm_shape = c->input(1);
@@ -1318,9 +1341,15 @@ REGISTER_OP("Transpose")
// all shape information, otherwise we can only return rank information,
// but no information for the dimensions.
if (perm != nullptr) {
- auto flat_perm = perm->flat<int32>();
+ std::vector<int64> data;
+ if (perm->dtype() == DT_INT32) {
+ data = AsInt64<int32>(perm, rank);
+ } else {
+ data = AsInt64<int64>(perm, rank);
+ }
+
for (int32 i = 0; i < rank; ++i) {
- int32 in_idx = flat_perm(i);
+ int64 in_idx = data[i];
if (in_idx >= rank) {
return errors::InvalidArgument(
"perm dim ", in_idx, " is out of range of input rank ", rank);
@@ -1347,8 +1376,9 @@ The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
REGISTER_OP("Unique")
.Input("x: T")
.Output("y: T")
- .Output("idx: int32")
+ .Output("idx: out_idx")
.Attr("T: type")
+ .Attr("out_idx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
c->set_output(1, c->input(0));
@@ -1382,9 +1412,10 @@ idx: 1-D.
REGISTER_OP("UniqueWithCounts")
.Input("x: T")
.Output("y: T")
- .Output("idx: int32")
- .Output("count: int32")
+ .Output("idx: out_idx")
+ .Output("count: out_idx")
.Attr("T: type")
+ .Attr("out_idx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
auto uniq = c->Vector(InferenceContext::kUnknownDim);
c->set_output(0, uniq);
@@ -1439,8 +1470,9 @@ Status ShapeShapeFn(InferenceContext* c) {
// --------------------------------------------------------------------------
REGISTER_OP("Shape")
.Input("input: T")
- .Output("output: int32")
+ .Output("output: out_type")
.Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(ShapeShapeFn)
.Doc(R"doc(
Returns the shape of a tensor.
@@ -1458,9 +1490,10 @@ shape(t) ==> [2, 2, 3]
REGISTER_OP("ShapeN")
.Input("input: N * T")
- .Output("output: N * int32")
+ .Output("output: N * out_type")
.Attr("N: int")
.Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(ShapeShapeFn)
.Doc(R"doc(
Returns shape of tensors.
@@ -1606,8 +1639,9 @@ of the tensor. Rank is also known as "order", "degree", or "ndims."
// --------------------------------------------------------------------------
REGISTER_OP("Size")
.Input("input: T")
- .Output("output: int32")
+ .Output("output: out_type")
.Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Returns the size of a tensor.
@@ -1854,11 +1888,13 @@ shape must be exactly the shape produced by the slice of `ref`.
// TODO(aselle): Fix this documentation once StridedSliceAssign Supports
// broadcasting.
// --------------------------------------------------------------------------
+
REGISTER_OP("Tile")
.Input("input: T")
- .Input("multiples: int32")
+ .Input("multiples: Tmultiples")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tmultiples: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
ShapeHandle multiples;
@@ -1882,12 +1918,15 @@ REGISTER_OP("Tile")
return Status::OK();
}
- // Multiply each input dimension by its corresponding value
- // from the multiples tensor.
- auto multiples_data = multiples_t->vec<int32>();
+ std::vector<int64> data;
+ if (multiples_t->dtype() == DT_INT32) {
+ data = AsInt64<int32>(multiples_t, rank);
+ } else {
+ data = AsInt64<int64>(multiples_t, rank);
+ }
std::vector<DimensionHandle> dims(rank);
for (int i = 0; i < rank; ++i) {
- const int32 multiple = multiples_data(i);
+ const int64 multiple = data[i];
TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, i), multiple, &dims[i]));
}
c->set_output(0, c->MakeShape(dims));
@@ -1968,10 +2007,11 @@ where(input) ==> [[0, 0, 0],
// --------------------------------------------------------------------------
REGISTER_OP("BroadcastGradientArgs")
- .Input("s0: int32")
- .Input("s1: int32")
- .Output("r0: int32")
- .Output("r1: int32")
+ .Input("s0: T")
+ .Input("s1: T")
+ .Output("r0: T")
+ .Output("r1: T")
+ .Attr("T: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
// TODO(mrry): Implement constant_value for BroadcastGradientArgs?
ShapeHandle unused;
@@ -1990,9 +2030,10 @@ This is typically used by gradient computations for a broadcasting operation.
// --------------------------------------------------------------------------
REGISTER_OP("Pad")
.Input("input: T")
- .Input("paddings: int32")
+ .Input("paddings: Tpaddings")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tpaddings: {int32, int64} = DT_INT32")
.SetShapeFn(PadShapeFn)
.Doc(R"doc(
Pads a tensor with zeros.
@@ -2025,9 +2066,10 @@ pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
// --------------------------------------------------------------------------
REGISTER_OP("MirrorPad")
.Input("input: T")
- .Input("paddings: int32")
+ .Input("paddings: Tpaddings")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tpaddings: {int32, int64} = DT_INT32")
.Attr(GetMirrorPadModeAttrString())
.SetShapeFn(PadShapeFn)
.Doc(R"doc(
@@ -2071,11 +2113,33 @@ output: The padded tensor.
)doc");
// --------------------------------------------------------------------------
+namespace {
+template <typename T>
+Status MirrorPadKnown(InferenceContext* c, ShapeHandle input,
+ const Tensor* paddings_t, int32 input_rank) {
+ auto paddings_data = paddings_t->matrix<T>();
+ std::vector<DimensionHandle> dims(input_rank);
+ for (int i = 0; i < input_rank; ++i) {
+ const int64 pad0 = static_cast<int64>(paddings_data(i, 0));
+ const int64 pad1 = static_cast<int64>(paddings_data(i, 1));
+ if (pad0 < 0 || pad1 < 0) {
+ return errors::InvalidArgument("Paddings must be non-negative");
+ }
+
+ TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i]));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+}
+
+} // namespace
+
REGISTER_OP("MirrorPadGrad")
.Input("input: T")
- .Input("paddings: int32")
+ .Input("paddings: Tpaddings")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tpaddings: {int32, int64} = DT_INT32")
.Attr(GetMirrorPadModeAttrString())
.SetShapeFn([](InferenceContext* c) {
ShapeHandle paddings;
@@ -2103,20 +2167,11 @@ REGISTER_OP("MirrorPadGrad")
return Status::OK();
}
- auto paddings_data = paddings_t->matrix<int32>();
- std::vector<DimensionHandle> dims(input_rank);
- for (int i = 0; i < input_rank; ++i) {
- const int64 pad0 = static_cast<int64>(paddings_data(i, 0));
- const int64 pad1 = static_cast<int64>(paddings_data(i, 1));
- if (pad0 < 0 || pad1 < 0) {
- return errors::InvalidArgument("Paddings must be non-negative");
- }
-
- TF_RETURN_IF_ERROR(
- c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i]));
+ if (paddings_t->dtype() == DT_INT32) {
+ return MirrorPadKnown<int32>(c, input, paddings_t, input_rank);
+ } else {
+ return MirrorPadKnown<int64>(c, input, paddings_t, input_rank);
}
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
})
.Doc(R"doc(
Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor.
@@ -2218,9 +2273,10 @@ shape: The (possibly partial) shape of the tensor.
// --------------------------------------------------------------------------
REGISTER_OP("ExpandDims")
.Input("input: T")
- .Input("dim: int32")
+ .Input("dim: Tdim")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tdim: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
ShapeHandle expand_dim;
@@ -2231,7 +2287,14 @@ REGISTER_OP("ExpandDims")
c->set_output(0, c->UnknownShape());
return Status::OK();
}
- int32 which_dim = dim_t->flat<int32>()(0);
+
+ int64 which_dim;
+ if (dim_t->dtype() == DT_INT32) {
+ which_dim = static_cast<int64>(dim_t->flat<int32>()(0));
+ } else {
+ which_dim = dim_t->flat<int64>()(0);
+ }
+
if (which_dim < 0) {
which_dim += c->Rank(input) + 1;
}
@@ -2388,8 +2451,9 @@ REGISTER_OP("ListDiff")
.Input("x: T")
.Input("y: T")
.Output("out: T")
- .Output("idx: int32")
+ .Output("idx: out_idx")
.Attr("T: type")
+ .Attr("out_idx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
@@ -2434,9 +2498,10 @@ idx: 1-D. Positions of `x` values preserved in `out`.
// --------------------------------------------------------------------------
REGISTER_OP("SpaceToBatch")
.Input("input: T")
- .Input("paddings: int32")
+ .Input("paddings: Tpaddings")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tpaddings: {int32, int64} = DT_INT32")
.Attr("block_size: int >= 2")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
@@ -2470,11 +2535,20 @@ REGISTER_OP("SpaceToBatch")
output_height = c->UnknownDim();
output_width = c->UnknownDim();
} else {
- auto pad_matrix = paddings_t->matrix<int32>();
- const int32 pad_top = pad_matrix(0, 0);
- const int32 pad_bottom = pad_matrix(0, 1);
- const int32 pad_left = pad_matrix(1, 0);
- const int32 pad_right = pad_matrix(1, 1);
+ int64 pad_top, pad_bottom, pad_left, pad_right;
+ if (paddings_t->dtype() == DT_INT32) {
+ auto pad_matrix = paddings_t->matrix<int32>();
+ pad_top = pad_matrix(0, 0);
+ pad_bottom = pad_matrix(0, 1);
+ pad_left = pad_matrix(1, 0);
+ pad_right = pad_matrix(1, 1);
+ } else {
+ auto pad_matrix = paddings_t->matrix<int64>();
+ pad_top = pad_matrix(0, 0);
+ pad_bottom = pad_matrix(0, 1);
+ pad_left = pad_matrix(1, 0);
+ pad_right = pad_matrix(1, 1);
+ }
if (pad_top < 0 || pad_bottom < 0 || pad_left < 0 || pad_right < 0) {
return errors::InvalidArgument("Paddings cannot be negative.");
@@ -2599,10 +2673,11 @@ regular convolution.
// --------------------------------------------------------------------------
REGISTER_OP("BatchToSpace")
.Input("input: T")
- .Input("crops: int32")
+ .Input("crops: Tidx")
.Output("output: T")
.Attr("T: type")
.Attr("block_size: int >= 2")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
@@ -2640,11 +2715,20 @@ REGISTER_OP("BatchToSpace")
output_height = c->UnknownDim();
output_width = c->UnknownDim();
} else {
- auto crops_matrix = crops_t->matrix<int32>();
- const int32 crops_top = crops_matrix(0, 0);
- const int32 crops_bottom = crops_matrix(0, 1);
- const int32 crops_left = crops_matrix(1, 0);
- const int32 crops_right = crops_matrix(1, 1);
+ int64 crops_top, crops_bottom, crops_left, crops_right;
+ if (crops_t->dtype() == DT_INT32) {
+ auto crops_matrix = crops_t->matrix<int32>();
+ crops_top = crops_matrix(0, 0);
+ crops_bottom = crops_matrix(0, 1);
+ crops_left = crops_matrix(1, 0);
+ crops_right = crops_matrix(1, 1);
+ } else {
+ auto crops_matrix = crops_t->matrix<int64>();
+ crops_top = crops_matrix(0, 0);
+ crops_bottom = crops_matrix(0, 1);
+ crops_left = crops_matrix(1, 0);
+ crops_right = crops_matrix(1, 1);
+ }
if (crops_top < 0 || crops_bottom < 0 || crops_left < 0 ||
crops_right < 0) {
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 42c69f0950..b3c07fedae 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -293,8 +293,8 @@ TEST(ArrayOpsTest, PadD_ShapeFn) {
// Make the paddings tensor known and verify padding values get added.
// E.g., if padding is ((1,10),(2,20),(3,30)) then values 11,22,33 are added
// to input dims to get output.
- Tensor paddings_t(DT_INT32, TensorShape{3, 2});
- test::FillValues<int32>(&paddings_t, {1, 10, 2, 20, 3, 30});
+ Tensor paddings_t(DT_INT64, TensorShape{3, 2});
+ test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30});
op.input_tensors[1] = &paddings_t;
INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]");
INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]");
@@ -326,8 +326,8 @@ TEST(ArrayOpsTest, MirrorPadGrad_ShapeFn) {
// Make the paddings tensor known and verify padding values get
// subtracted. E.g., if padding is ((1,10),(2,20),(3,30)) then
// values 11,22,33 are subtracted from input dims to get output.
- Tensor paddings_t(DT_INT32, TensorShape{3, 2});
- test::FillValues<int32>(&paddings_t, {1, 10, 2, 20, 3, 30});
+ Tensor paddings_t(DT_INT64, TensorShape{3, 2});
+ test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30});
op.input_tensors[1] = &paddings_t;
INFER_OK(op, "[111,222,333];[3,2]", "[100,200,300]");
@@ -823,6 +823,9 @@ TEST(ArrayOpsTest, Tile_ShapeFn) {
Tensor multiples = test::AsTensor<int32>({2, 3, 4, 5});
op.input_tensors[1] = &multiples;
INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]");
+ // Test 64-bit tensor type
+ multiples = test::AsTensor<int64>({2, 3, 4, 5});
+ INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]");
}
TEST(ArrayOpsTest, EditDistance_ShapeFn) {
@@ -933,6 +936,8 @@ TEST(ArrayOpsTest, SpaceToBatch_ShapeFn) {
Tensor paddings = test::AsTensor<int32>({4, 2, 2, 4}, {{2, 2}});
op.input_tensors[1] = &paddings;
INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]");
+ paddings = test::AsTensor<int64>({4, 2, 2, 4}, {{2, 2}});
+ INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]");
// Bad paddings values
paddings = test::AsTensor<int32>({1, 2, 3, 4}, {{2, 2}});
@@ -970,7 +975,7 @@ TEST(ArrayOpsTest, BatchToSpace_ShapeFn) {
INFER_ERROR("BatchToSpace requires crops with shape [2,2]", op,
"[4,8,8,3];[2,3]");
- Tensor croppings = test::AsTensor<int32>({4, 2, 2, 4}, {{2, 2}});
+ Tensor croppings = test::AsTensor<int64>({4, 2, 2, 4}, {{2, 2}});
op.input_tensors[1] = &croppings;
INFER_OK(op, "[4,8,8,3];[2,2]", "[1,10,10,d0_3]");
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 7c5a6cf11b..f034b58a7b 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1015,10 +1015,11 @@ matrix multiply on one platform was 30% zero values in the sparse matrix.
// dimensions of the input.
REGISTER_OP("Sum")
.Input("input: T")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Output("output: T")
.Attr("keep_dims: bool = false")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the sum of elements across dimensions of a tensor.
@@ -1036,10 +1037,11 @@ output: The reduced tensor.
REGISTER_OP("Mean")
.Input("input: T")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Output("output: T")
.Attr("keep_dims: bool = false")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the mean of elements across dimensions of a tensor.
@@ -1057,10 +1059,11 @@ output: The reduced tensor.
REGISTER_OP("Prod")
.Input("input: T")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Output("output: T")
.Attr("keep_dims: bool = false")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the product of elements across dimensions of a tensor.
@@ -1078,10 +1081,11 @@ output: The reduced tensor.
REGISTER_OP("Min")
.Input("input: T")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Output("output: T")
.Attr("keep_dims: bool = false")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the minimum of elements across dimensions of a tensor.
@@ -1099,10 +1103,11 @@ output: The reduced tensor.
REGISTER_OP("Max")
.Input("input: T")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Output("output: T")
.Attr("keep_dims: bool = false")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the maximum of elements across dimensions of a tensor.
@@ -1149,7 +1154,13 @@ Status ArgOpShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
- const int32 dimension_val = dim_t->scalar<int32>()();
+ int64 dimension_val;
+ if (dim_t->dtype() == DT_INT32) {
+ dimension_val = dim_t->scalar<int32>()();
+ } else {
+ dimension_val = dim_t->scalar<int64>()();
+ }
+
if (dimension_val < 0 || dimension_val >= input_rank) {
return errors::InvalidArgument("Dimension (", dimension_val,
") must be in the range [0, ", input_rank,
@@ -1172,9 +1183,10 @@ Status ArgOpShape(shape_inference::InferenceContext* c) {
REGISTER_OP("ArgMax")
.Input("input: T")
- .Input("dimension: int32")
+ .Input("dimension: Tidx")
.Output("output: int64")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(ArgOpShape)
.Doc(R"doc(
Returns the index with the largest value across dimensions of a tensor.
@@ -1185,9 +1197,10 @@ dimension: int32, 0 <= dimension < rank(input). Describes which dimension
REGISTER_OP("ArgMin")
.Input("input: T")
- .Input("dimension: int32")
+ .Input("dimension: Tidx")
.Output("output: int64")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(ArgOpShape)
.Doc(R"doc(
Returns the index with the smallest value across dimensions of a tensor.
@@ -1490,10 +1503,11 @@ output: Has same shape as data, except for the first `segment_ids.rank`
REGISTER_OP("SparseSegmentSum")
.Input("data: T")
- .Input("indices: int32")
+ .Input("indices: Tidx")
.Input("segment_ids: int32")
.Output("output: T")
.Attr("T: realnumbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn)
.Doc(R"doc(
Computes the sum along sparse segments of a tensor.
@@ -1538,10 +1552,11 @@ output: Has same shape as data, except for dimension 0 which
REGISTER_OP("SparseSegmentMean")
.Input("data: T")
- .Input("indices: int32")
+ .Input("indices: Tidx")
.Input("segment_ids: int32")
.Output("output: T")
.Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn)
.Doc(R"doc(
Computes the mean along sparse segments of a tensor.
@@ -1564,11 +1579,12 @@ output: Has same shape as data, except for dimension 0 which
REGISTER_OP("SparseSegmentMeanGrad")
.Input("grad: T")
- .Input("indices: int32")
+ .Input("indices: Tidx")
.Input("segment_ids: int32")
.Input("output_dim0: int32")
.Output("output: T")
.Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionGradShapeFn)
.Doc(R"doc(
Computes gradients for SparseSegmentMean.
@@ -1584,10 +1600,11 @@ output_dim0: dimension 0 of "data" passed to SparseSegmentMean op.
REGISTER_OP("SparseSegmentSqrtN")
.Input("data: T")
- .Input("indices: int32")
+ .Input("indices: Tidx")
.Input("segment_ids: int32")
.Output("output: T")
.Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn)
.Doc(R"doc(
Computes the sum along sparse segments of a tensor divided by the sqrt of N.
@@ -1609,11 +1626,12 @@ output: Has same shape as data, except for dimension 0 which
REGISTER_OP("SparseSegmentSqrtNGrad")
.Input("grad: T")
- .Input("indices: int32")
+ .Input("indices: Tidx")
.Input("segment_ids: int32")
.Input("output_dim0: int32")
.Output("output: T")
.Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionGradShapeFn)
.Doc(R"doc(
Computes gradients for SparseSegmentSqrtN.
@@ -1629,9 +1647,10 @@ output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op.
REGISTER_OP("All")
.Input("input: bool")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Output("output: bool")
.Attr("keep_dims: bool = false")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the "logical and" of elements across dimensions of a tensor.
@@ -1649,9 +1668,10 @@ output: The reduced tensor.
REGISTER_OP("Any")
.Input("input: bool")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Attr("keep_dims: bool = false")
.Output("output: bool")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the "logical or" of elements across dimensions of a tensor.
@@ -1670,10 +1690,11 @@ output: The reduced tensor.
// --------------------------------------------------------------------------
REGISTER_OP("Range")
- .Input("start: int32")
- .Input("limit: int32")
- .Input("delta: int32")
- .Output("output: int32")
+ .Input("start: Tidx")
+ .Input("limit: Tidx")
+ .Input("delta: Tidx")
+ .Output("output: Tidx")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
@@ -1689,9 +1710,17 @@ REGISTER_OP("Range")
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
return Status::OK();
}
- const int32 start = start_t->scalar<int32>()();
- const int32 limit = limit_t->scalar<int32>()();
- const int32 delta = delta_t->scalar<int32>()();
+ // TODO
+ int64 start, limit, delta;
+ if (start_t->dtype() == DT_INT32) {
+ start = start_t->scalar<int32>()();
+ limit = limit_t->scalar<int32>()();
+ delta = delta_t->scalar<int32>()();
+ } else {
+ start = start_t->scalar<int64>()();
+ limit = limit_t->scalar<int64>()();
+ delta = delta_t->scalar<int64>()();
+ }
if (start > limit) {
return errors::InvalidArgument("Requires start <= limit: ", start, "/",
limit);
@@ -1699,7 +1728,7 @@ REGISTER_OP("Range")
if (delta <= 0) {
return errors::InvalidArgument("Requires delta > 0: ", delta);
}
- const int32 size = (limit - start + delta - 1) / delta;
+ const int64 size = (limit - start + delta - 1) / delta;
c->set_output(0, c->Vector(size));
return Status::OK();
})
@@ -1727,9 +1756,10 @@ output: 1-D.
REGISTER_OP("LinSpace")
.Input("start: T")
.Input("stop: T")
- .Input("num: int32")
+ .Input("num: Tidx")
.Output("output: T")
.Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
@@ -1743,7 +1773,13 @@ REGISTER_OP("LinSpace")
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
return Status::OK();
}
- const int64 num = num_t->scalar<int32>()();
+
+ int64 num;
+ if (num_t->dtype() == DT_INT32) {
+ num = num_t->scalar<int32>()();
+ } else {
+ num = num_t->scalar<int64>()();
+ }
if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num);
c->set_output(0, c->Vector(num));
return Status::OK();
@@ -2057,11 +2093,12 @@ product: Pairwise cross product of the vectors in `a` and `b`.
REGISTER_OP("Cumsum")
.Input("x: T")
- .Input("axis: int32")
+ .Input("axis: Tidx")
.Attr("exclusive: bool = false")
.Attr("reverse: bool = false")
.Output("out: T")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.Doc(R"doc(
Compute the cumulative sum of the tensor `x` along `axis`.
@@ -2092,11 +2129,12 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) ==> [b + c, c, 0]
REGISTER_OP("Cumprod")
.Input("x: T")
- .Input("axis: int32")
+ .Input("axis: Tidx")
.Attr("exclusive: bool = false")
.Attr("reverse: bool = false")
.Output("out: T")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.Doc(R"doc(
Compute the cumulative product of the tensor `x` along `axis`.
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index b2e1dfce27..b34bb068d7 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -349,6 +349,23 @@ class OpDefLibrary(object):
(op_type_name, producer, deprecation_version,
op_def.deprecation.explanation))
+ # Fill in the list of default types for all "type" attrs. This
+ # will be used to choose a preferred dtype to convert to in the
+ # absence of input type information.
+ #
+ # TODO(b/31302892): Currently the defaults are not applied correctly
+ # when there are two inputs and the type resolution of one depends
+ # on the other. Handling this will require restructuring this code
+ # significantly.
+ default_type_attr_map = {}
+ for attr_def in op_def.attr:
+ if attr_def.type != "type":
+ continue
+ key = attr_def.name
+ if attr_def.HasField("default_value"):
+ default_type_attr_map[key] = dtypes.as_dtype(
+ attr_def.default_value.type)
+
# Requires that op_def has passed validation (using the C++
# ValidateOpDef() from ../framework/op_def_util.h).
attrs = {}
@@ -390,6 +407,7 @@ class OpDefLibrary(object):
# In cases where we expect all elements of the list to have the
# same dtype, try to cast non-Tensor elements to that type.
dtype = None
+ default_dtype = None
if input_arg.type != types_pb2.DT_INVALID:
dtype = input_arg.type
elif input_arg.number_attr:
@@ -401,11 +419,19 @@ class OpDefLibrary(object):
dtype = t.dtype
break
+ # If the dtype is still unknown, prefer the default dtype
+ # from the attr.
+ if dtype is None and input_arg.type_attr in default_type_attr_map:
+ default_dtype = default_type_attr_map[input_arg.type_attr]
+
try:
if not input_arg.is_ref and dtype:
dtype = dtypes.as_dtype(dtype).base_dtype
values = ops.convert_n_to_tensor(
- values, name=input_arg.name, dtype=dtype if dtype else None,
+ values,
+ name=input_arg.name,
+ dtype=dtype if dtype else None,
+ preferred_dtype=default_dtype,
as_ref=input_arg.is_ref)
if input_arg.number_attr and len(
set(v.dtype.base_dtype for v in values)) > 1:
@@ -444,14 +470,24 @@ class OpDefLibrary(object):
# In cases where we have an expected type, try to convert non-Tensor
# arguments to that type.
dtype = None
+ default_dtype = None
if input_arg.type != types_pb2.DT_INVALID:
dtype = input_arg.type
elif input_arg.type_attr in attrs:
dtype = attrs[input_arg.type_attr]
+ elif input_arg.type_attr in default_type_attr_map:
+ # The dtype could not be inferred solely from the inputs, so we
+ # prefer the attr's default. That way, adding a new attr with a
+ # default to an op stays backwards compatible.
+ default_dtype = default_type_attr_map[input_arg.type_attr]
+
try:
values = ops.convert_to_tensor(
- values, name=input_arg.name, dtype=dtype,
- as_ref=input_arg.is_ref)
+ values,
+ name=input_arg.name,
+ dtype=dtype,
+ as_ref=input_arg.is_ref,
+ preferred_dtype=default_dtype)
except ValueError:
# What type does convert_to_tensor think it has?
observed = ops.convert_to_tensor(values,
@@ -462,6 +498,14 @@ class OpDefLibrary(object):
raise TypeError("%s expected type of %s." %
(prefix, dtypes.as_dtype(input_arg.type).name))
else:
+ # Update the maps with the default, if needed.
+ k = input_arg.type_attr
+ if k in default_type_attr_map:
+ if k not in attrs:
+ attrs[k] = default_type_attr_map[k]
+ if k not in inferred_from:
+ inferred_from[k] = "Default in OpDef"
+
raise TypeError(
"%s type %s of argument '%s'." %
(prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
diff --git a/tensorflow/python/framework/op_def_library_test.py b/tensorflow/python/framework/op_def_library_test.py
index 27933cc193..d95bb45cb8 100644
--- a/tensorflow/python/framework/op_def_library_test.py
+++ b/tensorflow/python/framework/op_def_library_test.py
@@ -47,6 +47,8 @@ ops.RegisterShape("AttrShape")(None)
ops.RegisterShape("AttrShapeList")(None)
ops.RegisterShape("AttrPartialShape")(None)
ops.RegisterShape("AttrPartialShapeList")(None)
+ops.RegisterShape("AttrTypeDefault")(None)
+ops.RegisterShape("AttrListTypeDefault")(None)
ops.RegisterShape("Binary")(None)
ops.RegisterShape("ComplexStruct")(None)
ops.RegisterShape("InPolymorphicTwice")(None)
@@ -868,6 +870,44 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
name: 'x' op: 'ReservedAttr' attr { key: 'range' value { i: 7 } }
""", op.node_def)
+ def testDefaultAttrType(self):
+ self._add_op("name: 'AttrTypeDefault' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' "
+ " default_value { type: DT_INT32 } }")
+
+ # Give an empty-list input, whose values do not determine a dtype.
+ op = self._lib.apply_op("AttrTypeDefault", a=[], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'AttrTypeDefault' input: 'n/a'
+ attr { key: 'T' value { type: DT_INT32 } }
+ """, op.node_def)
+
+ # Give an input whose type can be inferred as different
+ # than the default.
+ op = self._lib.apply_op("AttrTypeDefault", a=[1.0], name="f")
+ self.assertProtoEquals("""
+ name: 'f' op: 'AttrTypeDefault' input: 'f/a'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ """, op.node_def)
+
+ def testDefaultListAttrType(self):
+ self._add_op("name: 'AttrListTypeDefault' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' "
+ " default_value { type: DT_INT32 } }"
+ "attr { name: 'N' type: 'int' }")
+
+ # Give an input whose type can be inferred as different
+ # than the default.
+ op = self._lib.apply_op("AttrListTypeDefault", a=[1.0], b=[2.0], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'AttrListTypeDefault' input: 'n/a_0' input: 'n/b_0'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'N' value { i: 1 } }
+ """, op.node_def)
+
def testNIntsIn(self):
self._add_op("name: 'NIntsIn' "
"input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f47ecd1b36..a5e253eabf 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -577,7 +577,11 @@ _tensor_conversion_func_registry = {
register_dense_tensor_like_type(Tensor)
-def convert_to_tensor(value, dtype=None, name=None, as_ref=False):
+def convert_to_tensor(value,
+ dtype=None,
+ name=None,
+ as_ref=False,
+ preferred_dtype=None):
"""Converts the given `value` to a `Tensor`.
This function converts Python objects of various types to `Tensor`
@@ -610,6 +614,11 @@ def convert_to_tensor(value, dtype=None, name=None, as_ref=False):
name: Optional name to use if a new `Tensor` is created.
as_ref: True if we want the result as a ref tensor. Only used if a new
`Tensor` is created.
+ preferred_dtype: Optional element type for the returned tensor,
+ used when dtype is None. In some cases, a caller may not have a
+ dtype in mind when converting to a tensor, so preferred_dtype
+ can be used as a soft preference. If the conversion to
+ `preferred_dtype` is not possible, this argument has no effect.
Returns:
A `Tensor` based on `value`.
@@ -625,9 +634,31 @@ def convert_to_tensor(value, dtype=None, name=None, as_ref=False):
for _, funcs_at_priority in sorted(_tensor_conversion_func_registry.items()):
for base_type, conversion_func in funcs_at_priority:
if isinstance(value, base_type):
- ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
+ # If dtype is None but preferred_dtype is not None, we try to
+ # cast to preferred_dtype first.
+ ret = None
+ if dtype is None and preferred_dtype is not None:
+ try:
+ ret = conversion_func(
+ value, dtype=preferred_dtype, name=name, as_ref=as_ref)
+ except (TypeError, ValueError):
+ # Could not coerce the conversion to use the preferred dtype.
+ ret = None
+
+ if ret is not None and ret is not NotImplemented:
+ if (ret.dtype.base_dtype !=
+ dtypes.as_dtype(preferred_dtype).base_dtype):
+ raise TypeError("convert_to_tensor did not convert to "
+ "the preferred dtype: %s vs %s " %
+ (ret.dtype.base_dtype,
+ dtypes.as_dtype(preferred_dtype).base_dtype))
+
+ if ret is None:
+ ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
+
if ret is NotImplemented:
continue
+
if not isinstance(ret, Tensor):
raise RuntimeError(
"%sConversion function %r for type %s returned non-Tensor: %r"
@@ -644,7 +675,11 @@ def convert_to_tensor(value, dtype=None, name=None, as_ref=False):
% (error_prefix, value, type(value)))
-def convert_n_to_tensor(values, dtype=None, name=None, as_ref=False):
+def convert_n_to_tensor(values,
+ dtype=None,
+ name=None,
+ as_ref=False,
+ preferred_dtype=None):
"""Converts `values` to a list of `Tensor` objects.
Args:
@@ -654,6 +689,11 @@ def convert_n_to_tensor(values, dtype=None, name=None, as_ref=False):
created, in which case element `i` will be given the name `name
+ '_' + i`.
as_ref: True if the caller wants the results as ref tensors.
+ preferred_dtype: Optional element type for the returned tensors,
+ used when dtype is None. In some cases, a caller may not have a
+ dtype in mind when converting to a tensor, so preferred_dtype
+ can be used as a soft preference. If the conversion to
+ `preferred_dtype` is not possible, this argument has no effect.
Returns:
A list of `Tensor` and/or `IndexedSlices` objects.
@@ -669,7 +709,13 @@ def convert_n_to_tensor(values, dtype=None, name=None, as_ref=False):
ret = []
for i, value in enumerate(values):
n = None if name is None else "%s_%d" % (name, i)
- ret.append(convert_to_tensor(value, dtype=dtype, name=n, as_ref=as_ref))
+ ret.append(
+ convert_to_tensor(
+ value,
+ dtype=dtype,
+ name=n,
+ as_ref=as_ref,
+ preferred_dtype=preferred_dtype))
return ret
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 98ad4a2d60..eac85ac844 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -280,6 +280,31 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual(tensor_shape.unknown_shape(),
_apply_op(g, "an_op", [], [dtypes.float32]).get_shape())
+ def testConvertToTensorPreferred(self):
+ with self.test_session():
+ values = [2, 3, 5, 7]
+ tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
+ self.assertEqual(dtypes.float32, tensor.dtype)
+
+ with self.test_session():
+ # An empty list can be converted to any preferred dtype.
+ values = []
+ tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
+ self.assertEqual(dtypes.int64, tensor.dtype)
+
+ with self.test_session():
+ # Converting to the preferred dtype is a type error, so the
+ # value is converted to float32 instead.
+ values = [1.23]
+ tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
+ self.assertEqual(dtypes.float32, tensor.dtype)
+
+ def testConvertToInvalidTensorType(self):
+ with self.assertRaises(TypeError):
+ # Forcing an invalid dtype should fail with a type error.
+ values = [1.23]
+ _ = ops.convert_to_tensor(values, dtype=dtypes.int64)
+
def testNoConvert(self):
# Operation cannot be converted to Tensor
op = control_flow_ops.no_op()