Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--  tensorflow/contrib/lite/build_def.bzl | 4
-rw-r--r--  tensorflow/contrib/lite/builtin_op_data.h | 10
-rw-r--r--  tensorflow/contrib/lite/builtin_ops.h | 2
-rw-r--r--  tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc | 16
-rw-r--r--  tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc | 94
-rw-r--r--  tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md | 25
-rw-r--r--  tensorflow/contrib/lite/interpreter.cc | 11
-rw-r--r--  tensorflow/contrib/lite/interpreter_test.cc | 16
-rw-r--r--  tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java | 8
-rw-r--r--  tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java | 9
-rw-r--r--  tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java | 15
-rw-r--r--  tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java | 239
-rw-r--r--  tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java | 179
-rw-r--r--  tensorflow/contrib/lite/java/src/main/native/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc | 38
-rw-r--r--  tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc | 307
-rw-r--r--  tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h | 79
-rw-r--r--  tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc | 123
-rw-r--r--  tensorflow/contrib/lite/java/src/main/native/tensor_jni.h | 40
-rw-r--r--  tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java | 8
-rw-r--r--  tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java | 236
-rw-r--r--  tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java | 131
-rw-r--r--  tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD | 14
-rw-r--r--  tensorflow/contrib/lite/kernels/arg_min_max.cc | 10
-rw-r--r--  tensorflow/contrib/lite/kernels/arg_min_max_test.cc | 89
-rw-r--r--  tensorflow/contrib/lite/kernels/fake_quant.cc | 81
-rw-r--r--  tensorflow/contrib/lite/kernels/fake_quant_test.cc | 112
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc | 59
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/select.cc | 3
-rw-r--r--  tensorflow/contrib/lite/kernels/select_test.cc | 13
-rw-r--r--  tensorflow/contrib/lite/model.cc | 24
-rw-r--r--  tensorflow/contrib/lite/nnapi_delegate.cc | 15
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs | 14
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h | 300
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py | 57
-rw-r--r--  tensorflow/contrib/lite/testing/generated_examples_zip_test.cc | 7
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc | 19
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc | 25
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | 7
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc | 27
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 10
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc | 13
-rw-r--r--  tensorflow/contrib/lite/toco/model.h | 12
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/export.cc | 18
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc | 50
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator_test.cc | 7
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc | 1
-rw-r--r--  tensorflow/contrib/lite/tools/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc | 3
-rw-r--r--  tensorflow/contrib/lite/tools/benchmark/benchmark_model.h | 1
-rw-r--r--  tensorflow/contrib/lite/tools/visualize.py | 17
53 files changed, 1712 insertions, 896 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 6e1dafefa9..b735d08b4b 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -195,7 +195,7 @@ def json_to_tflite(name, src, out):
def generated_test_models():
return [
"add",
- "arg_max",
+ "arg_min_max",
"avg_pool",
"batch_to_space_nd",
"concat",
@@ -232,7 +232,7 @@ def generated_test_models():
"not_equal",
"pad",
"padv2",
- # "prelu",
+ "prelu",
"pow",
"relu",
"relu1",
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index cda889bf50..a58dde9a7b 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -250,6 +250,10 @@ typedef struct {
} TfLiteArgMaxParams;
typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMinParams;
+
+typedef struct {
TfLitePadding padding;
int stride_width;
int stride_height;
@@ -263,6 +267,12 @@ typedef struct {
TfLiteType out_type;
} TfLiteShapeParams;
+typedef struct {
+ float min;
+ float max;
+ int num_bits;
+} TfLiteFakeQuantParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index a44e918230..6bde5d2e6d 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -104,6 +104,8 @@ typedef enum {
kTfLiteBuiltinRsqrt = 76,
kTfLiteBuiltinShape = 77,
kTfLiteBuiltinPow = 78,
+ kTfLiteBuiltinArgMin = 79,
+ kTfLiteBuiltinFakeQuant = 80,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index fd798c209e..f0d16575ec 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -452,6 +452,22 @@ class NNAPIDelegateKernel {
} else {
return nullptr;
}
+ case kTfLiteBuiltinTranspose:
+ // Transpose requires NNAPI 1.1. Also note that the permutation input
+ // tensor value dictates the output dimensions.
+ // TODO(b/110888333): Support dynamically-sized tensors in delegates.
+ if ((version == 1) &&
+ (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) &&
+ (node->inputs->size > 1) &&
+ (context->tensors[node->inputs->data[1]].allocation_type ==
+ kTfLiteMmapRo)) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ return ANEURALNETWORKS_TRANSPOSE;
+ };
+ } else {
+ return nullptr;
+ }
break;
default:
return nullptr;
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
index aad10c9ce7..ab2181e8ff 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -27,14 +27,20 @@ using ::testing::ElementsAreArray;
// TODO(b/110368244): figure out how to share the existing tests in kernels/ but
// with the delegation on. Also, add more unit tests to improve code coverage.
-class FloatAddOpModel : public SingleOpModel {
+class SingleOpModelWithNNAPI : public SingleOpModel {
+ public:
+ SingleOpModelWithNNAPI() {
+ this->SetApplyDelegate([](Interpreter* interpreter) {
+ interpreter->ModifyGraphWithDelegate(NnApiDelegate(), false);
+ });
+ }
+};
+
+class FloatAddOpModel : public SingleOpModelWithNNAPI {
public:
FloatAddOpModel(const TensorData& input1, const TensorData& input2,
const TensorData& output,
ActivationFunctionType activation_type) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
input1_ = AddInput(input1);
input2_ = AddInput(input2);
output_ = AddOutput(output);
@@ -81,9 +87,6 @@ class FloatMulOpModel : public SingleOpModel {
FloatMulOpModel(const TensorData& input1, const TensorData& input2,
const TensorData& output,
ActivationFunctionType activation_type) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
input1_ = AddInput(input1);
input2_ = AddInput(input2);
output_ = AddOutput(output);
@@ -114,15 +117,11 @@ TEST(NNAPIDelegate, MulWithNoActivation) {
ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4})));
}
-class FloatPoolingOpModel : public SingleOpModel {
+class FloatPoolingOpModel : public SingleOpModelWithNNAPI {
public:
FloatPoolingOpModel(BuiltinOperator type, const TensorData& input,
int filter_width, int filter_height,
const TensorData& output) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(input);
output_ = AddOutput(output);
@@ -193,10 +192,6 @@ class BaseConvolutionOpModel : public SingleOpModel {
enum Padding padding = Padding_VALID,
enum ActivationFunctionType activation = ActivationFunctionType_NONE,
int dilation_width_factor = 1, int dilation_height_factor = 1) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -344,14 +339,10 @@ TEST(NNAPIDelegate, Conv2DWithNoActivation) {
}));
}
-class DepthwiseConvolutionOpModel : public SingleOpModel {
+class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI {
public:
DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter,
const TensorData& output) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -426,15 +417,11 @@ TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) {
}));
}
-class FloatFullyConnectedOpModel : public SingleOpModel {
+class FloatFullyConnectedOpModel : public SingleOpModelWithNNAPI {
public:
FloatFullyConnectedOpModel(int units, int batches, const TensorData& input,
const TensorData& output = {TensorType_FLOAT32})
: batches_(batches), units_(units) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
int total_input_size = 1;
for (int i = 0; i < input.shape.size(); ++i) {
total_input_size *= input.shape[i];
@@ -515,14 +502,10 @@ TEST(NNAPIDelegate, FullyConnectedSimpleTest) {
EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
-class SoftmaxOpModel : public SingleOpModel {
+class SoftmaxOpModel : public SingleOpModelWithNNAPI {
public:
SoftmaxOpModel(int batches, int size, float beta)
: batches_(batches), input_size_(size), beta_(beta) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
@@ -566,14 +549,10 @@ TEST(NNAPIDelegate, SoftmaxSimpleTest) {
1e-6)));
}
-class ReshapeOpModel : public SingleOpModel {
+class ReshapeOpModel : public SingleOpModelWithNNAPI {
public:
ReshapeOpModel(std::initializer_list<int> input_shape,
std::initializer_list<int> new_shape) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(TensorType_FLOAT32);
new_shape_ = AddInput(TensorType_INT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -605,14 +584,10 @@ TEST(NNAPIDelegate, ReshapeSimpleTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
}
-class SqueezeOpModel : public SingleOpModel {
+class SqueezeOpModel : public SingleOpModelWithNNAPI {
public:
SqueezeOpModel(const TensorData& input, const TensorData& output,
std::initializer_list<int> axis) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(input);
output_ = AddOutput(output);
SetBuiltinOp(
@@ -666,6 +641,43 @@ TEST(NNAPIDelegate, SqueezeWithAxisTest) {
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
}
+class TransposeSimpleModel : public SingleOpModelWithNNAPI {
+ public:
+ TransposeSimpleModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> perm_shape,
+ std::initializer_list<int> perm) {
+ input_ = AddInput(TensorType_FLOAT32);
+ perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
+ CreateTransposeOptions(builder_).Union());
+ BuildInterpreter({input_shape, perm_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int perm_;
+ int output_;
+};
+
+TEST(NNAPIDelegate, TransposeSimpleTest) {
+ TransposeSimpleModel m({2, 3, 4}, {3}, {2, 0, 1});
+ m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3}));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21,
+ 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index dcd17bbeab..49d00a66ba 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -42,6 +42,7 @@ counterparts:
*as long as the input tensor is 4D (1 batch + 2 spatial + 1 other) and the
crops attribute is not used*
* [tf.exp](https://www.tensorflow.org/api_docs/python/tf/exp)
+* [tf.fake_quant*](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args)
* [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul) - *as long
as the second argument is constant and transposition is not used*
* [tf.nn.avg_pool](https://www.tensorflow.org/api_docs/python/tf/nn/avg_pool)
@@ -790,6 +791,30 @@ Outputs {
}
```
+**ARG_MAX**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor specifying the axis (dimension) along which to take the arg max
+}
+Outputs {
+ 0: A tensor of indices of maximum values.
+}
+```
+
+**ARG_MIN**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor specifying the axis (dimension) along which to take the arg min
+}
+Outputs {
+ 0: A tensor of indices of minimum values.
+}
+```
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 521216a4f1..0641a08636 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -441,6 +441,13 @@ TfLiteStatus Interpreter::AllocateTensors() {
TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
state_ = kStateInvokable;
+
+ // Reset the variable tensors to zero after (re)allocating the tensors.
+ // Developers shouldn't rely on the side effect of this function to reset
+ // variable tensors. They should call `ResetVariableTensorsToZero` directly
+ // instead.
+ ResetVariableTensorsToZero();
+
return kTfLiteOk;
}
@@ -565,6 +572,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt(
nodes_and_registration_[node_index].second;
EnsureTensorsVectorCapacity();
if (OpPrepare(registration, &node) == kTfLiteError) {
+ context_.ReportError(&context_, "Node %d failed to prepare.\n",
+ node_index);
return kTfLiteError;
}
@@ -665,6 +674,8 @@ TfLiteStatus Interpreter::Invoke() {
EnsureTensorsVectorCapacity();
tensor_resized_since_op_invoke_ = false;
if (OpInvoke(registration, &node) == kTfLiteError) {
+ context_.ReportError(&context_, "Node %d failed to invoke.\n",
+ node_index);
status = kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 4fa97512fc..10119903fe 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -57,6 +57,22 @@ TEST(BasicInterpreter, InvokeInvalidModel) {
ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
}
+TEST(BasicInterpreter, TestAllocateTensorsResetVariableTensors) {
+ Interpreter interpreter;
+ int tensor_index;
+ ASSERT_EQ(interpreter.AddTensors(1, &tensor_index), kTfLiteOk);
+ constexpr int kTensorSize = 16;
+ interpreter.SetTensorParametersReadWrite(tensor_index, kTfLiteFloat32, "",
+ {kTensorSize}, {}, true);
+ interpreter.SetVariables({tensor_index});
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ TfLiteTensor* tensor = interpreter.tensor(tensor_index);
+ // Ensure that variable tensors are reset to zero.
+ for (int i = 0; i < kTensorSize; ++i) {
+ ASSERT_EQ(tensor->data.f[i], 0.0f);
+ }
+}
+
// Test size accessor functions.
TEST(BasicInterpreter, TestSizeFunctions) {
Interpreter interpreter;
diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
index 56f3e7604a..1587c3c56f 100644
--- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
+++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
@@ -127,12 +127,8 @@ public final class OvicClassifierTest {
try {
testResult = classifier.classifyByteBuffer(testImage);
fail();
- } catch (RuntimeException e) {
- assertThat(e)
- .hasMessageThat()
- .contains(
- "Failed to get input dimensions. 0-th input should have 49152 bytes, "
- + "but found 150528 bytes.");
+ } catch (IllegalArgumentException e) {
+ // Success.
}
}
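The test change above reflects the new validation path: byte-count checking now happens in `Tensor`'s compatibility check rather than in the JNI `getInputDims` call, so a wrongly sized `ByteBuffer` surfaces as an `IllegalArgumentException` with a generic byte-count message. A minimal sketch of the failure mode, assuming a JUnit-style harness (`classifier`, `fail()`) and the test's model, whose 0-th input expects 49152 bytes:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Feed 150528 bytes (a 224x224x3 image) to an input expecting 49152 bytes
// (a 128x128x3 image); sizes here are taken from the old error message.
ByteBuffer tooLarge = ByteBuffer.allocateDirect(150528).order(ByteOrder.nativeOrder());
try {
  classifier.classifyByteBuffer(tooLarge);
  fail();
} catch (IllegalArgumentException e) {
  // Thrown by Tensor's compatibility check: "Cannot convert between a
  // TensorFlowLite buffer with 49152 bytes and a ByteBuffer with 150528 bytes."
}
```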
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
index 75334cd96e..94a1ec65d6 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
@@ -27,10 +27,7 @@ enum DataType {
UINT8(3),
/** 64-bit signed integer. */
- INT64(4),
-
- /** A {@link ByteBuffer}. */
- BYTEBUFFER(999);
+ INT64(4);
private final int value;
@@ -69,8 +66,6 @@ enum DataType {
return 1;
case INT64:
return 8;
- case BYTEBUFFER:
- return 1;
}
throw new IllegalArgumentException(
"DataType error: DataType " + this + " is not supported yet");
@@ -87,8 +82,6 @@ enum DataType {
return "byte";
case INT64:
return "long";
- case BYTEBUFFER:
- return "ByteBuffer";
}
throw new IllegalArgumentException(
"DataType error: DataType " + this + " is not supported yet");
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index 589fd6426f..7002f82677 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -165,20 +165,7 @@ public final class Interpreter implements AutoCloseable {
if (wrapper == null) {
throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
}
- Tensor[] tensors = wrapper.run(inputs);
- if (outputs == null || tensors == null || outputs.size() > tensors.length) {
- throw new IllegalArgumentException("Output error: Outputs do not match with model outputs.");
- }
- final int size = tensors.length;
- for (Integer idx : outputs.keySet()) {
- if (idx == null || idx < 0 || idx >= size) {
- throw new IllegalArgumentException(
- String.format(
- "Output error: Invalid index of output %d (should be in range [0, %d))",
- idx, size));
- }
- tensors[idx].copyTo(outputs.get(idx));
- }
+ wrapper.run(inputs, outputs);
}
/**
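The output-map validation that used to live here moves into `NativeInterpreterWrapper.run(inputs, outputs)` (below), which throws `IllegalArgumentException` for null or empty arguments and for out-of-range output indices. A hedged usage sketch of the surrounding multi-input/multi-output API; the method name `runForMultipleInputsOutputs` is assumed from the enclosing Interpreter class, and shapes and types are illustrative only:

```java
import java.util.HashMap;
import java.util.Map;

// Assumes an interpreter whose model takes one float input and produces one
// float output; the map keys are output indices.
Object[] inputs = {new float[1][4]};
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, new float[1][4]);
interpreter.runForMultipleInputsOutputs(inputs, outputs);
```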
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 80de88b6a1..767a220f8c 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -15,10 +15,10 @@ limitations under the License.
package org.tensorflow.lite;
-import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@@ -40,6 +40,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
modelHandle = createModel(modelPath, errorHandle);
interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
isMemoryAllocated = true;
+ inputTensors = new Tensor[getInputCount(interpreterHandle)];
+ outputTensors = new Tensor[getOutputCount(interpreterHandle)];
}
/**
@@ -72,6 +74,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
isMemoryAllocated = true;
+ inputTensors = new Tensor[getInputCount(interpreterHandle)];
+ outputTensors = new Tensor[getOutputCount(interpreterHandle)];
}
/** Releases resources associated with this {@code NativeInterpreterWrapper}. */
@@ -85,75 +89,63 @@ final class NativeInterpreterWrapper implements AutoCloseable {
inputsIndexes = null;
outputsIndexes = null;
isMemoryAllocated = false;
+ Arrays.fill(inputTensors, null);
+ Arrays.fill(outputTensors, null);
}
/** Sets inputs, runs model inference and returns outputs. */
- Tensor[] run(Object[] inputs) {
+ void run(Object[] inputs, Map<Integer, Object> outputs) {
+ inferenceDurationNanoseconds = -1;
if (inputs == null || inputs.length == 0) {
throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
}
- int[] dataTypes = new int[inputs.length];
- Object[] sizes = new Object[inputs.length];
- int[] numsOfBytes = new int[inputs.length];
+ if (outputs == null || outputs.isEmpty()) {
+ throw new IllegalArgumentException("Input error: Outputs should not be null or empty.");
+ }
+
+ // TODO(b/80431971): Remove implicit resize after deprecating multi-dimensional array inputs.
+ // Rather than forcing an immediate resize + allocation if an input's shape differs, we first
+ // flush all resizes, avoiding redundant allocations.
for (int i = 0; i < inputs.length; ++i) {
- DataType dataType = dataTypeOf(inputs[i]);
- dataTypes[i] = dataType.getNumber();
- if (dataType == DataType.BYTEBUFFER) {
- ByteBuffer buffer = (ByteBuffer) inputs[i];
- if (buffer == null || !buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()) {
- throw new IllegalArgumentException(
- "Input error: ByteBuffer should be a direct ByteBuffer that uses "
- + "ByteOrder.nativeOrder().");
- }
- numsOfBytes[i] = buffer.limit();
- sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]);
- } else if (isNonEmptyArray(inputs[i])) {
- int[] dims = shapeOf(inputs[i]);
- sizes[i] = dims;
- numsOfBytes[i] = dataType.elemByteSize() * numElements(dims);
- } else {
- throw new IllegalArgumentException(
- String.format(
- "Input error: %d-th element of the %d inputs is not an array or a ByteBuffer.",
- i, inputs.length));
+ Tensor tensor = getInputTensor(i);
+ int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]);
+ if (newShape != null) {
+ resizeInput(i, newShape);
}
}
- inferenceDurationNanoseconds = -1;
- long[] outputsHandles =
- run(
- interpreterHandle,
- errorHandle,
- sizes,
- dataTypes,
- numsOfBytes,
- inputs,
- this,
- isMemoryAllocated);
- if (outputsHandles == null || outputsHandles.length == 0) {
- throw new IllegalStateException("Internal error: Interpreter has no outputs.");
+
+ if (!isMemoryAllocated) {
+ allocateTensors(interpreterHandle, errorHandle);
+ isMemoryAllocated = true;
+ // Allocation can trigger dynamic resizing of output tensors, so clear the
+ // output tensor cache.
+ Arrays.fill(outputTensors, null);
}
- isMemoryAllocated = true;
- Tensor[] outputs = new Tensor[outputsHandles.length];
- for (int i = 0; i < outputsHandles.length; ++i) {
- outputs[i] = Tensor.fromHandle(outputsHandles[i]);
+
+ for (int i = 0; i < inputs.length; ++i) {
+ getInputTensor(i).setTo(inputs[i]);
+ }
+
+ long inferenceStartNanos = System.nanoTime();
+ run(interpreterHandle, errorHandle);
+ long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos;
+
+ for (Map.Entry<Integer, Object> output : outputs.entrySet()) {
+ getOutputTensor(output.getKey()).copyTo(output.getValue());
}
- return outputs;
+
+ // Only set if the entire operation succeeds.
+ this.inferenceDurationNanoseconds = inferenceDurationNanoseconds;
}
- private static native long[] run(
- long interpreterHandle,
- long errorHandle,
- Object[] sizes,
- int[] dtypes,
- int[] numsOfBytes,
- Object[] values,
- NativeInterpreterWrapper wrapper,
- boolean memoryAllocated);
+ private static native boolean run(long interpreterHandle, long errorHandle);
/** Resizes dimensions of a specific input. */
void resizeInput(int idx, int[] dims) {
if (resizeInput(interpreterHandle, errorHandle, idx, dims)) {
isMemoryAllocated = false;
+ // Resizing will invalidate the Tensor's shape, so invalidate the Tensor handle.
+ inputTensors[idx] = null;
}
}
@@ -212,78 +204,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
}
- static int numElements(int[] shape) {
- if (shape == null) {
- return 0;
- }
- int n = 1;
- for (int i = 0; i < shape.length; i++) {
- n *= shape[i];
- }
- return n;
- }
-
- static boolean isNonEmptyArray(Object o) {
- return (o != null && o.getClass().isArray() && Array.getLength(o) != 0);
- }
-
- /** Returns the type of the data. */
- static DataType dataTypeOf(Object o) {
- if (o != null) {
- Class<?> c = o.getClass();
- while (c.isArray()) {
- c = c.getComponentType();
- }
- if (float.class.equals(c)) {
- return DataType.FLOAT32;
- } else if (int.class.equals(c)) {
- return DataType.INT32;
- } else if (byte.class.equals(c)) {
- return DataType.UINT8;
- } else if (long.class.equals(c)) {
- return DataType.INT64;
- } else if (ByteBuffer.class.isInstance(o)) {
- return DataType.BYTEBUFFER;
- }
- }
- throw new IllegalArgumentException(
- "DataType error: cannot resolve DataType of " + o.getClass().getName());
- }
-
- /** Returns the shape of an object as an int array. */
- static int[] shapeOf(Object o) {
- int size = numDimensions(o);
- int[] dimensions = new int[size];
- fillShape(o, 0, dimensions);
- return dimensions;
- }
-
- static int numDimensions(Object o) {
- if (o == null || !o.getClass().isArray()) {
- return 0;
- }
- if (Array.getLength(o) == 0) {
- throw new IllegalArgumentException("Array lengths cannot be 0.");
- }
- return 1 + numDimensions(Array.get(o, 0));
- }
-
- static void fillShape(Object o, int dim, int[] shape) {
- if (shape == null || dim == shape.length) {
- return;
- }
- final int len = Array.getLength(o);
- if (shape[dim] == 0) {
- shape[dim] = len;
- } else if (shape[dim] != len) {
- throw new IllegalArgumentException(
- String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
- }
- for (int i = 0; i < len; ++i) {
- fillShape(Array.get(o, i), dim + 1, shape);
- }
- }
-
/**
* Gets the last inference duration in nanoseconds. It returns null if there is no previous
* inference run or the last inference run failed.
@@ -293,40 +213,55 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
/**
- * Gets the dimensions of an input. It throws IllegalArgumentException if input index is invalid.
+ * Gets the quantization zero point of an output.
+ *
+ * @throws IllegalArgumentException if the output index is invalid.
*/
- int[] getInputDims(int index) {
- return getInputDims(interpreterHandle, index, -1);
+ int getOutputQuantizationZeroPoint(int index) {
+ return getOutputQuantizationZeroPoint(interpreterHandle, index);
}
/**
- * Gets the dimensions of an input. If numBytes >= 0, it will check whether num of bytes match the
- * input.
+ * Gets the quantization scale of an output.
+ *
+ * @throws IllegalArgumentException if the output index is invalid.
*/
- private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes);
-
- /** Gets the type of an output. It throws IllegalArgumentException if output index is invalid. */
- String getOutputDataType(int index) {
- int type = getOutputDataType(interpreterHandle, index);
- return DataType.fromNumber(type).toStringName();
+ float getOutputQuantizationScale(int index) {
+ return getOutputQuantizationScale(interpreterHandle, index);
}
/**
- * Gets the quantization zero point of an output.
+ * Gets the input {@link Tensor} for the provided input index.
*
- * @throws IllegalArgumentExeption if the output index is invalid.
+ * @throws IllegalArgumentException if the input index is invalid.
*/
- int getOutputQuantizationZeroPoint(int index) {
- return getOutputQuantizationZeroPoint(interpreterHandle, index);
+ Tensor getInputTensor(int index) {
+ if (index < 0 || index >= inputTensors.length) {
+ throw new IllegalArgumentException("Invalid input Tensor index: " + index);
+ }
+ Tensor inputTensor = inputTensors[index];
+ if (inputTensor == null) {
+ inputTensor =
+ inputTensors[index] = Tensor.fromHandle(getInputTensor(interpreterHandle, index));
+ }
+ return inputTensor;
}
/**
- * Gets the quantization scale of an output.
+ * Gets the output {@link Tensor} for the provided output index.
*
- * @throws IllegalArgumentExeption if the output index is invalid.
+ * @throws IllegalArgumentException if the output index is invalid.
*/
- float getOutputQuantizationScale(int index) {
- return getOutputQuantizationScale(interpreterHandle, index);
+ Tensor getOutputTensor(int index) {
+ if (index < 0 || index >= outputTensors.length) {
+ throw new IllegalArgumentException("Invalid output Tensor index: " + index);
+ }
+ Tensor outputTensor = outputTensors[index];
+ if (outputTensor == null) {
+ outputTensor =
+ outputTensors[index] = Tensor.fromHandle(getOutputTensor(interpreterHandle, index));
+ }
+ return outputTensor;
}
private static native int getOutputDataType(long interpreterHandle, int outputIdx);
@@ -343,18 +278,30 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private long modelHandle;
- private int inputSize;
-
private long inferenceDurationNanoseconds = -1;
private ByteBuffer modelByteBuffer;
+ // Lazily constructed maps of input and output names to input and output Tensor indexes.
private Map<String, Integer> inputsIndexes;
-
private Map<String, Integer> outputsIndexes;
+ // Lazily constructed and populated arrays of input and output Tensor wrappers.
+ private final Tensor[] inputTensors;
+ private final Tensor[] outputTensors;
+
private boolean isMemoryAllocated = false;
+ private static native long allocateTensors(long interpreterHandle, long errorHandle);
+
+ private static native long getInputTensor(long interpreterHandle, int inputIdx);
+
+ private static native long getOutputTensor(long interpreterHandle, int outputIdx);
+
+ private static native int getInputCount(long interpreterHandle);
+
+ private static native int getOutputCount(long interpreterHandle);
+
private static native String[] getInputNames(long interpreterHandle);
private static native String[] getOutputNames(long interpreterHandle);
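Two behaviors in the rewritten `run()` are worth calling out: inputs whose array shape differs from the model's trigger an implicit `resizeInput()` before a single `allocateTensors()` pass, and the cached `Tensor` wrappers are invalidated whenever a resize or allocation may have replaced the underlying native tensor. A hedged sketch of the implicit-resize flow; the batch size and shapes are hypothetical:

```java
import java.util.HashMap;
import java.util.Map;

// Model was built with input shape {1, 224, 224, 3}; passing a batch of 4
// now resizes input 0 to {4, 224, 224, 3} inside run() instead of failing.
Object[] inputs = {new float[4][224][224][3]};
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, new float[4][1001]);
wrapper.run(inputs, outputs);
```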
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index b2a3e04c55..2403570c52 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -15,6 +15,7 @@ limitations under the License.
package org.tensorflow.lite;
+import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
@@ -31,43 +32,179 @@ final class Tensor {
return new Tensor(nativeHandle);
}
+ /** Returns the {@link DataType} of elements stored in the Tensor. */
+ public DataType dataType() {
+ return dtype;
+ }
+
+ /** Returns the size, in bytes, of the tensor data. */
+ public int numBytes() {
+ return numBytes(nativeHandle);
+ }
+
+ /**
+ * Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of
+ * the Tensor, i.e., the sizes of each dimension.
+ *
+ * @return an array where the i-th element is the size of the i-th dimension of the tensor.
+ */
+ public int[] shape() {
+ return shapeCopy;
+ }
+
+ /**
+ * Copies the contents of the provided {@code src} object to the Tensor.
+ *
+ * <p>The {@code src} should either be a (multi-dimensional) array with a shape matching that of
+ * this tensor, or a {@link ByteBuffer} of compatible primitive type with a matching flat size.
+ *
+ * @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible
+ * with the tensor (for example, mismatched data types or shapes).
+ */
+ void setTo(Object src) {
+ throwExceptionIfTypeIsIncompatible(src);
+ if (isByteBuffer(src)) {
+ ByteBuffer srcBuffer = (ByteBuffer) src;
+ // For direct ByteBuffer instances we support zero-copy. Note that this assumes the caller
+ // retains ownership of the source buffer until inference has completed.
+ if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) {
+ writeDirectBuffer(nativeHandle, srcBuffer);
+ } else {
+ buffer().put(srcBuffer);
+ }
+ return;
+ }
+ writeMultiDimensionalArray(nativeHandle, src);
+ }
+
/**
* Copies the contents of the tensor to {@code dst} and returns {@code dst}.
*
* @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}.
* @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example,
* mismatched data types or shapes).
- * @throws BufferOverflowException If {@code dst} is a ByteBuffer with insufficient space for the
- * data in this tensor.
*/
- <T> T copyTo(T dst) {
+ Object copyTo(Object dst) {
+ throwExceptionIfTypeIsIncompatible(dst);
if (dst instanceof ByteBuffer) {
ByteBuffer dstByteBuffer = (ByteBuffer) dst;
dstByteBuffer.put(buffer());
return dst;
}
- if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) {
+ readMultiDimensionalArray(nativeHandle, dst);
+ return dst;
+ }
+
+ /** Returns the provided buffer's shape if specified and different from this Tensor's shape. */
+ // TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs.
+ int[] getInputShapeIfDifferent(Object input) {
+ // Implicit resizes based on ByteBuffer capacity aren't supported, so short-circuit that path.
+ // The ByteBuffer's size will be validated against this Tensor's size in {@link #setTo(Object)}.
+ if (isByteBuffer(input)) {
+ return null;
+ }
+ int[] inputShape = shapeOf(input);
+ if (Arrays.equals(shapeCopy, inputShape)) {
+ return null;
+ }
+ return inputShape;
+ }
+
+ /** Returns the type of the data. */
+ static DataType dataTypeOf(Object o) {
+ if (o != null) {
+ Class<?> c = o.getClass();
+ while (c.isArray()) {
+ c = c.getComponentType();
+ }
+ if (float.class.equals(c)) {
+ return DataType.FLOAT32;
+ } else if (int.class.equals(c)) {
+ return DataType.INT32;
+ } else if (byte.class.equals(c)) {
+ return DataType.UINT8;
+ } else if (long.class.equals(c)) {
+ return DataType.INT64;
+ }
+ }
+ throw new IllegalArgumentException(
+ "DataType error: cannot resolve DataType of " + o.getClass().getName());
+ }
+
+ /** Returns the shape of an object as an int array. */
+ static int[] shapeOf(Object o) {
+ int size = numDimensions(o);
+ int[] dimensions = new int[size];
+ fillShape(o, 0, dimensions);
+ return dimensions;
+ }
+
+ /** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */
+ static int numDimensions(Object o) {
+ if (o == null || !o.getClass().isArray()) {
+ return 0;
+ }
+ if (Array.getLength(o) == 0) {
+ throw new IllegalArgumentException("Array lengths cannot be 0.");
+ }
+ return 1 + numDimensions(Array.get(o, 0));
+ }
+
+ /** Recursively populates the shape dimensions for a given (multi-dimensional) array. */
+ static void fillShape(Object o, int dim, int[] shape) {
+ if (shape == null || dim == shape.length) {
+ return;
+ }
+ final int len = Array.getLength(o);
+ if (shape[dim] == 0) {
+ shape[dim] = len;
+ } else if (shape[dim] != len) {
+ throw new IllegalArgumentException(
+ String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
+ }
+ for (int i = 0; i < len; ++i) {
+ fillShape(Array.get(o, i), dim + 1, shape);
+ }
+ }
+
+ private void throwExceptionIfTypeIsIncompatible(Object o) {
+ if (isByteBuffer(o)) {
+ ByteBuffer oBuffer = (ByteBuffer) o;
+ if (oBuffer.capacity() != numBytes()) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Cannot convert between a TensorFlowLite buffer with %d bytes and a "
+ + "ByteBuffer with %d bytes.",
+ numBytes(), oBuffer.capacity()));
+ }
+ return;
+ }
+ DataType oType = dataTypeOf(o);
+ if (oType != dtype) {
throw new IllegalArgumentException(
String.format(
- "Output error: Cannot convert an TensorFlowLite tensor with type %s to a Java "
- + "object of type %s (which is compatible with the TensorFlowLite type %s)",
- dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst)));
+ "Cannot convert between a TensorFlowLite tensor with type %s and a Java "
+ + "object of type %s (which is compatible with the TensorFlowLite type %s).",
+ dtype, o.getClass().getName(), oType));
}
- int[] dstShape = NativeInterpreterWrapper.shapeOf(dst);
- if (!Arrays.equals(dstShape, shapeCopy)) {
+
+ int[] oShape = shapeOf(o);
+ if (!Arrays.equals(oShape, shapeCopy)) {
throw new IllegalArgumentException(
String.format(
- "Output error: Shape of output target %s does not match with the shape of the "
- + "Tensor %s.",
- Arrays.toString(dstShape), Arrays.toString(shapeCopy)));
+ "Cannot copy between a TensorFlowLite tensor with shape %s and a Java object "
+ + "with shape %s.",
+ Arrays.toString(shapeCopy), Arrays.toString(oShape)));
}
- readMultiDimensionalArray(nativeHandle, dst);
- return dst;
}
- final long nativeHandle;
- final DataType dtype;
- final int[] shapeCopy;
+ private static boolean isByteBuffer(Object o) {
+ return o instanceof ByteBuffer;
+ }
+
+ private final long nativeHandle;
+ private final DataType dtype;
+ private final int[] shapeCopy;
private Tensor(long nativeHandle) {
this.nativeHandle = nativeHandle;
@@ -81,11 +218,17 @@ final class Tensor {
private static native ByteBuffer buffer(long handle);
+ private static native void writeDirectBuffer(long handle, ByteBuffer src);
+
private static native int dtype(long handle);
private static native int[] shape(long handle);
- private static native void readMultiDimensionalArray(long handle, Object value);
+ private static native int numBytes(long handle);
+
+ private static native void readMultiDimensionalArray(long handle, Object dst);
+
+ private static native void writeMultiDimensionalArray(long handle, Object src);
static {
TensorFlowLite.init();
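`setTo()` distinguishes two buffer paths: a direct, native-order `ByteBuffer` is handed to the native `writeDirectBuffer()` for a zero-copy write, while heap or non-native-order buffers fall back to a copy through `buffer().put(...)`. A short usage sketch, assuming a package-private `tensor` reference and a source buffer whose size matches `tensor.numBytes()`:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

int numBytes = tensor.numBytes();

// Zero-copy path: direct buffer in native byte order.
ByteBuffer direct = ByteBuffer.allocateDirect(numBytes).order(ByteOrder.nativeOrder());
tensor.setTo(direct);

// Copy path: a heap buffer is staged through the tensor's backing buffer.
ByteBuffer heap = ByteBuffer.allocate(numBytes);
tensor.setTo(heap);
```

Note that the zero-copy path assumes the caller keeps the source buffer alive until inference completes, as the comment in `setTo()` states.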
diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD
index 4399ed2025..4b4e1c21d8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/BUILD
+++ b/tensorflow/contrib/lite/java/src/main/native/BUILD
@@ -11,7 +11,6 @@ licenses(["notice"]) # Apache 2.0
cc_library(
name = "native_framework_only",
srcs = [
- "duration_utils_jni.cc",
"exception_jni.cc",
"nativeinterpreterwrapper_jni.cc",
"tensor_jni.cc",
diff --git a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc b/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc
deleted file mode 100644
index 0e08a04370..0000000000
--- a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <jni.h>
-#include <time.h>
-
-namespace tflite {
-
-// Gets the elapsed wall-clock timespec.
-timespec getCurrentTime() {
- timespec time;
- clock_gettime(CLOCK_MONOTONIC, &time);
- return time;
-}
-
-// Computes the time diff from two timespecs. Returns '-1' if 'stop' is earlier
-// than 'start'.
-jlong timespec_diff_nanoseconds(struct timespec* start, struct timespec* stop) {
- jlong result = stop->tv_sec - start->tv_sec;
- if (result < 0) return -1;
- result = 1000000000 * result + (stop->tv_nsec - start->tv_nsec);
- if (result < 0) return -1;
- return result;
-}
-
-} // namespace tflite
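With this native `clock_gettime()` helper deleted, inference latency is now measured on the Java side with `System.nanoTime()` around the native `run()` call (see `NativeInterpreterWrapper.run()` above). From user code the value is still read through the existing accessor; a hedged sketch, with the accessor name assumed from the surrounding Java API:

```java
// After a successful run, the wrapper exposes the last measured duration;
// it is null if no inference has run or the last run failed.
interpreter.run(input, output);
Long latencyNanos = interpreter.getLastNativeInferenceDurationNanoseconds();
if (latencyNanos != null) {
  System.out.println("Inference took " + latencyNanos + " ns");
}
```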
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 31f7b58fbc..e2c1edd9af 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -16,9 +16,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h"
namespace {
-const int kByteBufferValue = 999;
-const int kBufferSize = 256;
-
tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalArgumentException,
@@ -62,22 +59,6 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
return outputs;
}
-bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; }
-
-TfLiteType resolveDataType(jint data_type) {
- switch (data_type) {
- case 1:
- return kTfLiteFloat32;
- case 2:
- return kTfLiteInt32;
- case 3:
- return kTfLiteUInt8;
- case 4:
- return kTfLiteInt64;
- default:
- return kTfLiteNoType;
- }
-}
int getDataType(TfLiteType data_type) {
switch (data_type) {
@@ -108,64 +89,6 @@ void printDims(char* buffer, int max_size, int* dims, int num_dims) {
}
}
-TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter,
- const int input_size, jintArray data_types,
- jintArray nums_of_bytes, jobjectArray values,
- jobjectArray sizes) {
- if (input_size != interpreter->inputs().size()) {
- throwException(env, kIllegalArgumentException,
- "Input error: Expected num of inputs is %d but got %d",
- interpreter->inputs().size(), input_size);
- return kTfLiteError;
- }
- if (input_size != env->GetArrayLength(data_types) ||
- input_size != env->GetArrayLength(nums_of_bytes) ||
- input_size != env->GetArrayLength(values)) {
- throwException(env, kIllegalArgumentException,
- "Internal error: Arrays in arguments should be of the same "
- "length, but got %d sizes, %d data_types, %d nums_of_bytes, "
- "and %d values",
- input_size, env->GetArrayLength(data_types),
- env->GetArrayLength(nums_of_bytes),
- env->GetArrayLength(values));
- return kTfLiteError;
- }
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- TfLiteTensor* target = interpreter->tensor(input_idx);
- jintArray dims =
- static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
- int num_dims = static_cast<int>(env->GetArrayLength(dims));
- if (target->dims->size != num_dims) {
- throwException(env, kIllegalArgumentException,
- "Input error: %d-th input should have %d dimensions, but "
- "found %d dimensions",
- i, target->dims->size, num_dims);
- return kTfLiteError;
- }
- jint* ptr = env->GetIntArrayElements(dims, nullptr);
- for (int j = 1; j < num_dims; ++j) {
- if (target->dims->data[j] != ptr[j]) {
- std::unique_ptr<char[]> expected_dims(new char[kBufferSize]);
- std::unique_ptr<char[]> obtained_dims(new char[kBufferSize]);
- printDims(expected_dims.get(), kBufferSize, target->dims->data,
- num_dims);
- printDims(obtained_dims.get(), kBufferSize, ptr, num_dims);
- throwException(env, kIllegalArgumentException,
- "Input error: %d-th input dimension should be [%s], but "
- "found [%s]",
- i, expected_dims.get(), obtained_dims.get());
- env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
- return kTfLiteError;
- }
- }
- env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
- env->DeleteLocalRef(dims);
- if (env->ExceptionCheck()) return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
// Checks whether there is any difference between dimensions of a tensor and a
// given dimensions. Returns true if there is difference, else false.
bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) {
@@ -188,74 +111,6 @@ bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) {
return false;
}
-bool areInputDimensionsTheSame(JNIEnv* env, tflite::Interpreter* interpreter,
- int input_size, jobjectArray sizes) {
- if (interpreter->inputs().size() != input_size) {
- return false;
- }
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- jintArray dims =
- static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
- TfLiteTensor* target = interpreter->tensor(input_idx);
- if (areDimsDifferent(env, target, dims)) return false;
- env->DeleteLocalRef(dims);
- if (env->ExceptionCheck()) return false;
- }
- return true;
-}
-
-TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter,
- int input_size, jobjectArray sizes) {
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- jintArray dims =
- static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
- TfLiteStatus status = interpreter->ResizeInputTensor(
- input_idx, convertJIntArrayToVector(env, dims));
- if (status != kTfLiteOk) {
- return status;
- }
- env->DeleteLocalRef(dims);
- if (env->ExceptionCheck()) return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter,
- int input_size, jintArray data_types,
- jintArray nums_of_bytes, jobjectArray values) {
- jint* data_type = env->GetIntArrayElements(data_types, nullptr);
- jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr);
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- TfLiteTensor* target = interpreter->tensor(input_idx);
- jobject value = env->GetObjectArrayElement(values, i);
- bool is_byte_buffer = isByteBuffer(data_type[i]);
- if (is_byte_buffer) {
- writeByteBuffer(env, value, &(target->data.raw),
- static_cast<int>(num_bytes[i]));
- } else {
- TfLiteType type = resolveDataType(data_type[i]);
- if (type != target->type) {
- throwException(env, kIllegalArgumentException,
- "Input error: DataType (%d) of input data does not "
- "match with the DataType (%d) of model inputs.",
- type, target->type);
- return kTfLiteError;
- }
- writeMultiDimensionalArray(env, value, target->type, target->dims->size,
- &(target->data.raw),
- static_cast<int>(num_bytes[i]));
- }
- env->DeleteLocalRef(value);
- if (env->ExceptionCheck()) return kTfLiteError;
- }
- env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT);
- env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT);
- return kTfLiteOk;
-}
-
// TODO(yichengfan): evaluate the benefit to use tflite verifier.
bool VerifyModel(const void* buf, size_t len) {
flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
@@ -287,6 +142,63 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
return names;
}
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
+ JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return;
+ BufferErrorReporter* error_reporter =
+ convertLongToErrorReporter(env, error_handle);
+ if (error_reporter == nullptr) return;
+
+ if (interpreter->AllocateTensors() != kTfLiteOk) {
+ throwException(env, kNullPointerException,
+ "Internal error: Cannot allocate memory for the interpreter:"
+ " %s",
+ error_reporter->CachedErrorMessage());
+ }
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return reinterpret_cast<jlong>(
+ interpreter->tensor(interpreter->inputs()[index]));
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return reinterpret_cast<jlong>(
+ interpreter->tensor(interpreter->outputs()[index]));
+}
+
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return static_cast<jint>(interpreter->inputs().size());
+}
+
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return static_cast<jint>(interpreter->outputs().size());
+}
+
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
jclass clazz,
@@ -434,114 +346,21 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
}
// Sets inputs, runs inference, and returns outputs as long handles.
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
- JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
- jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
- jobjectArray values, jobject wrapper, jboolean memory_allocated) {
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
tflite::Interpreter* interpreter =
convertLongToInterpreter(env, interpreter_handle);
- if (interpreter == nullptr) return nullptr;
+ if (interpreter == nullptr) return;
BufferErrorReporter* error_reporter =
convertLongToErrorReporter(env, error_handle);
- if (error_reporter == nullptr) return nullptr;
- const int input_size = env->GetArrayLength(sizes);
- // validates inputs
- TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types,
- nums_of_bytes, values, sizes);
- if (status != kTfLiteOk) return nullptr;
- if (!memory_allocated ||
- !areInputDimensionsTheSame(env, interpreter, input_size, sizes)) {
- // resizes inputs
- status = resizeInputs(env, interpreter, input_size, sizes);
- if (status != kTfLiteOk) {
- throwException(env, kNullPointerException,
- "Internal error: Can not resize the input: %s",
- error_reporter->CachedErrorMessage());
- return nullptr;
- }
- // allocates memory
- status = interpreter->AllocateTensors();
- if (status != kTfLiteOk) {
- throwException(env, kNullPointerException,
- "Internal error: Can not allocate memory for the given "
- "inputs: %s",
- error_reporter->CachedErrorMessage());
- return nullptr;
- }
- }
- // sets inputs
- status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes,
- values);
- if (status != kTfLiteOk) return nullptr;
- timespec beforeInference = ::tflite::getCurrentTime();
- // runs inference
+ if (error_reporter == nullptr) return;
+
if (interpreter->Invoke() != kTfLiteOk) {
throwException(env, kIllegalArgumentException,
"Internal error: Failed to run on the given Interpreter: %s",
error_reporter->CachedErrorMessage());
- return nullptr;
- }
- timespec afterInference = ::tflite::getCurrentTime();
- jclass wrapper_clazz = env->GetObjectClass(wrapper);
- jfieldID fid =
- env->GetFieldID(wrapper_clazz, "inferenceDurationNanoseconds", "J");
- if (env->ExceptionCheck()) {
- env->ExceptionClear();
- } else if (fid != nullptr) {
- env->SetLongField(
- wrapper, fid,
- ::tflite::timespec_diff_nanoseconds(&beforeInference, &afterInference));
- }
- // returns outputs
- const std::vector<int>& results = interpreter->outputs();
- if (results.empty()) {
- throwException(
- env, kIllegalArgumentException,
- "Internal error: The Interpreter does not have any outputs.");
- return nullptr;
- }
- jlongArray outputs = env->NewLongArray(results.size());
- size_t size = results.size();
- for (int i = 0; i < size; ++i) {
- TfLiteTensor* source = interpreter->tensor(results[i]);
- jlong output = reinterpret_cast<jlong>(source);
- env->SetLongArrayRegion(outputs, i, 1, &output);
- }
- return outputs;
-}
-
-JNIEXPORT jintArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
- JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) {
- tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
- if (interpreter == nullptr) return nullptr;
- const int idx = static_cast<int>(input_idx);
- if (input_idx < 0 || input_idx >= interpreter->inputs().size()) {
- throwException(env, kIllegalArgumentException,
- "Input error: Out of range: Failed to get %d-th input out of"
- " %d inputs",
- input_idx, interpreter->inputs().size());
- return nullptr;
- }
- TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]);
- int size = target->dims->size;
- if (num_bytes >= 0) { // verifies num of bytes matches if num_bytes if valid.
- int expected_num_bytes = elementByteSize(target->type);
- for (int i = 0; i < size; ++i) {
- expected_num_bytes *= target->dims->data[i];
- }
- if (num_bytes != expected_num_bytes) {
- throwException(env, kIllegalArgumentException,
- "Input error: Failed to get input dimensions. %d-th input "
- "should have %d bytes, but found %d bytes.",
- idx, expected_num_bytes, num_bytes);
- return nullptr;
- }
+ return;
}
- jintArray outputs = env->NewIntArray(size);
- env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0]));
- return outputs;
}
JNIEXPORT jint JNICALL
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 128ece4981..618fba480e 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -29,9 +29,6 @@ limitations under the License.
namespace tflite {
// This is to be provided at link-time by a library.
extern std::unique_ptr<OpResolver> CreateOpResolver();
-extern timespec getCurrentTime();
-extern jlong timespec_diff_nanoseconds(struct timespec* start,
- struct timespec* stop);
} // namespace tflite
#ifdef __cplusplus
@@ -40,6 +37,57 @@ extern "C" {
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: allocateTensors
+ * Signature: (JJ)V
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
+ JNIEnv* env, jclass clazz, jlong handle, jlong error_handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getInputTensor
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getOutputTensor
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getInputCount
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getOutputCount
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
* Signature: (J)[Ljava/lang/Object;
*/
@@ -118,28 +166,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
- * Method:
- * Signature:
- * (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;Ljava/lang/Object;Z)[J
- */
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
- JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
- jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
- jobjectArray values, jobject wrapper, jboolean memory_allocated);
-
-/*
- * Class: org_tensorflow_lite_NativeInterpreterWrapper
- * Method:
- * Signature: (JII)[I
- *
- * Gets input dimensions. If num_bytes is non-negative, it will check whether
- * num_bytes matches num of bytes required by the input, and return null and
- * throw IllegalArgumentException if not.
+ * Method: run
+ * Signature: (JJ)V
*/
-JNIEXPORT jintArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
- JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes);
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
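Taken together, the header changes replace the old all-in-one run entry point, which marshalled shapes, dtypes, byte counts, and values across JNI on every call, with a handle-based API: tensors are fetched and filled up front through the new Tensor natives, and run only has to trigger inference. A minimal sketch of what the matching native body could look like, assuming the convertLongToInterpreter and throwException helpers used elsewhere in this diff (illustrative only; the real implementation also routes failures through the error reporter behind error_handle):

JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
    JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
  tflite::Interpreter* interpreter =
      convertLongToInterpreter(env, interpreter_handle);
  if (interpreter == nullptr) return;
  // Inputs were already written via the Tensor JNI methods
  // (writeDirectBuffer / writeMultiDimensionalArray), so inference is a
  // single call; error_handle is left unused in this sketch.
  if (interpreter->Invoke() != kTfLiteOk) {
    throwException(env, kIllegalArgumentException,
                   "Internal error: Failed to run on the given Interpreter.");
  }
}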
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
index 08b4d04280..7ff96a3172 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
@@ -29,6 +29,35 @@ TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) {
return reinterpret_cast<TfLiteTensor*>(handle);
}
+size_t elementByteSize(TfLiteType data_type) {
+ // The code in this file makes the assumption that the
+ // TensorFlow TF_DataTypes and the Java primitive types
+ // have the same byte sizes. Validate that:
+ switch (data_type) {
+ case kTfLiteFloat32:
+ static_assert(sizeof(jfloat) == 4,
+ "Interal error: Java float not compatible with "
+ "kTfLiteFloat");
+ return 4;
+ case kTfLiteInt32:
+ static_assert(sizeof(jint) == 4,
+ "Interal error: Java int not compatible with kTfLiteInt");
+ return 4;
+ case kTfLiteUInt8:
+ static_assert(sizeof(jbyte) == 1,
+ "Interal error: Java byte not compatible with "
+ "kTfLiteUInt8");
+ return 1;
+ case kTfLiteInt64:
+ static_assert(sizeof(jlong) == 8,
+ "Interal error: Java long not compatible with "
+ "kTfLiteInt64");
+ return 8;
+ default:
+ return 0;
+ }
+}
+
size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type,
void* dst, size_t dst_size) {
jarray array = static_cast<jarray>(object);
@@ -141,48 +170,6 @@ size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src,
}
}
-} // namespace
-
-size_t elementByteSize(TfLiteType data_type) {
- // The code in this file makes the assumption that the
- // TensorFlow TF_DataTypes and the Java primitive types
- // have the same byte sizes. Validate that:
- switch (data_type) {
- case kTfLiteFloat32:
- static_assert(sizeof(jfloat) == 4,
- "Interal error: Java float not compatible with "
- "kTfLiteFloat");
- return 4;
- case kTfLiteInt32:
- static_assert(sizeof(jint) == 4,
- "Interal error: Java int not compatible with kTfLiteInt");
- return 4;
- case kTfLiteUInt8:
- static_assert(sizeof(jbyte) == 1,
- "Interal error: Java byte not compatible with "
- "kTfLiteUInt8");
- return 1;
- case kTfLiteInt64:
- static_assert(sizeof(jlong) == 8,
- "Interal error: Java long not compatible with "
- "kTfLiteInt64");
- return 8;
- default:
- return 0;
- }
-}
-
-size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) {
- char* buf = static_cast<char*>(env->GetDirectBufferAddress(object));
- if (!buf) {
- throwException(env, kIllegalArgumentException,
- "Input ByteBuffer is not a direct buffer");
- return 0;
- }
- *dst = buf;
- return dst_size;
-}
-
size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
int dims_left, char** dst, int dst_size) {
if (dims_left <= 1) {
@@ -203,16 +190,37 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
}
}
+} // namespace
+
JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
jclass clazz,
jlong handle) {
TfLiteTensor* tensor = convertLongToTensor(env, handle);
if (tensor == nullptr) return nullptr;
-
+ if (tensor->data.raw == nullptr) {
+ throwException(env, kIllegalArgumentException,
+ "Internal error: Tensor hasn't been allocated.");
+ return nullptr;
+ }
return env->NewDirectByteBuffer(static_cast<void*>(tensor->data.raw),
static_cast<jlong>(tensor->bytes));
}
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
+ JNIEnv* env, jclass clazz, jlong handle, jobject src) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return;
+
+ char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src));
+ if (!src_data_raw) {
+ throwException(env, kIllegalArgumentException,
+ "Input ByteBuffer is not a direct buffer");
+ return;
+ }
+
+ tensor->data.raw = src_data_raw;
+}
+
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
jclass clazz,
@@ -230,6 +238,27 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
num_dims, static_cast<jarray>(value));
}
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jobject src) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return;
+ if (tensor->data.raw == nullptr) {
+ throwException(env, kIllegalArgumentException,
+ "Internal error: Target Tensor hasn't been allocated.");
+ return;
+ }
+ if (tensor->dims->size == 0) {
+ throwException(env, kIllegalArgumentException,
+ "Internal error: Cannot copy empty/scalar Tensors.");
+ return;
+ }
+ writeMultiDimensionalArray(env, src, tensor->type, tensor->dims->size,
+ &tensor->data.raw, tensor->bytes);
+}
+
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
jclass clazz,
jlong handle) {
@@ -247,3 +276,11 @@ Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) {
env->SetIntArrayRegion(result, 0, num_dims, tensor->dims->data);
return result;
}
+
+JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ const TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return 0;
+ return static_cast<jint>(tensor->bytes);
+}
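A note on the new numBytes native above: tensor->bytes is the precomputed product of the element size (per the elementByteSize table now at the top of this file) and the volume of the tensor's shape. In editorial form, with hypothetical helper names:

#include <cstddef>

// bytes = element size * product of all dimensions; TfLiteTensor caches
// this in tensor->bytes, which numBytes() returns directly.
size_t ExpectedTensorBytes(size_t element_byte_size, const int* dims,
                           int num_dims) {
  size_t bytes = element_byte_size;
  for (int i = 0; i < num_dims; ++i) bytes *= dims[i];
  return bytes;
}

// e.g. a FLOAT32 tensor of shape {2, 8, 8, 3}: 4 * 2*8*8*3 = 1536 bytes,
// the value TensorTest's numBytes() assertion below checks.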
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
index 9ba95d9ac4..06e2546af8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -34,6 +34,14 @@ JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
/*
* Class: org_tensorflow_lite_Tensor
+ * Method: writeDirectBuffer
+ * Signature: (JLjava/nio/ByteBuffer;)V
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
+ JNIEnv* env, jclass clazz, jlong handle, jobject src);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
* Method: dtype
* Signature: (J)I
*/
@@ -52,6 +60,15 @@ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
/*
* Class: org_tensorflow_lite_Tensor
+ * Method: numBytes
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
* Method: readMultiDimensionalArray
* Signature: (JLjava/lang/Object;)
*/
@@ -59,23 +76,18 @@ JNIEXPORT void JNICALL
Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
jclass clazz,
jlong handle,
- jobject value);
+ jobject dst);
/*
- * Finds the size of each data type.
- */
-size_t elementByteSize(TfLiteType data_type);
-
-/*
- * Writes data of a ByteBuffer into dest.
- */
-size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size);
-
-/*
- * Writes a multi-dimensional array into dest.
+ * Class: org_tensorflow_lite_Tensor
+ * Method: writeMultiDimensionalArray
+ * Signature: (JLjava/lang/Object;)V
*/
-size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
- int dims_left, char** dst, int dst_size);
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jobject src);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index 42096ef9a3..d66a73db94 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -221,7 +221,9 @@ public final class InterpreterTest {
assertThat(e)
.hasMessageThat()
.contains(
- "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ "Cannot convert between a TensorFlowLite tensor with type "
+ + "FLOAT32 and a Java object of type [[[[I (which is compatible with the"
+ + " TensorFlowLite type INT32)");
}
interpreter.close();
}
@@ -241,8 +243,8 @@ public final class InterpreterTest {
assertThat(e)
.hasMessageThat()
.contains(
- "Cannot convert an TensorFlowLite tensor with type "
- + "FLOAT32 to a Java object of type [[[[I (which is compatible with the"
+ "Cannot convert between a TensorFlowLite tensor with type "
+ + "FLOAT32 and a Java object of type [[[[I (which is compatible with the"
+ " TensorFlowLite type INT32)");
}
interpreter.close();
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 029e5853e2..9c4a5acd79 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -20,6 +20,8 @@ import static org.junit.Assert.fail;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.HashMap;
+import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -101,10 +103,10 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
float[][][][] parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
@@ -119,11 +121,11 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs).hasLength(1);
ByteBuffer parsedOutput =
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
- outputs[0].copyTo(parsedOutput);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutput);
+ wrapper.run(inputs, outputs);
float[] outputOneD = {
parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
};
@@ -140,17 +142,16 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
float[][][][] parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
- outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
outputOneD = parsedOutputs[0][0][0];
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
wrapper.close();
@@ -164,10 +165,10 @@ public final class NativeInterpreterWrapperTest {
int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
int[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
int[][][][] parsedOutputs = new int[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
int[] outputOneD = parsedOutputs[0][0][0];
int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4};
assertThat(outputOneD).isEqualTo(expected);
@@ -182,10 +183,10 @@ public final class NativeInterpreterWrapperTest {
long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
long[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
long[][][][] parsedOutputs = new long[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
long[] outputOneD = parsedOutputs[0][0][0];
long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L,
-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L};
@@ -203,10 +204,10 @@ public final class NativeInterpreterWrapperTest {
Object[] inputs = {fourD};
int[] inputDims = {2, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
byte[][][][] parsedOutputs = new byte[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0};
@@ -229,13 +230,14 @@ public final class NativeInterpreterWrapperTest {
}
}
}
+ bbuf.rewind();
Object[] inputs = {bbuf};
int[] inputDims = {2, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
byte[][][][] parsedOutputs = new byte[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
@@ -261,21 +263,22 @@ public final class NativeInterpreterWrapperTest {
}
}
Object[] inputs = {bbuf};
+ float[][][][] parsedOutputs = new float[4][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes");
+ "Cannot convert between a TensorFlowLite buffer with 768 bytes and a "
+ + "ByteBuffer with 3072 bytes.");
}
int[] inputDims = {4, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
- float[][][][] parsedOutputs = new float[4][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
@@ -288,14 +291,18 @@ public final class NativeInterpreterWrapperTest {
ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3);
bbuf.order(ByteOrder.nativeOrder());
Object[] inputs = {bbuf};
+ Map<Integer, Object> outputs = new HashMap<>();
+ ByteBuffer parsedOutput = ByteBuffer.allocateDirect(2 * 7 * 8 * 3);
+ outputs.put(0, parsedOutput);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "Failed to get input dimensions. 0-th input should have 192 bytes, but found 336 bytes.");
+ "Cannot convert between a TensorFlowLite buffer with 192 bytes and a "
+ + "ByteBuffer with 336 bytes.");
}
wrapper.close();
}
@@ -308,14 +315,18 @@ public final class NativeInterpreterWrapperTest {
int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
int[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ int[][][][] parsedOutputs = new int[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object "
+ + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)");
}
wrapper.close();
}
@@ -329,8 +340,11 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter.");
@@ -342,7 +356,7 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
try {
Object[] inputs = {};
- wrapper.run(inputs);
+ wrapper.run(inputs, null);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Inputs should not be null or empty.");
@@ -358,11 +372,14 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD, fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2");
+ assertThat(e).hasMessageThat().contains("Invalid input Tensor index: 1");
}
wrapper.close();
}
@@ -374,13 +391,18 @@ public final class NativeInterpreterWrapperTest {
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
Object[] inputs = {threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
- .contains("0-th input should have 4 dimensions, but found 3 dimensions");
+ .contains(
+ "Cannot copy between a TensorFlowLite tensor with shape [8, 7, 3] and a "
+ + "Java object with shape [2, 8, 8, 3].");
}
wrapper.close();
}
@@ -393,92 +415,23 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
- .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ .contains(
+ "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 7, 3] and a "
+ + "Java object with shape [2, 8, 8, 3].");
}
wrapper.close();
}
@Test
- public void testNumElements() {
- int[] shape = {2, 3, 4};
- int num = NativeInterpreterWrapper.numElements(shape);
- assertThat(num).isEqualTo(24);
- shape = null;
- num = NativeInterpreterWrapper.numElements(shape);
- assertThat(num).isEqualTo(0);
- }
-
- @Test
- public void testIsNonEmtpyArray() {
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse();
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse();
- int[] emptyArray = {};
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse();
- int[] validArray = {9, 5, 2, 1};
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue();
- }
-
- @Test
- public void testDataTypeOf() {
- float[] testEmtpyArray = {};
- DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- float[] testFloatArray = {0.783f, 0.251f};
- dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
- dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- try {
- double[] testDoubleArray = {0.783, 0.251};
- NativeInterpreterWrapper.dataTypeOf(testDoubleArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
- }
- try {
- Float[] testBoxedArray = {0.783f, 0.251f};
- NativeInterpreterWrapper.dataTypeOf(testBoxedArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
- }
- }
-
- @Test
- public void testNumDimensions() {
- int scalar = 1;
- assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0);
- int[][] array = {{2, 4}, {1, 9}};
- assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2);
- try {
- int[] emptyArray = {};
- NativeInterpreterWrapper.numDimensions(emptyArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
- }
- }
-
- @Test
- public void testFillShape() {
- int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
- int num = NativeInterpreterWrapper.numDimensions(array);
- int[] shape = new int[num];
- NativeInterpreterWrapper.fillShape(array, 0, shape);
- assertThat(num).isEqualTo(3);
- assertThat(shape[0]).isEqualTo(2);
- assertThat(shape[1]).isEqualTo(3);
- assertThat(shape[2]).isEqualTo(1);
- }
-
- @Test
public void testGetInferenceLatency() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
float[] oneD = {1.23f, 6.54f, 7.81f};
@@ -486,8 +439,10 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isGreaterThan(0L);
wrapper.close();
}
@@ -507,13 +462,14 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
- assertThat(e)
- .hasMessageThat()
- .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ // Expected.
}
assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull();
wrapper.close();
@@ -523,41 +479,7 @@ public final class NativeInterpreterWrapperTest {
public void testGetInputDims() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
int[] expectedDims = {1, 8, 8, 3};
- assertThat(wrapper.getInputDims(0)).isEqualTo(expectedDims);
- wrapper.close();
- }
-
- @Test
- public void testGetInputDimsOutOfRange() {
- NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
- try {
- wrapper.getInputDims(-1);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Out of range");
- }
- try {
- wrapper.getInputDims(1);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Out of range");
- }
- wrapper.close();
- }
-
- @Test
- public void testGetOutputDataType() {
- NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("float");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("long");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("int");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("byte");
+ assertThat(wrapper.getInputTensor(0).shape()).isEqualTo(expectedDims);
wrapper.close();
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index dd9d37eeda..71ef044943 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -18,9 +18,10 @@ package org.tensorflow.lite;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
-import java.nio.BufferOverflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.HashMap;
+import java.util.Map;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -35,7 +36,7 @@ public final class TensorTest {
"tensorflow/contrib/lite/java/src/testdata/add.bin";
private NativeInterpreterWrapper wrapper;
- private long nativeHandle;
+ private Tensor tensor;
@Before
public void setUp() {
@@ -45,8 +46,10 @@ public final class TensorTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- nativeHandle = outputs[0].nativeHandle;
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, new float[2][8][8][3]);
+ wrapper.run(inputs, outputs);
+ tensor = wrapper.getOutputTensor(0);
}
@After
@@ -55,17 +58,16 @@ public final class TensorTest {
}
@Test
- public void testFromHandle() throws Exception {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
+ public void testBasic() throws Exception {
assertThat(tensor).isNotNull();
int[] expectedShape = {2, 8, 8, 3};
- assertThat(tensor.shapeCopy).isEqualTo(expectedShape);
- assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32);
+ assertThat(tensor.shape()).isEqualTo(expectedShape);
+ assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4);
}
@Test
public void testCopyTo() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
float[][][][] parsedOutputs = new float[2][8][8][3];
tensor.copyTo(parsedOutputs);
float[] outputOneD = parsedOutputs[0][0][0];
@@ -75,7 +77,6 @@ public final class TensorTest {
@Test
public void testCopyToByteBuffer() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
ByteBuffer parsedOutput =
ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
tensor.copyTo(parsedOutput);
@@ -89,19 +90,17 @@ public final class TensorTest {
@Test
public void testCopyToInvalidByteBuffer() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
try {
tensor.copyTo(parsedOutput);
fail();
- } catch (BufferOverflowException e) {
+ } catch (IllegalArgumentException e) {
// Expected.
}
}
@Test
public void testCopyToWrongType() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
int[][][][] parsedOutputs = new int[2][8][8][3];
try {
tensor.copyTo(parsedOutputs);
@@ -110,15 +109,13 @@ public final class TensorTest {
assertThat(e)
.hasMessageThat()
.contains(
- "Cannot convert an TensorFlowLite tensor with type "
- + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite "
- + "type INT32)");
+ "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object "
+ + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)");
}
}
@Test
public void testCopyToWrongShape() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
float[][][][] parsedOutputs = new float[1][8][8][3];
try {
tensor.copyTo(parsedOutputs);
@@ -127,8 +124,104 @@ public final class TensorTest {
assertThat(e)
.hasMessageThat()
.contains(
- "Shape of output target [1, 8, 8, 3] does not match "
- + "with the shape of the Tensor [2, 8, 8, 3].");
+ "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 8, 3] "
+ + "and a Java object with shape [1, 8, 8, 3].");
+ }
+ }
+
+ @Test
+ public void testSetTo() {
+ float[][][][] input = new float[2][8][8][3];
+ float[][][][] output = new float[2][8][8][3];
+ ByteBuffer inputByteBuffer =
+ ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
+
+ input[0][0][0][0] = 2.0f;
+ tensor.setTo(input);
+ tensor.copyTo(output);
+ assertThat(output[0][0][0][0]).isEqualTo(2.0f);
+
+ inputByteBuffer.putFloat(0, 3.0f);
+ tensor.setTo(inputByteBuffer);
+ tensor.copyTo(output);
+ assertThat(output[0][0][0][0]).isEqualTo(3.0f);
+ }
+
+ @Test
+ public void testSetToInvalidByteBuffer() {
+ ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
+ try {
+ tensor.setTo(input);
+ fail();
+ } catch (IllegalArgumentException e) {
+ // Success.
+ }
+ }
+
+ @Test
+ public void testGetInputShapeIfDifferent() {
+ ByteBuffer byteBufferInput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
+ assertThat(tensor.getInputShapeIfDifferent(byteBufferInput)).isNull();
+
+ float[][][][] sameShapeInput = new float[2][8][8][3];
+ assertThat(tensor.getInputShapeIfDifferent(sameShapeInput)).isNull();
+
+ float[][][][] differentShapeInput = new float[1][8][8][3];
+ assertThat(tensor.getInputShapeIfDifferent(differentShapeInput))
+ .isEqualTo(new int[] {1, 8, 8, 3});
+ }
+
+ @Test
+ public void testDataTypeOf() {
+ float[] testEmptyArray = {};
+ DataType dataType = Tensor.dataTypeOf(testEmptyArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ float[] testFloatArray = {0.783f, 0.251f};
+ dataType = Tensor.dataTypeOf(testFloatArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
+ dataType = Tensor.dataTypeOf(testMultiDimArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ try {
+ double[] testDoubleArray = {0.783, 0.251};
+ Tensor.dataTypeOf(testDoubleArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
+ }
+ try {
+ Float[] testBoxedArray = {0.783f, 0.251f};
+ Tensor.dataTypeOf(testBoxedArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
}
}
+
+ @Test
+ public void testNumDimensions() {
+ int scalar = 1;
+ assertThat(Tensor.numDimensions(scalar)).isEqualTo(0);
+ int[][] array = {{2, 4}, {1, 9}};
+ assertThat(Tensor.numDimensions(array)).isEqualTo(2);
+ try {
+ int[] emptyArray = {};
+ Tensor.numDimensions(emptyArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
+ }
+ }
+
+ @Test
+ public void testFillShape() {
+ int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
+ int num = Tensor.numDimensions(array);
+ int[] shape = new int[num];
+ Tensor.fillShape(array, 0, shape);
+ assertThat(num).isEqualTo(3);
+ assertThat(shape[0]).isEqualTo(2);
+ assertThat(shape[1]).isEqualTo(3);
+ assertThat(shape[2]).isEqualTo(1);
+ }
}
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
index 3aef0c3bb6..c23521c077 100644
--- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
@@ -58,7 +58,7 @@ public class TestHelper {
*/
public static int[] getInputDims(Interpreter interpreter, int index) {
if (interpreter != null && interpreter.wrapper != null) {
- return interpreter.wrapper.getInputDims(index);
+ return interpreter.wrapper.getInputTensor(index).shape();
} else {
throw new IllegalArgumentException(
"Interpreter has not initialized;" + " Failed to get input dimensions.");
@@ -77,7 +77,7 @@ public class TestHelper {
*/
public static String getOutputDataType(Interpreter interpreter, int index) {
if (interpreter != null && interpreter.wrapper != null) {
- return interpreter.wrapper.getOutputDataType(index);
+ return interpreter.wrapper.getOutputTensor(index).dataType().toStringName();
} else {
throw new IllegalArgumentException(
"Interpreter has not initialized;" + " Failed to get output data type.");
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index e749edc5ee..edce73989c 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -155,6 +155,7 @@ cc_library(
"embedding_lookup_sparse.cc",
"exp.cc",
"expand_dims.cc",
+ "fake_quant.cc",
"floor.cc",
"fully_connected.cc",
"gather.cc",
@@ -564,6 +565,19 @@ tf_cc_test(
)
tf_cc_test(
+ name = "fake_quant_test",
+ size = "small",
+ srcs = ["fake_quant_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
name = "maximum_minimum_test",
size = "small",
srcs = ["maximum_minimum_test.cc"],
diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc
index 2e2ec94fab..4f30d09030 100644
--- a/tensorflow/contrib/lite/kernels/arg_min_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc
@@ -177,6 +177,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
return kTfLiteOk;
}
+TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, false);
+}
+
TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
return Eval(context, node, true);
}
@@ -189,6 +193,12 @@ TfLiteRegistration* Register_ARG_MAX() {
return &r;
}
+TfLiteRegistration* Register_ARG_MIN() {
+ static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
+ arg_min_max::ArgMinEval};
+ return &r;
+}
+
} // namespace builtin
} // namespace ops
} // namespace tflite
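Register_ARG_MIN reuses Prepare and the shared Eval, differing from ARG_MAX only in the boolean forwarded by ArgMinEval. A scalar illustration of that shared reduction over the last axis (editorial sketch with hypothetical names, not the kernel's templated implementation; assumes inner_size evenly divides the input length):

#include <cstddef>
#include <cstdint>
#include <vector>

// For each slice along the last axis, emit the index of the extremum;
// the flag mirrors the bool passed by ArgMinEval/ArgMaxEval.
std::vector<int32_t> ArgExtremumLastAxis(const std::vector<float>& input,
                                         int inner_size, bool is_arg_max) {
  std::vector<int32_t> output;
  for (size_t start = 0; start + inner_size <= input.size();
       start += inner_size) {
    int32_t best = 0;
    for (int i = 1; i < inner_size; ++i) {
      const float cur = input[start + i];
      const float best_val = input[start + best];
      if (is_arg_max ? cur > best_val : cur < best_val) best = i;
    }
    output.push_back(best);
  }
  return output;
}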
diff --git a/tensorflow/contrib/lite/kernels/arg_min_max_test.cc b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc
index 31b15fe19a..90e5fdc532 100644
--- a/tensorflow/contrib/lite/kernels/arg_min_max_test.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc
@@ -24,16 +24,13 @@ namespace {
using ::testing::ElementsAreArray;
template <typename T>
-class ArgMaxOpModel : public SingleOpModel {
+class ArgBaseOpModel : public SingleOpModel {
public:
- ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
- TensorType output_type, TensorType index_output_type) {
+ ArgBaseOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type) {
input_ = AddInput(input_type);
axis_ = AddInput(TensorType_INT32);
output_ = AddOutput(output_type);
- SetBuiltinOp(BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
- CreateArgMaxOptions(builder_, index_output_type).Union());
- BuildInterpreter({input_shape, {1, 1, 1, 1}});
}
int input() { return input_; }
@@ -42,12 +39,42 @@ class ArgMaxOpModel : public SingleOpModel {
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
- private:
+ protected:
int input_;
int axis_;
int output_;
};
+template <typename T>
+class ArgMaxOpModel : public ArgBaseOpModel<T> {
+ public:
+ ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type)
+ : ArgBaseOpModel<T>(input_shape, input_type, output_type,
+ index_output_type) {
+ ArgBaseOpModel<T>::SetBuiltinOp(
+ BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
+ CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, index_output_type)
+ .Union());
+ ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
+ }
+};
+
+template <typename T>
+class ArgMinOpModel : public ArgBaseOpModel<T> {
+ public:
+ ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type)
+ : ArgBaseOpModel<T>(input_shape, input_type, output_type,
+ index_output_type) {
+ ArgBaseOpModel<T>::SetBuiltinOp(
+ BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
+ CreateArgMinOptions(ArgBaseOpModel<T>::builder_, index_output_type)
+ .Union());
+ ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
+ }
+};
+
TEST(ArgMaxOpTest, GetMaxArgFloat) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
TensorType_INT32, TensorType_INT32);
@@ -96,6 +123,54 @@ TEST(ArgMaxOpTest, GetMaxArgOutput64) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
}
+TEST(ArgMinOpTest, GetMinArgFloat) {
+ ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
+ TensorType_INT32, TensorType_INT32);
+ model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgInt) {
+ ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32,
+ TensorType_INT32);
+ model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgMulDimensions) {
+ ArgMinOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
+ TensorType_INT32);
+ model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgOutput64) {
+ ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64,
+ TensorType_INT64);
+ model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
new file mode 100644
index 0000000000..f8927a0799
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace fake_quant {
+
+// This file has a reference implementation of FakeQuant.
+enum KernelType {
+ kReference,
+};
+
+struct OpContext {
+ OpContext(TfLiteContext* context, TfLiteNode* node) {
+ input = GetInput(context, node, 0);
+ output = GetOutput(context, node, 0);
+ }
+ const TfLiteTensor* input;
+ TfLiteTensor* output;
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ OpContext op_context(context, node);
+ TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input->dims);
+ op_context.output->type = op_context.input->type;
+ return context->ResizeTensor(context, op_context.output, output_dims);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+
+ const auto* params =
+ reinterpret_cast<TfLiteFakeQuantParams*>(node->builtin_data);
+
+ reference_ops::FakeQuant(GetTensorData<float>(op_context.input),
+ GetTensorDims(op_context.input), params->min,
+ params->max, params->num_bits,
+ GetTensorData<float>(op_context.output),
+ GetTensorDims(op_context.output));
+
+ return kTfLiteOk;
+}
+
+} // namespace fake_quant
+
+TfLiteRegistration* Register_FAKE_QUANT_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, fake_quant::Prepare,
+ fake_quant::Eval<fake_quant::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_FAKE_QUANT() { return Register_FAKE_QUANT_REF(); }
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
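The reference kernel's math follows TensorFlow's fake_quant_with_min_max_args scheme: the user-supplied [min, max] range is first nudged so that 0.0 falls exactly on the quantization grid, then each value is rounded onto a grid with 2^num_bits - 1 steps and mapped back to float. In editorial notation, with b = num_bits and z additionally clamped to [0, 2^b - 1]:

$$ s = \frac{\mathrm{max} - \mathrm{min}}{2^{b}-1}, \qquad z = \mathrm{round}\!\left(\frac{-\mathrm{min}}{s}\right), \qquad \hat{x} = s\left(\mathrm{round}\!\left(\mathrm{clamp}\!\left(\frac{x}{s}+z,\; 0,\; 2^{b}-1\right)\right)-z\right) $$

With min = -0.9, max = 0.9, b = 8 this gives s ≈ 0.0070588 and z = 128, so an input of -0.9 comes back as s(1 - 128) ≈ -0.896471, which is the value the tests below expect.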
diff --git a/tensorflow/contrib/lite/kernels/fake_quant_test.cc b/tensorflow/contrib/lite/kernels/fake_quant_test.cc
new file mode 100644
index 0000000000..11a02f7ed7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fake_quant_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class FakeQuantOpModel : public SingleOpModel {
+ public:
+ FakeQuantOpModel(const TensorData& input, const TensorType& output, float min,
+ float max, int num_bits) {
+ input_ = AddInput(input);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_FAKE_QUANT, BuiltinOptions_FakeQuantOptions,
+ CreateFakeQuantOptions(builder_, min, max, num_bits).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ template <class T>
+ void SetInput(std::initializer_list<T> data) {
+ PopulateTensor(input_, data);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+TEST(FakeQuantOpTest, FloatPositiveRange8Test) {
+ std::initializer_list<float> data = {0.0, 1.0, 0.25,
+ 0.50, 0.4444444, 0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, 0.0f,
+ 1.0f, 8);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(
+ m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({0, 1, 0.25098, 0.498039, 0.443137, 0})));
+}
+
+TEST(FakeQuantOpTest, FloatNegativeRange8Test) {
+ std::initializer_list<float> data = {0.0, -0.9, 0.25,
+ 0.50, 0.4444444, -0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, -0.9f,
+ 0.9f, 8);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, -0.896471, 0.247059, 0.501176, 0.444706, 0})));
+}
+
+TEST(FakeQuantOpTest, FloatPositiveRange16Test) {
+ std::initializer_list<float> data = {0.0, 1.0, 0.25,
+ 0.50, 0.4444444, 0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, 0.0f,
+ 1.0f, 16);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, 1, 0.250004, 0.500008, 0.44445, 1.5259e-05})));
+}
+
+TEST(FakeQuantOpTest, FloatNegativeRange16Test) {
+ std::initializer_list<float> data = {0.0, -0.9, 0.25,
+ 0.50, 0.4444444, -0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, -0.9f,
+ 0.9f, 16);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, -0.900014, 0.249998, 0.499995, 0.444431, 0})));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
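For checking the expected values above numerically, a self-contained sketch of the nudged quantization (editorial helper names, not the reference_ops::FakeQuant signature):

#include <algorithm>
#include <cmath>
#include <cstdio>

float FakeQuantize(float x, float min, float max, int num_bits) {
  const float quant_max = static_cast<float>((1 << num_bits) - 1);
  const float scale = (max - min) / quant_max;
  // Nudge the zero point onto the integer grid so 0.0f is representable.
  const float zero_point =
      std::round(std::min(std::max(-min / scale, 0.0f), quant_max));
  const float q =
      std::round(std::min(std::max(x / scale + zero_point, 0.0f), quant_max));
  return scale * (q - zero_point);
}

int main() {
  // Reproduces FloatNegativeRange8Test: -0.9f quantizes to ~-0.896471f.
  std::printf("%f\n", FakeQuantize(-0.9f, -0.9f, 0.9f, 8));
}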
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 5ba7e2af9b..c19f8e8a81 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -62,72 +62,35 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t),
&aligned_vector_cache_free));
- const int kUnrollSize = 2;
for (int b = 0; b < n_batch; b++) {
float* result_in_batch = result + b * m_rows * result_stride;
const float* vector_in_batch = vector + b * m_cols;
-
- const float* matrix_ptr0 = matrix;
- // If there is only 1 row, we don't want to assign an illegal pointer.
- const float* matrix_ptr1 = nullptr;
- if (m_rows > 1) {
- matrix_ptr1 = matrix + m_cols;
- }
+ const float* matrix_row = matrix;
// Cache the vector.
for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c);
}
- // Main matrix by vector multiplication loop, which handles two rows of
- // matrix by vector multiplication.
- for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) {
- float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
- float32x4_t acc1_32x4 = vmovq_n_f32(0.0);
- for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
- float32x4_t temp = vector_cache_float32x4[c >> 2];
- // Load 4 float values from vector1 and vector2 and accumulator.
- float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
- float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c);
- // Vector multiply-accumulate 4 float
- acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
- acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp);
- }
- // Add the 4 intermediate sum values to get the final dot-prod value for
- // this column.
- *result_in_batch +=
- (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
- vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
- *(result_in_batch + result_stride) +=
- (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) +
- vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3));
- for (int c = postamble_start; c < m_cols; c++) {
- *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
- *(result_in_batch + result_stride) +=
- matrix_ptr1[c] * vector_in_batch[c];
- }
- matrix_ptr0 += kUnrollSize * m_cols;
- matrix_ptr1 += kUnrollSize * m_cols;
- result_in_batch += kUnrollSize * result_stride;
- }
- for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) {
- float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
+ // Main matrix-by-vector multiplication loop.
+ for (int r = 0; r < m_rows; r++) {
+ float32x4_t acc_32x4 = vmovq_n_f32(0.0);
for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
float32x4_t temp = vector_cache_float32x4[c >> 2];
- // Load 4 float values from vector1 and vector2 and accumulator.
- float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
+ // Load 4 float values from the current matrix row.
+ float32x4_t v_f32x4 = vld1q_f32(matrix_row + c);
// Vector multiply-accumulate 4 float
- acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
+ acc_32x4 = vmlaq_f32(acc_32x4, v_f32x4, temp);
}
// Add the 4 intermediate sum values to get the final dot-prod value for
// this column.
*result_in_batch +=
- (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
- vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
+ (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) +
+ vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3));
for (int c = postamble_start; c < m_cols; c++) {
- *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
+ *result_in_batch += matrix_row[c] * vector_in_batch[c];
}
- matrix_ptr0 += m_cols;
+ matrix_row += m_cols;
result_in_batch += result_stride;
}
}
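The rewrite drops the kUnrollSize = 2 two-rows-per-iteration unrolling and its separate remainder loop in favor of a single row-per-iteration loop; the arithmetic is unchanged. For reference, a scalar restatement of what the routine computes (editorial sketch; the real code consumes kFloatWeightsPerNeonLane floats per NEON step and handles the trailing columns after postamble_start):

// Results accumulate in place, exactly as in the NEON version.
void ScalarMatrixBatchVectorMultiplyAccumulate(
    const float* matrix, int m_rows, int m_cols, const float* vector,
    int n_batch, float* result, int result_stride) {
  for (int b = 0; b < n_batch; ++b) {
    const float* vector_in_batch = vector + b * m_cols;
    float* result_in_batch = result + b * m_rows * result_stride;
    const float* matrix_row = matrix;
    for (int r = 0; r < m_rows; ++r) {
      float acc = 0.0f;
      for (int c = 0; c < m_cols; ++c) {
        acc += matrix_row[c] * vector_in_batch[c];
      }
      *result_in_batch += acc;
      matrix_row += m_cols;
      result_in_batch += result_stride;
    }
  }
}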
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 0ca08cd8f3..1994e85ce3 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -82,6 +82,7 @@ TfLiteRegistration* Register_PRELU();
TfLiteRegistration* Register_MAXIMUM();
TfLiteRegistration* Register_MINIMUM();
TfLiteRegistration* Register_ARG_MAX();
+TfLiteRegistration* Register_ARG_MIN();
TfLiteRegistration* Register_GREATER();
TfLiteRegistration* Register_GREATER_EQUAL();
TfLiteRegistration* Register_LESS();
@@ -102,6 +103,7 @@ TfLiteRegistration* Register_SQRT();
TfLiteRegistration* Register_RSQRT();
TfLiteRegistration* Register_SHAPE();
TfLiteRegistration* Register_POW();
+TfLiteRegistration* Register_FAKE_QUANT();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -167,6 +169,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
+ AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
AddBuiltin(BuiltinOperator_LESS, Register_LESS());
@@ -187,6 +190,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
AddBuiltin(BuiltinOperator_POW, Register_POW());
+ AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 9b6cee3cb5..3cdb5db209 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -89,6 +89,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8: \
TF_LITE_SELECT(uint8_t, op); \
break; \
+ case kTfLiteInt16: \
+ TF_LITE_SELECT(int16_t, op); \
+ break; \
case kTfLiteInt32: \
TF_LITE_SELECT(int32_t, op); \
break; \
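The new kTfLiteInt16 case instantiates the same element-wise select as the existing types. In editorial form, the semantics the TF_LITE_SELECT macro expands to, exercised by the SelectInt16 test below (condition {false, true, false, false} choosing between {1, 2, 3, 4} and {5, 6, 7, 8} yields {5, 2, 7, 8}):

#include <cstddef>
#include <cstdint>
#include <vector>

// out[i] = condition[i] ? x[i] : y[i], for any supported element type T.
template <typename T>
std::vector<T> Select(const std::vector<bool>& condition,
                      const std::vector<T>& x, const std::vector<T>& y) {
  std::vector<T> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = condition[i] ? x[i] : y[i];
  }
  return out;
}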
diff --git a/tensorflow/contrib/lite/kernels/select_test.cc b/tensorflow/contrib/lite/kernels/select_test.cc
index 4664b9acb4..5b2e61cd29 100644
--- a/tensorflow/contrib/lite/kernels/select_test.cc
+++ b/tensorflow/contrib/lite/kernels/select_test.cc
@@ -96,6 +96,19 @@ TEST(SelectOpTest, SelectUInt8) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
+TEST(SelectOpTest, SelectInt16) {
+ SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
+ TensorType_INT16);
+
+ model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
+ model.PopulateTensor<int16_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int16_t>(model.input3(), {5, 6, 7, 8});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput<int16_t>(), ElementsAreArray({5, 2, 7, 8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
TEST(SelectOpTest, SelectInt32) {
SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
TensorType_INT32);
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 0bc717bfca..93b3df98f3 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -206,8 +206,9 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
} else if (builtin_code != BuiltinOperator_CUSTOM) {
registration = op_resolver_.FindOp(builtin_code, version);
if (registration == nullptr) {
- error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
- EnumNameBuiltinOperator(builtin_code));
+ error_reporter_->Report(
+ "Didn't find op for builtin opcode '%s' version '%d'\n",
+ EnumNameBuiltinOperator(builtin_code), version);
status = kTfLiteError;
}
} else if (!opcode->custom_code()) {
@@ -663,6 +664,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_ARG_MIN: {
+ auto* params = MallocPOD<TfLiteArgMinParams>();
+ if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
+ ConvertTensorType(schema_params->output_type(), &params->output_type,
+ error_reporter);
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_TRANSPOSE_CONV: {
TfLiteTransposeConvParams* params =
MallocPOD<TfLiteTransposeConvParams>();
@@ -699,6 +709,16 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
error_reporter->Report("DELEGATE op shouldn't exist in model.");
return kTfLiteError;
}
+ case BuiltinOperator_FAKE_QUANT: {
+ auto* params = MallocPOD<TfLiteFakeQuantParams>();
+ if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
+ params->min = schema_params->min();
+ params->max = schema_params->max();
+ params->num_bits = schema_params->num_bits();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
// Below are the ops with no builtin_data structure.
case BuiltinOperator_BATCH_TO_SPACE_ND:
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 905c0919cb..cc668485a4 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -548,6 +548,18 @@ TfLiteStatus AddOpsAndParams(
add_squeeze_params(node.builtin_data);
nn_op_type = ANEURALNETWORKS_SQUEEZE;
break;
+ case tflite::BuiltinOperator_TRANSPOSE:
+ // The permutation input tensor value dictates the output dimensions.
+ // TODO(b/110888333): Support dynamically-sized tensors in delegates.
+ if ((node.inputs->size > 1) &&
+ (interpreter->tensor(node.inputs->data[1])->allocation_type !=
+ kTfLiteMmapRo)) {
+ logError("NNAPI does not yet support dynamic tensors.");
+ return kTfLiteError;
+ }
+ nnapi_version = 11; // require NNAPI 1.1
+ nn_op_type = ANEURALNETWORKS_TRANSPOSE;
+ break;
case tflite::BuiltinOperator_CONCAT_EMBEDDINGS:
case tflite::BuiltinOperator_LSH_PROJECTION:
case tflite::BuiltinOperator_HASHTABLE_LOOKUP:
@@ -567,7 +579,6 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_SPACE_TO_BATCH_ND:
case tflite::BuiltinOperator_BATCH_TO_SPACE_ND:
case tflite::BuiltinOperator_TOPK_V2:
- case tflite::BuiltinOperator_TRANSPOSE:
case tflite::BuiltinOperator_SPLIT:
case tflite::BuiltinOperator_STRIDED_SLICE:
case tflite::BuiltinOperator_EXP:
@@ -579,6 +590,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_MAXIMUM:
case tflite::BuiltinOperator_MINIMUM:
case tflite::BuiltinOperator_ARG_MAX:
+ case tflite::BuiltinOperator_ARG_MIN:
case tflite::BuiltinOperator_GREATER:
case tflite::BuiltinOperator_GREATER_EQUAL:
case tflite::BuiltinOperator_LESS:
@@ -599,6 +611,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_RSQRT:
case tflite::BuiltinOperator_SHAPE:
case tflite::BuiltinOperator_POW:
+ case tflite::BuiltinOperator_FAKE_QUANT:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
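
[Note] NNAPI fixes tensor shapes when a model is compiled, so the TRANSPOSE case above can only be delegated when the permutation input is a compile-time constant; kTfLiteMmapRo is the allocation type of read-only data mapped in from the model file. The same guard could be factored out as below (hypothetical helper, not part of the patch):

// Hypothetical helper equivalent to the inline guard above.
bool HasConstantInput(Interpreter* interpreter, const TfLiteNode& node,
                      int input_index) {
  const TfLiteTensor* t = interpreter->tensor(node.inputs->data[input_index]);
  return t->allocation_type == kTfLiteMmapRo;
}
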
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 5e6467f676..17ea26052d 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -160,6 +160,8 @@ enum BuiltinOperator : byte {
RSQRT = 76,
SHAPE = 77,
POW = 78,
+ ARG_MIN = 79,
+ FAKE_QUANT = 80,
}
// Options for the builtin operators.
@@ -220,6 +222,8 @@ union BuiltinOptions {
NotEqualOptions,
ShapeOptions,
PowOptions,
+ ArgMinOptions,
+ FakeQuantOptions,
}
enum Padding : byte { SAME, VALID }
@@ -469,6 +473,10 @@ table ArgMaxOptions {
output_type : TensorType;
}
+table ArgMinOptions {
+ output_type : TensorType;
+}
+
table GreaterOptions {
}
@@ -517,6 +525,12 @@ table ShapeOptions {
table PowOptions {
}
+table FakeQuantOptions {
+ min:float;
+ max:float;
+ num_bits:int;
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
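
[Note] schema_generated.h (next file) is the flatc output for this schema, so the two files change together; the object API (the *T NativeTable types with Pack/UnPack) comes from flatc's --gen-object-api mode. On the reading side, consumers reach the new tables through the typed accessors this diff adds to Operator, roughly:

// Sketch: reading the new options from a tflite::Operator.
if (const auto* fq = op->builtin_options_as_FakeQuantOptions()) {
  float min = fq->min();           // 0.0f when the field was not written
  float max = fq->max();
  int32_t num_bits = fq->num_bits();
}
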
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index fe0ff9a7a5..37489ebc68 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -157,6 +157,9 @@ struct TileOptionsT;
struct ArgMaxOptions;
struct ArgMaxOptionsT;
+struct ArgMinOptions;
+struct ArgMinOptionsT;
+
struct GreaterOptions;
struct GreaterOptionsT;
@@ -199,6 +202,9 @@ struct ShapeOptionsT;
struct PowOptions;
struct PowOptionsT;
+struct FakeQuantOptions;
+struct FakeQuantOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -343,11 +349,13 @@ enum BuiltinOperator {
BuiltinOperator_RSQRT = 76,
BuiltinOperator_SHAPE = 77,
BuiltinOperator_POW = 78,
+ BuiltinOperator_ARG_MIN = 79,
+ BuiltinOperator_FAKE_QUANT = 80,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_POW
+ BuiltinOperator_MAX = BuiltinOperator_FAKE_QUANT
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[80] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -426,7 +434,9 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] {
BuiltinOperator_SQRT,
BuiltinOperator_RSQRT,
BuiltinOperator_SHAPE,
- BuiltinOperator_POW
+ BuiltinOperator_POW,
+ BuiltinOperator_ARG_MIN,
+ BuiltinOperator_FAKE_QUANT
};
return values;
}
@@ -512,6 +522,8 @@ inline const char **EnumNamesBuiltinOperator() {
"RSQRT",
"SHAPE",
"POW",
+ "ARG_MIN",
+ "FAKE_QUANT",
nullptr
};
return names;
@@ -580,11 +592,13 @@ enum BuiltinOptions {
BuiltinOptions_NotEqualOptions = 54,
BuiltinOptions_ShapeOptions = 55,
BuiltinOptions_PowOptions = 56,
+ BuiltinOptions_ArgMinOptions = 57,
+ BuiltinOptions_FakeQuantOptions = 58,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_PowOptions
+ BuiltinOptions_MAX = BuiltinOptions_FakeQuantOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[59] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -642,7 +656,9 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] {
BuiltinOptions_EqualOptions,
BuiltinOptions_NotEqualOptions,
BuiltinOptions_ShapeOptions,
- BuiltinOptions_PowOptions
+ BuiltinOptions_PowOptions,
+ BuiltinOptions_ArgMinOptions,
+ BuiltinOptions_FakeQuantOptions
};
return values;
}
@@ -706,6 +722,8 @@ inline const char **EnumNamesBuiltinOptions() {
"NotEqualOptions",
"ShapeOptions",
"PowOptions",
+ "ArgMinOptions",
+ "FakeQuantOptions",
nullptr
};
return names;
@@ -944,6 +962,14 @@ template<> struct BuiltinOptionsTraits<PowOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_PowOptions;
};
+template<> struct BuiltinOptionsTraits<ArgMinOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_ArgMinOptions;
+};
+
+template<> struct BuiltinOptionsTraits<FakeQuantOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_FakeQuantOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1423,6 +1449,22 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_PowOptions ?
reinterpret_cast<const PowOptionsT *>(value) : nullptr;
}
+ ArgMinOptionsT *AsArgMinOptions() {
+ return type == BuiltinOptions_ArgMinOptions ?
+ reinterpret_cast<ArgMinOptionsT *>(value) : nullptr;
+ }
+ const ArgMinOptionsT *AsArgMinOptions() const {
+ return type == BuiltinOptions_ArgMinOptions ?
+ reinterpret_cast<const ArgMinOptionsT *>(value) : nullptr;
+ }
+ FakeQuantOptionsT *AsFakeQuantOptions() {
+ return type == BuiltinOptions_FakeQuantOptions ?
+ reinterpret_cast<FakeQuantOptionsT *>(value) : nullptr;
+ }
+ const FakeQuantOptionsT *AsFakeQuantOptions() const {
+ return type == BuiltinOptions_FakeQuantOptions ?
+ reinterpret_cast<const FakeQuantOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -4486,6 +4528,60 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(
flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct ArgMinOptionsT : public flatbuffers::NativeTable {
+ typedef ArgMinOptions TableType;
+ TensorType output_type;
+ ArgMinOptionsT()
+ : output_type(TensorType_FLOAT32) {
+ }
+};
+
+struct ArgMinOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ArgMinOptionsT NativeTableType;
+ enum {
+ VT_OUTPUT_TYPE = 4
+ };
+ TensorType output_type() const {
+ return static_cast<TensorType>(GetField<int8_t>(VT_OUTPUT_TYPE, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_OUTPUT_TYPE) &&
+ verifier.EndTable();
+ }
+ ArgMinOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ArgMinOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ArgMinOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_output_type(TensorType output_type) {
+ fbb_.AddElement<int8_t>(ArgMinOptions::VT_OUTPUT_TYPE, static_cast<int8_t>(output_type), 0);
+ }
+ explicit ArgMinOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ArgMinOptionsBuilder &operator=(const ArgMinOptionsBuilder &);
+ flatbuffers::Offset<ArgMinOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ArgMinOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ TensorType output_type = TensorType_FLOAT32) {
+ ArgMinOptionsBuilder builder_(_fbb);
+ builder_.add_output_type(output_type);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct GreaterOptionsT : public flatbuffers::NativeTable {
typedef GreaterOptions TableType;
GreaterOptionsT() {
@@ -5112,6 +5208,84 @@ inline flatbuffers::Offset<PowOptions> CreatePowOptions(
flatbuffers::Offset<PowOptions> CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct FakeQuantOptionsT : public flatbuffers::NativeTable {
+ typedef FakeQuantOptions TableType;
+ float min;
+ float max;
+ int32_t num_bits;
+ FakeQuantOptionsT()
+ : min(0.0f),
+ max(0.0f),
+ num_bits(0) {
+ }
+};
+
+struct FakeQuantOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FakeQuantOptionsT NativeTableType;
+ enum {
+ VT_MIN = 4,
+ VT_MAX = 6,
+ VT_NUM_BITS = 8
+ };
+ float min() const {
+ return GetField<float>(VT_MIN, 0.0f);
+ }
+ float max() const {
+ return GetField<float>(VT_MAX, 0.0f);
+ }
+ int32_t num_bits() const {
+ return GetField<int32_t>(VT_NUM_BITS, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<float>(verifier, VT_MIN) &&
+ VerifyField<float>(verifier, VT_MAX) &&
+ VerifyField<int32_t>(verifier, VT_NUM_BITS) &&
+ verifier.EndTable();
+ }
+ FakeQuantOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(FakeQuantOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<FakeQuantOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct FakeQuantOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_min(float min) {
+ fbb_.AddElement<float>(FakeQuantOptions::VT_MIN, min, 0.0f);
+ }
+ void add_max(float max) {
+ fbb_.AddElement<float>(FakeQuantOptions::VT_MAX, max, 0.0f);
+ }
+ void add_num_bits(int32_t num_bits) {
+ fbb_.AddElement<int32_t>(FakeQuantOptions::VT_NUM_BITS, num_bits, 0);
+ }
+ explicit FakeQuantOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FakeQuantOptionsBuilder &operator=(const FakeQuantOptionsBuilder &);
+ flatbuffers::Offset<FakeQuantOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FakeQuantOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FakeQuantOptions> CreateFakeQuantOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ float min = 0.0f,
+ float max = 0.0f,
+ int32_t num_bits = 0) {
+ FakeQuantOptionsBuilder builder_(_fbb);
+ builder_.add_num_bits(num_bits);
+ builder_.add_max(max);
+ builder_.add_min(min);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<FakeQuantOptions> CreateFakeQuantOptions(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5413,6 +5587,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const PowOptions *builtin_options_as_PowOptions() const {
return builtin_options_type() == BuiltinOptions_PowOptions ? static_cast<const PowOptions *>(builtin_options()) : nullptr;
}
+ const ArgMinOptions *builtin_options_as_ArgMinOptions() const {
+ return builtin_options_type() == BuiltinOptions_ArgMinOptions ? static_cast<const ArgMinOptions *>(builtin_options()) : nullptr;
+ }
+ const FakeQuantOptions *builtin_options_as_FakeQuantOptions() const {
+ return builtin_options_type() == BuiltinOptions_FakeQuantOptions ? static_cast<const FakeQuantOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -5668,6 +5848,14 @@ template<> inline const PowOptions *Operator::builtin_options_as<PowOptions>() c
return builtin_options_as_PowOptions();
}
+template<> inline const ArgMinOptions *Operator::builtin_options_as<ArgMinOptions>() const {
+ return builtin_options_as_ArgMinOptions();
+}
+
+template<> inline const FakeQuantOptions *Operator::builtin_options_as<FakeQuantOptions>() const {
+ return builtin_options_as_FakeQuantOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -7333,6 +7521,32 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatB
_output_type);
}
+inline ArgMinOptionsT *ArgMinOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ArgMinOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void ArgMinOptions::UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = output_type(); _o->output_type = _e; };
+}
+
+inline flatbuffers::Offset<ArgMinOptions> ArgMinOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateArgMinOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ArgMinOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _output_type = _o->output_type;
+ return tflite::CreateArgMinOptions(
+ _fbb,
+ _output_type);
+}
+
inline GreaterOptionsT *GreaterOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new GreaterOptionsT();
UnPackTo(_o, _resolver);
@@ -7670,6 +7884,38 @@ inline flatbuffers::Offset<PowOptions> CreatePowOptions(flatbuffers::FlatBufferB
_fbb);
}
+inline FakeQuantOptionsT *FakeQuantOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new FakeQuantOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void FakeQuantOptions::UnPackTo(FakeQuantOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = min(); _o->min = _e; };
+ { auto _e = max(); _o->max = _e; };
+ { auto _e = num_bits(); _o->num_bits = _e; };
+}
+
+inline flatbuffers::Offset<FakeQuantOptions> FakeQuantOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateFakeQuantOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<FakeQuantOptions> CreateFakeQuantOptions(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FakeQuantOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _min = _o->min;
+ auto _max = _o->max;
+ auto _num_bits = _o->num_bits;
+ return tflite::CreateFakeQuantOptions(
+ _fbb,
+ _min,
+ _max,
+ _num_bits);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -8083,6 +8329,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const PowOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<const ArgMinOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ auto ptr = reinterpret_cast<const FakeQuantOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -8325,6 +8579,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const PowOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<const ArgMinOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ auto ptr = reinterpret_cast<const FakeQuantOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -8555,6 +8817,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const PowOptionsT *>(value);
return CreatePowOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<const ArgMinOptionsT *>(value);
+ return CreateArgMinOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ auto ptr = reinterpret_cast<const FakeQuantOptionsT *>(value);
+ return CreateFakeQuantOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -8785,6 +9055,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new PowOptionsT(*reinterpret_cast<PowOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_ArgMinOptions: {
+ value = new ArgMinOptionsT(*reinterpret_cast<ArgMinOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ value = new FakeQuantOptionsT(*reinterpret_cast<FakeQuantOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -9072,6 +9350,16 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<ArgMinOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ auto ptr = reinterpret_cast<FakeQuantOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
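
[Note] On the writing side, the generated Create* helpers build the tables directly into a FlatBufferBuilder. A minimal sketch using the functions defined above:

flatbuffers::FlatBufferBuilder fbb;
auto arg_min_opts = tflite::CreateArgMinOptions(fbb, tflite::TensorType_INT32);
auto fake_quant_opts = tflite::CreateFakeQuantOptions(
    fbb, /*min=*/-6.0f, /*max=*/6.0f, /*num_bits=*/8);
// Each offset is then stored in an Operator together with the matching
// BuiltinOptions_ArgMinOptions / BuiltinOptions_FakeQuantOptions type tag.
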
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 50237ed792..1093bd2cbe 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -678,6 +678,55 @@ def make_relu6_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_prelu_tests(zip_path):
+ """Make a set of tests to do PReLU."""
+
+ test_parameters = [{
+ # The canonical case for image processing is having a 4D `input` (NHWC)
+ # and `shared_axes`=[1, 2], so the alpha parameter is per channel.
+ "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]],
+ "shared_axes": [[1, 2], [1]],
+ }]
+
+ def build_graph(parameters):
+ """Build the graph for the test case."""
+
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"])
+ out = prelu(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ """Build the inputs for the test case."""
+
+ input_shape = parameters["input_shape"]
+ input_values = create_tensor_data(
+ np.float32, input_shape, min_value=-10, max_value=10)
+ shared_axes = parameters["shared_axes"]
+
+ alpha_shape = []
+ for dim in range(1, len(input_shape)):
+ alpha_shape.append(1 if dim in shared_axes else input_shape[dim])
+
+ alpha_values = create_tensor_data(np.float32, alpha_shape)
+
+ # There should be only 1 trainable variable tensor.
+ variables = tf.all_variables()
+ assert len(variables) == 1
+ sess.run(variables[0].assign(alpha_values))
+
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(
+ zip_path,
+ test_parameters,
+ build_graph,
+ build_inputs,
+ use_frozen_graph=True)
+
+
# This function tests various TensorFlow functions that generate Const ops,
# including `tf.ones`, `tf.zeros` and random functions.
def make_constant_tests(zip_path):
@@ -2175,7 +2224,7 @@ def make_topk_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-def make_arg_max_tests(zip_path):
+def make_arg_min_max_tests(zip_path):
"""Make a set of tests to do arg_max."""
test_parameters = [{
@@ -2183,6 +2232,7 @@ def make_arg_max_tests(zip_path):
"input_shape": [[1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]],
"output_type": [tf.int32, tf.int64],
"axis_is_last_dim": [True, False],
+ "is_arg_max": [True],
}]
def build_graph(parameters):
@@ -2195,7 +2245,10 @@ def make_arg_max_tests(zip_path):
axis = len(parameters["input_shape"]) - 1
else:
axis = random.randint(0, max(len(parameters["input_shape"]) - 2, 0))
- out = tf.arg_max(input_value, axis, output_type=parameters["output_type"])
+ if parameters["is_arg_max"]:
+ out = tf.arg_max(input_value, axis, output_type=parameters["output_type"])
+ else:
+ out = tf.arg_min(input_value, axis, output_type=parameters["output_type"])
return [input_value], [out]
def build_inputs(parameters, sess, inputs, outputs):
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index c4e20312d8..5bc6b53416 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -97,11 +97,12 @@ std::map<string, string> kBrokenTests = {
{R"(^\/gather.*axis=1)", "76910444"},
    // No support for arbitrary dimensions in ArgMin/ArgMax.
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])",
+ {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])",
"77546240"},
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])",
+ {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])",
+ "77546240"},
+ {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.\])",
"77546240"},
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.\])", "77546240"},
};
// Allows test data to be unzipped into a temporary directory and makes
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 6be6b25f93..a08cdbfba6 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1135,6 +1135,22 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
GetTensorFlowDataType(model, src_op.outputs[0]));
}
+void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node();
+ argmin_op->set_op("ArgMin");
+ argmin_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *argmin_op->add_input() = src_op.inputs[0];
+ *argmin_op->add_input() = src_op.inputs[1];
+ (*argmin_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.inputs[0]));
+ (*argmin_op->mutable_attr())["Tidx"].set_type(
+ GetTensorFlowDataType(model, src_op.inputs[1]));
+ (*argmin_op->mutable_attr())["output_type"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
+}
+
void ConvertTransposeOperator(const Model& model,
const TransposeOperator& src_op,
GraphDef* tensorflow_graph) {
@@ -1964,6 +1980,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kArgMax) {
ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kArgMin) {
+ ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kTopK_V2) {
ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op),
tensorflow_graph);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
index 30be4ac0aa..b90a156a0d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
@@ -74,14 +74,30 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
const auto* relu_neg_input_op = GetOpWithOutput(*model, mul_op->inputs[1]);
if (relu_neg_input_op == nullptr ||
- relu_neg_input_op->type != OperatorType::kNeg ||
- relu_neg_input_op->fused_activation_function !=
- FusedActivationFunctionType::kRelu ||
relu_neg_input_op->inputs.size() != 1) {
return false;
}
- if (relu_input_op->inputs[0] != relu_neg_input_op->inputs[0]) {
+ const Operator* final_input_op;
+ if (relu_neg_input_op->type == OperatorType::kNeg &&
+ relu_neg_input_op->fused_activation_function ==
+ FusedActivationFunctionType::kRelu) {
+ // This detects a Neg op with fused Relu activation function.
+ final_input_op = relu_neg_input_op;
+ } else {
+      // This detects a Neg op followed by a separate Relu op.
+ const auto* neg_input_op =
+ GetOpWithOutput(*model, relu_neg_input_op->inputs[0]);
+ if (neg_input_op == nullptr || neg_input_op->inputs.size() != 1 ||
+ relu_neg_input_op->type != OperatorType::kRelu ||
+ relu_neg_input_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ return false;
+ }
+ final_input_op = neg_input_op;
+ }
+
+ if (relu_input_op->inputs[0] != final_input_op->inputs[0]) {
return false;
}
@@ -112,7 +128,6 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
// intermediate tensors aren't used by other ops, those will be removed by
// other graph transformation rules.
model->operators.erase(FindOp(*model, add_op));
-
return true;
}
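
[Note] The relaxed matcher above now accepts either a Neg with a fused Relu or a Neg followed by a standalone Relu as the producer of the negative branch. Both shapes encode the same Keras-style expansion of PReLU, with the sign folded into the stored alpha tensor:

PReLU(x) = max(0, x) + alpha * min(0, x)
         = Relu(x) - alpha * Relu(Neg(x))

so the subgraph being matched is Add(Relu(x), Mul(alpha', Relu(Neg(x)))) with alpha' = -alpha.
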
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 00ab7cbaa9..670bcf64e7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -100,6 +100,13 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
break;
}
+ case OperatorType::kArgMin: {
+ // Data type of the ArgMin op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* argmin_op = static_cast<ArgMinOperator*>(op);
+ model->GetArray(op->outputs[0]).data_type = argmin_op->output_data_type;
+ break;
+ }
case OperatorType::kRange: {
auto* range_op = static_cast<RangeOperator*>(op);
// Output type of the Range op can be set via an attribute
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
index 0f2592d05f..53fc87da7b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -30,15 +30,9 @@ namespace {
bool ChangeArrayDataType(GraphTransformation* transformation, Array* array,
ArrayDataType new_data_type,
const MinMax* new_minmax) {
- // The code below assumes kInt16, see
- // GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>
- if (new_data_type != ArrayDataType::kInt16) {
- return false;
- }
-
- bool changed = false;
// Ensure the array ends up in the new type (if it hasn't yet been quantized).
- if ((array->final_data_type != new_data_type)) {
+ bool changed = false;
+ if (array->final_data_type != new_data_type) {
array->final_data_type = new_data_type;
changed = true;
}
@@ -75,8 +69,20 @@ bool ChangeArrayDataType(GraphTransformation* transformation, Array* array,
array_minmax.min = min;
array_minmax.max = max;
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
- array_minmax, array->quantization_params.get());
+ switch (new_data_type) {
+ case ArrayDataType::kUint8:
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
+ array_minmax, array->quantization_params.get());
+ break;
+ case ArrayDataType::kInt16:
+ GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
+ array_minmax, array->quantization_params.get());
+ break;
+ default:
+ CHECK(false) << "Unsupported quantized data type: "
+ << ArrayDataTypeName(new_data_type);
+ return false;
+ }
// Directly change the type as the array was already quantized.
array->data_type = new_data_type;
@@ -95,6 +101,7 @@ bool ChangeArrayDataType(GraphTransformation* transformation, Array* array,
changed = true;
}
}
+
return changed;
}
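
[Note] GetQuantizationParamsFromMinMax<T> derives an affine (scale, zero_point) pair from the recorded MinMax for the target integer type, which is why the switch above has to name each supported type explicitly. Under the usual affine scheme (the exact zero-point nudging lives inside the helper):

scale      = (max - min) / (qmax - qmin)
zero_point = round(qmin - min / scale)

For example, kUint8 (qmin = 0, qmax = 255) with min = -1.0 and max = 1.0
gives scale = 2/255 ≈ 0.00784 and zero_point = round(127.5) = 128.
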
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 8eb0423283..4f95c57451 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1404,7 +1404,8 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
}
}
-void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
+template <typename Op>
+void ProcessArgMinMaxOperator(Model* model, Op* op) {
CHECK_EQ(op->inputs.size(), 2);
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
@@ -1696,7 +1697,12 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
static_cast<StridedSliceOperator*>(op));
break;
case OperatorType::kArgMax:
- ProcessArgMaxOperator(model, static_cast<ArgMaxOperator*>(op));
+ ProcessArgMinMaxOperator<ArgMaxOperator>(
+ model, static_cast<ArgMaxOperator*>(op));
+ break;
+ case OperatorType::kArgMin:
+ ProcessArgMinMaxOperator<ArgMinOperator>(
+ model, static_cast<ArgMinOperator*>(op));
break;
case OperatorType::kUnsupported:
break;
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 5c32a39035..bc439a2feb 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1230,10 +1230,11 @@ tensorflow::Status ConvertGatherOperator(
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertArgMaxOperator(
+template <typename Op, const char* op_name>
+tensorflow::Status ConvertArgMinMaxOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
- CHECK_EQ(node.op(), "ArgMax");
+ CHECK_EQ(node.op(), op_name);
TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
const auto axis_data_type =
HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
@@ -1242,7 +1243,7 @@ tensorflow::Status ConvertArgMaxOperator(
: DT_INT64;
CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32);
CHECK(output_type == DT_INT64 || output_type == DT_INT32);
- auto* op = new ArgMaxOperator;
+ auto* op = new Op;
op->output_data_type = ConvertDataType(output_type);
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
@@ -1833,12 +1834,16 @@ using ConverterType = tensorflow::Status (*)(
Model* model);
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
+constexpr char kArgMax[] = "ArgMax";
+constexpr char kArgMin[] = "ArgMin";
+
ConverterMapType GetTensorFlowNodeConverterMap() {
return std::unordered_map<std::string, ConverterType>({
{"Add", ConvertSimpleOperator<AddOperator, 2>},
{"AddN", ConvertSimpleOperator<AddNOperator>},
{"All", ConvertSimpleOperator<TensorFlowAllOperator>},
- {"ArgMax", ConvertArgMaxOperator},
+ {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>},
+ {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>},
{"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
{"AvgPool", ConvertAvgPoolOperator},
{"BatchMatMul", ConvertBatchMatMulOperator},
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 3a1d243f87..8660464fdb 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -140,6 +140,7 @@ enum class OperatorType : uint8 {
kEqual,
kNotEqual,
kPow,
+ kArgMin,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -1528,6 +1529,17 @@ struct ArgMaxOperator : Operator {
ArrayDataType output_data_type = ArrayDataType::kInt64;
};
+// ArgMin operator. It returns the index of the minimum value along axis.
+//
+// Inputs:
+//   inputs[0]: required: the input tensor
+//   inputs[1]: required: the axis to take the arg min across
+//
+// TensorFlow equivalent: ArgMin
+struct ArgMinOperator : Operator {
+ ArgMinOperator() : Operator(OperatorType::kArgMin) {}
+ ArrayDataType output_data_type = ArrayDataType::kInt64;
+};
+
// ResizeBilinear operator. It resizes input images with bilinear interpolation.
// It does not support align_corners at the moment.
//
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 1972246807..5ad307af14 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -336,17 +336,13 @@ void Export(
auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
&builder, &error_summary);
- const string fake_quant_operation_name = "FAKE_QUANT";
-
- if (error_summary.count(fake_quant_operation_name) != 0) {
- LOG(ERROR)
- << fake_quant_operation_name
- << " operation was not converted. If running quantized make sure you "
- "are passing --inference_type=QUANTIZED_UINT8 and values for "
- "--std_values and --mean_values.";
- // Remove the fake quant operation from the errors, since it shouldn't
- // be provided a custom implementation.
- error_summary.erase(fake_quant_operation_name);
+ for (const auto& op : model.operators) {
+ if (op->type == OperatorType::kFakeQuant) {
+ LOG(WARNING) << "FAKE_QUANT operation " << LogName(*op)
+ << " was not converted. If running quantized make sure you "
+ "are passing --inference_type=QUANTIZED_UINT8 and values "
+ "for --std_values and --mean_values.";
+ }
}
if (!allow_custom_ops && !error_summary.empty()) {
// Remove ExpandDims and ReorderAxes from unimplemented list unless they
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 7e55ae92bd..8377ba6a03 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -282,22 +282,24 @@ class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
int GetVersion(const Operator& op) const override { return 1; }
};
-class FakeQuant : public CustomOperator<FakeQuantOperator> {
+class FakeQuant
+ : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
+ ::tflite::BuiltinOptions_FakeQuantOptions> {
public:
- using CustomOperator::CustomOperator;
- void WriteOptions(const TocoOperator& op,
- flexbuffers::Builder* fbb) const override {
- fbb->Float("min", op.minmax->min);
- fbb->Float("max", op.minmax->max);
- fbb->Int("num_bits", op.num_bits);
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateFakeQuantOptions(*builder, op.minmax->min,
+ op.minmax->max, op.num_bits);
}
- void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
auto* minmax = new MinMax;
- minmax->min = m["min"].AsFloat();
- minmax->max = m["max"].AsFloat();
+ minmax->min = options.min();
+ minmax->max = options.max();
op->minmax.reset(minmax);
- const auto& num_bits = m["num_bits"];
- op->num_bits = num_bits.IsInt() ? num_bits.AsInt32() : 8;
+ op->num_bits = options.num_bits();
}
int GetVersion(const Operator& op) const override { return 1; }
@@ -885,6 +887,25 @@ class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
int GetVersion(const Operator& op) const override { return 1; }
};
+class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
+ ::tflite::BuiltinOptions_ArgMinOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateArgMinOptions(
+ *builder, DataType::Serialize(op.output_data_type));
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->output_data_type = DataType::Deserialize(options.output_type());
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TransposeConv
: public BuiltinOperator<TransposeConvOperator,
::tflite::TransposeConvOptions,
@@ -1175,6 +1196,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(
new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax));
ops.emplace_back(
+ new ArgMin(::tflite::BuiltinOperator_ARG_MIN, OperatorType::kArgMin));
+ ops.emplace_back(
new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS,
OperatorType::kExpandDims));
@@ -1184,11 +1207,12 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kSparseToDense));
ops.emplace_back(
new Shape(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
+ ops.emplace_back(new FakeQuant(::tflite::BuiltinOperator_FAKE_QUANT,
+ OperatorType::kFakeQuant));
// Custom Operators.
ops.emplace_back(
new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
- ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
ops.emplace_back(new TensorFlowUnsupported("TENSORFLOW_UNSUPPORTED",
OperatorType::kUnsupported));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 8b6808d3c7..ff2d35b1f5 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -416,6 +416,13 @@ TEST_F(OperatorTest, BuiltinArgMax) {
EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
}
+TEST_F(OperatorTest, BuiltinArgMin) {
+ ArgMinOperator op;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("ARG_MIN", OperatorType::kArgMin), op);
+ EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
+}
+
TEST_F(OperatorTest, BuiltinTransposeConv) {
TransposeConvOperator op;
op.stride_width = 123;
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 8abdb014e4..4ec74e351f 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -387,6 +387,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Mean)
HANDLE_OPERATORTYPENAME_CASE(Svdf)
HANDLE_OPERATORTYPENAME_CASE(ArgMax)
+ HANDLE_OPERATORTYPENAME_CASE(ArgMin)
HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
HANDLE_OPERATORTYPENAME_CASE(Unsupported)
HANDLE_OPERATORTYPENAME_CASE(Exp)
diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD
index a3df37358f..d070018e83 100644
--- a/tensorflow/contrib/lite/tools/BUILD
+++ b/tensorflow/contrib/lite/tools/BUILD
@@ -14,6 +14,7 @@ py_binary(
srcs = ["visualize.py"],
data = [
"//tensorflow/contrib/lite/schema:schema.fbs",
+ "//tensorflow/python:platform",
"@flatbuffers//:flatc",
],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc
index 08648bcfe2..19b9a9c7ba 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc
@@ -98,10 +98,13 @@ void BenchmarkModel::LogFlags() {
<< "]";
}
+void BenchmarkModel::PrepareInputsAndOutputs() {}
+
Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) {
Stat<int64_t> run_stats;
TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations ";
for (int run = 0; run < num_times; run++) {
+ PrepareInputsAndOutputs();
listeners_.OnSingleRunStart(run_type);
int64_t start_us = profiling::time::NowMicros();
RunImpl();
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
index 942e21f67a..3c7063b2d4 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
@@ -150,6 +150,7 @@ class BenchmarkModel {
virtual std::vector<Flag> GetFlags();
virtual uint64_t ComputeInputBytes() = 0;
virtual tensorflow::Stat<int64_t> Run(int num_times, RunType run_type);
+ virtual void PrepareInputsAndOutputs();
virtual void RunImpl() = 0;
BenchmarkParams params_;
BenchmarkListeners listeners_;
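
[Note] The new PrepareInputsAndOutputs hook is a no-op in the base class but runs before the start timestamp in Run(), so a subclass can refresh its input buffers on every iteration without polluting the measured latency. A minimal sketch of a subclass using it (hypothetical; the real class has additional virtuals not shown in this hunk):

class FreshInputBenchmark : public BenchmarkModel {
 public:
  uint64_t ComputeInputBytes() override { return 0; }
  void RunImpl() override { /* invoke the model once */ }
  void PrepareInputsAndOutputs() override {
    // Refill input tensors here; this executes outside the timed region.
  }
};
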
diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/contrib/lite/tools/visualize.py
index f571dd59da..e07f899e4d 100644
--- a/tensorflow/contrib/lite/tools/visualize.py
+++ b/tensorflow/contrib/lite/tools/visualize.py
@@ -28,11 +28,24 @@ import json
import os
import sys
+from tensorflow.python.platform import resource_loader
+
# Schema to use for flatbuffers
_SCHEMA = "third_party/tensorflow/contrib/lite/schema/schema.fbs"
-# Where the binary will be once built in for the flatc converter
-_BINARY = "third_party/flatbuffers/flatc"
+# TODO(angerson): fix later when rules are simplified.
+_SCHEMA = resource_loader.get_path_to_datafile("../schema/schema.fbs")
+_BINARY = resource_loader.get_path_to_datafile("../../../../flatbuffers/flatc")
+# Account for different package positioning internal vs. external.
+if not os.path.exists(_BINARY):
+ _BINARY = resource_loader.get_path_to_datafile(
+ "../../../../../flatbuffers/flatc")
+
+if not os.path.exists(_SCHEMA):
+ raise RuntimeError("Sorry, schema file cannot be found at %r" % _SCHEMA)
+if not os.path.exists(_BINARY):
+ raise RuntimeError("Sorry, flatc is not available at %r" % _BINARY)
+
# A CSS description for making the visualizer
_CSS = """