From d4ef9aa02c3c8297a053176918beaf34c13b73a6 Mon Sep 17 00:00:00 2001 From: "David G. Andersen" Date: Mon, 23 May 2016 15:51:03 -0800 Subject: int64->32 fixes. Notable fixes: Have listdiff return an explicit error instead of mishandling > 2^31 entry X tensors. More descriptive naming inside LRN. Change: 123053130 --- tensorflow/core/kernels/decode_png_op.cc | 4 ++-- tensorflow/core/kernels/edit_distance_op.cc | 2 +- tensorflow/core/kernels/example_parsing_ops.cc | 9 ++++++--- tensorflow/core/kernels/listdiff_op.cc | 20 ++++++++++++-------- tensorflow/core/kernels/lrn_op.cc | 19 ++++++++++--------- tensorflow/examples/android/jni/tensorflow_jni.cc | 12 +++++------- tensorflow/examples/android/jni/tensorflow_jni.h | 8 +++----- 7 files changed, 39 insertions(+), 35 deletions(-) diff --git a/tensorflow/core/kernels/decode_png_op.cc b/tensorflow/core/kernels/decode_png_op.cc index 827d8d23f7..6cd4d7e66f 100644 --- a/tensorflow/core/kernels/decode_png_op.cc +++ b/tensorflow/core/kernels/decode_png_op.cc @@ -70,8 +70,8 @@ class DecodePngOp : public OpKernel { // verify single dimension is not too large. // - verify when width and height are multiplied together, there are a few // bits to spare as well. - const int width = decode.width; - const int height = decode.height; + const int width = static_cast(decode.width); + const int height = static_cast(decode.height); const int64 total_size = static_cast(width) * static_cast(height); if (width != static_cast(decode.width) || width <= 0 || diff --git a/tensorflow/core/kernels/edit_distance_op.cc b/tensorflow/core/kernels/edit_distance_op.cc index b4d14e8c62..7f0b73e6a2 100644 --- a/tensorflow/core/kernels/edit_distance_op.cc +++ b/tensorflow/core/kernels/edit_distance_op.cc @@ -144,7 +144,7 @@ class EditDistanceOp : public OpKernel { std::iota(group_dims.begin(), group_dims.end(), 0); TensorShape output_shape; - for (size_t d = 0; d < group_dims.size(); ++d) { + for (int d = 0; d < static_cast(group_dims.size()); ++d) { output_shape.AddDim(std::max(hypothesis_st_shape.dim_size(d), truth_st_shape.dim_size(d))); } diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc index b7e0109b75..221706f831 100644 --- a/tensorflow/core/kernels/example_parsing_ops.cc +++ b/tensorflow/core/kernels/example_parsing_ops.cc @@ -48,6 +48,8 @@ class ExampleParserOp : public OpKernel { errors::InvalidArgument("len(dense_keys) != len(dense_types")); OP_REQUIRES(ctx, static_cast(num_dense_) == dense_shapes_.size(), errors::InvalidArgument("len(dense_keys) != len(dense_shapes")); + OP_REQUIRES(ctx, num_dense_ <= TensorShape::MaxDimensions(), + errors::InvalidArgument("num_dense_ too large")); for (const DataType& type : dense_types_) { OP_REQUIRES_OK(ctx, CheckValidType(type)); } @@ -108,7 +110,7 @@ class ExampleParserOp : public OpKernel { "Expected len(dense_defaults) == len(dense_keys) but got: ", dense_defaults.size(), " vs. ", num_dense_)); - for (int d = 0; d < num_dense_; ++d) { + for (int d = 0; d < static_cast(num_dense_); ++d) { const Tensor& def_value = dense_defaults[d]; if (def_value.NumElements() > 0) { OP_REQUIRES(ctx, def_value.shape() == dense_shapes_[d], @@ -126,7 +128,7 @@ class ExampleParserOp : public OpKernel { auto serialized_t = serialized->vec(); - const int batch_size = serialized_t.size(); + const int64 batch_size = serialized_t.size(); OpOutputList sparse_indices; OpOutputList sparse_values; @@ -146,7 +148,8 @@ class ExampleParserOp : public OpKernel { // Preallocate dense_values, since we know their sizes TensorShape out_shape; out_shape.AddDim(batch_size); - for (const int dim : dense_shapes_[d].dim_sizes()) out_shape.AddDim(dim); + for (const int64 dim : dense_shapes_[d].dim_sizes()) + out_shape.AddDim(dim); Tensor* out = nullptr; dense_values.allocate(d, out_shape, &out); diff --git a/tensorflow/core/kernels/listdiff_op.cc b/tensorflow/core/kernels/listdiff_op.cc index 891f7888ab..9e221efac9 100644 --- a/tensorflow/core/kernels/listdiff_op.cc +++ b/tensorflow/core/kernels/listdiff_op.cc @@ -42,20 +42,24 @@ class ListDiffOp : public OpKernel { OP_REQUIRES(context, TensorShapeUtils::IsVector(y.shape()), errors::InvalidArgument("y should be a 1D vector.")); - std::unordered_set y_set; + const auto Tx = x.vec(); + const size_t x_size = Tx.size(); const auto Ty = y.vec(); - const int y_size = Ty.size(); + const size_t y_size = Ty.size(); + + OP_REQUIRES(context, x_size < std::numeric_limits::max(), + errors::InvalidArgument("x too large for int32 indexing")); + + std::unordered_set y_set; y_set.reserve(y_size); - for (int i = 0; i < y_size; ++i) { + for (size_t i = 0; i < y_size; ++i) { y_set.insert(Ty(i)); } // Compute the size of the output. - const auto Tx = x.vec(); - const int x_size = Tx.size(); - int out_size = 0; - for (int i = 0; i < x_size; ++i) { + int64 out_size = 0; + for (size_t i = 0; i < x_size; ++i) { if (y_set.count(Tx(i)) == 0) { ++out_size; } @@ -70,7 +74,7 @@ class ListDiffOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(1, {out_size}, &indices)); auto Tindices = indices->vec(); - for (int i = 0, p = 0; i < x_size; ++i) { + for (int i = 0, p = 0; i < static_cast(x_size); ++i) { if (y_set.count(Tx(i)) == 0) { OP_REQUIRES(context, p < out_size, errors::InvalidArgument( diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc index 5a4930ec63..8a1b7b3de0 100644 --- a/tensorflow/core/kernels/lrn_op.cc +++ b/tensorflow/core/kernels/lrn_op.cc @@ -33,10 +33,10 @@ namespace tensorflow { namespace { -// When the depth is large and beta_ is 0.5 or 1.0, MognetLRN is faster than the -// main band matrix approach used below. Benchmarks suggest switching to -// MognetLRN when depth > 384. -const int kMognetLRNDepthCutoff = 384; +// When the depth is large and beta_ is 0.5 or 1.0, Single-threaded +// LRN is faster than the main band matrix approach used +// below. Benchmarks suggest switching to SingleThreadedLRN when depth > 384. +const int kSingleThreadedLRNDepthCutoff = 384; // Create a depth-by-depth band matrix with 1s along a swath of size (2 * // depth_radius + 1) around the diagonal. @@ -88,10 +88,11 @@ class LRNOp : public OpKernel { 0, TensorShape({batch, rows, cols, depth}), &output)); #if defined(__ANDROID__) - MognetLRN(in, batch, rows, cols, depth, output); + SingleThreadedLRN(in, batch, rows, cols, depth, output); #else - if (depth > kMognetLRNDepthCutoff && (beta_ == 0.5f || beta_ == 1.0f)) { - MognetLRN(in, batch, rows, cols, depth, output); + if (depth > kSingleThreadedLRNDepthCutoff && + (beta_ == 0.5f || beta_ == 1.0f)) { + SingleThreadedLRN(in, batch, rows, cols, depth, output); return; } @@ -124,8 +125,8 @@ class LRNOp : public OpKernel { private: typedef Eigen::Tensor::DimensionPair DimPair; - void MognetLRN(const Tensor& in, const int batch, const int rows, - const int cols, const int depth, Tensor* out) { + void SingleThreadedLRN(const Tensor& in, const int batch, const int rows, + const int cols, const int depth, Tensor* out) { Eigen::Map> data_in(in.flat().data(), depth, batch * rows * cols); diff --git a/tensorflow/examples/android/jni/tensorflow_jni.cc b/tensorflow/examples/android/jni/tensorflow_jni.cc index f61eb0655c..75d834b735 100644 --- a/tensorflow/examples/android/jni/tensorflow_jni.cc +++ b/tensorflow/examples/android/jni/tensorflow_jni.cc @@ -48,7 +48,7 @@ static std::vector g_label_strings; static bool g_compute_graph_initialized = false; //static mutex g_compute_graph_mutex(base::LINKER_INITIALIZED); -static int g_tensorflow_input_size; // The image size for the mognet input. +static int g_tensorflow_input_size; // The image size for the model input. static int g_image_mean; // The image mean. static std::unique_ptr g_stats; @@ -82,11 +82,9 @@ inline static int64 CurrentThreadTimeUs() { return tv.tv_sec * 1000000 + tv.tv_usec; } -JNIEXPORT jint JNICALL -TENSORFLOW_METHOD(initializeTensorflow)( - JNIEnv* env, jobject thiz, jobject java_asset_manager, - jstring model, jstring labels, - jint num_classes, jint mognet_input_size, jint image_mean) { +JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorflow)( + JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring model, + jstring labels, jint num_classes, jint model_input_size, jint image_mean) { g_num_runs = 0; g_timing_total_us = 0; g_frequency_start.Reset(); @@ -103,7 +101,7 @@ TENSORFLOW_METHOD(initializeTensorflow)( const char* const model_cstr = env->GetStringUTFChars(model, NULL); const char* const labels_cstr = env->GetStringUTFChars(labels, NULL); - g_tensorflow_input_size = mognet_input_size; + g_tensorflow_input_size = model_input_size; g_image_mean = image_mean; LOG(INFO) << "Loading Tensorflow."; diff --git a/tensorflow/examples/android/jni/tensorflow_jni.h b/tensorflow/examples/android/jni/tensorflow_jni.h index 8c94e76a75..7c714b986a 100644 --- a/tensorflow/examples/android/jni/tensorflow_jni.h +++ b/tensorflow/examples/android/jni/tensorflow_jni.h @@ -30,11 +30,9 @@ extern "C" { #define TENSORFLOW_METHOD(METHOD_NAME) \ Java_org_tensorflow_demo_TensorflowClassifier_##METHOD_NAME // NOLINT -JNIEXPORT jint JNICALL -TENSORFLOW_METHOD(initializeTensorflow)( - JNIEnv* env, jobject thiz, jobject java_asset_manager, - jstring model, jstring labels, - jint num_classes, jint mognet_input_size, jint image_mean); +JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorflow)( + JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring model, + jstring labels, jint num_classes, jint model_input_size, jint image_mean); JNIEXPORT jstring JNICALL TENSORFLOW_METHOD(classifyImageBmp)( -- cgit v1.2.3