aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar David G. Andersen <dga@google.com>2016-05-23 15:51:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-23 17:02:47 -0700
commitd4ef9aa02c3c8297a053176918beaf34c13b73a6 (patch)
treefe900bd53361c3091339bc457d5af5c02ad47d7f
parent81af087fb71a8323b0af09694a4de1ab54646de8 (diff)
int64->32 fixes. Notable fixes: Have listdiff return an explicit error instead
of mishandling > 2^31 entry X tensors. More descriptive naming inside LRN. Change: 123053130
-rw-r--r--tensorflow/core/kernels/decode_png_op.cc4
-rw-r--r--tensorflow/core/kernels/edit_distance_op.cc2
-rw-r--r--tensorflow/core/kernels/example_parsing_ops.cc9
-rw-r--r--tensorflow/core/kernels/listdiff_op.cc20
-rw-r--r--tensorflow/core/kernels/lrn_op.cc19
-rw-r--r--tensorflow/examples/android/jni/tensorflow_jni.cc12
-rw-r--r--tensorflow/examples/android/jni/tensorflow_jni.h8
7 files changed, 39 insertions, 35 deletions
diff --git a/tensorflow/core/kernels/decode_png_op.cc b/tensorflow/core/kernels/decode_png_op.cc
index 827d8d23f7..6cd4d7e66f 100644
--- a/tensorflow/core/kernels/decode_png_op.cc
+++ b/tensorflow/core/kernels/decode_png_op.cc
@@ -70,8 +70,8 @@ class DecodePngOp : public OpKernel {
// verify single dimension is not too large.
// - verify when width and height are multiplied together, there are a few
// bits to spare as well.
- const int width = decode.width;
- const int height = decode.height;
+ const int width = static_cast<int>(decode.width);
+ const int height = static_cast<int>(decode.height);
const int64 total_size =
static_cast<int64>(width) * static_cast<int64>(height);
if (width != static_cast<int64>(decode.width) || width <= 0 ||
diff --git a/tensorflow/core/kernels/edit_distance_op.cc b/tensorflow/core/kernels/edit_distance_op.cc
index b4d14e8c62..7f0b73e6a2 100644
--- a/tensorflow/core/kernels/edit_distance_op.cc
+++ b/tensorflow/core/kernels/edit_distance_op.cc
@@ -144,7 +144,7 @@ class EditDistanceOp : public OpKernel {
std::iota(group_dims.begin(), group_dims.end(), 0);
TensorShape output_shape;
- for (size_t d = 0; d < group_dims.size(); ++d) {
+ for (int d = 0; d < static_cast<int>(group_dims.size()); ++d) {
output_shape.AddDim(std::max(hypothesis_st_shape.dim_size(d),
truth_st_shape.dim_size(d)));
}
diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc
index b7e0109b75..221706f831 100644
--- a/tensorflow/core/kernels/example_parsing_ops.cc
+++ b/tensorflow/core/kernels/example_parsing_ops.cc
@@ -48,6 +48,8 @@ class ExampleParserOp : public OpKernel {
errors::InvalidArgument("len(dense_keys) != len(dense_types"));
OP_REQUIRES(ctx, static_cast<size_t>(num_dense_) == dense_shapes_.size(),
errors::InvalidArgument("len(dense_keys) != len(dense_shapes"));
+ OP_REQUIRES(ctx, num_dense_ <= TensorShape::MaxDimensions(),
+ errors::InvalidArgument("num_dense_ too large"));
for (const DataType& type : dense_types_) {
OP_REQUIRES_OK(ctx, CheckValidType(type));
}
@@ -108,7 +110,7 @@ class ExampleParserOp : public OpKernel {
"Expected len(dense_defaults) == len(dense_keys) but got: ",
dense_defaults.size(), " vs. ", num_dense_));
- for (int d = 0; d < num_dense_; ++d) {
+ for (int d = 0; d < static_cast<int>(num_dense_); ++d) {
const Tensor& def_value = dense_defaults[d];
if (def_value.NumElements() > 0) {
OP_REQUIRES(ctx, def_value.shape() == dense_shapes_[d],
@@ -126,7 +128,7 @@ class ExampleParserOp : public OpKernel {
auto serialized_t = serialized->vec<string>();
- const int batch_size = serialized_t.size();
+ const int64 batch_size = serialized_t.size();
OpOutputList sparse_indices;
OpOutputList sparse_values;
@@ -146,7 +148,8 @@ class ExampleParserOp : public OpKernel {
// Preallocate dense_values, since we know their sizes
TensorShape out_shape;
out_shape.AddDim(batch_size);
- for (const int dim : dense_shapes_[d].dim_sizes()) out_shape.AddDim(dim);
+ for (const int64 dim : dense_shapes_[d].dim_sizes())
+ out_shape.AddDim(dim);
Tensor* out = nullptr;
dense_values.allocate(d, out_shape, &out);
diff --git a/tensorflow/core/kernels/listdiff_op.cc b/tensorflow/core/kernels/listdiff_op.cc
index 891f7888ab..9e221efac9 100644
--- a/tensorflow/core/kernels/listdiff_op.cc
+++ b/tensorflow/core/kernels/listdiff_op.cc
@@ -42,20 +42,24 @@ class ListDiffOp : public OpKernel {
OP_REQUIRES(context, TensorShapeUtils::IsVector(y.shape()),
errors::InvalidArgument("y should be a 1D vector."));
- std::unordered_set<T> y_set;
+ const auto Tx = x.vec<T>();
+ const size_t x_size = Tx.size();
const auto Ty = y.vec<T>();
- const int y_size = Ty.size();
+ const size_t y_size = Ty.size();
+
+ OP_REQUIRES(context, x_size < std::numeric_limits<int32>::max(),
+ errors::InvalidArgument("x too large for int32 indexing"));
+
+ std::unordered_set<T> y_set;
y_set.reserve(y_size);
- for (int i = 0; i < y_size; ++i) {
+ for (size_t i = 0; i < y_size; ++i) {
y_set.insert(Ty(i));
}
// Compute the size of the output.
- const auto Tx = x.vec<T>();
- const int x_size = Tx.size();
- int out_size = 0;
- for (int i = 0; i < x_size; ++i) {
+ int64 out_size = 0;
+ for (size_t i = 0; i < x_size; ++i) {
if (y_set.count(Tx(i)) == 0) {
++out_size;
}
@@ -70,7 +74,7 @@ class ListDiffOp : public OpKernel {
OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices));
auto Tindices = indices->vec<int32>();
- for (int i = 0, p = 0; i < x_size; ++i) {
+ for (int i = 0, p = 0; i < static_cast<int32>(x_size); ++i) {
if (y_set.count(Tx(i)) == 0) {
OP_REQUIRES(context, p < out_size,
errors::InvalidArgument(
diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc
index 5a4930ec63..8a1b7b3de0 100644
--- a/tensorflow/core/kernels/lrn_op.cc
+++ b/tensorflow/core/kernels/lrn_op.cc
@@ -33,10 +33,10 @@ namespace tensorflow {
namespace {
-// When the depth is large and beta_ is 0.5 or 1.0, MognetLRN is faster than the
-// main band matrix approach used below. Benchmarks suggest switching to
-// MognetLRN when depth > 384.
-const int kMognetLRNDepthCutoff = 384;
+// When the depth is large and beta_ is 0.5 or 1.0, Single-threaded
+// LRN is faster than the main band matrix approach used
+// below. Benchmarks suggest switching to SingleThreadedLRN when depth > 384.
+const int kSingleThreadedLRNDepthCutoff = 384;
// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
// depth_radius + 1) around the diagonal.
@@ -88,10 +88,11 @@ class LRNOp : public OpKernel {
0, TensorShape({batch, rows, cols, depth}), &output));
#if defined(__ANDROID__)
- MognetLRN(in, batch, rows, cols, depth, output);
+ SingleThreadedLRN(in, batch, rows, cols, depth, output);
#else
- if (depth > kMognetLRNDepthCutoff && (beta_ == 0.5f || beta_ == 1.0f)) {
- MognetLRN(in, batch, rows, cols, depth, output);
+ if (depth > kSingleThreadedLRNDepthCutoff &&
+ (beta_ == 0.5f || beta_ == 1.0f)) {
+ SingleThreadedLRN(in, batch, rows, cols, depth, output);
return;
}
@@ -124,8 +125,8 @@ class LRNOp : public OpKernel {
private:
typedef Eigen::Tensor<float, 1, Eigen::RowMajor>::DimensionPair DimPair;
- void MognetLRN(const Tensor& in, const int batch, const int rows,
- const int cols, const int depth, Tensor* out) {
+ void SingleThreadedLRN(const Tensor& in, const int batch, const int rows,
+ const int cols, const int depth, Tensor* out) {
Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>>
data_in(in.flat<float>().data(), depth, batch * rows * cols);
diff --git a/tensorflow/examples/android/jni/tensorflow_jni.cc b/tensorflow/examples/android/jni/tensorflow_jni.cc
index f61eb0655c..75d834b735 100644
--- a/tensorflow/examples/android/jni/tensorflow_jni.cc
+++ b/tensorflow/examples/android/jni/tensorflow_jni.cc
@@ -48,7 +48,7 @@ static std::vector<std::string> g_label_strings;
static bool g_compute_graph_initialized = false;
//static mutex g_compute_graph_mutex(base::LINKER_INITIALIZED);
-static int g_tensorflow_input_size; // The image size for the mognet input.
+static int g_tensorflow_input_size; // The image size for the model input.
static int g_image_mean; // The image mean.
static std::unique_ptr<StatSummarizer> g_stats;
@@ -82,11 +82,9 @@ inline static int64 CurrentThreadTimeUs() {
return tv.tv_sec * 1000000 + tv.tv_usec;
}
-JNIEXPORT jint JNICALL
-TENSORFLOW_METHOD(initializeTensorflow)(
- JNIEnv* env, jobject thiz, jobject java_asset_manager,
- jstring model, jstring labels,
- jint num_classes, jint mognet_input_size, jint image_mean) {
+JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorflow)(
+ JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring model,
+ jstring labels, jint num_classes, jint model_input_size, jint image_mean) {
g_num_runs = 0;
g_timing_total_us = 0;
g_frequency_start.Reset();
@@ -103,7 +101,7 @@ TENSORFLOW_METHOD(initializeTensorflow)(
const char* const model_cstr = env->GetStringUTFChars(model, NULL);
const char* const labels_cstr = env->GetStringUTFChars(labels, NULL);
- g_tensorflow_input_size = mognet_input_size;
+ g_tensorflow_input_size = model_input_size;
g_image_mean = image_mean;
LOG(INFO) << "Loading Tensorflow.";
diff --git a/tensorflow/examples/android/jni/tensorflow_jni.h b/tensorflow/examples/android/jni/tensorflow_jni.h
index 8c94e76a75..7c714b986a 100644
--- a/tensorflow/examples/android/jni/tensorflow_jni.h
+++ b/tensorflow/examples/android/jni/tensorflow_jni.h
@@ -30,11 +30,9 @@ extern "C" {
#define TENSORFLOW_METHOD(METHOD_NAME) \
Java_org_tensorflow_demo_TensorflowClassifier_##METHOD_NAME // NOLINT
-JNIEXPORT jint JNICALL
-TENSORFLOW_METHOD(initializeTensorflow)(
- JNIEnv* env, jobject thiz, jobject java_asset_manager,
- jstring model, jstring labels,
- jint num_classes, jint mognet_input_size, jint image_mean);
+JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorflow)(
+ JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring model,
+ jstring labels, jint num_classes, jint model_input_size, jint image_mean);
JNIEXPORT jstring JNICALL
TENSORFLOW_METHOD(classifyImageBmp)(