aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--tensorflow/contrib/lite/build_def.bzl10
-rw-r--r--tensorflow/contrib/lite/examples/label_image/BUILD10
-rw-r--r--tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h16
-rw-r--r--tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h87
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.cc48
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.h7
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.md12
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD24
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.cc1
-rw-r--r--tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc5
-rw-r--r--tensorflow/contrib/lite/toco/model.h1
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc5
16 files changed, 167 insertions, 66 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 0a097d5a69..19829e4991 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -5,25 +5,25 @@ def tflite_copts():
copts = [
"-DFARMHASH_NO_CXX_STRING",
] + select({
- "//tensorflow:android_arm64": [
+ str(Label("//tensorflow:android_arm64")): [
"-std=c++11",
"-O3",
],
- "//tensorflow:android_arm": [
+ str(Label("//tensorflow:android_arm")): [
"-mfpu=neon",
"-mfloat-abi=softfp",
"-std=c++11",
"-O3",
],
- "//tensorflow:android_x86": [
+ str(Label("//tensorflow:android_x86")): [
"-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK",
],
- "//tensorflow:ios_x86_64": [
+ str(Label("//tensorflow:ios_x86_64")): [
"-msse4.1",
],
"//conditions:default": [],
}) + select({
- "//tensorflow:with_default_optimizations": [],
+ str(Label("//tensorflow:with_default_optimizations")): [],
"//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"],
})
diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD
index 476d85c031..959347b549 100644
--- a/tensorflow/contrib/lite/examples/label_image/BUILD
+++ b/tensorflow/contrib/lite/examples/label_image/BUILD
@@ -42,7 +42,15 @@ cc_library(
"bitmap_helpers_impl.h",
"label_image.h",
],
- deps = ["//tensorflow/contrib/lite:string"],
+ deps = [
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite:string",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ ],
)
# TODO(ahentz): Test disabled as it has a memory leek from read_bmp
diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h
index 860e27e5ba..97343dde6b 100644
--- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h
+++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H
-#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H
+#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_
+#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_
#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h"
#include "tensorflow/contrib/lite/examples/label_image/label_image.h"
@@ -26,15 +26,15 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height,
int* channels, Settings* s);
template <class T>
-void downsize(T* out, uint8_t* in, int image_height, int image_width,
- int image_channels, int wanted_height, int wanted_width,
- int wanted_channels, Settings* s);
+void resize(T* out, uint8_t* in, int image_height, int image_width,
+ int image_channels, int wanted_height, int wanted_width,
+ int wanted_channels, Settings* s);
// explicit instantiation
-template void downsize<uint8_t>(uint8_t*, unsigned char*, int, int, int, int,
- int, int, Settings*);
-template void downsize<float>(float*, unsigned char*, int, int, int, int, int,
+template void resize<uint8_t>(uint8_t*, unsigned char*, int, int, int, int, int,
int, Settings*);
+template void resize<float>(float*, unsigned char*, int, int, int, int, int,
+ int, Settings*);
} // namespace label_image
} // namespace tflite
diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
index 64a931082b..d57f597875 100644
--- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
+++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
@@ -13,8 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H
-#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H
+#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_
+#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/string_util.h"
+#include "tensorflow/contrib/lite/version.h"
#include "tensorflow/contrib/lite/examples/label_image/label_image.h"
@@ -22,28 +28,67 @@ namespace tflite {
namespace label_image {
template <class T>
-void downsize(T* out, uint8_t* in, int image_height, int image_width,
- int image_channels, int wanted_height, int wanted_width,
- int wanted_channels, Settings* s) {
- for (int y = 0; y < wanted_height; ++y) {
- const int in_y = (y * image_height) / wanted_height;
- uint8_t* in_row = in + (in_y * image_width * image_channels);
- T* out_row = out + (y * wanted_width * wanted_channels);
- for (int x = 0; x < wanted_width; ++x) {
- const int in_x = (x * image_width) / wanted_width;
- uint8_t* in_pixel = in_row + (in_x * image_channels);
- T* out_pixel = out_row + (x * wanted_channels);
- for (int c = 0; c < wanted_channels; ++c) {
- if (s->input_floating)
- out_pixel[c] = (in_pixel[c] - s->input_mean) / s->input_std;
- else
- out_pixel[c] = in_pixel[c];
- }
- }
+void resize(T* out, uint8_t* in, int image_height, int image_width,
+ int image_channels, int wanted_height, int wanted_width,
+ int wanted_channels, Settings* s) {
+ int number_of_pixels = image_height * image_width * image_channels;
+ std::unique_ptr<Interpreter> interpreter(new Interpreter);
+
+ int base_index = 0;
+
+ // two inputs: input and new_sizes
+ interpreter->AddTensors(2, &base_index);
+ // one output
+ interpreter->AddTensors(1, &base_index);
+ // set input and output tensors
+ interpreter->SetInputs({0, 1});
+ interpreter->SetOutputs({2});
+
+ // set parameters of tensors
+ TfLiteQuantizationParams quant;
+ interpreter->SetTensorParametersReadWrite(
+ 0, kTfLiteFloat32, "input",
+ {1, image_height, image_width, image_channels}, quant);
+ interpreter->SetTensorParametersReadWrite(1, kTfLiteInt32, "new_size", {2},
+ quant);
+ interpreter->SetTensorParametersReadWrite(
+ 2, kTfLiteFloat32, "output",
+ {1, wanted_height, wanted_width, wanted_channels}, quant);
+
+ ops::builtin::BuiltinOpResolver resolver;
+ TfLiteRegistration* resize_op =
+ resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR);
+ interpreter->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr,
+ resize_op, nullptr);
+
+ interpreter->AllocateTensors();
+
+ // fill input image
+ // in[] are integers, cannot do memcpy() directly
+ auto input = interpreter->typed_tensor<float>(0);
+ for (int i = 0; i < number_of_pixels; i++) {
+ input[i] = in[i];
+ }
+
+ // fill new_sizes
+ interpreter->typed_tensor<int>(1)[0] = wanted_height;
+ interpreter->typed_tensor<int>(1)[1] = wanted_width;
+
+ interpreter->Invoke();
+
+ auto output = interpreter->typed_tensor<float>(2);
+ auto output_number_of_pixels =
+ wanted_height * wanted_height * wanted_channels;
+
+ for (int i = 0; i < output_number_of_pixels; i++) {
+ if (s->input_floating)
+ out[i] = (output[i] - s->input_mean) / s->input_std;
+ else
+ out[i] = (uint8_t)output[i];
}
}
} // namespace label_image
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H
+#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc
index 4d2e1ce0bc..a91467d345 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.cc
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc
@@ -148,14 +148,22 @@ void RunInference(Settings* s) {
int wanted_width = dims->data[2];
int wanted_channels = dims->data[3];
- if (s->input_floating) {
- downsize<float>(interpreter->typed_tensor<float>(input), in, image_height,
+ switch (interpreter->tensor(input)->type) {
+ case kTfLiteFloat32:
+ s->input_floating = true;
+ resize<float>(interpreter->typed_tensor<float>(input), in, image_height,
image_width, image_channels, wanted_height, wanted_width,
wanted_channels, s);
- } else {
- downsize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
+ break;
+ case kTfLiteUInt8:
+ resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
image_height, image_width, image_channels, wanted_height,
wanted_width, wanted_channels, s);
+ break;
+ default:
+ LOG(FATAL) << "cannot handle input type "
+ << interpreter->tensor(input)->type << " yet";
+ exit(-1);
}
struct timeval start_time, stop_time;
@@ -177,13 +185,21 @@ void RunInference(Settings* s) {
std::vector<std::pair<float, int>> top_results;
- if (s->input_floating) {
- get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
- num_results, threshold, &top_results, s->input_floating);
- } else {
- get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
- output_size, num_results, threshold, &top_results,
- s->input_floating);
+ int output = interpreter->outputs()[0];
+ switch (interpreter->tensor(output)->type) {
+ case kTfLiteFloat32:
+ get_top_n<float>(interpreter->typed_output_tensor<float>(0), output_size,
+ num_results, threshold, &top_results, true);
+ break;
+ case kTfLiteUInt8:
+ get_top_n<uint8_t>(interpreter->typed_output_tensor<uint8_t>(0),
+ output_size, num_results, threshold, &top_results,
+ false);
+ break;
+ default:
+ LOG(FATAL) << "cannot handle output type "
+ << interpreter->tensor(input)->type << " yet";
+ exit(-1);
}
std::vector<string> labels;
@@ -203,13 +219,11 @@ void display_usage() {
LOG(INFO) << "label_image\n"
<< "--accelerated, -a: [0|1], use Android NNAPI or note\n"
<< "--count, -c: loop interpreter->Invoke() for certain times\n"
- << "--input_floating, -f: [0|1] type of input layer is floating "
- "point numbers\n"
<< "--input_mean, -b: input mean\n"
<< "--input_std, -s: input standard deviation\n"
<< "--image, -i: image_name.bmp\n"
<< "--labels, -l: labels for the model\n"
- << "--tflite_mode, -m: model_name.tflite\n"
+ << "--tflite_model, -m: model_name.tflite\n"
<< "--threads, -t: number of threads\n"
<< "--verbose, -v: [0|1] print more information\n"
<< "\n";
@@ -223,7 +237,6 @@ int Main(int argc, char** argv) {
static struct option long_options[] = {
{"accelerated", required_argument, 0, 'a'},
{"count", required_argument, 0, 'c'},
- {"input_floating", required_argument, 0, 'f'},
{"verbose", required_argument, 0, 'v'},
{"image", required_argument, 0, 'i'},
{"labels", required_argument, 0, 'l'},
@@ -254,11 +267,6 @@ int Main(int argc, char** argv) {
s.loop_count = strtol( // NOLINT(runtime/deprecated_fn)
optarg, (char**)NULL, 10);
break;
- case 'f':
- s.input_floating = strtol( // NOLINT(runtime/deprecated_fn)
- optarg, (char**)NULL, 10);
- s.input_layer_type = "float";
- break;
case 'i':
s.input_bmp_name = optarg;
break;
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h
index ce98e06fc1..4de32e33fb 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.h
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.h
@@ -16,9 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
-#include <string>
#include "tensorflow/contrib/lite/string.h"
+namespace tflite {
+namespace label_image {
+
struct Settings {
bool verbose = false;
bool accel = false;
@@ -33,4 +35,7 @@ struct Settings {
int number_of_threads = 4;
};
+} // namespace label_image
+} // namespace tflite
+
#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.md b/tensorflow/contrib/lite/examples/label_image/label_image.md
index d6019d673f..9ce32cf101 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.md
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.md
@@ -1,8 +1,12 @@
label_image for TensorFlow Lite inspired by TensorFlow's label_image.
+
+To build label_image for Android, run $TENSORFLOW_ROOT/configure
+and set Android NDK or configure NDK setting in
+$TENSORFLOW_ROOT/WORKSPACE first.
To build it for android ARMv8:
```
-> bazel build --cxxopt=-std=c++11 \
+> bazel build --config monolithic --cxxopt=-std=c++11 \
--crosstool_top=//external:android/crosstool \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
--cpu=arm64-v8a \
@@ -10,13 +14,13 @@ To build it for android ARMv8:
```
or
```
-> bazel build --config android_arm64 --cxxopt=-std=c++11 \
+> bazel build --config android_arm64 --config monolithic --cxxopt=-std=c++11 \
//tensorflow/contrib/lite/examples/label_image:label_image
```
To build it for android arm-v7a:
```
-> bazel build --cxxopt=-std=c++11 \
+> bazel build --config monolithic --cxxopt=-std=c++11 \
--crosstool_top=//external:android/crosstool \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
--cpu=armeabi-v7a \
@@ -24,7 +28,7 @@ To build it for android arm-v7a:
```
or
```
-> bazel build --config android_arm --cxxopt=-std=c++11 \
+> bazel build --config android_arm --config monolithic --cxxopt=-std=c++11 \
//tensorflow/contrib/lite/examples/label_image:label_image
```
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index 4691a543e9..a6ccc99a51 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -278,6 +278,8 @@ cc_library(
"optimized/neon_tensor_utils.cc",
],
hdrs = [
+ "common.h",
+ "optimized/cpu_check.h",
"optimized/neon_tensor_utils.h",
"optimized/tensor_utils_impl.h",
],
@@ -285,8 +287,11 @@ cc_library(
deps = [
":cpu_check",
":portable_tensor_utils",
+ ":types",
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite/kernels:activation_functor",
+ "@arm_neon_2_x86_sse",
+ "@gemmlowp",
],
)
@@ -306,14 +311,21 @@ cc_library(
"tensor_utils.cc",
],
hdrs = [
+ "common.h",
+ "compatibility.h",
+ "optimized/cpu_check.h",
+ "optimized/neon_tensor_utils.h",
"optimized/tensor_utils_impl.h",
"reference/portable_tensor_utils.h",
"tensor_utils.h",
+ "types.h",
],
copts = NEON_FLAGS_IF_APPLICABLE,
deps = [
"//tensorflow/contrib/lite/kernels:activation_functor",
"//tensorflow/contrib/lite:builtin_op_data",
+ "@arm_neon_2_x86_sse",
+ "@gemmlowp",
] + select({
":arm": [
":neon_tensor_utils",
@@ -333,6 +345,18 @@ cc_library(
":ios_arm64": [
":neon_tensor_utils",
],
+ ":x86_64": [
+ ":neon_tensor_utils",
+ ],
+ ":x86": [
+ ":neon_tensor_utils",
+ ],
+ ":k8": [
+ ":neon_tensor_utils",
+ ],
+ ":darwin": [
+ ":neon_tensor_utils",
+ ],
"//conditions:default": [
":portable_tensor_utils",
],
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
index 6cb556bf45..3a53d3ab07 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h
@@ -34,7 +34,7 @@ inline bool TestCPUFeatureNeon() {
#endif // __aarch64__
}
-#elif __ARM_NEON
+#elif defined USE_NEON || defined __ARM_NEON
inline bool TestCPUFeatureNeon() { return true; }
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index bf0bdfb1fb..883c7f270d 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -16,11 +16,11 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
#ifdef USE_NEON
-#include <arm_neon.h>
#define kFloatWeightsPerNeonLane 4
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc
index 904a97803a..f4181b18a8 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
#ifndef USE_NEON
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
index 7019c29959..76032771af 100644
--- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
+++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
@@ -1571,7 +1571,7 @@ inline int ANeuralNetworksModel_addOperation(ANeuralNetworksModel* model,
}
/**
- * Specfifies which operands will be the model's inputs and outputs.
+ * Specifies which operands will be the model's inputs and outputs.
*
* An operand cannot be used for both input and output. Doing so will
* return an error.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
index 2340f0e850..6961e23690 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
@@ -132,6 +132,7 @@ bool GraphTransformationsPass(int increment, Model* model,
CHECK(increment == 1 || increment == -1);
bool changed = false;
if (model->operators.empty()) {
+ LOG(INFO) << "Model is empty!!!";
return false;
}
int op_index = increment == 1 ? 0 : model->operators.size() - 1;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
index 833c97c758..e79e2a32fc 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
@@ -189,7 +189,10 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
// Remove all the resolved arrays.
for (const string& input_name : concat_op->inputs) {
- model->EraseArray(input_name);
+ // Check to prevent removal of shared tensors
+ if (CountOpsWithInput(*model, input_name) == 1) {
+ model->EraseArray(input_name);
+ }
}
// Remove concatenate operator
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 8f12bc59fb..0bee694387 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
#define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
+#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 1add90fb82..ce0fde57f4 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -698,10 +698,11 @@ void CheckNonExistentIOArrays(const Model& model) {
void CheckNoMissingArray(const Model& model) {
for (const auto& op : model.operators) {
for (const auto& input : op->inputs) {
- CHECK(model.HasArray(input) || model.optional_arrays.count(input));
+ CHECK(model.HasArray(input) || model.optional_arrays.count(input))
+ << "Input: " << input << " missing for op: " << op->outputs[0] << ".";
}
for (const auto& output : op->outputs) {
- CHECK(model.HasArray(output));
+ CHECK(model.HasArray(output)) << "Output: " << output << " missing.";
}
}
CheckNonExistentIOArrays(model);