aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-09-06 17:30:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-06 18:47:10 -0700
commit91ce95d497ec2957535b2ce6a965cd8269d723e5 (patch)
treea97799f5b2b92cc676e7890aa0f206c25f705490
parent970c408a0c34726d5361b8b7e90fe02376f7e022 (diff)
Change shape ops that produce int32s to also output int64s in the type
system, with the default remaining int32 for backwards compatibility, and change some ops that consume only int32s to also take in int64. Adds changes to the python op_def_library to properly handle the addition of new types with unchanged defaults. This changes the specification, not the implementations. The implementations will come later with additional registrations. Note that in some cases we are not yet ready for registering the int64 types, because there is already an int64 type registered that assumes data is in device memory. We'll need to fix the problem of choosing kernels based on memorytype in a follow up CL before being able to register multiple kernels for the same type but with different memorytypes. Change: 132386744
-rw-r--r--tensorflow/core/common_runtime/function_test.cc66
-rw-r--r--tensorflow/core/kernels/transpose_op.cc22
-rw-r--r--tensorflow/core/ops/array_ops.cc212
-rw-r--r--tensorflow/core/ops/array_ops_test.cc15
-rw-r--r--tensorflow/core/ops/math_ops.cc92
-rw-r--r--tensorflow/python/framework/op_def_library.py50
-rw-r--r--tensorflow/python/framework/op_def_library_test.py40
-rw-r--r--tensorflow/python/framework/ops.py54
-rw-r--r--tensorflow/python/framework/ops_test.py25
9 files changed, 430 insertions, 146 deletions
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index e263e62bd8..523547b4eb 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -534,11 +534,11 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
(n2:float, n3:float) -> (n9:float) {
n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [0] values: >]()
n10 = Const[dtype=float, value=Tensor<type: float shape: [] values: 2>]()
- n6 = Shape[T=float](n2)
+ n6 = Shape[T=float, out_type=int32](n2)
n5 = Mul[T=float](n3, n10)
- n7 = BroadcastGradientArgs(n6, n11)
- n8 = Sum[T=float, keep_dims=false](n5, n7)
- n9 = Reshape[T=float](n8, n6)
+ n7 = BroadcastGradientArgs[T=int32](n6, n11)
+ n8 = Sum[T=float, Tidx=int32, keep_dims=false](n5, n7)
+ n9 = Reshape[T=float, Tshape=int32](n8, n6)
}
)P";
EXPECT_EQ(e2, DebugString(g));
@@ -555,13 +555,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) {
(n7:float, n5:float, n2:float) -> (n14:float, n11:float) {
n3 = Identity[T=float](n2)
n4 = Identity[T=float](n2)
- n6 = Shape[T=float](n5)
- n8 = Shape[T=float](n7)
- n9 = BroadcastGradientArgs(n8, n6)
- n10 = Sum[T=float, keep_dims=false](n3, n9:1)
- n13 = Sum[T=float, keep_dims=false](n4, n9)
- n11 = Reshape[T=float](n10, n6)
- n14 = Reshape[T=float](n13, n8)
+ n6 = Shape[T=float, out_type=int32](n5)
+ n8 = Shape[T=float, out_type=int32](n7)
+ n9 = BroadcastGradientArgs[T=int32](n8, n6)
+ n10 = Sum[T=float, Tidx=int32, keep_dims=false](n3, n9:1)
+ n13 = Sum[T=float, Tidx=int32, keep_dims=false](n4, n9)
+ n11 = Reshape[T=float, Tshape=int32](n10, n6)
+ n14 = Reshape[T=float, Tshape=int32](n13, n8)
}
)P";
EXPECT_EQ(e0, DebugString(g));
@@ -576,14 +576,14 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) {
const char* e0 = R"P(
(n6:float, n3:float, n2:float) -> (n14:float, n11:float) {
n4 = Mul[T=float](n2, n3)
- n5 = Shape[T=float](n3)
+ n5 = Shape[T=float, out_type=int32](n3)
n7 = Mul[T=float](n6, n2)
- n8 = Shape[T=float](n6)
- n9 = BroadcastGradientArgs(n8, n5)
- n10 = Sum[T=float, keep_dims=false](n7, n9:1)
- n13 = Sum[T=float, keep_dims=false](n4, n9)
- n11 = Reshape[T=float](n10, n5)
- n14 = Reshape[T=float](n13, n8)
+ n8 = Shape[T=float, out_type=int32](n6)
+ n9 = BroadcastGradientArgs[T=int32](n8, n5)
+ n10 = Sum[T=float, Tidx=int32, keep_dims=false](n7, n9:1)
+ n13 = Sum[T=float, Tidx=int32, keep_dims=false](n4, n9)
+ n11 = Reshape[T=float, Tshape=int32](n10, n5)
+ n14 = Reshape[T=float, Tshape=int32](n13, n8)
}
)P";
EXPECT_EQ(e0, DebugString(g));
@@ -643,10 +643,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
n24 = Identity[T=float](n4)
n14 = Add[T=float](n24, n25)
n15 = Rank[T=float](n14)
- n16 = Range(n11, n15, n10)
+ n16 = Range[Tidx=int32](n11, n15, n10)
n20 = ZerosLike[T=int32](n15)
- n17 = Sum[T=float, keep_dims=false](n14, n16)
- n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, keep_dims=false]](n14, n16, n26)
+ n17 = Sum[T=float, Tidx=int32, keep_dims=false](n14, n16)
+ n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, Tidx=int32, keep_dims=false]](n14, n16, n26)
n21 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Add[T=float]](n24, n25, n19)
n28 = Identity[T=float](n21:1)
n27 = Identity[T=float](n21)
@@ -662,23 +662,23 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]()
n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
- n19 = Shape[T=float](n3)
+ n19 = Shape[T=float, out_type=int32](n3)
n8 = Add[T=float](n4, n3)
- n20 = Shape[T=float](n4)
+ n20 = Shape[T=float, out_type=int32](n4)
n9 = Rank[T=float](n8)
- n14 = Shape[T=float](n8)
- n21 = BroadcastGradientArgs(n20, n19)
- n10 = Range(n7, n9, n11)
- n12 = Shape[T=int32](n10)
+ n14 = Shape[T=float, out_type=int32](n8)
+ n21 = BroadcastGradientArgs[T=int32](n20, n19)
+ n10 = Range[Tidx=int32](n7, n9, n11)
+ n12 = Shape[T=int32, out_type=int32](n10)
n13 = Fill[T=int32](n12, n11)
n15 = DynamicStitch[N=2, T=int32](n10, n10, n14, n13)
- n16 = Reshape[T=float](n2, n15)
+ n16 = Reshape[T=float, Tshape=int32](n2, n15)
n17 = Div[T=int32](n14, n15)
- n18 = Tile[T=float](n16, n17)
- n24 = Sum[T=float, keep_dims=false](n18, n21)
- n22 = Sum[T=float, keep_dims=false](n18, n21:1)
- n25 = Reshape[T=float](n24, n20)
- n23 = Reshape[T=float](n22, n19)
+ n18 = Tile[T=float, Tmultiples=int32](n16, n17)
+ n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21)
+ n22 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21:1)
+ n25 = Reshape[T=float, Tshape=int32](n24, n20)
+ n23 = Reshape[T=float, Tshape=int32](n22, n19)
}
)P";
EXPECT_EQ(e2, DebugString(g));
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index fbde0c9626..35429a382e 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -169,11 +169,12 @@ Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
out);
}
-#define REGISTER(T) \
- REGISTER_KERNEL_BUILDER(Name("Transpose") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("perm"), \
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER(Name("Transpose") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tperm") \
+ .HostMemory("perm"), \
TransposeCpuOp);
TF_CALL_ALL_TYPES(REGISTER)
REGISTER(bfloat16);
@@ -187,11 +188,12 @@ Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
out);
}
-#define REGISTER(T) \
- REGISTER_KERNEL_BUILDER(Name("Transpose") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("perm"), \
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER(Name("Transpose") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<int32>("Tperm") \
+ .HostMemory("perm"), \
TransposeGpuOp);
TF_CALL_POD_TYPES(REGISTER);
#undef REGISTER
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index fc1d031d4b..4ab47754df 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -39,6 +39,34 @@ Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
return Status::OK();
}
+template <typename T>
+std::vector<int64> AsInt64(const Tensor* tensor, int num_elements) {
+ std::vector<int64> ret(num_elements);
+ auto data = tensor->vec<T>();
+ for (int i = 0; i < num_elements; ++i) {
+ ret[i] = data(i);
+ }
+ return ret;
+}
+
+template <typename T>
+Status PadKnown(InferenceContext* c, ShapeHandle input,
+ const Tensor* paddings_t, int32 num_dims) {
+ // paddings_t is known.
+ std::vector<DimensionHandle> dims(num_dims);
+ auto paddings_data = paddings_t->matrix<T>();
+ for (int i = 0; i < num_dims; ++i) {
+ const T pad0 = paddings_data(i, 0);
+ const T pad1 = paddings_data(i, 1);
+ if (pad0 < 0 || pad1 < 0) {
+ return errors::InvalidArgument("Paddings must be non-negative");
+ }
+ TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i]));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+}
+
Status PadShapeFn(InferenceContext* c) {
// Paddings is a matrix of [input_rank, 2].
ShapeHandle paddings;
@@ -73,19 +101,11 @@ Status PadShapeFn(InferenceContext* c) {
const auto num_dims = c->Value(n_dim);
DCHECK_EQ(num_dims, paddings_t->shape().dim_size(0));
- // paddings_t is known.
- auto paddings_data = paddings_t->matrix<int32>();
- std::vector<DimensionHandle> dims(num_dims);
- for (int i = 0; i < num_dims; ++i) {
- const int32 pad0 = paddings_data(i, 0);
- const int32 pad1 = paddings_data(i, 1);
- if (pad0 < 0 || pad1 < 0) {
- return errors::InvalidArgument("Paddings must be non-negative");
- }
- TF_RETURN_IF_ERROR(c->Add(c->Dim(input, i), pad0 + pad1, &dims[i]));
+ if (paddings_t->dtype() == DT_INT32) {
+ return PadKnown<int32>(c, input, paddings_t, num_dims);
+ } else {
+ return PadKnown<int64>(c, input, paddings_t, num_dims);
}
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
}
} // namespace
@@ -1127,9 +1147,10 @@ message: Prefix of the error message.
// --------------------------------------------------------------------------
REGISTER_OP("Reshape")
.Input("tensor: T")
- .Input("shape: int32")
+ .Input("shape: Tshape")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tshape: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle in = c->input(0);
ShapeHandle out;
@@ -1247,8 +1268,9 @@ shape: Defines the shape of the output tensor.
// --------------------------------------------------------------------------
REGISTER_OP("InvertPermutation")
- .Input("x: int32")
- .Output("y: int32")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle x;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &x));
@@ -1281,9 +1303,10 @@ y: 1-D.
// --------------------------------------------------------------------------
REGISTER_OP("Transpose")
.Input("x: T")
- .Input("perm: int32")
+ .Input("perm: Tperm")
.Output("y: T")
.Attr("T: type")
+ .Attr("Tperm: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
ShapeHandle perm_shape = c->input(1);
@@ -1318,9 +1341,15 @@ REGISTER_OP("Transpose")
// all shape informantion, otherwise we can only return rank information,
// but no information for the dimensions.
if (perm != nullptr) {
- auto flat_perm = perm->flat<int32>();
+ std::vector<int64> data;
+ if (perm->dtype() == DT_INT32) {
+ data = AsInt64<int32>(perm, rank);
+ } else {
+ data = AsInt64<int64>(perm, rank);
+ }
+
for (int32 i = 0; i < rank; ++i) {
- int32 in_idx = flat_perm(i);
+ int64 in_idx = data[i];
if (in_idx >= rank) {
return errors::InvalidArgument(
"perm dim ", in_idx, " is out of range of input rank ", rank);
@@ -1347,8 +1376,9 @@ The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
REGISTER_OP("Unique")
.Input("x: T")
.Output("y: T")
- .Output("idx: int32")
+ .Output("idx: out_idx")
.Attr("T: type")
+ .Attr("out_idx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
c->set_output(1, c->input(0));
@@ -1382,9 +1412,10 @@ idx: 1-D.
REGISTER_OP("UniqueWithCounts")
.Input("x: T")
.Output("y: T")
- .Output("idx: int32")
- .Output("count: int32")
+ .Output("idx: out_idx")
+ .Output("count: out_idx")
.Attr("T: type")
+ .Attr("out_idx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
auto uniq = c->Vector(InferenceContext::kUnknownDim);
c->set_output(0, uniq);
@@ -1439,8 +1470,9 @@ Status ShapeShapeFn(InferenceContext* c) {
// --------------------------------------------------------------------------
REGISTER_OP("Shape")
.Input("input: T")
- .Output("output: int32")
+ .Output("output: out_type")
.Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(ShapeShapeFn)
.Doc(R"doc(
Returns the shape of a tensor.
@@ -1458,9 +1490,10 @@ shape(t) ==> [2, 2, 3]
REGISTER_OP("ShapeN")
.Input("input: N * T")
- .Output("output: N * int32")
+ .Output("output: N * out_type")
.Attr("N: int")
.Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(ShapeShapeFn)
.Doc(R"doc(
Returns shape of tensors.
@@ -1606,8 +1639,9 @@ of the tensor. Rank is also known as "order", "degree", or "ndims."
// --------------------------------------------------------------------------
REGISTER_OP("Size")
.Input("input: T")
- .Output("output: int32")
+ .Output("output: out_type")
.Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Returns the size of a tensor.
@@ -1854,11 +1888,13 @@ shape must be exactly the shape produced by the slice of `ref`.
// TODO(aselle): Fix this documentation once StridedSliceAssign Supports
// broadcasting.
// --------------------------------------------------------------------------
+
REGISTER_OP("Tile")
.Input("input: T")
- .Input("multiples: int32")
+ .Input("multiples: Tmultiples")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tmultiples: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
ShapeHandle multiples;
@@ -1882,12 +1918,15 @@ REGISTER_OP("Tile")
return Status::OK();
}
- // Multiply each input dimension by its corresponding value
- // from the multiples tensor.
- auto multiples_data = multiples_t->vec<int32>();
+ std::vector<int64> data;
+ if (multiples_t->dtype() == DT_INT32) {
+ data = AsInt64<int32>(multiples_t, rank);
+ } else {
+ data = AsInt64<int64>(multiples_t, rank);
+ }
std::vector<DimensionHandle> dims(rank);
for (int i = 0; i < rank; ++i) {
- const int32 multiple = multiples_data(i);
+ const int64 multiple = data[i];
TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input, i), multiple, &dims[i]));
}
c->set_output(0, c->MakeShape(dims));
@@ -1968,10 +2007,11 @@ where(input) ==> [[0, 0, 0],
// --------------------------------------------------------------------------
REGISTER_OP("BroadcastGradientArgs")
- .Input("s0: int32")
- .Input("s1: int32")
- .Output("r0: int32")
- .Output("r1: int32")
+ .Input("s0: T")
+ .Input("s1: T")
+ .Output("r0: T")
+ .Output("r1: T")
+ .Attr("T: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
// TODO(mrry): Implement constant_value for BroadcastGradientArgs?
ShapeHandle unused;
@@ -1990,9 +2030,10 @@ This is typically used by gradient computations for a broadcasting operation.
// --------------------------------------------------------------------------
REGISTER_OP("Pad")
.Input("input: T")
- .Input("paddings: int32")
+ .Input("paddings: Tpaddings")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tpaddings: {int32, int64} = DT_INT32")
.SetShapeFn(PadShapeFn)
.Doc(R"doc(
Pads a tensor with zeros.
@@ -2025,9 +2066,10 @@ pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
// --------------------------------------------------------------------------
REGISTER_OP("MirrorPad")
.Input("input: T")
- .Input("paddings: int32")
+ .Input("paddings: Tpaddings")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tpaddings: {int32, int64} = DT_INT32")
.Attr(GetMirrorPadModeAttrString())
.SetShapeFn(PadShapeFn)
.Doc(R"doc(
@@ -2071,11 +2113,33 @@ output: The padded tensor.
)doc");
// --------------------------------------------------------------------------
+namespace {
+template <typename T>
+Status MirrorPadKnown(InferenceContext* c, ShapeHandle input,
+ const Tensor* paddings_t, int32 input_rank) {
+ auto paddings_data = paddings_t->matrix<T>();
+ std::vector<DimensionHandle> dims(input_rank);
+ for (int i = 0; i < input_rank; ++i) {
+ const int64 pad0 = static_cast<int64>(paddings_data(i, 0));
+ const int64 pad1 = static_cast<int64>(paddings_data(i, 1));
+ if (pad0 < 0 || pad1 < 0) {
+ return errors::InvalidArgument("Paddings must be non-negative");
+ }
+
+ TF_RETURN_IF_ERROR(c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i]));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+}
+
+} // namespace
+
REGISTER_OP("MirrorPadGrad")
.Input("input: T")
- .Input("paddings: int32")
+ .Input("paddings: Tpaddings")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tpaddings: {int32, int64} = DT_INT32")
.Attr(GetMirrorPadModeAttrString())
.SetShapeFn([](InferenceContext* c) {
ShapeHandle paddings;
@@ -2103,20 +2167,11 @@ REGISTER_OP("MirrorPadGrad")
return Status::OK();
}
- auto paddings_data = paddings_t->matrix<int32>();
- std::vector<DimensionHandle> dims(input_rank);
- for (int i = 0; i < input_rank; ++i) {
- const int64 pad0 = static_cast<int64>(paddings_data(i, 0));
- const int64 pad1 = static_cast<int64>(paddings_data(i, 1));
- if (pad0 < 0 || pad1 < 0) {
- return errors::InvalidArgument("Paddings must be non-negative");
- }
-
- TF_RETURN_IF_ERROR(
- c->Subtract(c->Dim(input, i), pad0 + pad1, &dims[i]));
+ if (paddings_t->dtype() == DT_INT32) {
+ return MirrorPadKnown<int32>(c, input, paddings_t, input_rank);
+ } else {
+ return MirrorPadKnown<int64>(c, input, paddings_t, input_rank);
}
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
})
.Doc(R"doc(
Gradient op for `MirrorPad` op. This op folds a mirror-padded tensor.
@@ -2218,9 +2273,10 @@ shape: The (possibly partial) shape of the tensor.
// --------------------------------------------------------------------------
REGISTER_OP("ExpandDims")
.Input("input: T")
- .Input("dim: int32")
+ .Input("dim: Tdim")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tdim: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
ShapeHandle expand_dim;
@@ -2231,7 +2287,14 @@ REGISTER_OP("ExpandDims")
c->set_output(0, c->UnknownShape());
return Status::OK();
}
- int32 which_dim = dim_t->flat<int32>()(0);
+
+ int64 which_dim;
+ if (dim_t->dtype() == DT_INT32) {
+ which_dim = static_cast<int64>(dim_t->flat<int32>()(0));
+ } else {
+ which_dim = dim_t->flat<int64>()(0);
+ }
+
if (which_dim < 0) {
which_dim += c->Rank(input) + 1;
}
@@ -2388,8 +2451,9 @@ REGISTER_OP("ListDiff")
.Input("x: T")
.Input("y: T")
.Output("out: T")
- .Output("idx: int32")
+ .Output("idx: out_idx")
.Attr("T: type")
+ .Attr("out_idx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
@@ -2434,9 +2498,10 @@ idx: 1-D. Positions of `x` values preserved in `out`.
// --------------------------------------------------------------------------
REGISTER_OP("SpaceToBatch")
.Input("input: T")
- .Input("paddings: int32")
+ .Input("paddings: Tpaddings")
.Output("output: T")
.Attr("T: type")
+ .Attr("Tpaddings: {int32, int64} = DT_INT32")
.Attr("block_size: int >= 2")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
@@ -2470,11 +2535,20 @@ REGISTER_OP("SpaceToBatch")
output_height = c->UnknownDim();
output_width = c->UnknownDim();
} else {
- auto pad_matrix = paddings_t->matrix<int32>();
- const int32 pad_top = pad_matrix(0, 0);
- const int32 pad_bottom = pad_matrix(0, 1);
- const int32 pad_left = pad_matrix(1, 0);
- const int32 pad_right = pad_matrix(1, 1);
+ int64 pad_top, pad_bottom, pad_left, pad_right;
+ if (paddings_t->dtype() == DT_INT32) {
+ auto pad_matrix = paddings_t->matrix<int32>();
+ pad_top = pad_matrix(0, 0);
+ pad_bottom = pad_matrix(0, 1);
+ pad_left = pad_matrix(1, 0);
+ pad_right = pad_matrix(1, 1);
+ } else {
+ auto pad_matrix = paddings_t->matrix<int64>();
+ pad_top = pad_matrix(0, 0);
+ pad_bottom = pad_matrix(0, 1);
+ pad_left = pad_matrix(1, 0);
+ pad_right = pad_matrix(1, 1);
+ }
if (pad_top < 0 || pad_bottom < 0 || pad_left < 0 || pad_right < 0) {
return errors::InvalidArgument("Paddings cannot be negative.");
@@ -2599,10 +2673,11 @@ regular convolution.
// --------------------------------------------------------------------------
REGISTER_OP("BatchToSpace")
.Input("input: T")
- .Input("crops: int32")
+ .Input("crops: Tidx")
.Output("output: T")
.Attr("T: type")
.Attr("block_size: int >= 2")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
@@ -2640,11 +2715,20 @@ REGISTER_OP("BatchToSpace")
output_height = c->UnknownDim();
output_width = c->UnknownDim();
} else {
- auto crops_matrix = crops_t->matrix<int32>();
- const int32 crops_top = crops_matrix(0, 0);
- const int32 crops_bottom = crops_matrix(0, 1);
- const int32 crops_left = crops_matrix(1, 0);
- const int32 crops_right = crops_matrix(1, 1);
+ int64 crops_top, crops_bottom, crops_left, crops_right;
+ if (crops_t->dtype() == DT_INT32) {
+ auto crops_matrix = crops_t->matrix<int32>();
+ crops_top = crops_matrix(0, 0);
+ crops_bottom = crops_matrix(0, 1);
+ crops_left = crops_matrix(1, 0);
+ crops_right = crops_matrix(1, 1);
+ } else {
+ auto crops_matrix = crops_t->matrix<int64>();
+ crops_top = crops_matrix(0, 0);
+ crops_bottom = crops_matrix(0, 1);
+ crops_left = crops_matrix(1, 0);
+ crops_right = crops_matrix(1, 1);
+ }
if (crops_top < 0 || crops_bottom < 0 || crops_left < 0 ||
crops_right < 0) {
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 42c69f0950..b3c07fedae 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -293,8 +293,8 @@ TEST(ArrayOpsTest, PadD_ShapeFn) {
// Make the paddings tensor known and verify padding values get added.
// E.g., if padding is ((1,10),(2,20),(3,30)) then values 11,22,23 are added
// to input dims to get output.
- Tensor paddings_t(DT_INT32, TensorShape{3, 2});
- test::FillValues<int32>(&paddings_t, {1, 10, 2, 20, 3, 30});
+ Tensor paddings_t(DT_INT64, TensorShape{3, 2});
+ test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30});
op.input_tensors[1] = &paddings_t;
INFER_OK(op, "[100,200,300];[3,2]", "[111,222,333]");
INFER_OK(op, "[100,?,300];[3,2]", "[111,?,333]");
@@ -326,8 +326,8 @@ TEST(ArrayOpsTest, MirrorPadGrad_ShapeFn) {
// Make the paddings tensor known and verify padding values get
// subtracted. E.g., if padding is ((1,10),(2,20),(3,30)) then
// values 11,22,23 are subtracted to input dims to get output.
- Tensor paddings_t(DT_INT32, TensorShape{3, 2});
- test::FillValues<int32>(&paddings_t, {1, 10, 2, 20, 3, 30});
+ Tensor paddings_t(DT_INT64, TensorShape{3, 2});
+ test::FillValues<int64>(&paddings_t, {1, 10, 2, 20, 3, 30});
op.input_tensors[1] = &paddings_t;
INFER_OK(op, "[111,222,333];[3,2]", "[100,200,300]");
@@ -823,6 +823,9 @@ TEST(ArrayOpsTest, Tile_ShapeFn) {
Tensor multiples = test::AsTensor<int32>({2, 3, 4, 5});
op.input_tensors[1] = &multiples;
INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]");
+ // Test 64-bit tensor type
+ multiples = test::AsTensor<int64>({2, 3, 4, 5});
+ INFER_OK(op, "[2,3,1,4];[4]", "[4,9,4,20]");
}
TEST(ArrayOpsTest, EditDistance_ShapeFn) {
@@ -933,6 +936,8 @@ TEST(ArrayOpsTest, SpaceToBatch_ShapeFn) {
Tensor paddings = test::AsTensor<int32>({4, 2, 2, 4}, {{2, 2}});
op.input_tensors[1] = &paddings;
INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]");
+ paddings = test::AsTensor<int64>({4, 2, 2, 4}, {{2, 2}});
+ INFER_OK(op, "[1,10,10,3];[2,2]", "[4,8,8,d0_3]");
// Bad paddings values
paddings = test::AsTensor<int32>({1, 2, 3, 4}, {{2, 2}});
@@ -970,7 +975,7 @@ TEST(ArrayOpsTest, BatchToSpace_ShapeFn) {
INFER_ERROR("BatchToSpace requires crops with shape [2,2]", op,
"[4,8,8,3];[2,3]");
- Tensor croppings = test::AsTensor<int32>({4, 2, 2, 4}, {{2, 2}});
+ Tensor croppings = test::AsTensor<int64>({4, 2, 2, 4}, {{2, 2}});
op.input_tensors[1] = &croppings;
INFER_OK(op, "[4,8,8,3];[2,2]", "[1,10,10,d0_3]");
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 7c5a6cf11b..f034b58a7b 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1015,10 +1015,11 @@ matrix multiply on one platform was 30% zero values in the sparse matrix.
// dimensions of the input.
REGISTER_OP("Sum")
.Input("input: T")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Output("output: T")
.Attr("keep_dims: bool = false")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the sum of elements across dimensions of a tensor.
@@ -1036,10 +1037,11 @@ output: The reduced tensor.
REGISTER_OP("Mean")
.Input("input: T")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Output("output: T")
.Attr("keep_dims: bool = false")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the mean of elements across dimensions of a tensor.
@@ -1057,10 +1059,11 @@ output: The reduced tensor.
REGISTER_OP("Prod")
.Input("input: T")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Output("output: T")
.Attr("keep_dims: bool = false")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the product of elements across dimensions of a tensor.
@@ -1078,10 +1081,11 @@ output: The reduced tensor.
REGISTER_OP("Min")
.Input("input: T")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Output("output: T")
.Attr("keep_dims: bool = false")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the minimum of elements across dimensions of a tensor.
@@ -1099,10 +1103,11 @@ output: The reduced tensor.
REGISTER_OP("Max")
.Input("input: T")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Output("output: T")
.Attr("keep_dims: bool = false")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the maximum of elements across dimensions of a tensor.
@@ -1149,7 +1154,13 @@ Status ArgOpShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
- const int32 dimension_val = dim_t->scalar<int32>()();
+ int64 dimension_val;
+ if (dim_t->dtype() == DT_INT32) {
+ dimension_val = dim_t->scalar<int32>()();
+ } else {
+ dimension_val = dim_t->scalar<int64>()();
+ }
+
if (dimension_val < 0 || dimension_val >= input_rank) {
return errors::InvalidArgument("Dimension (", dimension_val,
") must be in the range [0, ", input_rank,
@@ -1172,9 +1183,10 @@ Status ArgOpShape(shape_inference::InferenceContext* c) {
REGISTER_OP("ArgMax")
.Input("input: T")
- .Input("dimension: int32")
+ .Input("dimension: Tidx")
.Output("output: int64")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(ArgOpShape)
.Doc(R"doc(
Returns the index with the largest value across dimensions of a tensor.
@@ -1185,9 +1197,10 @@ dimension: int32, 0 <= dimension < rank(input). Describes which dimension
REGISTER_OP("ArgMin")
.Input("input: T")
- .Input("dimension: int32")
+ .Input("dimension: Tidx")
.Output("output: int64")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(ArgOpShape)
.Doc(R"doc(
Returns the index with the smallest value across dimensions of a tensor.
@@ -1490,10 +1503,11 @@ output: Has same shape as data, except for the first `segment_ids.rank`
REGISTER_OP("SparseSegmentSum")
.Input("data: T")
- .Input("indices: int32")
+ .Input("indices: Tidx")
.Input("segment_ids: int32")
.Output("output: T")
.Attr("T: realnumbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn)
.Doc(R"doc(
Computes the sum along sparse segments of a tensor.
@@ -1538,10 +1552,11 @@ output: Has same shape as data, except for dimension 0 which
REGISTER_OP("SparseSegmentMean")
.Input("data: T")
- .Input("indices: int32")
+ .Input("indices: Tidx")
.Input("segment_ids: int32")
.Output("output: T")
.Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn)
.Doc(R"doc(
Computes the mean along sparse segments of a tensor.
@@ -1564,11 +1579,12 @@ output: Has same shape as data, except for dimension 0 which
REGISTER_OP("SparseSegmentMeanGrad")
.Input("grad: T")
- .Input("indices: int32")
+ .Input("indices: Tidx")
.Input("segment_ids: int32")
.Input("output_dim0: int32")
.Output("output: T")
.Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionGradShapeFn)
.Doc(R"doc(
Computes gradients for SparseSegmentMean.
@@ -1584,10 +1600,11 @@ output_dim0: dimension 0 of "data" passed to SparseSegmentMean op.
REGISTER_OP("SparseSegmentSqrtN")
.Input("data: T")
- .Input("indices: int32")
+ .Input("indices: Tidx")
.Input("segment_ids: int32")
.Output("output: T")
.Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionShapeFn)
.Doc(R"doc(
Computes the sum along sparse segments of a tensor divided by the sqrt of N.
@@ -1609,11 +1626,12 @@ output: Has same shape as data, except for dimension 0 which
REGISTER_OP("SparseSegmentSqrtNGrad")
.Input("grad: T")
- .Input("indices: int32")
+ .Input("indices: Tidx")
.Input("segment_ids: int32")
.Input("output_dim0: int32")
.Output("output: T")
.Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(SparseSegmentReductionGradShapeFn)
.Doc(R"doc(
Computes gradients for SparseSegmentSqrtN.
@@ -1629,9 +1647,10 @@ output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op.
REGISTER_OP("All")
.Input("input: bool")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Output("output: bool")
.Attr("keep_dims: bool = false")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the "logical and" of elements across dimensions of a tensor.
@@ -1649,9 +1668,10 @@ output: The reduced tensor.
REGISTER_OP("Any")
.Input("input: bool")
- .Input("reduction_indices: int32")
+ .Input("reduction_indices: Tidx")
.Attr("keep_dims: bool = false")
.Output("output: bool")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ReductionShape)
.Doc(R"doc(
Computes the "logical or" of elements across dimensions of a tensor.
@@ -1670,10 +1690,11 @@ output: The reduced tensor.
// --------------------------------------------------------------------------
REGISTER_OP("Range")
- .Input("start: int32")
- .Input("limit: int32")
- .Input("delta: int32")
- .Output("output: int32")
+ .Input("start: Tidx")
+ .Input("limit: Tidx")
+ .Input("delta: Tidx")
+ .Output("output: Tidx")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
@@ -1689,9 +1710,17 @@ REGISTER_OP("Range")
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
return Status::OK();
}
- const int32 start = start_t->scalar<int32>()();
- const int32 limit = limit_t->scalar<int32>()();
- const int32 delta = delta_t->scalar<int32>()();
+ // TODO
+ int64 start, limit, delta;
+ if (start_t->dtype() == DT_INT32) {
+ start = start_t->scalar<int32>()();
+ limit = limit_t->scalar<int32>()();
+ delta = delta_t->scalar<int32>()();
+ } else {
+ start = start_t->scalar<int64>()();
+ limit = limit_t->scalar<int64>()();
+ delta = delta_t->scalar<int64>()();
+ }
if (start > limit) {
return errors::InvalidArgument("Requires start <= limit: ", start, "/",
limit);
@@ -1699,7 +1728,7 @@ REGISTER_OP("Range")
if (delta <= 0) {
return errors::InvalidArgument("Requires delta > 0: ", delta);
}
- const int32 size = (limit - start + delta - 1) / delta;
+ const int64 size = (limit - start + delta - 1) / delta;
c->set_output(0, c->Vector(size));
return Status::OK();
})
@@ -1727,9 +1756,10 @@ output: 1-D.
REGISTER_OP("LinSpace")
.Input("start: T")
.Input("stop: T")
- .Input("num: int32")
+ .Input("num: Tidx")
.Output("output: T")
.Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
@@ -1743,7 +1773,13 @@ REGISTER_OP("LinSpace")
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
return Status::OK();
}
- const int64 num = num_t->scalar<int32>()();
+
+ int64 num;
+ if (num_t->dtype() == DT_INT32) {
+ num = num_t->scalar<int32>()();
+ } else {
+ num = num_t->scalar<int64>()();
+ }
if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num);
c->set_output(0, c->Vector(num));
return Status::OK();
@@ -2057,11 +2093,12 @@ product: Pairwise cross product of the vectors in `a` and `b`.
REGISTER_OP("Cumsum")
.Input("x: T")
- .Input("axis: int32")
+ .Input("axis: Tidx")
.Attr("exclusive: bool = false")
.Attr("reverse: bool = false")
.Output("out: T")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.Doc(R"doc(
Compute the cumulative sum of the tensor `x` along `axis`.
@@ -2092,11 +2129,12 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) ==> [b + c, c, 0]
REGISTER_OP("Cumprod")
.Input("x: T")
- .Input("axis: int32")
+ .Input("axis: Tidx")
.Attr("exclusive: bool = false")
.Attr("reverse: bool = false")
.Output("out: T")
.Attr("T: numbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
.Doc(R"doc(
Compute the cumulative product of the tensor `x` along `axis`.
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index b2e1dfce27..b34bb068d7 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -349,6 +349,23 @@ class OpDefLibrary(object):
(op_type_name, producer, deprecation_version,
op_def.deprecation.explanation))
+ # Fill in the list of default types for all "type" attrs. This
+ # will be used to choose a preferred dtype to convert to in the
+ # absence of input type information.
+ #
+ # TODO(b/31302892): Currently the defaults don't work in the right
+ # way if you have two inputs, one of whose type resolution depends
+ # on the other. Handling this will require restructuring this code
+ # significantly.
+ default_type_attr_map = {}
+ for attr_def in op_def.attr:
+ if attr_def.type != "type":
+ continue
+ key = attr_def.name
+ if attr_def.HasField("default_value"):
+ default_type_attr_map[key] = dtypes.as_dtype(
+ attr_def.default_value.type)
+
# Requires that op_def has passed validation (using the C++
# ValidateOpDef() from ../framework/op_def_util.h).
attrs = {}
@@ -390,6 +407,7 @@ class OpDefLibrary(object):
# In cases where we expect all elements of the list to have the
# same dtype, try to cast non-Tensor elements to that type.
dtype = None
+ default_dtype = None
if input_arg.type != types_pb2.DT_INVALID:
dtype = input_arg.type
elif input_arg.number_attr:
@@ -401,11 +419,19 @@ class OpDefLibrary(object):
dtype = t.dtype
break
+ # dtype still not found, prefer using the default dtype
+ # from the attr.
+ if dtype is None and input_arg.type_attr in default_type_attr_map:
+ default_dtype = default_type_attr_map[input_arg.type_attr]
+
try:
if not input_arg.is_ref and dtype:
dtype = dtypes.as_dtype(dtype).base_dtype
values = ops.convert_n_to_tensor(
- values, name=input_arg.name, dtype=dtype if dtype else None,
+ values,
+ name=input_arg.name,
+ dtype=dtype if dtype else None,
+ preferred_dtype=default_dtype,
as_ref=input_arg.is_ref)
if input_arg.number_attr and len(
set(v.dtype.base_dtype for v in values)) > 1:
@@ -444,14 +470,24 @@ class OpDefLibrary(object):
# In cases where we have an expected type, try to convert non-Tensor
# arguments to that type.
dtype = None
+ default_dtype = None
if input_arg.type != types_pb2.DT_INVALID:
dtype = input_arg.type
elif input_arg.type_attr in attrs:
dtype = attrs[input_arg.type_attr]
+ elif input_arg.type_attr in default_type_attr_map:
+ # The dtype could not be inferred solely from the inputs,
+ # so we prefer the attr's default, so code that adds a new attr
+ # with a default is backwards compatible.
+ default_dtype = default_type_attr_map[input_arg.type_attr]
+
try:
values = ops.convert_to_tensor(
- values, name=input_arg.name, dtype=dtype,
- as_ref=input_arg.is_ref)
+ values,
+ name=input_arg.name,
+ dtype=dtype,
+ as_ref=input_arg.is_ref,
+ preferred_dtype=default_dtype)
except ValueError:
# What type does convert_to_tensor think it has?
observed = ops.convert_to_tensor(values,
@@ -462,6 +498,14 @@ class OpDefLibrary(object):
raise TypeError("%s expected type of %s." %
(prefix, dtypes.as_dtype(input_arg.type).name))
else:
+ # Update the maps with the default, if needed.
+ k = input_arg.type_attr
+ if k in default_type_attr_map:
+ if k not in attrs:
+ attrs[k] = default_type_attr_map[k]
+ if k not in inferred_from:
+ inferred_from[k] = "Default in OpDef"
+
raise TypeError(
"%s type %s of argument '%s'." %
(prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
diff --git a/tensorflow/python/framework/op_def_library_test.py b/tensorflow/python/framework/op_def_library_test.py
index 27933cc193..d95bb45cb8 100644
--- a/tensorflow/python/framework/op_def_library_test.py
+++ b/tensorflow/python/framework/op_def_library_test.py
@@ -47,6 +47,8 @@ ops.RegisterShape("AttrShape")(None)
ops.RegisterShape("AttrShapeList")(None)
ops.RegisterShape("AttrPartialShape")(None)
ops.RegisterShape("AttrPartialShapeList")(None)
+ops.RegisterShape("AttrTypeDefault")(None)
+ops.RegisterShape("AttrListTypeDefault")(None)
ops.RegisterShape("Binary")(None)
ops.RegisterShape("ComplexStruct")(None)
ops.RegisterShape("InPolymorphicTwice")(None)
@@ -868,6 +870,44 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
name: 'x' op: 'ReservedAttr' attr { key: 'range' value { i: 7 } }
""", op.node_def)
+ def testDefaultAttrType(self):
+ self._add_op("name: 'AttrTypeDefault' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' "
+ " default_value { type: DT_INT32 } }")
+
+ # Give an input whose type has no obvious output type.
+ op = self._lib.apply_op("AttrTypeDefault", a=[], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'AttrTypeDefault' input: 'n/a'
+ attr { key: 'T' value { type: DT_INT32 } }
+ """, op.node_def)
+
+ # Give an input whose type can be inferred as different
+ # than the default.
+ op = self._lib.apply_op("AttrTypeDefault", a=[1.0], name="f")
+ self.assertProtoEquals("""
+ name: 'f' op: 'AttrTypeDefault' input: 'f/a'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ """, op.node_def)
+
+ def testDefaultListAttrType(self):
+ self._add_op("name: 'AttrListTypeDefault' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' "
+ " default_value { type: DT_INT32 } }"
+ "attr { name: 'N' type: 'int' }")
+
+ # Give an input whose type can be inferred as different
+ # than the default.
+ op = self._lib.apply_op("AttrListTypeDefault", a=[1.0], b=[2.0], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'AttrListTypeDefault' input: 'n/a_0' input: 'n/b_0'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'N' value { i: 1 } }
+ """, op.node_def)
+
def testNIntsIn(self):
self._add_op("name: 'NIntsIn' "
"input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f47ecd1b36..a5e253eabf 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -577,7 +577,11 @@ _tensor_conversion_func_registry = {
register_dense_tensor_like_type(Tensor)
-def convert_to_tensor(value, dtype=None, name=None, as_ref=False):
+def convert_to_tensor(value,
+ dtype=None,
+ name=None,
+ as_ref=False,
+ preferred_dtype=None):
"""Converts the given `value` to a `Tensor`.
This function converts Python objects of various types to `Tensor`
@@ -610,6 +614,11 @@ def convert_to_tensor(value, dtype=None, name=None, as_ref=False):
name: Optional name to use if a new `Tensor` is created.
as_ref: True if we want the result as a ref tensor. Only used if a new
`Tensor` is created.
+ preferred_dtype: Optional element type for the returned tensor,
+ used when dtype is None. In some cases, a caller may not have a
+ dtype in mind when converting to a tensor, so preferred_dtype
+ can be used as a soft preference. If the conversion to
+ `preferred_dtype` is not possible, this argument has no effect.
Returns:
A `Tensor` based on `value`.
@@ -625,9 +634,31 @@ def convert_to_tensor(value, dtype=None, name=None, as_ref=False):
for _, funcs_at_priority in sorted(_tensor_conversion_func_registry.items()):
for base_type, conversion_func in funcs_at_priority:
if isinstance(value, base_type):
- ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
+ # If dtype is None but preferred_dtype is not None, we try to
+ # cast to preferred_dtype first.
+ ret = None
+ if dtype is None and preferred_dtype is not None:
+ try:
+ ret = conversion_func(
+ value, dtype=preferred_dtype, name=name, as_ref=as_ref)
+ except (TypeError, ValueError):
+ # Could not coerce the conversion to use the preferred dtype.
+ ret = None
+
+ if ret is not None and ret is not NotImplemented:
+ if (ret.dtype.base_dtype !=
+ dtypes.as_dtype(preferred_dtype).base_dtype):
+ raise TypeError("convert_to_tensor did not convert to "
+ "the preferred dtype: %s vs %s " %
+ (ret.dtype.base_dtype,
+ dtypes.as_dtype(preferred_dtype).base_dtype))
+
+ if ret is None:
+ ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
+
if ret is NotImplemented:
continue
+
if not isinstance(ret, Tensor):
raise RuntimeError(
"%sConversion function %r for type %s returned non-Tensor: %r"
@@ -644,7 +675,11 @@ def convert_to_tensor(value, dtype=None, name=None, as_ref=False):
% (error_prefix, value, type(value)))
-def convert_n_to_tensor(values, dtype=None, name=None, as_ref=False):
+def convert_n_to_tensor(values,
+ dtype=None,
+ name=None,
+ as_ref=False,
+ preferred_dtype=None):
"""Converts `values` to a list of `Tensor` objects.
Args:
@@ -654,6 +689,11 @@ def convert_n_to_tensor(values, dtype=None, name=None, as_ref=False):
created, in which case element `i` will be given the name `name
+ '_' + i`.
as_ref: True if the caller wants the results as ref tensors.
+ preferred_dtype: Optional element type for the returned tensors,
+ used when dtype is None. In some cases, a caller may not have a
+ dtype in mind when converting to a tensor, so preferred_dtype
+ can be used as a soft preference. If the conversion to
+ `preferred_dtype` is not possible, this argument has no effect.
Returns:
A list of `Tensor` and/or `IndexedSlices` objects.
@@ -669,7 +709,13 @@ def convert_n_to_tensor(values, dtype=None, name=None, as_ref=False):
ret = []
for i, value in enumerate(values):
n = None if name is None else "%s_%d" % (name, i)
- ret.append(convert_to_tensor(value, dtype=dtype, name=n, as_ref=as_ref))
+ ret.append(
+ convert_to_tensor(
+ value,
+ dtype=dtype,
+ name=n,
+ as_ref=as_ref,
+ preferred_dtype=preferred_dtype))
return ret
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 98ad4a2d60..eac85ac844 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -280,6 +280,31 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual(tensor_shape.unknown_shape(),
_apply_op(g, "an_op", [], [dtypes.float32]).get_shape())
+ def testConvertToTensorPreferred(self):
+ with self.test_session():
+ values = [2, 3, 5, 7]
+ tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
+ self.assertEqual(dtypes.float32, tensor.dtype)
+
+ with self.test_session():
+ # Convert empty tensor to anything
+ values = []
+ tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
+ self.assertEqual(dtypes.int64, tensor.dtype)
+
+ with self.test_session():
+ # The preferred dtype is a type error and will convert to
+ # float32 instead.
+ values = [1.23]
+ tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
+ self.assertEqual(dtypes.float32, tensor.dtype)
+
+ def testConvertToInvalidTensorType(self):
+ with self.assertRaises(TypeError):
+ # Forcing an invalid dtype should fail with a type error.
+ values = [1.23]
+ _ = ops.convert_to_tensor(values, dtype=dtypes.int64)
+
def testNoConvert(self):
# Operation cannot be converted to Tensor
op = control_flow_ops.no_op()