aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/autograph/converters/logical_expressions.py4
-rw-r--r--tensorflow/contrib/autograph/converters/logical_expressions_test.py9
-rw-r--r--tensorflow/contrib/autograph/impl/api_test.py3
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/anf.py10
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py40
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/live_values.py7
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py43
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py8
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py9
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py45
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py104
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py22
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/BUILD1
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py45
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py56
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py13
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py144
-rw-r--r--tensorflow/contrib/data/python/ops/map_defun.py2
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py3
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py3
-rw-r--r--tensorflow/contrib/estimator/BUILD1
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn.py14
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn_test.py41
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc2
-rw-r--r--tensorflow/contrib/learn/BUILD1
-rw-r--r--tensorflow/contrib/lite/BUILD46
-rw-r--r--tensorflow/contrib/lite/allocation.cc4
-rw-r--r--tensorflow/contrib/lite/allocation.h4
-rw-r--r--tensorflow/contrib/lite/arena_planner.h6
-rw-r--r--tensorflow/contrib/lite/build_def.bzl21
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h292
-rw-r--r--tensorflow/contrib/lite/c/BUILD39
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data.h298
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data_test.cc83
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal.c (renamed from tensorflow/contrib/lite/context.c)6
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal.h491
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal_test.cc (renamed from tensorflow/contrib/lite/context_test.cc)10
-rw-r--r--tensorflow/contrib/lite/context.h478
-rw-r--r--tensorflow/contrib/lite/context_util.h2
-rw-r--r--tensorflow/contrib/lite/core/api/BUILD57
-rw-r--r--tensorflow/contrib/lite/core/api/error_reporter.cc38
-rw-r--r--tensorflow/contrib/lite/core/api/error_reporter.h45
-rw-r--r--tensorflow/contrib/lite/core/api/error_reporter_test.cc49
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc622
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.h48
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc104
-rw-r--r--tensorflow/contrib/lite/core/api/op_resolver.cc60
-rw-r--r--tensorflow/contrib/lite/core/api/op_resolver.h47
-rw-r--r--tensorflow/contrib/lite/core/api/op_resolver_test.cc197
-rw-r--r--tensorflow/contrib/lite/delegates/eager/BUILD5
-rw-r--r--tensorflow/contrib/lite/delegates/eager/buffer_map.h2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate.h2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/kernel.cc2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/kernel.h2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.h2
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/BUILD2
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc2
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h2
-rw-r--r--tensorflow/contrib/lite/error_reporter.h38
-rw-r--r--tensorflow/contrib/lite/experimental/c/BUILD1
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.cc12
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.h13
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_internal.h6
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_test.cc4
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/BUILD3
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc2
-rw-r--r--tensorflow/contrib/lite/graph_info.h2
-rw-r--r--tensorflow/contrib/lite/interpreter.cc4
-rw-r--r--tensorflow/contrib/lite/interpreter.h5
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc2
-rw-r--r--tensorflow/contrib/lite/java/ovic/BUILD3
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h8
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.h2
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD44
-rw-r--r--tensorflow/contrib/lite/kernels/activation_functor.h2
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/add.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/arg_min_max.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/audio_spectrogram.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/cast.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/concatenation.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/dequantize.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/div.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.h2
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/exp.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/expand_dims.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/expand_dims_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/fake_quant.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/floor.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/floor_div.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/gather.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/gather_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.h2
-rw-r--r--tensorflow/contrib/lite/kernels/hashtable_lookup.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD46
-rw-r--r--tensorflow/contrib/lite/kernels/internal/common.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h129
-rw-r--r--tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h92
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor.h111
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h135
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.h5
-rw-r--r--tensorflow/contrib/lite/kernels/l2norm.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/local_response_norm.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/logical.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/lsh_projection.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/maximum_minimum.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/mfcc.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/mul.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/neg.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/one_hot.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/pack.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/pad.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/padding.h2
-rw-r--r--tensorflow/contrib/lite/kernels/pooling.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/pow.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/reduce.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/register.h3
-rw-r--r--tensorflow/contrib/lite/kernels/reshape.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/select.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/shape.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/skip_gram.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/slice.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_batch_nd.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_depth.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/sparse_to_dense.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/split.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/squeeze.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/sub.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/svdf.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/tile.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/tile_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/topk_v2.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/topk_v2_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/transpose.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/unpack.cc4
-rw-r--r--tensorflow/contrib/lite/memory_planner.h2
-rw-r--r--tensorflow/contrib/lite/mmap_allocation.cc2
-rw-r--r--tensorflow/contrib/lite/model.cc636
-rw-r--r--tensorflow/contrib/lite/model.h5
-rw-r--r--tensorflow/contrib/lite/model_test.cc2
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver.cc (renamed from tensorflow/contrib/lite/op_resolver.cc)3
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver.h79
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver_test.cc (renamed from tensorflow/contrib/lite/op_resolver_test.cc)2
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc4
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.h4
-rw-r--r--tensorflow/contrib/lite/op_resolver.h78
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena.h2
-rw-r--r--tensorflow/contrib/lite/stderr_reporter.cc (renamed from tensorflow/contrib/lite/error_reporter.cc)22
-rw-r--r--tensorflow/contrib/lite/stderr_reporter.h34
-rw-r--r--tensorflow/contrib/lite/string_util.cc2
-rw-r--r--tensorflow/contrib/lite/string_util.h2
-rw-r--r--tensorflow/contrib/lite/string_util_test.cc2
-rw-r--r--tensorflow/contrib/lite/testing/BUILD2
-rw-r--r--tensorflow/contrib/lite/testing/util.h2
-rw-r--r--tensorflow/contrib/lite/toco/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc19
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h5
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD1
-rw-r--r--tensorflow/contrib/lite/tools/make/Makefile108
-rw-r--r--tensorflow/contrib/lite/tools/optimize/quantize_weights.cc1
-rw-r--r--tensorflow/contrib/lite/tutorials/BUILD20
-rw-r--r--tensorflow/contrib/lite/tutorials/dataset.py122
-rw-r--r--tensorflow/contrib/lite/tutorials/mnist_tflite.py87
-rw-r--r--tensorflow/contrib/lite/util.h2
-rw-r--r--tensorflow/contrib/lite/util_test.cc2
-rw-r--r--tensorflow/contrib/makefile/proto_text_cc_files.txt1
-rw-r--r--tensorflow/contrib/opt/BUILD2
-rw-r--r--tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py91
-rw-r--r--tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py240
-rw-r--r--tensorflow/contrib/quantize/BUILD3
-rw-r--r--tensorflow/contrib/quantize/python/common.py26
-rw-r--r--tensorflow/contrib/quantize/python/common_test.py25
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py25
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py5
-rw-r--r--tensorflow/contrib/rnn/BUILD8
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py20
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py6
-rw-r--r--tensorflow/contrib/tpu/__init__.py1
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py10
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py15
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc2
213 files changed, 4072 insertions, 2323 deletions
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py
index 16eb1f0e3f..41c3424fa3 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions.py
+++ b/tensorflow/contrib/autograph/converters/logical_expressions.py
@@ -57,8 +57,8 @@ class LogicalExpressionTransformer(converter.Base):
gast.NotEq: 'tf.not_equal',
gast.Or: 'tf.logical_or',
gast.USub: 'tf.negative',
- gast.Is: 'autograph_utils.dynamic_is',
- gast.IsNot: 'autograph_utils.dynamic_is_not'
+ gast.Is: 'ag__.utils.dynamic_is',
+ gast.IsNot: 'ag__.utils.dynamic_is_not'
}
def _expect_simple_symbol(self, operand):
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py
index 8f9eee7081..409a73afba 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py
+++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py
@@ -47,6 +47,15 @@ class GradientsFunctionTest(converter_testing.TestCase):
with self.cached_session() as sess:
self.assertTrue(sess.run(result.test_fn(True, False, True)))
+ def test_ag_utils_lookup(self):
+ def test_fn(a, b):
+ return a is b or a is not b
+
+ with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or
+ ) as result:
+ with self.cached_session() as sess:
+ self.assertTrue(sess.run(result.test_fn(True, False)))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py
index 803fde9089..a4c6fed265 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/contrib/autograph/impl/api_test.py
@@ -38,9 +38,6 @@ class ApiTest(test.TestCase):
def setUp(self):
config.COMPILED_IMPORT_STATEMENTS = (
'from __future__ import print_function',
- 'from tensorflow.contrib.autograph import utils'
- ' as autograph_utils',
- 'tf = autograph_utils.fake_tf()',
)
def test_decorator_recurses(self):
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
index e42f679cfe..d77c15915b 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
@@ -394,10 +394,16 @@ class AnfTransformer(transformer.Base):
# just recur.
def visit_List(self, node):
- return self._visit_strict_expression(node)
+ node = self.generic_visit(node)
+ if not isinstance(node.ctx, gast.Store):
+ self._ensure_fields_trivial(node)
+ return node
def visit_Tuple(self, node):
- return self._visit_strict_expression(node)
+ node = self.generic_visit(node)
+ if not isinstance(node.ctx, gast.Store):
+ self._ensure_fields_trivial(node)
+ return node
def transform(node, entity_info, gensym_source=None):
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
index 951974820c..1ffd4bbe55 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
@@ -165,6 +165,46 @@ class AnfTransformerTest(test.TestCase):
self.assert_body_anfs_as_expected(expected_result, test_function)
+ def test_nested_multi_value_assign(self):
+
+ def test_function(a, b, c):
+ x, y = a, a + b
+ (z, y), x = (c, y + b), x + a
+ return z, (y, x)
+
+ def expected_result(a, b, c):
+ tmp_1001 = a + b
+ x, y = a, tmp_1001
+ tmp_1002 = y + b
+ tmp_1003 = (c, tmp_1002)
+ tmp_1004 = x + a
+ (z, y), x = tmp_1003, tmp_1004
+ tmp_1005 = y, x
+ tmp_1006 = z, tmp_1005
+ return tmp_1006
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_deeply_nested_multi_value_assign(self):
+
+ def test_function(a):
+ [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
+ return [([(b, c), [d, e]], (f, g)), [(h, i, j), k]]
+
+ def expected_result(a):
+ [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
+ tmp_1001 = b, c
+ tmp_1002 = [d, e]
+ tmp_1003 = [tmp_1001, tmp_1002]
+ tmp_1004 = f, g
+ tmp_1005 = h, i, j
+ tmp_1006 = tmp_1003, tmp_1004
+ tmp_1007 = [tmp_1005, k]
+ tmp_1008 = [tmp_1006, tmp_1007]
+ return tmp_1008
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
def test_local_definition_and_binary_compare(self):
def test_function():
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
index 2d8f922a45..e7baa244b2 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
@@ -29,6 +29,11 @@ from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+# TODO(aqj): Do we need this? Do other builtins fail in similar ways
+# See b/114389775 for a related bug in pyct
+# These symbols are legal in Python, but don't appear in the namespace.
+_special_symbols = {'range': range}
+
class LiveValueResolver(transformer.Base):
"""Annotates nodes with live values."""
@@ -66,6 +71,8 @@ class LiveValueResolver(transformer.Base):
# If the symbol value is for example a primitive, then it will not
# have a name.
pass
+ elif node.id in _special_symbols:
+ anno.setanno(node, 'live_val', _special_symbols[node.id])
else:
pass
# TODO(mdan): Should we raise an error here?
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 870ce2442b..4c7a538b38 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -52,7 +52,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
Args:
@@ -94,6 +95,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
@@ -134,7 +136,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -159,7 +162,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
Args:
@@ -201,6 +205,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -224,7 +229,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -251,7 +257,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
Args:
@@ -289,6 +296,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
super(GradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=model.model_builder,
@@ -303,7 +311,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -329,7 +338,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
center_bias=False,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained off the pairwise data and can be
@@ -377,6 +387,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
+
Raises:
ValueError: If learner_config is not valid.
"""
@@ -395,7 +407,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -444,7 +457,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
- output_leaf_index=False):
+ output_leaf_index=False,
+ num_quantiles=100):
"""Initializes a core version of GradientBoostedDecisionTreeEstimator.
Args:
@@ -474,6 +488,7 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
for example_prediction_result in result_dict:
# access leaf index list by example_prediction_result["leaf_index"]
# which contains one leaf index per tree
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
def _model_fn(features, labels, mode, config):
@@ -493,7 +508,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
- 'override_global_step_value': None
+ 'override_global_step_value': None,
+ 'num_quantiles': num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
@@ -517,7 +533,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
label_keys=None,
logits_modifier_function=None,
center_bias=False,
- output_leaf_index=False):
+ output_leaf_index=False,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained off the pairwise data and can be
@@ -552,6 +569,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
for result_dict in result_iter:
# access leaf index list by result_dict["leaf_index"]
# which contains one leaf index per tree
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
@@ -576,7 +594,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
- 'override_global_step_value': None
+ 'override_global_step_value': None,
+ 'num_quantiles': num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 04b46c3483..a6e422847d 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -81,6 +81,7 @@ def model_builder(features,
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
override_global_step_value = params.get("override_global_step_value", None)
+ num_quantiles = params["num_quantiles"]
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -116,7 +117,8 @@ def model_builder(features,
logits_dimension=head.logits_dimension,
features=training_features,
use_core_columns=use_core_libs,
- output_leaf_index=output_leaf_index)
+ output_leaf_index=output_leaf_index,
+ num_quantiles=num_quantiles)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)
logits = predictions_dict["predictions"]
@@ -237,6 +239,7 @@ def ranking_model_builder(features,
output_leaf_index = params["output_leaf_index"]
ranking_model_pair_keys = params["ranking_model_pair_keys"]
override_global_step_value = params.get("override_global_step_value", None)
+ num_quantiles = params["num_quantiles"]
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -299,7 +302,8 @@ def ranking_model_builder(features,
logits_dimension=head.logits_dimension,
features=main_features,
use_core_columns=use_core_libs,
- output_leaf_index=output_leaf_index)
+ output_leaf_index=output_leaf_index,
+ num_quantiles=num_quantiles)
with ops.name_scope("gbdt", "gbdt_optimizer"):
# Logits for inference.
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index b008c6e534..c7eb2493a8 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -304,7 +304,8 @@ class GradientBoostedDecisionTreeModel(object):
feature_columns=None,
use_core_columns=False,
output_leaf_index=False,
- output_leaf_index_modes=None):
+ output_leaf_index_modes=None,
+ num_quantiles=100):
"""Construct a new GradientBoostedDecisionTreeModel function.
Args:
@@ -327,6 +328,7 @@ class GradientBoostedDecisionTreeModel(object):
output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which
dictates when leaf indices will be outputted. By default, leaf indices
are only outputted in INFER mode.
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: if inputs are not valid.
@@ -399,6 +401,7 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config = learner_config
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
+ self._num_quantiles = num_quantiles
self._max_tree_depth = variables.Variable(
initial_value=self._learner_config.constraints.max_tree_depth)
self._attempted_trees = variables.Variable(
@@ -689,8 +692,8 @@ class GradientBoostedDecisionTreeModel(object):
loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
weak_learner_type = constant_op.constant(
self._learner_config.weak_learner_type)
- epsilon = 0.01
- num_quantiles = 100
+ num_quantiles = self._num_quantiles
+ epsilon = 1.0 / num_quantiles
strategy_tensor = constant_op.constant(strategy)
with ops.device(self._get_replica_device_setter(worker_device)):
# Create handlers for dense float columns
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 1ab150d74a..1056894f18 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -229,6 +229,10 @@ class TPUClusterResolver(ClusterResolver):
def get_master(self):
return self.master()
+ def get_job_name(self):
+ if self._shouldResolve():
+ return self._job_name
+
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 34f594f741..b9320e5fef 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -279,7 +279,9 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 9d8e955245..67242fecfe 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -428,10 +428,10 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
@parameterized.named_parameters(
- ("default", None, None),
- ("sequential_calls", 1, None),
- ("parallel_calls", 2, None),
- ("parallel_batches", None, 10),
+ ("Default", None, None),
+ ("SequentialCalls", 1, None),
+ ("ParallelCalls", 2, None),
+ ("ParallelBatches", None, 10),
)
def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
"""Test a dataset that maps a TF function across its input elements."""
@@ -505,8 +505,8 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
@parameterized.named_parameters(
- ("even", False),
- ("uneven", True),
+ ("Even", False),
+ ("Uneven", True),
)
def testMapAndBatchPartialBatch(self, drop_remainder):
iterator = (
@@ -663,7 +663,14 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
for _ in range(3):
sess.run(get_next)
- @parameterized.parameters(0, 5, 10, 90, 95, 99)
+ @parameterized.named_parameters(
+ ("1", 0),
+ ("2", 5),
+ ("3", 10),
+ ("4", 90),
+ ("5", 95),
+ ("6", 99),
+ )
def testMapAndBatchOutOfRangeError(self, threshold):
def raising_py_fn(i):
@@ -689,18 +696,18 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (False, dtypes.bool),
- (-42, dtypes.int8),
- (-42, dtypes.int16),
- (-42, dtypes.int32),
- (-42, dtypes.int64),
- (42, dtypes.uint8),
- (42, dtypes.uint16),
- (42.0, dtypes.float16),
- (42.0, dtypes.float32),
- (42.0, dtypes.float64),
- (b"hello", dtypes.string),
+ @parameterized.named_parameters(
+ ("1", False, dtypes.bool),
+ ("2", -42, dtypes.int8),
+ ("3", -42, dtypes.int16),
+ ("4", -42, dtypes.int32),
+ ("5", -42, dtypes.int64),
+ ("6", 42, dtypes.uint8),
+ ("7", 42, dtypes.uint16),
+ ("8", 42.0, dtypes.float16),
+ ("9", 42.0, dtypes.float32),
+ ("10", 42.0, dtypes.float64),
+ ("11", b"hello", dtypes.string),
)
def testMapAndBatchTypes(self, element, dtype):
def gen():
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 091eb5ce37..61567bc8d7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -17,7 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import time
+
from tensorflow.contrib.data.python.ops import map_defun
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -25,10 +28,10 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-
class MapDefunTest(test.TestCase):
def testMapDefunSimple(self):
@@ -146,6 +149,105 @@ class MapDefunTest(test.TestCase):
r"indices = 10 is not in \[0, 5\)"):
self.evaluate(map_defun_op)
+ def testMapDefunWithUnspecifiedOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ res = x * 2 + 3
+ return (res, res + 1, res + 2)
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems],
+ [dtypes.int32, dtypes.int32, dtypes.int32],
+ [None, (None,), (2,)])
+ expected = elems * 2 + 3
+ self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected))
+ self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
+ self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
+
+ def testMapDefunWithDifferentOutputShapeEachRun(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ elems = array_ops.placeholder(dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3])
+ self.assertAllEqual(
+ sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
+
+ def testMapDefunWithWrongOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0]
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(r)
+
+ def testMapDefunWithInvalidInput(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2
+
+ c = constant_op.constant(2)
+ with self.assertRaises(ValueError):
+ # Fails at graph construction time for inputs with known shapes.
+ r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0]
+ p = array_ops.placeholder(dtypes.int32)
+ r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(r, feed_dict={p: 0})
+
+
+class MapDefunBenchmark(test.Benchmark):
+
+ def _run(self, op, name=None, num_iters=3000):
+ with session.Session() as sess:
+ # Warm up the session
+ for _ in range(5):
+ sess.run(op)
+ start = time.time()
+ for _ in range(num_iters):
+ sess.run(op)
+ end = time.time()
+ mean_us = (end - start) * 1e6 / num_iters
+ self.report_benchmark(
+ name=name,
+ iters=num_iters,
+ wall_time=mean_us,
+ extras={"examples_per_sec": num_iters / (end - start)})
+
+ def benchmarkDefunVsMapFn(self):
+ """Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
+
+ @function.Defun(dtypes.int32)
+ def defun(x):
+ return array_ops.identity(x)
+
+ def map_fn(x):
+ return array_ops.identity(x)
+
+ base = math_ops.range(100)
+ for input_size in [10, 100, 1000, 10000]:
+ num_iters = 100000 // input_size
+ map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
+ map_fn_op = functional_ops.map_fn(map_fn, base)
+
+ self._run(
+ map_defun_op,
+ "benchmarkMapDefun_size_%d" % input_size,
+ num_iters=num_iters)
+ self._run(
+ map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index 586b4bee5f..6a7ef877f9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -44,22 +44,22 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
for i, fun1 in enumerate(functions):
for j, fun2 in enumerate(functions):
tests.append((
- "test_{}_{}".format(i, j),
+ "Test{}{}".format(i, j),
[fun1, fun2],
))
for k, fun3 in enumerate(functions):
tests.append((
- "test_{}_{}_{}".format(i, j, k),
+ "Test{}{}{}".format(i, j, k),
[fun1, fun2, fun3],
))
swap = lambda x, n: (n, x)
tests.append((
- "swap1",
+ "Swap1",
[lambda x: (x, 42), swap],
))
tests.append((
- "swap2",
+ "Swap2",
[lambda x: (x, 42), swap, swap],
))
return tuple(tests)
@@ -109,13 +109,13 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
for x, fun in enumerate(functions):
for y, predicate in enumerate(filters):
- tests.append(("mixed_{}_{}".format(x, y), fun, predicate))
+ tests.append(("Mixed{}{}".format(x, y), fun, predicate))
# Multi output
- tests.append(("multiOne", lambda x: (x, x),
+ tests.append(("Multi1", lambda x: (x, x),
lambda x, y: constant_op.constant(True)))
tests.append(
- ("multiTwo", lambda x: (x, 2),
+ ("Multi2", lambda x: (x, 2),
lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
return tuple(tests)
@@ -172,17 +172,17 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
identity = lambda x: x
for x, predicate_1 in enumerate(filters):
for y, predicate_2 in enumerate(filters):
- tests.append(("mixed_{}_{}".format(x, y), identity,
+ tests.append(("Mixed{}{}".format(x, y), identity,
[predicate_1, predicate_2]))
for z, predicate_3 in enumerate(filters):
- tests.append(("mixed_{}_{}_{}".format(x, y, z), identity,
+ tests.append(("Mixed{}{}{}".format(x, y, z), identity,
[predicate_1, predicate_2, predicate_3]))
take_all_multiple = lambda x, y: constant_op.constant(True)
# Multi output
- tests.append(("multiOne", lambda x: (x, x),
+ tests.append(("Multi1", lambda x: (x, x),
[take_all_multiple, take_all_multiple]))
- tests.append(("multiTwo", lambda x: (x, 2), [
+ tests.append(("Multi2", lambda x: (x, 2), [
take_all_multiple,
lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
]))
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
index 4881f63ab9..aa89674c6e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
@@ -210,6 +210,7 @@ py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
index ac3892fe81..243f6405a1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
@@ -27,42 +28,38 @@ from tensorflow.python.platform import test
class InterleaveDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
+ dataset_serialization_test_base.DatasetSerializationTestBase,
+ parameterized.TestCase):
- def _build_iterator_graph(self, input_values, cycle_length, block_length):
+ def _build_iterator_graph(self, input_values, cycle_length, block_length,
+ num_parallel_calls):
repeat_count = 2
return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
repeat_count).interleave(
lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length)
+ cycle_length, block_length, num_parallel_calls)
- def testSerializationCore(self):
+ @parameterized.named_parameters(
+ ("1", 2, 3, None),
+ ("2", 2, 3, 1),
+ ("3", 2, 3, 2),
+ ("4", 1, 3, None),
+ ("5", 1, 3, 1),
+ ("6", 2, 1, None),
+ ("7", 2, 1, 1),
+ ("8", 2, 1, 2),
+ )
+ def testSerializationCore(self, cycle_length, block_length,
+ num_parallel_calls):
input_values = np.array([4, 5, 6], dtype=np.int64)
num_outputs = np.sum(input_values) * 2
- # cycle_length > 1, block_length > 1
- cycle_length = 2
- block_length = 3
# pylint: disable=g-long-lambda
self.run_core_tests(
lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
+ input_values, cycle_length, block_length, num_parallel_calls),
lambda: self._build_iterator_graph(
- input_values, cycle_length * 2, block_length * 1),
+ input_values, cycle_length * 2, block_length, num_parallel_calls),
num_outputs)
- # cycle_length = 1
- cycle_length = 1
- block_length = 3
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
- # block_length = 1
- cycle_length = 2
- block_length = 1
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
# pylint: enable=g-long-lambda
def testSparseCore(self):
@@ -82,5 +79,5 @@ class InterleaveDatasetSerializationTest(
self.run_core_tests(_build_dataset, None, 20)
-if __name__ == '__main__':
+if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 8b2f846494..6b3e8e9f6e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -32,18 +32,18 @@ from tensorflow.python.platform import test
class SlideDatasetTest(test.TestCase, parameterized.TestCase):
- @parameterized.parameters(
- (20, 14, 7, 1),
- (20, 17, 9, 1),
- (20, 14, 14, 1),
- (20, 10, 14, 1),
- (20, 14, 19, 1),
- (20, 4, 1, 2),
- (20, 2, 1, 6),
- (20, 4, 7, 2),
- (20, 2, 7, 6),
- (1, 10, 4, 1),
- (0, 10, 4, 1),
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
)
def testSlideDataset(self, count, window_size, window_shift, window_stride):
"""Tests a dataset that slides a window its input elements."""
@@ -96,18 +96,18 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (20, 14, 7, 1),
- (20, 17, 9, 1),
- (20, 14, 14, 1),
- (20, 10, 14, 1),
- (20, 14, 19, 1),
- (20, 4, 1, 2),
- (20, 2, 1, 6),
- (20, 4, 7, 2),
- (20, 2, 7, 6),
- (1, 10, 4, 1),
- (0, 10, 4, 1),
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
)
def testSlideDatasetDeprecated(self, count, window_size, stride,
window_stride):
@@ -160,10 +160,10 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (14, 0, 3, 1),
- (14, 3, 0, 1),
- (14, 3, 3, 0),
+ @parameterized.named_parameters(
+ ("1", 14, 0, 3, 1),
+ ("2", 14, 3, 0, 1),
+ ("3", 14, 3, 3, 0),
)
def testSlideDatasetInvalid(self, count, window_size, window_shift,
window_stride):
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 0486e2bce2..4b08ec759d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -33,8 +33,17 @@ from tensorflow.python.platform import test
class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
- @parameterized.parameters((1, None), (2, None), (4, None), (8, None),
- (16, None), (4, -1), (4, 0), (4, 1), (4, 4))
+ @parameterized.named_parameters(
+ ("1", 1, None),
+ ("2", 2, None),
+ ("3", 4, None),
+ ("4", 8, None),
+ ("5", 16, None),
+ ("6", 4, -1),
+ ("7", 4, 0),
+ ("8", 4, 1),
+ ("9", 4, 4),
+ )
def testNumThreads(self, num_threads, max_intra_op_parallelism):
def get_thread_id(_):
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 33d95d6754..ff4d9b3260 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -64,15 +64,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
else:
self.assertEqual(xs, ys)
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetFlatMap(self, structure, shape, dtype):
"""Tests windowing by chaining it with flat map.
@@ -97,15 +97,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetBatchDense(self, structure, shape, dtype):
"""Tests batching of dense tensor windows.
@@ -135,10 +135,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([]),),
- (np.int32([1]),),
- (np.int32([1, 2, 3]),),
+ @parameterized.named_parameters(
+ ("1", np.int32([])),
+ ("2", np.int32([1])),
+ ("3", np.int32([1, 2, 3])),
)
def testWindowDatasetBatchDenseDynamicShape(self, shape):
"""Tests batching of dynamically shaped dense tensor windows.
@@ -203,15 +203,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
])
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetBatchSparse(self, structure, shape, dtype):
"""Tests batching of sparse tensor windows.
@@ -243,10 +243,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([]),),
- (np.int32([1]),),
- (np.int32([1, 2, 3]),),
+ @parameterized.named_parameters(
+ ("1", np.int32([])),
+ ("2", np.int32([1])),
+ ("3", np.int32([1, 2, 3])),
)
def testWindowDatasetBatchSparseDynamicShape(self, shape):
"""Tests batching of dynamically shaped sparse tensor windows.
@@ -284,17 +284,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
]))
- @parameterized.parameters(
- (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
- (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
- ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
+ ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
+ ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
+ ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+ ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
+ ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("8", (None,
+ (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
)
def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype,
padded_shape):
@@ -329,10 +330,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([[1], [2], [3]]), [-1]),
- (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+ @parameterized.named_parameters(
+ ("1", np.int32([[1], [2], [3]]), [-1]),
+ ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+ ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
)
def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape):
"""Tests padded batching of dynamically shaped dense tensor windows.
@@ -361,9 +362,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([[1]]), np.int32([0])),
- (np.int32([[10], [20]]), np.int32([15])),
+ @parameterized.named_parameters(
+ ("1", np.int32([[1]]), np.int32([0])),
+ ("2", np.int32([[10], [20]]), np.int32([15])),
)
def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
"""Tests invalid padded batching of dense tensor windows.
@@ -420,17 +421,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
])
- @parameterized.parameters(
- (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
- (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
- ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
+ @parameterized.named_parameters(
+ ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
+ ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
+ ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
+ ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+ ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
+ ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("8", (None,
+ (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
)
def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
padded_shape):
@@ -463,10 +465,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int64([[1], [2], [3]]), [-1]),
- (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+ @parameterized.named_parameters(
+ ("1", np.int64([[1], [2], [3]]), [-1]),
+ ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+ ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
)
def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
padded_shape):
@@ -495,9 +497,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int64([[1]]), [0]),
- (np.int64([[10], [20]]), [15]),
+ @parameterized.named_parameters(
+ ("1", np.int64([[1]]), [0]),
+ ("2", np.int64([[10], [20]]), [15]),
)
def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape):
"""Tests invalid padded batching of sparse tensor windows.
diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py
index 54d5cd6da0..3d0d0993c9 100644
--- a/tensorflow/contrib/data/python/ops/map_defun.py
+++ b/tensorflow/contrib/data/python/ops/map_defun.py
@@ -53,6 +53,4 @@ def map_defun(fn, elems, output_dtypes, output_shapes):
elems = [ops.convert_to_tensor(e) for e in elems]
output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
- if not all(s.is_fully_defined() for s in output_shapes):
- raise ValueError("All fn output shapes must be fully defined.")
return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index d39fd57294..3cee3e37a7 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -446,8 +446,7 @@ class TestWithDistributionStrategy(test.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- with self.assertRaisesRegexp(ValueError,
- 'expected input to have 2 dimensions'):
+ with self.assertRaisesRegexp(ValueError, 'expected input to have shape'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
# Wrong input shape
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 4fb70ec685..6ba83976fc 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -310,7 +310,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def get_host_cpu_device(self, host_id):
if self._tpu_cluster_resolver.get_master() in ('', 'local'):
return '/replica:0/task:0/device:CPU:0'
- return '/job:tpu_worker/task:%d/device:CPU:0' % (host_id,)
+ job_name = self._tpu_cluster_resolver.get_job_name() or 'tpu_worker'
+ return '/job:%s/task:%d/device:CPU:0' % (job_name, host_id)
def configure(self,
session_config=None,
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 77f62df99d..437b3d965d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -446,6 +446,7 @@ py_library(
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:optimizers",
+ "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py
index 7c49cd00d1..98660bb731 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import training_util
@@ -405,6 +406,7 @@ class RNNClassifier(estimator.Estimator):
weight_column=None,
label_vocabulary=None,
optimizer='Adagrad',
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
input_layer_partitioner=None,
config=None):
"""Initializes a `RNNClassifier` instance.
@@ -454,6 +456,8 @@ class RNNClassifier(estimator.Estimator):
string.
optimizer: An instance of `tf.Optimizer` or string specifying optimizer
type. Defaults to Adagrad optimizer.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.
input_layer_partitioner: Optional. Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
@@ -467,11 +471,15 @@ class RNNClassifier(estimator.Estimator):
if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
- n_classes, weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ n_classes,
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
+
def _model_fn(features, labels, mode, config):
return _rnn_model_fn(
features=features,
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
index 959b40371a..1aebed348d 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
@@ -713,7 +713,7 @@ class RNNClassifierTrainingTest(test.TestCase):
# Uses same checkpoint and examples as testBinaryClassEvaluationMetrics.
# See that test for loss calculation.
- mock_optimizer = self._mock_optimizer(expected_loss=1.119661)
+ mock_optimizer = self._mock_optimizer(expected_loss=0.559831)
sequence_feature_columns = [
seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -748,7 +748,7 @@ class RNNClassifierTrainingTest(test.TestCase):
# Uses same checkpoint and examples as testMultiClassEvaluationMetrics.
# See that test for loss calculation.
- mock_optimizer = self._mock_optimizer(expected_loss=2.662932)
+ mock_optimizer = self._mock_optimizer(expected_loss=1.331465)
sequence_feature_columns = [
seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -812,20 +812,32 @@ class RNNClassifierEvaluationTest(test.TestCase):
# probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]]
# loss = -label * ln(p) - (1 - label) * ln(1 - p)
# = [[0.436326], [0.683335]]
+ # sum_over_batch_size = (0.436326 + 0.683335)/2
expected_metrics = {
- ops.GraphKeys.GLOBAL_STEP: global_step,
- metric_keys.MetricKeys.LOSS: 1.119661,
- metric_keys.MetricKeys.LOSS_MEAN: 0.559831,
- metric_keys.MetricKeys.ACCURACY: 1.0,
- metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262,
- metric_keys.MetricKeys.LABEL_MEAN: 0.5,
- metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
+ ops.GraphKeys.GLOBAL_STEP:
+ global_step,
+ metric_keys.MetricKeys.LOSS:
+ 0.559831,
+ metric_keys.MetricKeys.LOSS_MEAN:
+ 0.559831,
+ metric_keys.MetricKeys.ACCURACY:
+ 1.0,
+ metric_keys.MetricKeys.PREDICTION_MEAN:
+ 0.429262,
+ metric_keys.MetricKeys.LABEL_MEAN:
+ 0.5,
+ metric_keys.MetricKeys.ACCURACY_BASELINE:
+ 0.5,
# With default threshold of 0.5, the model is a perfect classifier.
- metric_keys.MetricKeys.RECALL: 1.0,
- metric_keys.MetricKeys.PRECISION: 1.0,
+ metric_keys.MetricKeys.RECALL:
+ 1.0,
+ metric_keys.MetricKeys.PRECISION:
+ 1.0,
# Positive example is scored above negative, so AUC = 1.0.
- metric_keys.MetricKeys.AUC: 1.0,
- metric_keys.MetricKeys.AUC_PR: 1.0,
+ metric_keys.MetricKeys.AUC:
+ 1.0,
+ metric_keys.MetricKeys.AUC_PR:
+ 1.0,
}
self.assertAllClose(
sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))
@@ -871,9 +883,10 @@ class RNNClassifierEvaluationTest(test.TestCase):
# [0.059494, 0.572639, 0.367866]]
# loss = -1. * log(softmax[label])
# = [[2.105432], [0.557500]]
+ # sum_over_batch_size = (2.105432 + 0.557500)/2
expected_metrics = {
ops.GraphKeys.GLOBAL_STEP: global_step,
- metric_keys.MetricKeys.LOSS: 2.662932,
+ metric_keys.MetricKeys.LOSS: 1.331465,
metric_keys.MetricKeys.LOSS_MEAN: 1.331466,
metric_keys.MetricKeys.ACCURACY: 0.5,
}
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 0ccb4583ab..716bb87e38 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -174,7 +174,7 @@ class FusedConv2DBiasActivationOp : public OpKernel {
// Input bias is a 1-D tensor, with size matching output depth.
const Tensor& bias = context->input(kBias);
- OP_REQUIRES_OK(context, CheckShape(bias, "conv_input"));
+ OP_REQUIRES_OK(context, CheckShape(bias, "bias"));
const Tensor& conv_input_scale_tensor = context->input(kConvInputScale);
const Tensor& side_input_scale_tensor = context->input(kSideInputScale);
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 418b0cf392..61185f65a9 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -403,6 +403,7 @@ py_test(
srcs = ["python/learn/estimators/dnn_test.py"],
shard_count = 4,
srcs_version = "PY2AND3",
+ tags = ["notap"],
deps = [
":learn",
"//tensorflow/contrib/layers:layers_py",
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 0091587bf7..f320b53d94 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -36,10 +36,10 @@ cc_library(
srcs = ["arena_planner.cc"],
hdrs = ["arena_planner.h"],
deps = [
- ":context",
":graph_info",
":memory_planner",
":simple_memory_arena",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -54,6 +54,7 @@ cc_test(
deps = [
":arena_planner",
"//tensorflow/contrib/lite/testing:util",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_googletest//:gtest",
],
@@ -63,27 +64,27 @@ cc_test(
# TODO(aselle): Resolve problems preventing C99 usage.
cc_library(
name = "context",
- srcs = ["context.c"],
hdrs = ["context.h"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
name = "graph_info",
hdrs = ["graph_info.h"],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
name = "memory_planner",
hdrs = ["memory_planner.h"],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
name = "simple_memory_arena",
srcs = ["simple_memory_arena.cc"],
hdrs = ["simple_memory_arena.h"],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
@@ -91,7 +92,7 @@ cc_library(
hdrs = [
"builtin_op_data.h",
],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
@@ -121,12 +122,12 @@ cc_library(
name = "framework",
srcs = [
"allocation.cc",
- "error_reporter.cc",
"graph_info.cc",
"interpreter.cc",
"model.cc",
- "op_resolver.cc",
+ "mutable_op_resolver.cc",
"optional_debug_tools.cc",
+ "stderr_reporter.cc",
] + select({
"//tensorflow:android": [
"nnapi_delegate.cc",
@@ -149,9 +150,11 @@ cc_library(
"graph_info.h",
"interpreter.h",
"model.h",
+ "mutable_op_resolver.h",
"nnapi_delegate.h",
"op_resolver.h",
"optional_debug_tools.h",
+ "stderr_reporter.h",
],
copts = tflite_copts(),
linkopts = [
@@ -164,14 +167,14 @@ cc_library(
}),
deps = [
":arena_planner",
- ":builtin_op_data",
- ":context",
":graph_info",
":memory_planner",
":schema_fbs_version",
":simple_memory_arena",
":string",
":util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/kernels:eigen_support",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
@@ -210,6 +213,8 @@ cc_test(
deps = [
":framework",
":string_util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/kernels:kernel_util",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
@@ -259,6 +264,8 @@ cc_test(
],
deps = [
":framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
@@ -266,9 +273,9 @@ cc_test(
# Test OpResolver.
cc_test(
- name = "op_resolver_test",
+ name = "mutable_op_resolver_test",
size = "small",
- srcs = ["op_resolver_test.cc"],
+ srcs = ["mutable_op_resolver_test.cc"],
tags = ["no_oss"],
deps = [
":framework",
@@ -277,24 +284,12 @@ cc_test(
],
)
-# Test the C extension API code.
-cc_test(
- name = "context_test",
- size = "small",
- srcs = ["context_test.cc"],
- deps = [
- ":framework",
- "//tensorflow/contrib/lite/testing:util",
- "@com_google_googletest//:gtest",
- ],
-)
-
cc_library(
name = "util",
srcs = ["util.cc"],
hdrs = ["util.h"],
deps = [
- ":context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -304,7 +299,6 @@ cc_test(
srcs = ["util_test.cc"],
tags = ["no_oss"],
deps = [
- ":context",
":util",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc
index 8946261814..21cb1832a7 100644
--- a/tensorflow/contrib/lite/allocation.cc
+++ b/tensorflow/contrib/lite/allocation.cc
@@ -23,8 +23,8 @@ limitations under the License.
#include <cstring>
#include <utility>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h
index 121f3d2646..182bc0977f 100644
--- a/tensorflow/contrib/lite/allocation.h
+++ b/tensorflow/contrib/lite/allocation.h
@@ -20,8 +20,8 @@ limitations under the License.
#include <cstdio>
#include <cstdlib>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/simple_memory_arena.h"
#include "tensorflow/contrib/lite/string.h"
diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h
index 55003cf4e9..382577045b 100644
--- a/tensorflow/contrib/lite/arena_planner.h
+++ b/tensorflow/contrib/lite/arena_planner.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/graph_info.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/simple_memory_arena.h"
@@ -37,8 +37,8 @@ struct AllocationInfo;
// each tensor needs to be allocated and deallocated, and preallocates all the
// necessary memory (the PlanAllocations phase). It then assigns portions of
// this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may
-// share some of the buffer if a tensor B is to be allocated after another tensor
-// A has been deallocated.
+// share some of the buffer if a tensor B is to be allocated after another
+// tensor A has been deallocated.
//
// If dynamic tensors are used the planning steps can be repeated during model
// execution. Since dynamic tensors don't have sizes until after the
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 0246e7fa30..9317e2bb6e 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -49,6 +49,9 @@ def tflite_linkopts_unstripped():
Returns:
a select object with proper linkopts
"""
+
+ # In case you wonder why there's no --icf is because the gains were
+ # negligible, and created potential compatibility problems.
return select({
"//tensorflow:android": [
"-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj.
@@ -56,13 +59,7 @@ def tflite_linkopts_unstripped():
"-Wl,--gc-sections", # Eliminate unused code and data.
"-Wl,--as-needed", # Don't link unused libs.
],
- "//tensorflow:darwin": [],
- "//tensorflow:ios": [],
- "//tensorflow/contrib/lite:mips": [],
- "//tensorflow/contrib/lite:mips64": [],
- "//conditions:default": [
- "-Wl,--icf=all", # Identical code folding.
- ],
+ "//conditions:default": [],
})
def tflite_jni_linkopts_unstripped():
@@ -74,17 +71,15 @@ def tflite_jni_linkopts_unstripped():
Returns:
a select object with proper linkopts
"""
+
+ # In case you wonder why there's no --icf is because the gains were
+ # negligible, and created potential compatibility problems.
return select({
"//tensorflow:android": [
"-Wl,--gc-sections", # Eliminate unused code and data.
"-Wl,--as-needed", # Don't link unused libs.
],
- "//tensorflow:darwin": [],
- "//tensorflow/contrib/lite:mips": [],
- "//tensorflow/contrib/lite:mips64": [],
- "//conditions:default": [
- "-Wl,--icf=all", # Identical code folding.
- ],
+ "//conditions:default": [],
})
def tflite_linkopts():
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index aecd71910c..30901bd0fa 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -12,297 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// Compatibility shim for new location of interface definitions.
+
#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
-#include <stdint.h>
-
-#include "tensorflow/contrib/lite/context.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-// TODO(aselle): Consider using "if this then that" for testing.
-
-// Useful placeholder to put in otherwise empty structs to avoid size warnings.
-typedef struct {
- char dummy_;
-} EmptyStructPlaceholder;
-
-// Possible padding types (for convolutions)
-typedef enum {
- kTfLitePaddingUnknown = 0,
- kTfLitePaddingSame,
- kTfLitePaddingValid,
-} TfLitePadding;
-
-typedef struct {
- int width;
- int height;
-} TfLitePaddingValues;
-
-// Possible fused activation functions.
-// TODO(aselle): rename to TfLiteActivation
-typedef enum {
- kTfLiteActNone = 0,
- kTfLiteActRelu,
- kTfLiteActRelu1,
- kTfLiteActRelu6,
- kTfLiteActTanh,
- kTfLiteActSignBit,
- kTfLiteActSigmoid,
-} TfLiteFusedActivation;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
- int dilation_width_factor;
- int dilation_height_factor;
- TfLiteFusedActivation activation;
-} TfLiteConvParams;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
- int filter_width;
- int filter_height;
- TfLiteFusedActivation activation;
- struct {
- TfLitePaddingValues padding;
- } computed;
-} TfLitePoolParams;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
- int depth_multiplier;
- TfLiteFusedActivation activation;
-} TfLiteDepthwiseConvParams;
-
-typedef struct {
- int rank;
- TfLiteFusedActivation activation;
-} TfLiteSVDFParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteRNNParams;
-
-typedef struct {
- bool time_major;
- TfLiteFusedActivation activation;
-} TfLiteSequenceRNNParams;
-
-typedef enum {
- kTfLiteFullyConnectedWeightsFormatDefault = 0,
- kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
-} TfLiteFullyConnectedWeightsFormat;
-
-typedef struct {
- // Parameters for FullyConnected version 1 or above.
- TfLiteFusedActivation activation;
-
- // Parameters for FullyConnected version 2 or above.
- TfLiteFullyConnectedWeightsFormat weights_format;
-} TfLiteFullyConnectedParams;
-
-typedef enum {
- kTfLiteLshProjectionUnknown = 0,
- kTfLiteLshProjectionSparse = 1,
- kTfLiteLshProjectionDense = 2,
-} TfLiteLSHProjectionType;
-
-typedef struct {
- TfLiteLSHProjectionType type;
-} TfLiteLSHProjectionParams;
-
-typedef struct {
- float beta;
-} TfLiteSoftmaxParams;
-
-typedef struct {
- int axis;
- TfLiteFusedActivation activation;
-} TfLiteConcatenationParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteAddParams;
-
-typedef struct {
- EmptyStructPlaceholder placeholder_;
-} TfLiteSpaceToBatchNDParams;
-
-typedef struct {
- EmptyStructPlaceholder placeholder_;
-} TfLiteBatchToSpaceNDParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteMulParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteSubParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteDivParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteL2NormParams;
-
-typedef struct {
- int radius;
- float bias;
- float alpha;
- float beta;
-} TfLiteLocalResponseNormParams;
-
-typedef enum {
- kTfLiteLSTMFullKernel = 0,
- kTfLiteLSTMBasicKernel
-} TfLiteLSTMKernelType;
-
-typedef struct {
- // Parameters for LSTM version 1.
- TfLiteFusedActivation activation;
- float cell_clip;
- float proj_clip;
-
- // Parameters for LSTM version 2.
- // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
- TfLiteLSTMKernelType kernel_type;
-} TfLiteLSTMParams;
-
-typedef struct {
- bool align_corners;
-} TfLiteResizeBilinearParams;
-
-typedef struct {
- EmptyStructPlaceholder placeholder_;
-} TfLitePadParams;
-
-typedef struct {
- EmptyStructPlaceholder placeholder_;
-} TfLitePadV2Params;
-
-typedef struct {
- // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
- // For now we will fix the maximum possible number of dimensions.
- int shape[8];
- int num_dimensions;
-} TfLiteReshapeParams;
-
-typedef struct {
- int ngram_size;
- int max_skip_size;
- bool include_all_ngrams;
-} TfLiteSkipGramParams;
-
-typedef struct {
- int block_size;
-} TfLiteSpaceToDepthParams;
-
-typedef struct {
- TfLiteType in_data_type;
- TfLiteType out_data_type;
-} TfLiteCastParams;
-
-typedef enum {
- kTfLiteCombinerTypeSum = 0,
- kTfLiteCombinerTypeMean = 1,
- kTfLiteCombinerTypeSqrtn = 2,
-} TfLiteCombinerType;
-
-typedef struct {
- TfLiteCombinerType combiner;
-} TfLiteEmbeddingLookupSparseParams;
-
-typedef struct {
- int axis;
-} TfLiteGatherParams;
-
-typedef struct {
- EmptyStructPlaceholder placeholder_;
-} TfLiteTransposeParams;
-
-typedef struct {
- bool keep_dims;
-} TfLiteReducerParams;
-
-typedef struct {
- int num_splits;
-} TfLiteSplitParams;
-
-typedef struct {
- // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
- // For now we will fix the maximum possible number of dimensions.
- int squeeze_dims[8];
- int num_squeeze_dims;
-} TfLiteSqueezeParams;
-
-typedef struct {
- int begin_mask;
- int end_mask;
- int ellipsis_mask;
- int new_axis_mask;
- int shrink_axis_mask;
-} TfLiteStridedSliceParams;
-
-typedef struct {
- TfLiteType output_type;
-} TfLiteArgMaxParams;
-
-typedef struct {
- TfLiteType output_type;
-} TfLiteArgMinParams;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
-} TfLiteTransposeConvParams;
-
-typedef struct {
- bool validate_indices;
-} TfLiteSparseToDenseParams;
-
-typedef struct {
- TfLiteType out_type;
-} TfLiteShapeParams;
-
-typedef struct {
- // Parameters supported by version 1:
- float min;
- float max;
- int num_bits;
-
- // Parameters supported by version 2:
- bool narrow_range;
-} TfLiteFakeQuantParams;
-
-typedef struct {
- int values_count;
- int axis;
-} TfLitePackParams;
-
-typedef struct {
- int axis;
-} TfLiteOneHotParams;
-
-typedef struct {
- int num;
- int axis;
-} TfLiteUnpackParams;
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
diff --git a/tensorflow/contrib/lite/c/BUILD b/tensorflow/contrib/lite/c/BUILD
new file mode 100644
index 0000000000..663eb63cad
--- /dev/null
+++ b/tensorflow/contrib/lite/c/BUILD
@@ -0,0 +1,39 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "c_api_internal",
+ srcs = ["c_api_internal.c"],
+ hdrs = [
+ "builtin_op_data.h",
+ "c_api_internal.h",
+ ],
+ visibility = [
+ "//tensorflow/contrib/lite:__subpackages__",
+ ],
+)
+
+# Test the C extension API code.
+cc_test(
+ name = "c_api_internal_test",
+ size = "small",
+ srcs = ["c_api_internal_test.cc"],
+ deps = [
+ ":c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "builtin_op_data_test",
+ size = "small",
+ srcs = ["builtin_op_data_test.cc"],
+ copts = ["-Wno-unused-variable"],
+ deps = [
+ ":c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
new file mode 100644
index 0000000000..fa43e6a024
--- /dev/null
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -0,0 +1,298 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
+
+#include <stdint.h>
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// TODO(aselle): Consider using "if this then that" for testing.
+
+// Possible padding types (for convolutions)
+typedef enum {
+ kTfLitePaddingUnknown = 0,
+ kTfLitePaddingSame,
+ kTfLitePaddingValid,
+} TfLitePadding;
+
+typedef struct {
+ int width;
+ int height;
+} TfLitePaddingValues;
+
+// Possible fused activation functions.
+// TODO(aselle): rename to TfLiteActivation
+typedef enum {
+ kTfLiteActNone = 0,
+ kTfLiteActRelu,
+ kTfLiteActRelu1,
+ kTfLiteActRelu6,
+ kTfLiteActTanh,
+ kTfLiteActSignBit,
+ kTfLiteActSigmoid,
+} TfLiteFusedActivation;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int dilation_width_factor;
+ int dilation_height_factor;
+ TfLiteFusedActivation activation;
+} TfLiteConvParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int filter_width;
+ int filter_height;
+ TfLiteFusedActivation activation;
+ struct {
+ TfLitePaddingValues padding;
+ } computed;
+} TfLitePoolParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int depth_multiplier;
+ TfLiteFusedActivation activation;
+} TfLiteDepthwiseConvParams;
+
+typedef struct {
+ int rank;
+ TfLiteFusedActivation activation;
+} TfLiteSVDFParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteRNNParams;
+
+typedef struct {
+ bool time_major;
+ TfLiteFusedActivation activation;
+} TfLiteSequenceRNNParams;
+
+typedef enum {
+ kTfLiteFullyConnectedWeightsFormatDefault = 0,
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
+} TfLiteFullyConnectedWeightsFormat;
+
+typedef struct {
+ // Parameters for FullyConnected version 1 or above.
+ TfLiteFusedActivation activation;
+
+ // Parameters for FullyConnected version 2 or above.
+ TfLiteFullyConnectedWeightsFormat weights_format;
+} TfLiteFullyConnectedParams;
+
+typedef enum {
+ kTfLiteLshProjectionUnknown = 0,
+ kTfLiteLshProjectionSparse = 1,
+ kTfLiteLshProjectionDense = 2,
+} TfLiteLSHProjectionType;
+
+typedef struct {
+ TfLiteLSHProjectionType type;
+} TfLiteLSHProjectionParams;
+
+typedef struct {
+ float beta;
+} TfLiteSoftmaxParams;
+
+typedef struct {
+ int axis;
+ TfLiteFusedActivation activation;
+} TfLiteConcatenationParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteAddParams;
+
+typedef struct {
+} TfLiteSpaceToBatchNDParams;
+
+typedef struct {
+} TfLiteBatchToSpaceNDParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteMulParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteSubParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteDivParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteL2NormParams;
+
+typedef struct {
+ int radius;
+ float bias;
+ float alpha;
+ float beta;
+} TfLiteLocalResponseNormParams;
+
+typedef enum {
+ kTfLiteLSTMFullKernel = 0,
+ kTfLiteLSTMBasicKernel
+} TfLiteLSTMKernelType;
+
+typedef struct {
+ // Parameters for LSTM version 1.
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+
+ // Parameters for LSTM version 2.
+ // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
+ TfLiteLSTMKernelType kernel_type;
+} TfLiteLSTMParams;
+
+typedef struct {
+ bool align_corners;
+} TfLiteResizeBilinearParams;
+
+typedef struct {
+} TfLitePadParams;
+
+typedef struct {
+} TfLitePadV2Params;
+
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int shape[8];
+ int num_dimensions;
+} TfLiteReshapeParams;
+
+typedef struct {
+ int ngram_size;
+ int max_skip_size;
+ bool include_all_ngrams;
+} TfLiteSkipGramParams;
+
+typedef struct {
+ int block_size;
+} TfLiteSpaceToDepthParams;
+
+typedef struct {
+ TfLiteType in_data_type;
+ TfLiteType out_data_type;
+} TfLiteCastParams;
+
+typedef enum {
+ kTfLiteCombinerTypeSum = 0,
+ kTfLiteCombinerTypeMean = 1,
+ kTfLiteCombinerTypeSqrtn = 2,
+} TfLiteCombinerType;
+
+typedef struct {
+ TfLiteCombinerType combiner;
+} TfLiteEmbeddingLookupSparseParams;
+
+typedef struct {
+ int axis;
+} TfLiteGatherParams;
+
+typedef struct {
+} TfLiteTransposeParams;
+
+typedef struct {
+ bool keep_dims;
+} TfLiteReducerParams;
+
+typedef struct {
+ int num_splits;
+} TfLiteSplitParams;
+
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int squeeze_dims[8];
+ int num_squeeze_dims;
+} TfLiteSqueezeParams;
+
+typedef struct {
+ int begin_mask;
+ int end_mask;
+ int ellipsis_mask;
+ int new_axis_mask;
+ int shrink_axis_mask;
+} TfLiteStridedSliceParams;
+
+typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMaxParams;
+
+typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMinParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+} TfLiteTransposeConvParams;
+
+typedef struct {
+ bool validate_indices;
+} TfLiteSparseToDenseParams;
+
+typedef struct {
+ TfLiteType out_type;
+} TfLiteShapeParams;
+
+typedef struct {
+ // Parameters supported by version 1:
+ float min;
+ float max;
+ int num_bits;
+
+ // Parameters supported by version 2:
+ bool narrow_range;
+} TfLiteFakeQuantParams;
+
+typedef struct {
+ int values_count;
+ int axis;
+} TfLitePackParams;
+
+typedef struct {
+ int axis;
+} TfLiteOneHotParams;
+
+typedef struct {
+ int num;
+ int axis;
+} TfLiteUnpackParams;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
new file mode 100644
index 0000000000..4d0ba75e68
--- /dev/null
+++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+// Builtin op data is just a set of data definitions, so the only meaningful
+// test we can run is whether we can create the structs we expect to find.
+// Testing each struct's members might be possible, but it seems unnecessary
+// until we've locked down the API. The build rule has copts set to ignore the
+// unused variable warning, since this is just a compilation test.
+TEST(IntArray, CanCompileStructs) {
+ TfLitePadding padding = kTfLitePaddingSame;
+ TfLitePaddingValues padding_values;
+ TfLiteFusedActivation fused_activation = kTfLiteActRelu;
+ TfLiteConvParams conv_params;
+ TfLitePoolParams pool_params;
+ TfLiteDepthwiseConvParams depthwise_conv_params;
+ TfLiteSVDFParams svdf_params;
+ TfLiteRNNParams rnn_params;
+ TfLiteSequenceRNNParams sequence_rnn_params;
+ TfLiteFullyConnectedWeightsFormat fully_connected_weights_format =
+ kTfLiteFullyConnectedWeightsFormatDefault;
+ TfLiteFullyConnectedParams fully_connected_params;
+ TfLiteLSHProjectionType projection_type = kTfLiteLshProjectionDense;
+ TfLiteLSHProjectionParams projection_params;
+ TfLiteSoftmaxParams softmax_params;
+ TfLiteConcatenationParams concatenation_params;
+ TfLiteAddParams add_params;
+ TfLiteSpaceToBatchNDParams space_to_batch_nd_params;
+ TfLiteBatchToSpaceNDParams batch_to_space_nd_params;
+ TfLiteMulParams mul_params;
+ TfLiteSubParams sub_params;
+ TfLiteDivParams div_params;
+ TfLiteL2NormParams l2_norm_params;
+ TfLiteLocalResponseNormParams local_response_norm_params;
+ TfLiteLSTMKernelType lstm_kernel_type = kTfLiteLSTMBasicKernel;
+ TfLiteLSTMParams lstm_params;
+ TfLiteResizeBilinearParams resize_bilinear_params;
+ TfLitePadParams pad_params;
+ TfLitePadV2Params pad_v2_params;
+ TfLiteReshapeParams reshape_params;
+ TfLiteSkipGramParams skip_gram_params;
+ TfLiteSpaceToDepthParams space_to_depth_params;
+ TfLiteCastParams cast_params;
+ TfLiteCombinerType combiner_type = kTfLiteCombinerTypeSqrtn;
+ TfLiteEmbeddingLookupSparseParams lookup_sparse_params;
+ TfLiteGatherParams gather_params;
+ TfLiteTransposeParams transpose_params;
+ TfLiteReducerParams reducer_params;
+ TfLiteSplitParams split_params;
+ TfLiteSqueezeParams squeeze_params;
+ TfLiteStridedSliceParams strided_slice_params;
+ TfLiteArgMaxParams arg_max_params;
+ TfLiteArgMinParams arg_min_params;
+ TfLiteTransposeConvParams transpose_conv_params;
+ TfLiteSparseToDenseParams sparse_to_dense_params;
+ TfLiteShapeParams shape_params;
+ TfLiteFakeQuantParams fake_quant_params;
+ TfLitePackParams pack_params;
+ TfLiteOneHotParams one_hot_params;
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/c/c_api_internal.c
index 7f2aa316f4..1846bad4b7 100644
--- a/tensorflow/contrib/lite/context.c
+++ b/tensorflow/contrib/lite/c/c_api_internal.c
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include <stdio.h>
+#include <stdlib.h>
#include <string.h>
int TfLiteIntArrayGetSizeInBytes(int size) {
@@ -76,7 +77,8 @@ void TfLiteTensorFree(TfLiteTensor* t) {
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
- const void* allocation, bool is_variable, TfLiteTensor* tensor) {
+ const void* allocation, bool is_variable,
+ TfLiteTensor* tensor) {
TfLiteTensorFree(tensor);
tensor->type = type;
tensor->name = name;
diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h
new file mode 100644
index 0000000000..48df68a654
--- /dev/null
+++ b/tensorflow/contrib/lite/c/c_api_internal.h
@@ -0,0 +1,491 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file defines a C API for implementing operations in tflite.
+// These operations can be defined using c++ but the interface between
+// the interpreter and the operations are C.
+//
+// Summary of abstractions
+// TF_LITE_ENSURE - Self-sufficient error checking
+// TfLiteStatus - Status reporting
+// TfLiteIntArray - stores tensor shapes (dims),
+// TfLiteContext - allows an op to access the tensors
+// TfLiteTensor - tensor (a multidimensional array)
+// TfLiteNode - a single node or operation
+// TfLiteRegistration - the implementation of a conceptual operation.
+//
+// Some abstractions in this file are created and managed by Interpreter.
+#ifndef TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
+#define TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
+
+// The list of external context types known to TF Lite. This list exists solely
+// to avoid conflicts and to ensure ops can share the external contexts they
+// need. Access to the external contexts is controled by one of the
+// corresponding support files.
+typedef enum {
+ kTfLiteEigenContext = 0, // include eigen_support.h to use.
+ kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
+ kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
+ kTfLiteMaxExternalContexts = 3
+} TfLiteExternalContextType;
+
+// An external context is a collection of information unrelated to the TF Lite
+// framework, but useful to a subset of the ops. TF Lite knows very little
+// about about the actual contexts, but it keeps a list of them, and is able to
+// refresh them if configurations like the number of recommended threads
+// change.
+typedef struct {
+ TfLiteExternalContextType type;
+ TfLiteStatus (*Refresh)(struct TfLiteContext* context);
+} TfLiteExternalContext;
+
+// Forward declare so GetNode can use this is in Context.
+typedef struct _TfLiteRegistration TfLiteRegistration;
+typedef struct _TfLiteDelegate TfLiteDelegate;
+
+#define kOptionalTensor (-1)
+
+// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
+// indices
+typedef struct {
+ int size;
+// gcc 6.1+ have a bug where flexible members aren't properly handled
+// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
+#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
+ __GNUC_MINOR__ >= 1
+ int data[0];
+#else
+ int data[];
+#endif
+} TfLiteIntArray;
+
+// Given the size (number of elements) in a TfLiteIntArray, calculate its size
+// in bytes.
+int TfLiteIntArrayGetSizeInBytes(int size);
+
+// Create a array of a given `size` (uninitialized entries).
+// This returns a pointer, that you must free using TfLiteIntArrayFree().
+TfLiteIntArray* TfLiteIntArrayCreate(int size);
+
+// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise.
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
+
+// Create a copy of an array passed as `src`.
+// You are expected to free memory with TfLiteIntArrayFree
+TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
+
+// Free memory of array `v`.
+void TfLiteIntArrayFree(TfLiteIntArray* v);
+
+// Since we must not depend on any libraries, define a minimal subset of
+// error macros while avoiding names that have pre-conceived meanings like
+// assert and check.
+
+// Check whether value is true, and if not return kTfLiteError from
+// the current function (and report the error string msg).
+#define TF_LITE_ENSURE_MSG(context, value, msg) \
+ do { \
+ if (!(value)) { \
+ (context)->ReportError((context), __FILE__ " " msg); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+#define TF_LITE_ENSURE(context, a) \
+ do { \
+ if (!(a)) { \
+ (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
+ __LINE__, #a); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_STATUS(a) \
+ do { \
+ if ((a) != kTfLiteOk) { \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a == b` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+// `a` and `b` may be evaluated more than once, so no side effects or
+// extremely expensive computations should be done.
+#define TF_LITE_ENSURE_EQ(context, a, b) \
+ do { \
+ if ((a) != (b)) { \
+ (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
+ __LINE__, #a, #b, (a), (b)); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_OK(context, status) \
+ do { \
+ if ((status) != kTfLiteOk) { \
+ return status; \
+ } \
+ } while (0)
+
+// Single-precision complex data type compatible with the C99 definition.
+typedef struct {
+ float re, im; // real and imaginary parts, respectively.
+} TfLiteComplex64;
+
+// Types supported by tensor
+typedef enum {
+ kTfLiteNoType = 0,
+ kTfLiteFloat32 = 1,
+ kTfLiteInt32 = 2,
+ kTfLiteUInt8 = 3,
+ kTfLiteInt64 = 4,
+ kTfLiteString = 5,
+ kTfLiteBool = 6,
+ kTfLiteInt16 = 7,
+ kTfLiteComplex64 = 8,
+} TfLiteType;
+
+// Parameters for asymmetric quantization. Quantized values can be converted
+// back to float using:
+// real_value = scale * (quantized_value - zero_point);
+typedef struct {
+ float scale;
+ int32_t zero_point;
+} TfLiteQuantizationParams;
+
+// A union of pointers that points to memory for a given tensor.
+typedef union {
+ int* i32;
+ int64_t* i64;
+ float* f;
+ char* raw;
+ const char* raw_const;
+ uint8_t* uint8;
+ bool* b;
+ int16_t* i16;
+ TfLiteComplex64* c64;
+} TfLitePtrUnion;
+
+// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
+// data (or data externally allocated). kTfLiteArenaRw is arena allocated
+// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+typedef enum {
+ kTfLiteMemNone = 0,
+ kTfLiteMmapRo,
+ kTfLiteArenaRw,
+ kTfLiteArenaRwPersistent,
+ kTfLiteDynamic,
+} TfLiteAllocationType;
+
+// The delegates should use zero or positive integers to represent handles.
+// -1 is reserved from unallocated status.
+typedef int TfLiteBufferHandle;
+const TfLiteBufferHandle kTfLiteNullBufferHandle = -1;
+
+// An tensor in the interpreter system which is a wrapper around a buffer of
+// data including a dimensionality (or NULL if not currently defined).
+typedef struct {
+ // The data type specification for data stored in `data`. This affects
+ // what member of `data` union should be used.
+ TfLiteType type;
+ // A union of data pointers. The appropriate type should be used for a typed
+ // tensor based on `type`.
+ TfLitePtrUnion data;
+ // A pointer to a structure representing the dimensionality interpretation
+ // that the buffer should have. NOTE: the product of elements of `dims`
+ // and the element datatype size should be equal to `bytes` below.
+ TfLiteIntArray* dims;
+ // Quantization information.
+ TfLiteQuantizationParams params;
+ // How memory is mapped
+ // kTfLiteMmapRo: Memory mapped read only.
+ // i.e. weights
+ // kTfLiteArenaRw: Arena allocated read write memory
+ // (i.e. temporaries, outputs).
+ TfLiteAllocationType allocation_type;
+ // The number of bytes required to store the data of this Tensor. I.e.
+ // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
+ // type is kTfLiteFloat32 and dims = {3, 2} then
+ // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
+ size_t bytes;
+
+ // An opaque pointer to a tflite::MMapAllocation
+ const void* allocation;
+
+ // Null-terminated name of this tensor.
+ const char* name;
+
+ // The delegate which knows how to handle `buffer_handle`.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegate* delegate;
+
+ // An integer buffer handle that can be handled by `delegate`.
+ // The value is valid only when delegate is not null.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteBufferHandle buffer_handle;
+
+ // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
+ // responsible to set data_is_stale to true.
+ // `delegate->CopyFromBufferHandle` can be called to copy the data from
+ // delegate buffer.
+ // WARNING: This is an // experimental interface that is subject to change.
+ bool data_is_stale;
+
+ // True if the tensor is a variable.
+ bool is_variable;
+} TfLiteTensor;
+
+// Free data memory of tensor `t`;
+void TfLiteTensorDataFree(TfLiteTensor* t);
+
+// Free memory of tensor `t`;
+void TfLiteTensorFree(TfLiteTensor* t);
+
+// Set all of a tensor's fields (and free any previously allocated data).
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
+ TfLiteQuantizationParams quantization, char* buffer,
+ size_t size, TfLiteAllocationType allocation_type,
+ const void* allocation, bool is_variable,
+ TfLiteTensor* tensor);
+
+// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
+// types other than kTfLiteDynamic will be ignored.
+void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
+
+// A structure representing an instance of a node.
+// This structure only exhibits the inputs, outputs and user defined data, not
+// other features like the type.
+typedef struct {
+ // Inputs to this node expressed as indices into the simulator's tensors.
+ TfLiteIntArray* inputs;
+
+ // Outputs to this node expressed as indices into the simulator's tensors.
+ TfLiteIntArray* outputs;
+
+ // Temporary tensors uses during the computations. This usually contains no
+ // tensors, but ops are allowed to change that if they need scratch space of
+ // any sort.
+ TfLiteIntArray* temporaries;
+
+ // Opaque data provided by the node implementer through `Registration.init`.
+ void* user_data;
+
+ // Opaque data provided to the node if the node is a builtin. This is usually
+ // a structure defined in builtin_op_data.h
+ void* builtin_data;
+
+ // Custom initial data. This is the opaque data provided in the flatbuffer.
+ // WARNING: This is an experimental interface that is subject to change.
+ const void* custom_initial_data;
+ int custom_initial_data_size;
+
+ // The pointer to the delegate. This is non-null only when the node is
+ // created by calling `interpreter.ModifyGraphWithDelegate`.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegate* delegate;
+} TfLiteNode;
+
+typedef struct TfLiteContext {
+ // Number of tensors in the context.
+ size_t tensors_size;
+
+ // The execution plan contains a list of the node indices in execution
+ // order. execution_plan->size is the current number of nodes. And,
+ // execution_plan->data[0] is the first node that needs to be run.
+ // TfLiteDelegates can traverse the current execution plan by iterating
+ // through each member of this array and using GetNodeAndRegistration() to
+ // access details about a node. i.e.
+ // TfLiteIntArray* execution_plan;
+ // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
+ // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
+ // int node_index = execution_plan->data[exec_index];
+ // TfLiteNode* node;
+ // TfLiteRegistration* reg;
+ // context->GetNodeAndRegistration(context, node_index, &node, &reg);
+ // }
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
+ TfLiteIntArray** execution_plan);
+
+ // An array of tensors in the interpreter context (of length `tensors_size`)
+ TfLiteTensor* tensors;
+
+ // opaque full context ptr (an opaque c++ data structure)
+ void* impl_;
+
+ // Request memory pointer be resized. Updates dimensions on the tensor.
+ // NOTE: ResizeTensor takes ownership of newSize.
+ TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
+ TfLiteIntArray* new_size);
+ // Request that a error be reported with format string msg.
+ void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
+
+ // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
+ // non-null, the value pointed to by `first_new_tensor_index` will be set to
+ // the index of the first new tensor.
+ TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
+ int* first_new_tensor_index);
+
+ // Get a Tensor node by node_index.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index,
+ TfLiteNode** node,
+ TfLiteRegistration** registration);
+
+ // Replace ops with one or more stub delegate operations. This function
+ // does not take ownership of `nodes_to_replace`.
+ TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
+ struct TfLiteContext*, TfLiteRegistration registration,
+ const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
+
+ // Number of threads that are recommended to subsystems like gemmlowp and
+ // eigen.
+ int recommended_num_threads;
+
+ // Access external contexts by type.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
+ TfLiteExternalContextType);
+ // Set the value of a external context. Does not take ownership of the
+ // pointer.
+ // WARNING: This is an experimental interface that is subject to change.
+ void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
+ TfLiteExternalContext*);
+} TfLiteContext;
+
+typedef struct _TfLiteRegistration {
+ // Initializes the op from serialized data.
+ // If a built-in op:
+ // `buffer` is the op's params data (TfLiteLSTMParams*).
+ // `length` is zero.
+ // If custom op:
+ // `buffer` is the op's `custom_options`.
+ // `length` is the size of the buffer.
+ //
+ // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
+ // or an instance of a struct).
+ //
+ // The returned pointer will be stored with the node in the `user_data` field,
+ // accessible within prepare and invoke functions below.
+ // NOTE: if the data is already in the desired format, simply implement this
+ // function to return `nullptr` and implement the free function to be a no-op.
+ void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+
+ // The pointer `buffer` is the data previously returned by an init invocation.
+ void (*free)(TfLiteContext* context, void* buffer);
+
+ // prepare is called when the inputs this node depends on have been resized.
+ // context->ResizeTensor() can be called to request output tensors to be
+ // resized.
+ //
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+
+ // Execute the node (should read node->inputs and output to node->outputs).
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+
+ // profiling_string is called during summarization of profiling information
+ // in order to group executions together. Providing a value here will cause a
+ // given op to appear multiple times is the profiling report. This is
+ // particularly useful for custom ops that can perform significantly
+ // different calculations depending on their `user-data`.
+ const char* (*profiling_string)(const TfLiteContext* context,
+ const TfLiteNode* node);
+
+ // Builtin codes. If this kernel refers to a builtin this is the code
+ // of the builtin. This is so we can do marshaling to other frameworks like
+ // NN API.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ int32_t builtin_code;
+
+ // Custom op name. If the op is a builtin, this will be null.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ // WARNING: This is an experimental interface that is subject to change.
+ const char* custom_name;
+
+ // The version of the op.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ int version;
+} TfLiteRegistration;
+
+// WARNING: This is an experimental interface that is subject to change.
+typedef struct _TfLiteDelegate {
+ // Data that delegate needs to identify itself. This data is owned by the
+ // delegate. The delegate is owned in the user code, so the delegate is
+ // responsible for doing this when it is destroyed.
+ void* data_;
+
+ // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
+ // delegate a view of the current graph through TfLiteContext*. It typically
+ // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels()
+ // to ask the TensorFlow lite runtime to create macro-nodes to represent
+ // delegated subgraphs of the original graph.
+ TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate);
+
+ // Copy the data from delegate buffer handle to raw memory.
+ // This can be null if the delegate doesn't use its own buffer.
+ TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
+ TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, size_t size);
+
+ // Copy the data from raw memory to delegate buffer handle.
+ // This can be null if the delegate doesn't use its own buffer.
+ TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
+ TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, size_t size);
+
+ // Free the Delegate Buffer Handle. Note: This only frees the handle, but
+ // this doesn't release the underlying resource (e.g. textures). The
+ // resources are either owned by application layer or the delegate.
+ // This can be null if the delegate doesn't use its own buffer.
+ void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate,
+ TfLiteBufferHandle* handle);
+} TfLiteDelegate;
+
+// WARNING: This is an experimental interface that is subject to change.
+//
+// Currently, TfLiteDelegateParams has to be allocated in a way that it's
+// trivially destructable. It will be stored as `builtin_data` field in
+// `TfLiteNode` of the delegate node.
+//
+// See also the `CreateDelegateParams` function in `interpreter.cc` details.
+typedef struct {
+ TfLiteDelegate* delegate;
+ TfLiteIntArray* nodes_to_replace;
+ TfLiteIntArray* input_tensors;
+ TfLiteIntArray* output_tensors;
+} TfLiteDelegateParams;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/c/c_api_internal_test.cc
index 20d6f69a25..af398f3207 100644
--- a/tensorflow/contrib/lite/context_test.cc
+++ b/tensorflow/contrib/lite/c/c_api_internal_test.cc
@@ -13,16 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
// NOTE: this tests only the TfLiteIntArray part of context.
-// most of context.h is provided in the context of using it with interpreter.h
-// and interpreter.cc, so interpreter_test.cc tests context structures more
-// thoroughly.
+// most of c_api_internal.h is provided in the context of using it with
+// interpreter.h and interpreter.cc, so interpreter_test.cc tests context
+// structures more thoroughly.
TEST(IntArray, TestIntArrayCreate) {
TfLiteIntArray* a = TfLiteIntArrayCreate(0);
@@ -69,7 +68,6 @@ TEST(IntArray, TestIntArrayEqual) {
} // namespace tflite
int main(int argc, char** argv) {
- ::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index b23183b743..b86c2819b8 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -12,484 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// This file defines a C API for implementing operations in tflite.
-// These operations can be defined using c++ but the interface between
-// the interpreter and the operations are C.
-//
-// Summary of abstractions
-// TF_LITE_ENSURE - Self-sufficient error checking
-// TfLiteStatus - Status reporting
-// TfLiteIntArray - stores tensor shapes (dims),
-// TfLiteContext - allows an op to access the tensors
-// TfLiteTensor - tensor (a multidimensional array)
-// TfLiteNode - a single node or operation
-// TfLiteRegistration - the implementation of a conceptual operation.
-//
-// Some abstractions in this file are created and managed by Interpreter.
+// Compatibility shim for moved header location.
#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
#define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
-#include <stdbool.h>
-#include <stdint.h>
-#include <stdlib.h>
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
-
-// Forward declarations for use with dependent types.
-struct TfLiteContext;
-struct TfLiteNode;
-struct _TfLiteRegistration;
-struct _TfLiteDelegate;
-
-// The list of external context types known to TF Lite. This list exists solely
-// to avoid conflicts and to ensure ops can share the external contexts they
-// need. Access to the external contexts is controled by one of the
-// corresponding support files.
-typedef enum {
- kTfLiteEigenContext = 0, // include eigen_support.h to use.
- kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
- kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
- kTfLiteMaxExternalContexts = 3
-} TfLiteExternalContextType;
-
-// An external context is a collection of information unrelated to the TF Lite
-// framework, but useful to a subset of the ops. TF Lite knows very little
-// about about the actual contexts, but it keeps a list of them, and is able to
-// refresh them if configurations like the number of recommended threads
-// change.
-typedef struct {
- TfLiteExternalContextType type;
- TfLiteStatus (*Refresh)(struct TfLiteContext* context);
-} TfLiteExternalContext;
-
-#define kOptionalTensor (-1)
-
-// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
-// indices
-typedef struct {
- int size;
-// gcc 6.1+ have a bug where flexible members aren't properly handled
-// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
-#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
- __GNUC_MINOR__ >= 1
- int data[0];
-#else
- int data[];
-#endif
-} TfLiteIntArray;
-
-// Given the size (number of elements) in a TfLiteIntArray, calculate its size
-// in bytes.
-int TfLiteIntArrayGetSizeInBytes(int size);
-
-// Create a array of a given `size` (uninitialized entries).
-// This returns a pointer, that you must free using TfLiteIntArrayFree().
-TfLiteIntArray* TfLiteIntArrayCreate(int size);
-
-// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise.
-int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
-
-// Create a copy of an array passed as `src`.
-// You are expected to free memory with TfLiteIntArrayFree
-TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
-
-// Free memory of array `v`.
-void TfLiteIntArrayFree(TfLiteIntArray* v);
-
-// Since we must not depend on any libraries, define a minimal subset of
-// error macros while avoiding names that have pre-conceived meanings like
-// assert and check.
-
-// Check whether value is true, and if not return kTfLiteError from
-// the current function (and report the error string msg).
-#define TF_LITE_ENSURE_MSG(context, value, msg) \
- do { \
- if (!(value)) { \
- (context)->ReportError((context), __FILE__ " " msg); \
- return kTfLiteError; \
- } \
- } while (0)
-
-// Check whether the value `a` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-#define TF_LITE_ENSURE(context, a) \
- do { \
- if (!(a)) { \
- (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
- __LINE__, #a); \
- return kTfLiteError; \
- } \
- } while (0)
-
-#define TF_LITE_ENSURE_STATUS(a) \
- do { \
- if ((a) != kTfLiteOk) { \
- return kTfLiteError; \
- } \
- } while (0)
-
-// Check whether the value `a == b` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-// `a` and `b` may be evaluated more than once, so no side effects or
-// extremely expensive computations should be done.
-#define TF_LITE_ENSURE_EQ(context, a, b) \
- do { \
- if ((a) != (b)) { \
- (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
- __LINE__, #a, #b, (a), (b)); \
- return kTfLiteError; \
- } \
- } while (0)
-
-#define TF_LITE_ENSURE_OK(context, status) \
- do { \
- if ((status) != kTfLiteOk) { \
- return status; \
- } \
- } while (0)
-
-// Single-precision complex data type compatible with the C99 definition.
-typedef struct {
- float re, im; // real and imaginary parts, respectively.
-} TfLiteComplex64;
-
-// Types supported by tensor
-typedef enum {
- kTfLiteNoType = 0,
- kTfLiteFloat32 = 1,
- kTfLiteInt32 = 2,
- kTfLiteUInt8 = 3,
- kTfLiteInt64 = 4,
- kTfLiteString = 5,
- kTfLiteBool = 6,
- kTfLiteInt16 = 7,
- kTfLiteComplex64 = 8,
-} TfLiteType;
-
-// Parameters for asymmetric quantization. Quantized values can be converted
-// back to float using:
-// real_value = scale * (quantized_value - zero_point);
-typedef struct {
- float scale;
- int32_t zero_point;
-} TfLiteQuantizationParams;
-
-// A union of pointers that points to memory for a given tensor.
-typedef union {
- int* i32;
- int64_t* i64;
- float* f;
- char* raw;
- const char* raw_const;
- uint8_t* uint8;
- bool* b;
- int16_t* i16;
- TfLiteComplex64* c64;
-} TfLitePtrUnion;
-
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
-typedef enum {
- kTfLiteMemNone = 0,
- kTfLiteMmapRo,
- kTfLiteArenaRw,
- kTfLiteArenaRwPersistent,
- kTfLiteDynamic,
-} TfLiteAllocationType;
-
-// The delegates should use zero or positive integers to represent handles.
-// -1 is reserved from unallocated status.
-typedef int TfLiteBufferHandle;
-const TfLiteBufferHandle kTfLiteNullBufferHandle = -1;
-
-// An tensor in the interpreter system which is a wrapper around a buffer of
-// data including a dimensionality (or NULL if not currently defined).
-typedef struct {
- // The data type specification for data stored in `data`. This affects
- // what member of `data` union should be used.
- TfLiteType type;
- // A union of data pointers. The appropriate type should be used for a typed
- // tensor based on `type`.
- TfLitePtrUnion data;
- // A pointer to a structure representing the dimensionality interpretation
- // that the buffer should have. NOTE: the product of elements of `dims`
- // and the element datatype size should be equal to `bytes` below.
- TfLiteIntArray* dims;
- // Quantization information.
- TfLiteQuantizationParams params;
- // How memory is mapped
- // kTfLiteMmapRo: Memory mapped read only.
- // i.e. weights
- // kTfLiteArenaRw: Arena allocated read write memory
- // (i.e. temporaries, outputs).
- TfLiteAllocationType allocation_type;
- // The number of bytes required to store the data of this Tensor. I.e.
- // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
- // type is kTfLiteFloat32 and dims = {3, 2} then
- // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
- size_t bytes;
-
- // An opaque pointer to a tflite::MMapAllocation
- const void* allocation;
-
- // Null-terminated name of this tensor.
- const char* name;
-
- // The delegate which knows how to handle `buffer_handle`.
- // WARNING: This is an experimental interface that is subject to change.
- struct _TfLiteDelegate* delegate;
-
- // An integer buffer handle that can be handled by `delegate`.
- // The value is valid only when delegate is not null.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteBufferHandle buffer_handle;
-
- // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
- // responsible to set data_is_stale to true.
- // `delegate->CopyFromBufferHandle` can be called to copy the data from
- // delegate buffer.
- // WARNING: This is an // experimental interface that is subject to change.
- bool data_is_stale;
-
- // True if the tensor is a variable.
- bool is_variable;
-} TfLiteTensor;
-
-// Free data memory of tensor `t`;
-void TfLiteTensorDataFree(TfLiteTensor* t);
-
-// Free memory of tensor `t`;
-void TfLiteTensorFree(TfLiteTensor* t);
-
-// Set all of a tensor's fields (and free any previously allocated data).
-void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
- TfLiteQuantizationParams quantization, char* buffer,
- size_t size, TfLiteAllocationType allocation_type,
- const void* allocation, bool is_variable,
- TfLiteTensor* tensor);
-
-// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
-// types other than kTfLiteDynamic will be ignored.
-void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
-
-// A structure representing an instance of a node.
-// This structure only exhibits the inputs, outputs and user defined data, not
-// other features like the type.
-typedef struct TfLiteNode {
- // Inputs to this node expressed as indices into the simulator's tensors.
- TfLiteIntArray* inputs;
-
- // Outputs to this node expressed as indices into the simulator's tensors.
- TfLiteIntArray* outputs;
-
- // Temporary tensors uses during the computations. This usually contains no
- // tensors, but ops are allowed to change that if they need scratch space of
- // any sort.
- TfLiteIntArray* temporaries;
-
- // Opaque data provided by the node implementer through `Registration.init`.
- void* user_data;
-
- // Opaque data provided to the node if the node is a builtin. This is usually
- // a structure defined in builtin_op_data.h
- void* builtin_data;
-
- // Custom initial data. This is the opaque data provided in the flatbuffer.
- // WARNING: This is an experimental interface that is subject to change.
- const void* custom_initial_data;
- int custom_initial_data_size;
-
- // The pointer to the delegate. This is non-null only when the node is
- // created by calling `interpreter.ModifyGraphWithDelegate`.
- // WARNING: This is an experimental interface that is subject to change.
- struct _TfLiteDelegate* delegate;
-} TfLiteNode;
-
-typedef struct TfLiteContext {
- // Number of tensors in the context.
- size_t tensors_size;
-
- // The execution plan contains a list of the node indices in execution
- // order. execution_plan->size is the current number of nodes. And,
- // execution_plan->data[0] is the first node that needs to be run.
- // TfLiteDelegates can traverse the current execution plan by iterating
- // through each member of this array and using GetNodeAndRegistration() to
- // access details about a node. i.e.
- // TfLiteIntArray* execution_plan;
- // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
- // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
- // int node_index = execution_plan->data[exec_index];
- // TfLiteNode* node;
- // TfLiteRegistration* reg;
- // context->GetNodeAndRegistration(context, node_index, &node, &reg);
- // }
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
- TfLiteIntArray** execution_plan);
-
- // An array of tensors in the interpreter context (of length `tensors_size`)
- TfLiteTensor* tensors;
-
- // opaque full context ptr (an opaque c++ data structure)
- void* impl_;
-
- // Request memory pointer be resized. Updates dimensions on the tensor.
- // NOTE: ResizeTensor takes ownership of newSize.
- TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
- TfLiteIntArray* new_size);
- // Request that a error be reported with format string msg.
- void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
-
- // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
- // non-null, the value pointed to by `first_new_tensor_index` will be set to
- // the index of the first new tensor.
- TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
- int* first_new_tensor_index);
-
- // Get a Tensor node by node_index.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteStatus (*GetNodeAndRegistration)(
- struct TfLiteContext*, int node_index, struct TfLiteNode** node,
- struct _TfLiteRegistration** registration);
-
- // Replace ops with one or more stub delegate operations. This function
- // does not take ownership of `nodes_to_replace`.
- TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
- struct TfLiteContext*, struct _TfLiteRegistration registration,
- const TfLiteIntArray* nodes_to_replace, struct _TfLiteDelegate* delegate);
-
- // Number of threads that are recommended to subsystems like gemmlowp and
- // eigen.
- int recommended_num_threads;
-
- // Access external contexts by type.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
- TfLiteExternalContextType);
- // Set the value of a external context. Does not take ownership of the
- // pointer.
- // WARNING: This is an experimental interface that is subject to change.
- void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
- TfLiteExternalContext*);
-} TfLiteContext;
-
-typedef struct _TfLiteRegistration {
- // Initializes the op from serialized data.
- // If a built-in op:
- // `buffer` is the op's params data (TfLiteLSTMParams*).
- // `length` is zero.
- // If custom op:
- // `buffer` is the op's `custom_options`.
- // `length` is the size of the buffer.
- //
- // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
- // or an instance of a struct).
- //
- // The returned pointer will be stored with the node in the `user_data` field,
- // accessible within prepare and invoke functions below.
- // NOTE: if the data is already in the desired format, simply implement this
- // function to return `nullptr` and implement the free function to be a no-op.
- void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
-
- // The pointer `buffer` is the data previously returned by an init invocation.
- void (*free)(TfLiteContext* context, void* buffer);
-
- // prepare is called when the inputs this node depends on have been resized.
- // context->ResizeTensor() can be called to request output tensors to be
- // resized.
- //
- // Returns kTfLiteOk on success.
- TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
-
- // Execute the node (should read node->inputs and output to node->outputs).
- // Returns kTfLiteOk on success.
- TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
-
- // profiling_string is called during summarization of profiling information
- // in order to group executions together. Providing a value here will cause a
- // given op to appear multiple times is the profiling report. This is
- // particularly useful for custom ops that can perform significantly
- // different calculations depending on their `user-data`.
- const char* (*profiling_string)(const TfLiteContext* context,
- const TfLiteNode* node);
-
- // Builtin codes. If this kernel refers to a builtin this is the code
- // of the builtin. This is so we can do marshaling to other frameworks like
- // NN API.
- // Note: It is the responsibility of the registration binder to set this
- // properly.
- int32_t builtin_code;
-
- // Custom op name. If the op is a builtin, this will be null.
- // Note: It is the responsibility of the registration binder to set this
- // properly.
- // WARNING: This is an experimental interface that is subject to change.
- const char* custom_name;
-
- // The version of the op.
- // Note: It is the responsibility of the registration binder to set this
- // properly.
- int version;
-} TfLiteRegistration;
-
-// WARNING: This is an experimental interface that is subject to change.
-typedef struct _TfLiteDelegate {
- // Data that delegate needs to identify itself. This data is owned by the
- // delegate. The delegate is owned in the user code, so the delegate is
- // responsible for doing this when it is destroyed.
- void* data_;
-
- // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
- // delegate a view of the current graph through TfLiteContext*. It typically
- // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels()
- // to ask the TensorFlow lite runtime to create macro-nodes to represent
- // delegated subgraphs of the original graph.
- TfLiteStatus (*Prepare)(struct TfLiteContext* context,
- struct _TfLiteDelegate* delegate);
-
- // Copy the data from delegate buffer handle to raw memory.
- // This can be null if the delegate doesn't use its own buffer.
- TfLiteStatus (*CopyFromBufferHandle)(struct TfLiteContext* context,
- struct _TfLiteDelegate* delegate,
- TfLiteBufferHandle buffer_handle,
- void* data, size_t size);
-
- // Copy the data from raw memory to delegate buffer handle.
- // This can be null if the delegate doesn't use its own buffer.
- TfLiteStatus (*CopyToBufferHandle)(struct TfLiteContext* context,
- struct _TfLiteDelegate* delegate,
- TfLiteBufferHandle buffer_handle,
- void* data, size_t size);
-
- // Free the Delegate Buffer Handle. Note: This only frees the handle, but
- // this doesn't release the underlying resource (e.g. textures). The
- // resources are either owned by application layer or the delegate.
- // This can be null if the delegate doesn't use its own buffer.
- void (*FreeBufferHandle)(struct TfLiteContext* context,
- struct _TfLiteDelegate* delegate,
- TfLiteBufferHandle* handle);
-} TfLiteDelegate;
-
-// WARNING: This is an experimental interface that is subject to change.
-//
-// Currently, TfLiteDelegateParams has to be allocated in a way that it's
-// trivially destructable. It will be stored as `builtin_data` field in
-// `TfLiteNode` of the delegate node.
-//
-// See also the `CreateDelegateParams` function in `interpreter.cc` details.
-typedef struct {
- TfLiteDelegate* delegate;
- TfLiteIntArray* nodes_to_replace;
- TfLiteIntArray* input_tensors;
- TfLiteIntArray* output_tensors;
-} TfLiteDelegateParams;
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h
index abe802e342..ccda4c7393 100644
--- a/tensorflow/contrib/lite/context_util.h
+++ b/tensorflow/contrib/lite/context_util.h
@@ -17,7 +17,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/core/api/BUILD b/tensorflow/contrib/lite/core/api/BUILD
new file mode 100644
index 0000000000..e4500534f3
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/BUILD
@@ -0,0 +1,57 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+cc_library(
+ name = "api",
+ srcs = [
+ "error_reporter.cc",
+ "flatbuffer_conversions.cc",
+ "op_resolver.cc",
+ ],
+ hdrs = [
+ "error_reporter.h",
+ "flatbuffer_conversions.h",
+ "op_resolver.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ ],
+)
+
+cc_test(
+ name = "error_reporter_test",
+ size = "small",
+ srcs = ["error_reporter_test.cc"],
+ deps = [
+ ":api",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "op_resolver_test",
+ size = "small",
+ srcs = ["op_resolver_test.cc"],
+ deps = [
+ ":api",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "flatbuffer_conversions_test",
+ size = "small",
+ srcs = ["flatbuffer_conversions_test.cc"],
+ deps = [
+ ":api",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/core/api/error_reporter.cc b/tensorflow/contrib/lite/core/api/error_reporter.cc
new file mode 100644
index 0000000000..423f83b1a9
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter.cc
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include <cstdarg>
+
+namespace tflite {
+
+int ErrorReporter::Report(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+// TODO(aselle): Make the name of ReportError on context the same, so
+// we can use the ensure functions w/o a context and w/ a reporter.
+int ErrorReporter::ReportError(void*, const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/error_reporter.h b/tensorflow/contrib/lite/core/api/error_reporter.h
new file mode 100644
index 0000000000..a2f780b003
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
+
+#include <cstdarg>
+
+namespace tflite {
+
+// A functor that reports error to supporting system. Invoked similar to
+// printf.
+//
+// Usage:
+// ErrorReporter foo;
+// foo.Report("test %d", 5);
+// or
+// va_list args;
+// foo.Report("test %d", args); // where args is va_list
+//
+// Subclass ErrorReporter to provide another reporting destination.
+// For example, if you have a GUI program, you might redirect to a buffer
+// that drives a GUI error log box.
+class ErrorReporter {
+ public:
+ virtual ~ErrorReporter() {}
+ virtual int Report(const char* format, va_list args) = 0;
+ int Report(const char* format, ...);
+ int ReportError(void*, const char* format, ...);
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/core/api/error_reporter_test.cc b/tensorflow/contrib/lite/core/api/error_reporter_test.cc
new file mode 100644
index 0000000000..0463eee6be
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter_test.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+
+#include <cstdio>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ int Report(const char* format, va_list args) override {
+ vsnprintf(buffer_, kBufferSize, format, args);
+ return 0;
+ }
+ char* GetBuffer() { return buffer_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+};
+
+TEST(ErrorReporter, TestReport) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+ reporter->Report("Error: %d", 23);
+ EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
new file mode 100644
index 0000000000..1420fbcdc6
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -0,0 +1,622 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+
+#include <cstdlib>
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+namespace tflite {
+
+namespace {
+
+// Copies the contents from the flatbuffer int vector `flatbuffer` into the
+// int array `buffer`. `flat_vector` and `buffer` represent the same
+// configuration operation for a given operation.
+void FlatBufferIntVectorToArray(int max_size_of_buffer,
+ const flatbuffers::Vector<int32_t>* flat_vector,
+ int* buffer, ErrorReporter* error_reporter) {
+ if (!flat_vector) {
+ error_reporter->Report("Input array not provided for operation.\n");
+ } else {
+ int num_dimensions = flat_vector->Length();
+ if (num_dimensions > max_size_of_buffer / sizeof(int)) {
+ error_reporter->Report(
+ "Found too many dimensions in the operation's input array.\n");
+ } else {
+ for (int i = 0; i < num_dimensions; ++i) {
+ buffer[i] = flat_vector->Get(i);
+ }
+ }
+ }
+}
+
+// Allocate a structure using malloc, but make sure the structure is a POD
+// structure that doesn't require constructors to run. The reason we do this,
+// is that Interpreter's C extension part will take ownership so destructors
+// will not be run during deallocation.
+template <class T>
+T* MallocPOD() {
+ static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
+ return static_cast<T*>(malloc(sizeof(T)));
+}
+
+} // namespace
+
+TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
+ ErrorReporter* error_reporter) {
+ switch (tensor_type) {
+ case TensorType_FLOAT32:
+ *type = kTfLiteFloat32;
+ break;
+ case TensorType_INT16:
+ *type = kTfLiteInt16;
+ break;
+ case TensorType_INT32:
+ *type = kTfLiteInt32;
+ break;
+ case TensorType_UINT8:
+ *type = kTfLiteUInt8;
+ break;
+ case TensorType_INT64:
+ *type = kTfLiteInt64;
+ break;
+ case TensorType_STRING:
+ *type = kTfLiteString;
+ break;
+ case TensorType_BOOL:
+ *type = kTfLiteBool;
+ break;
+ case TensorType_COMPLEX64:
+ *type = kTfLiteComplex64;
+ break;
+ default:
+ error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
+ EnumNameTensorType(tensor_type), tensor_type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly as there are flatbuffer schemas.
+// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
+// need to be released by calling `free`.`
+// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter, void** builtin_data) {
+ auto parse_padding = [](Padding padding) {
+ switch (padding) {
+ case Padding_SAME:
+ return kTfLitePaddingSame;
+ case Padding_VALID:
+ return kTfLitePaddingValid;
+ }
+ return kTfLitePaddingUnknown;
+ };
+ auto parse_activation = [](ActivationFunctionType activation) {
+ switch (activation) {
+ case ActivationFunctionType_NONE:
+ return kTfLiteActNone;
+ case ActivationFunctionType_RELU:
+ return kTfLiteActRelu;
+ case ActivationFunctionType_RELU_N1_TO_1:
+ return kTfLiteActRelu1;
+ case ActivationFunctionType_RELU6:
+ return kTfLiteActRelu6;
+ case ActivationFunctionType_TANH:
+ return kTfLiteActTanh;
+ case ActivationFunctionType_SIGN_BIT:
+ return kTfLiteActSignBit;
+ }
+ return kTfLiteActNone;
+ };
+ auto parseLSHProjectionType = [](LSHProjectionType type) {
+ switch (type) {
+ case LSHProjectionType_SPARSE:
+ return kTfLiteLshProjectionSparse;
+ case LSHProjectionType_DENSE:
+ return kTfLiteLshProjectionDense;
+ default:
+ return kTfLiteLshProjectionUnknown;
+ }
+ };
+ auto parseCombinerType = [](CombinerType type) {
+ switch (type) {
+ case CombinerType_MEAN:
+ return kTfLiteCombinerTypeMean;
+ case CombinerType_SQRTN:
+ return kTfLiteCombinerTypeSqrtn;
+ case CombinerType_SUM:
+ default:
+ return kTfLiteCombinerTypeSum;
+ }
+ };
+
+ *builtin_data = nullptr;
+ switch (op_type) {
+ case BuiltinOperator_CONV_2D: {
+ TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+ if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+
+ params->dilation_width_factor = conv_params->dilation_w_factor();
+ params->dilation_height_factor = conv_params->dilation_h_factor();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_CAST: {
+ TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
+ if (auto* schema_params = op->builtin_options_as_CastOptions()) {
+ auto in_status =
+ ConvertTensorType(schema_params->in_data_type(),
+ &params->in_data_type, error_reporter);
+ auto out_status =
+ ConvertTensorType(schema_params->out_data_type(),
+ &params->out_data_type, error_reporter);
+ if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
+ free(params);
+ return kTfLiteError;
+ }
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_LSH_PROJECTION: {
+ TfLiteLSHProjectionParams* params =
+ MallocPOD<TfLiteLSHProjectionParams>();
+ if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
+ params->type = parseLSHProjectionType(lshParams->type());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_AVERAGE_POOL_2D:
+ case BuiltinOperator_MAX_POOL_2D:
+ case BuiltinOperator_L2_POOL_2D: {
+ TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
+ if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
+ params->padding = parse_padding(pool_params->padding());
+ params->stride_width = pool_params->stride_w();
+ params->stride_height = pool_params->stride_h();
+ params->filter_width = pool_params->filter_width();
+ params->filter_height = pool_params->filter_height();
+ params->activation =
+ parse_activation(pool_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DEPTHWISE_CONV_2D: {
+ TfLiteDepthwiseConvParams* params =
+ MallocPOD<TfLiteDepthwiseConvParams>();
+ if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->depth_multiplier = conv_params->depth_multiplier();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SVDF: {
+ TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
+ if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
+ params->rank = svdf_params->rank();
+ params->activation =
+ parse_activation(svdf_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
+ TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
+ if (auto* sequence_rnn_params =
+ op->builtin_options_as_SequenceRNNOptions()) {
+ params->activation =
+ parse_activation(sequence_rnn_params->fused_activation_function());
+ params->time_major = sequence_rnn_params->time_major();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RNN: {
+ TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
+ if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
+ params->activation =
+ parse_activation(rnn_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
+ TfLiteEmbeddingLookupSparseParams* params =
+ MallocPOD<TfLiteEmbeddingLookupSparseParams>();
+ if (auto* embedding_params =
+ op->builtin_options_as_EmbeddingLookupSparseOptions()) {
+ params->combiner = parseCombinerType(embedding_params->combiner());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_FULLY_CONNECTED: {
+ TfLiteFullyConnectedParams* params =
+ MallocPOD<TfLiteFullyConnectedParams>();
+ if (auto* fully_connected_params =
+ op->builtin_options_as_FullyConnectedOptions()) {
+ params->activation = parse_activation(
+ fully_connected_params->fused_activation_function());
+ switch (fully_connected_params->weights_format()) {
+ case FullyConnectedOptionsWeightsFormat_DEFAULT:
+ params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
+ break;
+ case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+ params->weights_format =
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
+ break;
+ default:
+ error_reporter->Report("Unhandled fully-connected weights format.");
+ return kTfLiteError;
+ }
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_HASHTABLE_LOOKUP:
+ // no-op.
+ break;
+ case BuiltinOperator_SOFTMAX: {
+ TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
+ if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
+ params->beta = softmax_params->beta();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_CONCATENATION: {
+ TfLiteConcatenationParams* params =
+ MallocPOD<TfLiteConcatenationParams>();
+ if (auto* concatenation_params =
+ op->builtin_options_as_ConcatenationOptions()) {
+ params->activation =
+ parse_activation(concatenation_params->fused_activation_function());
+ params->axis = concatenation_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_MUL: {
+ auto* params = MallocPOD<TfLiteMulParams>();
+ if (auto* schema_params = op->builtin_options_as_MulOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ADD: {
+ auto* params = MallocPOD<TfLiteAddParams>();
+ if (auto* schema_params = op->builtin_options_as_AddOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DIV: {
+ auto* params = MallocPOD<TfLiteDivParams>();
+ if (auto* schema_params = op->builtin_options_as_DivOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SUB: {
+ auto* params = MallocPOD<TfLiteSubParams>();
+ if (auto* schema_params = op->builtin_options_as_SubOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_L2_NORMALIZATION: {
+ auto* params = MallocPOD<TfLiteL2NormParams>();
+ if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
+ auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_LocalResponseNormalizationOptions()) {
+ params->radius = schema_params->radius();
+ params->bias = schema_params->bias();
+ params->alpha = schema_params->alpha();
+ params->beta = schema_params->beta();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
+ case BuiltinOperator_LSTM: {
+ TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
+ if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
+ params->activation =
+ parse_activation(lstm_params->fused_activation_function());
+ params->cell_clip = lstm_params->cell_clip();
+ params->proj_clip = lstm_params->proj_clip();
+ switch (lstm_params->kernel_type()) {
+ case LSTMKernelType_FULL:
+ params->kernel_type = kTfLiteLSTMFullKernel;
+ break;
+ case LSTMKernelType_BASIC:
+ params->kernel_type = kTfLiteLSTMBasicKernel;
+ break;
+ }
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RESIZE_BILINEAR: {
+ auto* params = MallocPOD<TfLiteResizeBilinearParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_ResizeBilinearOptions()) {
+ params->align_corners = schema_params->align_corners();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RESHAPE: {
+ auto* params = MallocPOD<TfLiteReshapeParams>();
+ if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
+ auto* new_shape = schema_params->new_shape();
+ FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
+ params->shape, error_reporter);
+ params->num_dimensions = new_shape->Length();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SKIP_GRAM: {
+ TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
+ if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
+ params->ngram_size = skip_gram_params->ngram_size();
+ params->max_skip_size = skip_gram_params->max_skip_size();
+ params->include_all_ngrams = skip_gram_params->include_all_ngrams();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPACE_TO_DEPTH: {
+ auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
+ if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
+ params->block_size = schema_params->block_size();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_GATHER: {
+ TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
+ params->axis = 0;
+ if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
+ params->axis = gather_params->axis();
+ }
+
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_MEAN:
+ case BuiltinOperator_REDUCE_MAX:
+ case BuiltinOperator_REDUCE_MIN:
+ case BuiltinOperator_REDUCE_PROD:
+ case BuiltinOperator_REDUCE_ANY:
+ case BuiltinOperator_SUM: {
+ auto* params = MallocPOD<TfLiteReducerParams>();
+ if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
+ params->keep_dims = schema_params->keep_dims();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPLIT: {
+ auto* params = MallocPOD<TfLiteSplitParams>();
+ if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
+ params->num_splits = schema_params->num_splits();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SQUEEZE: {
+ auto* params = MallocPOD<TfLiteSqueezeParams>();
+ if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
+ const auto& squeeze_dims = schema_params->squeeze_dims();
+ FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
+ params->squeeze_dims, error_reporter);
+ params->num_squeeze_dims = squeeze_dims->Length();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_STRIDED_SLICE: {
+ auto* params = MallocPOD<TfLiteStridedSliceParams>();
+ if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
+ params->begin_mask = schema_params->begin_mask();
+ params->end_mask = schema_params->end_mask();
+ params->ellipsis_mask = schema_params->ellipsis_mask();
+ params->new_axis_mask = schema_params->new_axis_mask();
+ params->shrink_axis_mask = schema_params->shrink_axis_mask();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ARG_MAX: {
+ auto* params = MallocPOD<TfLiteArgMaxParams>();
+ if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
+ ConvertTensorType(schema_params->output_type(), &params->output_type,
+ error_reporter);
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ARG_MIN: {
+ auto* params = MallocPOD<TfLiteArgMinParams>();
+ if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
+ ConvertTensorType(schema_params->output_type(), &params->output_type,
+ error_reporter);
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_TRANSPOSE_CONV: {
+ TfLiteTransposeConvParams* params =
+ MallocPOD<TfLiteTransposeConvParams>();
+ if (auto* transpose_conv_params =
+ op->builtin_options_as_TransposeConvOptions()) {
+ params->padding = parse_padding(transpose_conv_params->padding());
+ params->stride_width = transpose_conv_params->stride_w();
+ params->stride_height = transpose_conv_params->stride_h();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPARSE_TO_DENSE: {
+ TfLiteSparseToDenseParams* params =
+ MallocPOD<TfLiteSparseToDenseParams>();
+ if (auto* sparse_to_dense_params =
+ op->builtin_options_as_SparseToDenseOptions()) {
+ params->validate_indices = sparse_to_dense_params->validate_indices();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SHAPE: {
+ auto* params = MallocPOD<TfLiteShapeParams>();
+ if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
+ ConvertTensorType(schema_params->out_type(), &params->out_type,
+ error_reporter);
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_PACK: {
+ TfLitePackParams* params = MallocPOD<TfLitePackParams>();
+ if (auto* pack_params = op->builtin_options_as_PackOptions()) {
+ params->values_count = pack_params->values_count();
+ params->axis = pack_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DELEGATE: {
+ // TODO(ycling): Revisit when supporting saving delegated models.
+ error_reporter->Report("DELEGATE op shouldn't exist in model.");
+ return kTfLiteError;
+ }
+ case BuiltinOperator_FAKE_QUANT: {
+ auto* params = MallocPOD<TfLiteFakeQuantParams>();
+ if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
+ params->min = schema_params->min();
+ params->max = schema_params->max();
+ params->num_bits = schema_params->num_bits();
+ params->narrow_range = schema_params->narrow_range();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ONE_HOT: {
+ auto* params = MallocPOD<TfLiteOneHotParams>();
+ if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
+ params->axis = schema_params->axis();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_UNPACK: {
+ TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
+ if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
+ params->num = unpack_params->num();
+ params->axis = unpack_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+
+ // Below are the ops with no builtin_data strcture.
+ case BuiltinOperator_BATCH_TO_SPACE_ND:
+ // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
+ // ok for now, since there is no call implementation either.
+ case BuiltinOperator_CALL:
+ case BuiltinOperator_CONCAT_EMBEDDINGS:
+ case BuiltinOperator_CUSTOM:
+ case BuiltinOperator_DEQUANTIZE:
+ case BuiltinOperator_EMBEDDING_LOOKUP:
+ case BuiltinOperator_EQUAL:
+ case BuiltinOperator_EXP:
+ case BuiltinOperator_EXPAND_DIMS:
+ case BuiltinOperator_FLOOR:
+ case BuiltinOperator_GREATER:
+ case BuiltinOperator_GREATER_EQUAL:
+ case BuiltinOperator_LESS:
+ case BuiltinOperator_LESS_EQUAL:
+ case BuiltinOperator_LOG:
+ case BuiltinOperator_LOGISTIC:
+ case BuiltinOperator_LOG_SOFTMAX:
+ case BuiltinOperator_MAXIMUM:
+ case BuiltinOperator_MINIMUM:
+ case BuiltinOperator_NEG:
+ case BuiltinOperator_NOT_EQUAL:
+ case BuiltinOperator_PAD:
+ case BuiltinOperator_PADV2:
+ case BuiltinOperator_PRELU:
+ case BuiltinOperator_RELU:
+ case BuiltinOperator_RELU6:
+ case BuiltinOperator_RELU_N1_TO_1:
+ case BuiltinOperator_RSQRT:
+ case BuiltinOperator_SELECT:
+ case BuiltinOperator_SIN:
+ case BuiltinOperator_SLICE:
+ case BuiltinOperator_SPACE_TO_BATCH_ND:
+ case BuiltinOperator_SQRT:
+ case BuiltinOperator_TANH:
+ case BuiltinOperator_TILE:
+ case BuiltinOperator_TOPK_V2:
+ case BuiltinOperator_TRANSPOSE:
+ case BuiltinOperator_POW:
+ case BuiltinOperator_LOGICAL_OR:
+ case BuiltinOperator_LOGICAL_AND:
+ case BuiltinOperator_LOGICAL_NOT:
+ case BuiltinOperator_FLOOR_DIV:
+ break;
+ }
+ return kTfLiteOk;
+} // NOLINT[readability/fn_size]
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
new file mode 100644
index 0000000000..4dec6f9cfc
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
+
+// These functions transform codes and data structures that are defined in the
+// flatbuffer serialization format into in-memory values that are used by the
+// runtime API and interpreter.
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly as there are flatbuffer schemas.
+// If it returns kTfLiteOk, it passes the data out with `builtin_data`. The
+// calling function has to pass in an allocator object, and this allocator
+// will be called to reserve space for the output data. If the calling
+// function's allocator reserves memory on the heap, then it's the calling
+// function's responsibility to free it.
+// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter, void** builtin_data);
+
+// Converts the tensor data type used in the flat buffer to the representation
+// used by the runtime.
+TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
+ ErrorReporter* error_reporter);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
new file mode 100644
index 0000000000..b12bdf43b2
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
@@ -0,0 +1,104 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+
+#include <cstring>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+namespace tflite {
+namespace {
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ MockErrorReporter() : buffer_size_(0) {}
+ int Report(const char* format, va_list args) override {
+ buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args);
+ return buffer_size_;
+ }
+ char* GetBuffer() { return buffer_; }
+ int GetBufferSize() { return buffer_size_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+ int buffer_size_;
+};
+
+} // namespace
+
+TEST(FlatbufferConversions, TestParseOpDataConv) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<void> conv_options =
+ CreateConv2DOptions(builder, Padding_SAME, 1, 2,
+ ActivationFunctionType_RELU, 3, 4)
+ .Union();
+ flatbuffers::Offset<Operator> conv_offset = CreateOperatorDirect(
+ builder, 0, nullptr, nullptr, BuiltinOptions_Conv2DOptions, conv_options,
+ nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const Operator* conv_op = flatbuffers::GetRoot<Operator>(conv_pointer);
+ void* output_data = nullptr;
+ EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter,
+ &output_data));
+ EXPECT_NE(nullptr, output_data);
+ TfLiteConvParams* params = reinterpret_cast<TfLiteConvParams*>(output_data);
+ EXPECT_EQ(kTfLitePaddingSame, params->padding);
+ EXPECT_EQ(1, params->stride_width);
+ EXPECT_EQ(2, params->stride_height);
+ EXPECT_EQ(kTfLiteActRelu, params->activation);
+ EXPECT_EQ(3, params->dilation_width_factor);
+ EXPECT_EQ(4, params->dilation_height_factor);
+ free(output_data);
+}
+
+TEST(FlatbufferConversions, TestParseOpDataCustom) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<void> null_options;
+ flatbuffers::Offset<Operator> custom_offset = CreateOperatorDirect(
+ builder, 0, nullptr, nullptr, BuiltinOptions_NONE, null_options, nullptr,
+ CustomOptionsFormat_FLEXBUFFERS, nullptr);
+ builder.Finish(custom_offset);
+ void* custom_pointer = builder.GetBufferPointer();
+ const Operator* custom_op = flatbuffers::GetRoot<Operator>(custom_pointer);
+ void* output_data = nullptr;
+ EXPECT_EQ(kTfLiteOk, ParseOpData(custom_op, BuiltinOperator_CUSTOM, reporter,
+ &output_data));
+ EXPECT_EQ(nullptr, output_data);
+}
+
+TEST(FlatbufferConversions, TestConvertTensorType) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+ TfLiteType type;
+ EXPECT_EQ(kTfLiteOk, ConvertTensorType(TensorType_FLOAT32, &type, reporter));
+ EXPECT_EQ(kTfLiteFloat32, type);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/core/api/op_resolver.cc b/tensorflow/contrib/lite/core/api/op_resolver.cc
new file mode 100644
index 0000000000..55ee924843
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver.cc
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+
+namespace tflite {
+
+TfLiteStatus GetRegistrationFromOpCode(
+ const OperatorCode* opcode, const OpResolver& op_resolver,
+ ErrorReporter* error_reporter, const TfLiteRegistration** registration) {
+ TfLiteStatus status = kTfLiteOk;
+ *registration = nullptr;
+ auto builtin_code = opcode->builtin_code();
+ int version = opcode->version();
+
+ if (builtin_code > BuiltinOperator_MAX ||
+ builtin_code < BuiltinOperator_MIN) {
+ error_reporter->Report(
+ "Op builtin_code out of range: %d. Are you using old TFLite binary "
+ "with newer model?",
+ builtin_code);
+ status = kTfLiteError;
+ } else if (builtin_code != BuiltinOperator_CUSTOM) {
+ *registration = op_resolver.FindOp(builtin_code, version);
+ if (*registration == nullptr) {
+ error_reporter->Report(
+ "Didn't find op for builtin opcode '%s' version '%d'\n",
+ EnumNameBuiltinOperator(builtin_code), version);
+ status = kTfLiteError;
+ }
+ } else if (!opcode->custom_code()) {
+ error_reporter->Report(
+ "Operator with CUSTOM builtin_code has no custom_code.\n");
+ status = kTfLiteError;
+ } else {
+ const char* name = opcode->custom_code()->c_str();
+ *registration = op_resolver.FindOp(name, version);
+ if (*registration == nullptr) {
+ error_reporter->Report(
+ "Didn't find custom op for name '%s' with version %d\n", name,
+ version);
+ status = kTfLiteError;
+ }
+ }
+ return status;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/op_resolver.h b/tensorflow/contrib/lite/core/api/op_resolver.h
new file mode 100644
index 0000000000..5f5e6b2736
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Abstract interface that returns TfLiteRegistrations given op codes or custom
+// op names. This is the mechanism that ops being referenced in the flatbuffer
+// model are mapped to executable function pointers (TfLiteRegistrations).
+class OpResolver {
+ public:
+ // Finds the op registration for a builtin operator by enum code.
+ virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const = 0;
+ // Finds the op registration of a custom operator by op name.
+ virtual const TfLiteRegistration* FindOp(const char* op,
+ int version) const = 0;
+ virtual ~OpResolver() {}
+};
+
+// Handles the logic for converting between an OperatorCode structure extracted
+// from a flatbuffer and information about a registered operator implementation.
+TfLiteStatus GetRegistrationFromOpCode(const OperatorCode* opcode,
+ const OpResolver& op_resolver,
+ ErrorReporter* error_reporter,
+ const TfLiteRegistration** registration);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/core/api/op_resolver_test.cc b/tensorflow/contrib/lite/core/api/op_resolver_test.cc
new file mode 100644
index 0000000000..167463110e
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver_test.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+
+#include <cstring>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace {
+void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
+ // Do nothing.
+ return nullptr;
+}
+
+void MockFree(TfLiteContext* context, void* buffer) {
+ // Do nothing.
+}
+
+TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+class MockOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(BuiltinOperator op,
+ int version) const override {
+ if (op == BuiltinOperator_CONV_2D) {
+ static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
+ MockInvoke};
+ return &r;
+ } else {
+ return nullptr;
+ }
+ }
+ const TfLiteRegistration* FindOp(const char* op, int version) const override {
+ if (strcmp(op, "mock_custom") == 0) {
+ static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
+ MockInvoke};
+ return &r;
+ } else {
+ return nullptr;
+ }
+ }
+};
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ MockErrorReporter() : buffer_size_(0) {}
+ int Report(const char* format, va_list args) override {
+ buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args);
+ return buffer_size_;
+ }
+ char* GetBuffer() { return buffer_; }
+ int GetBufferSize() { return buffer_size_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+ int buffer_size_;
+};
+
+} // namespace
+
+TEST(OpResolver, TestResolver) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+
+ const TfLiteRegistration* registration =
+ resolver->FindOp(BuiltinOperator_CONV_2D, 0);
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+ registration = resolver->FindOp(BuiltinOperator_CAST, 0);
+ EXPECT_EQ(nullptr, registration);
+
+ registration = resolver->FindOp("mock_custom", 0);
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+ registration = resolver->FindOp("nonexistent_custom", 0);
+ EXPECT_EQ(nullptr, registration);
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeConv) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset =
+ CreateOperatorCodeDirect(builder, BuiltinOperator_CONV_2D, nullptr, 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
+ &registration));
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+ EXPECT_EQ(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeCast) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset =
+ CreateOperatorCodeDirect(builder, BuiltinOperator_CAST, nullptr, 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver,
+ reporter, &registration));
+ EXPECT_EQ(nullptr, registration);
+ EXPECT_NE(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeCustom) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect(
+ builder, BuiltinOperator_CUSTOM, "mock_custom", 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
+ &registration));
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+ EXPECT_EQ(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeNonexistentCustom) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect(
+ builder, BuiltinOperator_CUSTOM, "nonexistent_custom", 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver,
+ reporter, &registration));
+ EXPECT_EQ(nullptr, registration);
+ EXPECT_NE(0, mock_reporter.GetBufferSize());
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index b6b2357873..bf5d91899c 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -16,6 +16,7 @@ cc_library(
deps = [
":util",
"//tensorflow/c:c_api_internal",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
@@ -54,6 +55,7 @@ cc_library(
":delegate_data",
":kernel",
":util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:util",
] + select({
@@ -104,6 +106,7 @@ tf_cc_test(
":delegate_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
@@ -117,6 +120,7 @@ cc_library(
":delegate_data",
":util",
"@flatbuffers",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:string",
"//tensorflow/contrib/lite/kernels:kernel_util",
@@ -170,6 +174,7 @@ cc_library(
hdrs = ["util.h"],
deps = [
"//tensorflow/c:c_api_internal",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
index a28329ae7d..aaaa045840 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.h
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <map>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h
index 6d15ba47dc..70f3c15af4 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.h
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
index b3a0ffcec1..def063309f 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc
index 0ee4db1ffb..274c3c082a 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.cc
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/builtin_ops.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
#include "tensorflow/contrib/lite/delegates/eager/util.h"
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/eager/kernel.h
index 100672c82d..2478abccaa 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.h
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
namespace eager {
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h
index ff500d18f3..930cb99cb9 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.h
+++ b/tensorflow/contrib/lite/delegates/eager/util.h
@@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
#include "tensorflow/c/c_api_internal.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD
index 954955f24b..4e7b2948fb 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/BUILD
+++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD
@@ -13,6 +13,7 @@ cc_library(
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:kernel_util",
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
],
@@ -29,6 +30,7 @@ tf_cc_test(
deps = [
":nnapi_delegate",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index 980a1cb4a0..e3eebac4da 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/allocation.h"
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/builtin_ops.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
index 44cca2fd28..4852b76974 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h
index 3c5f805f12..5c20eedc25 100644
--- a/tensorflow/contrib/lite/error_reporter.h
+++ b/tensorflow/contrib/lite/error_reporter.h
@@ -12,43 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// Compatibility shim for moved header location.
#ifndef TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
#define TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
-#include <cstdarg>
-#include "tensorflow/contrib/lite/context.h"
-
-namespace tflite {
-
-// A functor that reports error to supporting system. Invoked similar to
-// printf.
-//
-// Usage:
-// ErrorReporter foo;
-// foo.Report("test %d", 5);
-// or
-// va_list args;
-// foo.Report("test %d", args); // where args is va_list
-//
-// Subclass ErrorReporter to provide another reporting destination.
-// For example, if you have a GUI program, you might redirect to a buffer
-// that drives a GUI error log box.
-class ErrorReporter {
- public:
- virtual ~ErrorReporter();
- virtual int Report(const char* format, va_list args) = 0;
- int Report(const char* format, ...);
- int ReportError(void*, const char* format, ...);
-};
-
-// An error reporter that simplify writes the message to stderr.
-struct StderrReporter : public ErrorReporter {
- int Report(const char* format, va_list args) override;
-};
-
-// Return the default error reporter (output to stderr).
-ErrorReporter* DefaultErrorReporter();
-
-} // namespace tflite
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
#endif // TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD
index 8fc07e8eb7..ea4a543252 100644
--- a/tensorflow/contrib/lite/experimental/c/BUILD
+++ b/tensorflow/contrib/lite/experimental/c/BUILD
@@ -78,6 +78,7 @@ cc_test(
data = ["//tensorflow/contrib/lite:testdata/add.bin"],
deps = [
":c_api",
+ "//tensorflow/contrib/lite:context",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc
index a4ab0e8c30..c589cf71ea 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
+#include <memory>
+
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
@@ -29,12 +31,14 @@ extern "C" {
TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) {
auto model = tflite::FlatBufferModel::BuildFromBuffer(
static_cast<const char*>(model_data), model_size);
- return model ? new TFL_Model{std::move(model)} : nullptr;
+ std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
+ return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr;
}
TFL_Model* TFL_NewModelFromFile(const char* model_path) {
auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
- return model ? new TFL_Model{std::move(model)} : nullptr;
+ std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
+ return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr;
}
void TFL_DeleteModel(TFL_Model* model) { delete model; }
@@ -72,7 +76,7 @@ TFL_Interpreter* TFL_NewInterpreter(
}
}
- return new TFL_Interpreter{std::move(interpreter)};
+ return new TFL_Interpreter{model->impl, std::move(interpreter)};
}
void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; }
@@ -129,6 +133,8 @@ void* TFL_TensorData(const TFL_Tensor* tensor) {
return static_cast<void*>(tensor->data.raw);
}
+const char* TFL_TensorName(const TFL_Tensor* tensor) { return tensor->name; }
+
TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data,
size_t input_data_size) {
if (tensor->bytes != input_data_size) {
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h
index 3757349b55..b429e76870 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api.h
@@ -93,7 +93,8 @@ typedef struct TFL_Interpreter TFL_Interpreter;
// failure.
//
// * `model` must be a valid model instance. The caller retains ownership of the
-// object, and can destroy it immediately after creating the interpreter.
+// object, and can destroy it immediately after creating the interpreter; the
+// interpreter will maintain its own reference to the underlying model data.
// * `optional_options` may be null. The caller retains ownership of the object,
// and can safely destroy it immediately after creating the interpreter.
//
@@ -145,6 +146,11 @@ TFL_CAPI_EXPORT extern int32_t TFL_InterpreterGetOutputTensorCount(
// Returns the tensor associated with the output index.
// REQUIRES: 0 <= input_index < TFL_InterpreterGetOutputTensorCount(tensor)
+//
+// NOTE: The shape and underlying data buffer for output tensors may be not
+// be available until after the output tensor has been both sized and allocated.
+// In general, best practice is to interact with the output tensor *after*
+// calling TFL_InterpreterInvoke().
TFL_CAPI_EXPORT extern const TFL_Tensor* TFL_InterpreterGetOutputTensor(
const TFL_Interpreter* interpreter, int32_t output_index);
@@ -172,12 +178,15 @@ TFL_CAPI_EXPORT extern size_t TFL_TensorByteSize(const TFL_Tensor* tensor);
// Returns a pointer to the underlying data buffer.
//
-// Note: The result may be null if tensors have not yet been allocated, e.g.,
+// NOTE: The result may be null if tensors have not yet been allocated, e.g.,
// if the Tensor has just been created or resized and `TFL_AllocateTensors()`
// has yet to be called, or if the output tensor is dynamically sized and the
// interpreter hasn't been invoked.
TFL_CAPI_EXPORT extern void* TFL_TensorData(const TFL_Tensor* tensor);
+// Returns the (null-terminated) name of the tensor.
+TFL_CAPI_EXPORT extern const char* TFL_TensorName(const TFL_Tensor* tensor);
+
// Copies from the provided input buffer into the tensor's buffer.
// REQUIRES: input_data_size == TFL_TensorByteSize(tensor)
TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyFromBuffer(
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
index c5c612a4c6..60c2e4e2cd 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
@@ -24,7 +24,8 @@ limitations under the License.
// not be depended on.
struct TFL_Model {
- std::unique_ptr<tflite::FlatBufferModel> impl;
+ // Sharing is safe as FlatBufferModel is const.
+ std::shared_ptr<const tflite::FlatBufferModel> impl;
};
struct TFL_InterpreterOptions {
@@ -35,6 +36,9 @@ struct TFL_InterpreterOptions {
};
struct TFL_Interpreter {
+ // Taking a reference to the (const) model data avoids lifetime-related issues
+ // and complexity with the TFL_Model's existence.
+ std::shared_ptr<const tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> impl;
};
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
index a631dae890..649dac8d1a 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
@@ -55,6 +55,8 @@ TEST(CApiSimple, Smoke) {
EXPECT_EQ(TFL_TensorNumDims(input_tensor), 1);
EXPECT_EQ(TFL_TensorDim(input_tensor, 0), 2);
EXPECT_EQ(TFL_TensorByteSize(input_tensor), sizeof(float) * 2);
+ EXPECT_NE(TFL_TensorData(input_tensor), nullptr);
+ EXPECT_STREQ(TFL_TensorName(input_tensor), "input");
std::array<float, 2> input = {1.f, 3.f};
ASSERT_EQ(TFL_TensorCopyFromBuffer(input_tensor, input.data(),
@@ -70,6 +72,8 @@ TEST(CApiSimple, Smoke) {
EXPECT_EQ(TFL_TensorNumDims(output_tensor), 1);
EXPECT_EQ(TFL_TensorDim(output_tensor, 0), 2);
EXPECT_EQ(TFL_TensorByteSize(output_tensor), sizeof(float) * 2);
+ EXPECT_NE(TFL_TensorData(output_tensor), nullptr);
+ EXPECT_STREQ(TFL_TensorName(output_tensor), "output");
std::array<float, 2> output;
ASSERT_EQ(TFL_TensorCopyToBuffer(output_tensor, output.data(),
diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/contrib/lite/experimental/kernels/BUILD
index 9c06c4ebd9..4786cc62f9 100644
--- a/tensorflow/contrib/lite/experimental/kernels/BUILD
+++ b/tensorflow/contrib/lite/experimental/kernels/BUILD
@@ -53,6 +53,7 @@ cc_library(
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels:kernel_util",
@@ -61,8 +62,8 @@ cc_library(
"//tensorflow/contrib/lite/kernels/internal:optimized",
"//tensorflow/contrib/lite/kernels/internal:optimized_base",
"//tensorflow/contrib/lite/kernels/internal:quantization_util",
- "//tensorflow/contrib/lite/kernels/internal:reference",
"//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:tensor",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
"@flatbuffers",
],
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
index 121997dcb2..8442c4d46c 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include <vector>
#include "flatbuffers/flexbuffers.h" // flatbuffers
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h
index 77268d7aeb..8ee83827bb 100644
--- a/tensorflow/contrib/lite/graph_info.h
+++ b/tensorflow/contrib/lite/graph_info.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 5ab53f4c1d..3f8f4d198f 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -21,9 +21,9 @@ limitations under the License.
#include <cstring>
#include "tensorflow/contrib/lite/arena_planner.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/nnapi_delegate.h"
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 2b1f1819b9..f0cd178c19 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -23,10 +23,11 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/profiling/profiler.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 5bcf0927d8..cdede430e2 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/interpreter.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD
index 06f46fb923..781289ceb2 100644
--- a/tensorflow/contrib/lite/java/ovic/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/BUILD
@@ -35,6 +35,7 @@ java_binary(
"//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt",
],
main_class = "org.tensorflow.ovic.OvicValidator",
+ tags = ["no_oss"],
deps = [
"//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java",
],
@@ -47,6 +48,7 @@ android_library(
"src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml",
+ tags = ["no_oss"],
deps = [
"//tensorflow/contrib/lite/java:tensorflowlite",
"//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
@@ -61,6 +63,7 @@ java_library(
"src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
javacopts = JAVACOPTS,
+ tags = ["no_oss"],
deps = [
"//tensorflow/contrib/lite/java:libtensorflowlite_jni.so",
"//tensorflow/contrib/lite/java:tensorflowlite_java",
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 55ca47fed7..06b35d77c8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -20,7 +20,7 @@ limitations under the License.
#include <stdio.h>
#include <time.h>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h"
@@ -124,9 +124,9 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint num_threads);
+ jclass clazz,
+ jlong handle,
+ jint num_threads);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
index c020f13d9c..2f73128bdf 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
#include <jni.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#ifdef __cplusplus
extern "C" {
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index b7c5cbf207..40f28aeab4 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -66,7 +66,7 @@ cc_library(
deps = [
":op_macros",
"//tensorflow/contrib/lite:arena_planner",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels/internal:optimized",
],
)
@@ -82,7 +82,7 @@ cc_library(
copts = tflite_copts(),
deps = [
":op_macros",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"@gemmlowp",
],
)
@@ -93,7 +93,7 @@ cc_library(
"activation_functor.h",
],
deps = [
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -113,9 +113,9 @@ cc_library(
"kernel_util.h",
],
deps = [
- "//tensorflow/contrib/lite:builtin_op_data",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels/internal:round",
+ "//tensorflow/contrib/lite/kernels/internal:types",
],
)
@@ -147,6 +147,15 @@ tf_cc_test(
)
cc_library(
+ name = "padding",
+ srcs = [],
+ hdrs = ["padding.h"],
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
+cc_library(
name = "builtin_op_kernels",
srcs = [
"activations.cc",
@@ -216,7 +225,6 @@ cc_library(
"unpack.cc",
],
hdrs = [
- "padding.h",
],
copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
visibility = ["//visibility:private"],
@@ -225,18 +233,19 @@ cc_library(
":eigen_support",
":kernel_util",
":op_macros",
- "//tensorflow/contrib/lite:builtin_op_data",
+ ":padding",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels/internal:audio_utils",
"//tensorflow/contrib/lite/kernels/internal:kernel_utils",
"//tensorflow/contrib/lite/kernels/internal:optimized",
"//tensorflow/contrib/lite/kernels/internal:optimized_base",
"//tensorflow/contrib/lite/kernels/internal:quantization_util",
- "//tensorflow/contrib/lite/kernels/internal:reference",
"//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:tensor",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
"@farmhash_archive//:farmhash",
"@flatbuffers",
@@ -251,6 +260,7 @@ cc_library(
":builtin_op_kernels",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -757,8 +767,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -774,8 +784,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1044,8 +1054,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1147,8 +1157,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1164,8 +1174,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1181,8 +1191,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1198,8 +1208,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1212,8 +1222,8 @@ tf_cc_test(
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1239,8 +1249,8 @@ tf_cc_test(
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h
index 41ec3cca33..e075dc7054 100644
--- a/tensorflow/contrib/lite/kernels/activation_functor.h
+++ b/tensorflow/contrib/lite/kernels/activation_functor.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <cmath>
#include <cstdlib>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 5cdd9fc94f..b2d9b84979 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index af9b5c7013..b4393e8097 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc
index 6e05f5a9b2..b91e348c27 100644
--- a/tensorflow/contrib/lite/kernels/arg_min_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
index 1170d84553..44ef587244 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/spectrogram.h"
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
index c5a5c0182f..1aa27602e5 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <stddef.h>
#include <stdint.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
index 4efa9d596d..fe2865dfb9 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 6b8ecdd5c3..541f320138 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index d988ef8b33..2f896c5289 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc
index 8dd48af57f..a7972140ac 100644
--- a/tensorflow/contrib/lite/kernels/cast.cc
+++ b/tensorflow/contrib/lite/kernels/cast.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <algorithm>
#include <complex>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 8b4d778332..4cd96348a2 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index 605a20ac3e..25ea556d5a 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 3ed0cdb131..ab6bdaecaa 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h"
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 21518156b8..347515f289 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc
index 2b0f04489a..3a08f48b00 100644
--- a/tensorflow/contrib/lite/kernels/dequantize.cc
+++ b/tensorflow/contrib/lite/kernels/dequantize.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
index 136697f945..d2906632d7 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <numeric>
#include <vector>
#include "flatbuffers/flexbuffers.h" // flatbuffers
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index d7420ddd8e..7945c095b1 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h
index b235829642..feb1543f7b 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.h
+++ b/tensorflow/contrib/lite/kernels/eigen_support.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace EigenForTFLite {
struct ThreadPoolDevice;
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index e19779ea59..04995d70dd 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include <cmath>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index b2dff87e62..fe33f98eb0 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -37,8 +37,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
index d3be36993c..aa75b03990 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
@@ -65,8 +65,8 @@ limitations under the License.
#include <algorithm>
#include <cmath>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc
index ce03cdfe26..673e7be90a 100644
--- a/tensorflow/contrib/lite/kernels/exp.cc
+++ b/tensorflow/contrib/lite/kernels/exp.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc
index ed33012864..fa1140b19c 100644
--- a/tensorflow/contrib/lite/kernels/expand_dims.cc
+++ b/tensorflow/contrib/lite/kernels/expand_dims.cc
@@ -15,8 +15,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
index 50dc860e5a..a3bc1813db 100644
--- a/tensorflow/contrib/lite/kernels/expand_dims_test.cc
+++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
index 0ef1a50b30..f9bc3747cb 100644
--- a/tensorflow/contrib/lite/kernels/fake_quant.cc
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc
index f7d5f5146d..59ff77f35b 100644
--- a/tensorflow/contrib/lite/kernels/floor.cc
+++ b/tensorflow/contrib/lite/kernels/floor.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/floor_div.cc b/tensorflow/contrib/lite/kernels/floor_div.cc
index 75cf19a5a7..5d62cd2755 100644
--- a/tensorflow/contrib/lite/kernels/floor_div.cc
+++ b/tensorflow/contrib/lite/kernels/floor_div.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index eaf5a67d67..7a71fcc219 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc
index 2b2a9e6620..badd2de11a 100644
--- a/tensorflow/contrib/lite/kernels/gather.cc
+++ b/tensorflow/contrib/lite/kernels/gather.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc
index 1d4292955c..1b48884e09 100644
--- a/tensorflow/contrib/lite/kernels/gather_test.cc
+++ b/tensorflow/contrib/lite/kernels/gather_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
index 37af772c68..43cd2b3055 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.h
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
#include "public/gemmlowp.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
namespace gemm_support {
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
index f37c66acb3..c0b3c3c0c5 100644
--- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
@@ -39,8 +39,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/string_util.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index 464163bd78..a6fd4ac2dd 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -163,7 +163,7 @@ cc_library(
":tensor_utils",
"//third_party/eigen3",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -198,7 +198,7 @@ cc_library(
":round",
"//third_party/eigen3",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -220,13 +220,15 @@ cc_library(
"optimized/eigen_spatial_convolutions.h",
"optimized/eigen_tensor_reduced_instantiations_oss.h",
"optimized/multithreaded_conv.h",
+ # FIXME(petewarden) - This should be removed, since it's a header from the
+ # :tensor dependency below.
"tensor.h",
],
deps = [
":optimized_base",
+ ":tensor",
":types",
- "//tensorflow/contrib/lite:builtin_op_data",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//third_party/eigen3",
],
)
@@ -236,7 +238,7 @@ cc_test(
srcs = ["tensor_test.cc"],
tags = ["no_oss"],
deps = [
- ":reference",
+ ":tensor",
"@com_google_googletest//:gtest",
],
)
@@ -296,7 +298,7 @@ cc_library(
":strided_slice_logic",
":types",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -326,7 +328,7 @@ cc_library(
":strided_slice_logic",
":types",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -341,11 +343,27 @@ cc_library(
)
cc_library(
+ name = "tensor",
+ hdrs = [
+ "tensor.h",
+ "tensor_ctypes.h",
+ ],
+ deps = [
+ ":types",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
+# Deprecated version of :tensor, kept for backwards compatibility.
+cc_library(
name = "reference",
- hdrs = ["tensor.h"],
+ hdrs = [
+ "tensor.h",
+ "tensor_ctypes.h",
+ ],
deps = [
":types",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -359,7 +377,7 @@ cc_library(
],
deps = [
":round",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:activation_functor",
"//tensorflow/contrib/lite/kernels:op_macros",
],
@@ -384,7 +402,7 @@ cc_library(
":cpu_check",
":round",
":types",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:activation_functor",
"//tensorflow/contrib/lite/kernels:op_macros",
"@arm_neon_2_x86_sse",
@@ -398,7 +416,7 @@ cc_library(
hdrs = ["kernel_utils.h"],
deps = [
":tensor_utils",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -441,7 +459,7 @@ cc_library(
copts = NEON_FLAGS_IF_APPLICABLE,
deps = [
"//tensorflow/contrib/lite/kernels:activation_functor",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"@arm_neon_2_x86_sse",
"@gemmlowp",
] + select({
@@ -517,7 +535,7 @@ cc_test(
],
deps = [
":tensor_utils",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest_main",
],
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index eb4d0108bd..e67fee11b8 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -45,7 +45,7 @@ limitations under the License.
#endif
#endif
-#include "public/gemmlowp.h"
+#include "fixedpoint/fixedpoint.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index b9dd40ddf9..56e9367878 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -14,8 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
-#include <algorithm>
-
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index 215ad04add..b5558cce55 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
namespace kernel_utils {
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index 921aae1303..5fb31889fe 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -26,7 +26,7 @@ limitations under the License.
#include <tuple>
#include <type_traits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 70b6994a2b..27418178fd 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index 5ca1b4b76f..630a6bbf29 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -17,7 +17,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index 7e53dc2fa2..f87760a6c3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -17,7 +17,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index 2a30910c3f..77e60adc18 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <string.h>
#include <algorithm>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index f5b3a84f07..714b1164ee 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -17,7 +17,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index a027a47726..0abacf85e1 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -3488,8 +3488,7 @@ inline void Gather(const tflite::GatherParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& coords_shape, const int32* coords_data,
const RuntimeShape& output_shape, T* output_data) {
- // TODO(b/80418076): Enable these checks when moving legacy ops to
- // legacy_reference_ops.
+ // Enable these checks when moving legacy ops to legacy_reference_ops.
//
// TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1);
const int input_rank = op_params.input_rank;
@@ -3808,58 +3807,110 @@ inline void Pad(const tflite::PadParams& op_params,
}
template <typename T>
-inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
- int begin_mask, int end_mask, int shrink_axis_mask,
- const std::vector<int>& start_indices,
- const std::vector<int>& stop_indices,
- const std::vector<int>& strides, T* output_data,
- const Dims<4>& output_dims) {
- // Note that the axis orders are reversed for runtime ops, so the indices,
- // strides and masks must be as well too.
- TFLITE_DCHECK_EQ(start_indices.size(), 4);
- TFLITE_DCHECK_EQ(stop_indices.size(), 4);
- TFLITE_DCHECK_EQ(strides.size(), 4);
- const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 3);
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ // Note that the output_shape is not used herein.
+ tflite::StridedSliceParams params_copy = op_params;
+
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ // Reverse and pad to 4 dimensions because that is what the runtime code
+ // requires (ie. all shapes must be 4D and are given backwards).
+ strided_slice::StridedSlicePadIndices(&params_copy, 4);
+
+ const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0);
const int stop_b =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 3, start_b);
- const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 2);
+ strided_slice::StopForAxis(params_copy, input_shape, 0, start_b);
+ const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1);
const int stop_h =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 2, start_h);
- const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 1);
+ strided_slice::StopForAxis(params_copy, input_shape, 1, start_h);
+ const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2);
const int stop_w =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 1, start_w);
- const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 0);
+ strided_slice::StopForAxis(params_copy, input_shape, 2, start_w);
+ const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3);
const int stop_d =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 0, start_d);
+ strided_slice::StopForAxis(params_copy, input_shape, 3, start_d);
T* out_ptr = output_data;
for (int in_b = start_b;
- !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
- in_b += strides[3]) {
+ !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]);
+ in_b += params_copy.strides[0]) {
for (int in_h = start_h;
- !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
- in_h += strides[2]) {
+ !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]);
+ in_h += params_copy.strides[1]) {
for (int in_w = start_w;
- !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
- in_w += strides[1]) {
- for (int in_d = start_d;
- !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
- in_d += strides[0]) {
- *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]);
+ in_w += params_copy.strides[2]) {
+ for (int in_d = start_d; !strided_slice::LoopCondition(
+ in_d, stop_d, params_copy.strides[3]);
+ in_d += params_copy.strides[3]) {
+ *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
}
}
}
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline uint32 LegacyReverseBits32(uint32 n) {
+ n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
+ n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
+ n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
+ return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
+ ((n & 0xFF000000) >> 24));
+}
+
+inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
+ std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
+ std::reverse(p->strides, p->strides + p->strides_count);
+
+ p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
+ (32 - p->start_indices_count);
+ p->ellipsis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
+ (32 - p->start_indices_count);
+ p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
+ (32 - p->start_indices_count);
+ p->new_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
+ (32 - p->start_indices_count);
+ p->shrink_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
+ (32 - p->start_indices_count);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices,
+ const std::vector<int>& stop_indices,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ TFLITE_DCHECK_EQ(start_indices.size(), 4);
+ auto op_params = strided_slice::BuildStridedSliceParams(
+ begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
+ strides);
+ StridedSliceReverseIndices(&op_params);
+
+ StridedSlice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
index 5994fad5c7..af5db1064c 100644
--- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
+++ b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <limits>
#include <vector>
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
-
namespace strided_slice {
// Use until std::clamp() is available from C++17.
@@ -32,15 +32,51 @@ inline int Clamp(const int v, const int lo, const int hi) {
return v;
}
+inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
+ int dim_count) {
+ // Add indices and mask bits to fully include extra dimensions
+ TFLITE_CHECK_LE(dim_count, 4);
+ TFLITE_CHECK_GE(dim_count, p->start_indices_count);
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ const int pad_count = dim_count - p->start_indices_count;
+
+ // Pad indices at start, so move arrays by pad_count.
+ for (int i = p->start_indices_count - 1; i > 0; --i) {
+ p->strides[i + pad_count] = p->strides[i];
+ p->start_indices[i + pad_count] = p->start_indices[i];
+ p->stop_indices[i + pad_count] = p->stop_indices[i];
+ }
+ for (int i = 0; i < pad_count; ++i) {
+ p->start_indices[i] = 0;
+ p->stop_indices[i] = 0;
+ p->strides[i] = 1;
+ }
+
+ // Pad masks with 0s or 1s as required.
+ p->shrink_axis_mask <<= pad_count;
+ p->ellipsis_mask <<= pad_count;
+ p->new_axis_mask <<= pad_count;
+ p->begin_mask <<= pad_count;
+ p->end_mask <<= pad_count;
+ p->begin_mask |= (1 << pad_count) - 1;
+ p->end_mask |= (1 << pad_count) - 1;
+
+ p->start_indices_count = dim_count;
+ p->stop_indices_count = dim_count;
+ p->strides_count = dim_count;
+}
+
// Return the index for the first element along that axis. This index will be a
// positive integer between [0, axis_size - 1] that can be used to index
// directly into the data.
-template <typename IntType>
-inline int StartForAxis(int begin_mask,
- std::vector<IntType> const& start_indices,
- std::vector<IntType> const& strides,
- int const* input_shape, int axis) {
- // Begin with the specified index
+inline int StartForAxis(const tflite::StridedSliceParams& params,
+ const RuntimeShape& input_shape, int axis) {
+ const auto begin_mask = params.begin_mask;
+ const auto* start_indices = params.start_indices;
+ const auto* strides = params.strides;
+ // Begin with the specified index.
int start = start_indices[axis];
// begin_mask override
@@ -57,7 +93,7 @@ inline int StartForAxis(int begin_mask,
}
// Handle negative indices
- int axis_size = input_shape[axis];
+ int axis_size = input_shape.Dims(axis);
if (start < 0) {
start += axis_size;
}
@@ -73,11 +109,14 @@ inline int StartForAxis(int begin_mask,
// element. ie. So if you were iterating through all elements of a 1D array of
// size 4, this function would return 4 as the stop, because it is one past the
// "real" indices of 0, 1, 2 & 3.
-template <typename IntType>
-inline int StopForAxis(int end_mask, int shrink_axis_mask,
- std::vector<IntType> const& stop_indices,
- std::vector<IntType> const& strides,
- int const* input_shape, int axis, int start_for_axis) {
+inline int StopForAxis(const tflite::StridedSliceParams& params,
+ const RuntimeShape& input_shape, int axis,
+ int start_for_axis) {
+ const auto end_mask = params.end_mask;
+ const auto shrink_axis_mask = params.shrink_axis_mask;
+ const auto* stop_indices = params.stop_indices;
+ const auto* strides = params.strides;
+
// Begin with the specified index
const bool shrink_axis = shrink_axis_mask & (1 << axis);
int stop = stop_indices[axis];
@@ -103,7 +142,7 @@ inline int StopForAxis(int end_mask, int shrink_axis_mask,
}
// Handle negative indices
- const int axis_size = input_shape[axis];
+ const int axis_size = input_shape.Dims(axis);
if (stop < 0) {
stop += axis_size;
}
@@ -127,6 +166,31 @@ inline bool LoopCondition(int index, int stop, int stride) {
return stride > 0 ? index >= stop : index <= stop;
}
+inline tflite::StridedSliceParams BuildStridedSliceParams(
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
+ const std::vector<int>& strides) {
+ tflite::StridedSliceParams op_params;
+ const int dims_count = start_indices.size();
+
+ op_params.start_indices_count = dims_count;
+ op_params.stop_indices_count = dims_count;
+ op_params.strides_count = dims_count;
+ for (int i = 0; i < dims_count; ++i) {
+ op_params.start_indices[i] = start_indices[i];
+ op_params.stop_indices[i] = stop_indices[i];
+ op_params.strides[i] = strides[i];
+ }
+
+ op_params.begin_mask = begin_mask;
+ op_params.ellipsis_mask = 0;
+ op_params.end_mask = end_mask;
+ op_params.new_axis_mask = 0;
+ op_params.shrink_axis_mask = shrink_axis_mask;
+
+ return op_params;
+}
+
} // namespace strided_slice
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index ee2af5b460..13106456df 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -17,44 +17,12 @@ limitations under the License.
#include <complex>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
-template <typename T>
-inline T* GetTensorData(TfLiteTensor* tensor);
-
-template <>
-inline float* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.f : nullptr;
-}
-
-template <>
-inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.uint8 : nullptr;
-}
-
-template <>
-inline int16_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i16 : nullptr;
-}
-
-template <>
-inline int32_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i32 : nullptr;
-}
-
-template <>
-inline int64_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i64 : nullptr;
-}
-
-template <>
-inline bool* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.b : nullptr;
-}
-
template <>
inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) {
return tensor != nullptr
@@ -62,39 +30,6 @@ inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) {
: nullptr;
}
-template <typename T>
-inline const T* GetTensorData(const TfLiteTensor* tensor);
-
-template <>
-inline const float* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.f : nullptr;
-}
-
-template <>
-inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.uint8 : nullptr;
-}
-
-template <>
-inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i16 : nullptr;
-}
-
-template <>
-inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i32 : nullptr;
-}
-
-template <>
-inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i64 : nullptr;
-}
-
-template <>
-inline const bool* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.b : nullptr;
-}
-
template <>
inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
return tensor != nullptr
@@ -102,56 +37,14 @@ inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
: nullptr;
}
-inline int RemapDim(int max_dimensions, int d) {
- return max_dimensions - d - 1;
-}
-
-// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
-// even if the original tensors were not 4D. We should consider rewriting them
-// to take a more generic 'shape' object.
-inline Dims<4> GetTensorDims(const int data[], const int size) {
- Dims<4> d;
- for (int i = 0; i < 4; ++i) {
- int src = size - i - 1;
- if (src >= 0) {
- d.sizes[i] = data[src];
- } else {
- d.sizes[i] = 1;
- }
- }
- d.strides[0] = 1;
- for (int i = 1; i < 4; i++) {
- d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
- }
- return d;
-}
-
inline Dims<4> GetTensorDims(std::vector<int32_t> data) {
return GetTensorDims(data.data(), data.size());
}
-inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return Dims<4>();
- }
-
- auto* dims = tensor->dims;
- return GetTensorDims(dims->data, dims->size);
-}
-
inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
return RuntimeShape(data.size(), data.data());
}
-inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return RuntimeShape();
- }
-
- auto* dims = tensor->dims;
- return RuntimeShape(dims->size, dims->data);
-}
-
// A list of tensors in a format that can be used by kernels like split and
// concatenation.
template <typename T>
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
new file mode 100644
index 0000000000..77e22a08b4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+template <typename T>
+inline T* GetTensorData(TfLiteTensor* tensor);
+
+template <>
+inline float* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline int16_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
+inline int32_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline int64_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline bool* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
+template <typename T>
+inline const T* GetTensorData(const TfLiteTensor* tensor);
+
+template <>
+inline const float* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
+inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline const bool* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
+inline int RemapDim(int max_dimensions, int d) {
+ return max_dimensions - d - 1;
+}
+
+// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
+// even if the original tensors were not 4D. We should consider rewriting them
+// to take a more generic 'shape' object.
+inline Dims<4> GetTensorDims(const int data[], const int size) {
+ Dims<4> d;
+ for (int i = 0; i < 4; ++i) {
+ int src = size - i - 1;
+ if (src >= 0) {
+ d.sizes[i] = data[src];
+ } else {
+ d.sizes[i] = 1;
+ }
+ }
+ d.strides[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
+ }
+ return d;
+}
+
+inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return Dims<4>();
+ }
+
+ auto* dims = tensor->dims;
+ return GetTensorDims(dims->data, dims->size);
+}
+
+inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return RuntimeShape();
+ }
+
+ TfLiteIntArray* dims = tensor->dims;
+ const int dims_size = dims->size;
+ const int32_t* dims_data = dims->data;
+ return RuntimeShape(dims_size, dims_data);
+}
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index 1439bf8c37..b0fe5adf65 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index dad924fc28..6458af714b 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include <gmock/gmock.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
index ed46cd984f..e9a5fd7a40 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.h
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -16,9 +16,10 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
#include <algorithm>
+#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index 5b3536de0c..e02d7df9ef 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc
index 799c1528bd..334d2a2788 100644
--- a/tensorflow/contrib/lite/kernels/local_response_norm.cc
+++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc
index c71f3b4701..f770cb35d1 100644
--- a/tensorflow/contrib/lite/kernels/logical.cc
+++ b/tensorflow/contrib/lite/kernels/logical.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc
index 69523b02cc..9fa1c5f100 100644
--- a/tensorflow/contrib/lite/kernels/lsh_projection.cc
+++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc
@@ -59,8 +59,8 @@ limitations under the License.
#include <limits>
#include <memory>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include <farmhash.h>
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 74dc3f25f9..aaa3ce966e 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
index 0308a3976a..7cb01465ee 100644
--- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc
+++ b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc
index 306f676619..66cf147d75 100644
--- a/tensorflow/contrib/lite/kernels/mfcc.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/mfcc.h"
#include "flatbuffers/flexbuffers.h" // flatbuffers
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h"
#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 92d8bc8b67..e0aac8a842 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/contrib/lite/kernels/neg.cc
index 4124c05388..0ddd0644f5 100644
--- a/tensorflow/contrib/lite/kernels/neg.cc
+++ b/tensorflow/contrib/lite/kernels/neg.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc
index 9ff3dca932..910aed6f14 100644
--- a/tensorflow/contrib/lite/kernels/one_hot.cc
+++ b/tensorflow/contrib/lite/kernels/one_hot.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc
index cc326a7d51..4cb98fdd19 100644
--- a/tensorflow/contrib/lite/kernels/pack.cc
+++ b/tensorflow/contrib/lite/kernels/pack.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 3bce05353d..0d939405f6 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h
index 3cb55f19a9..42b6b45d3b 100644
--- a/tensorflow/contrib/lite/kernels/padding.h
+++ b/tensorflow/contrib/lite/kernels/padding.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
index 29a5be0683..6451142391 100644
--- a/tensorflow/contrib/lite/kernels/pooling.cc
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc
index d676de5b1d..1e96cc80b1 100644
--- a/tensorflow/contrib/lite/kernels/pow.cc
+++ b/tensorflow/contrib/lite/kernels/pow.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index ca83797936..d94d821e87 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <limits>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
index 0296152d68..61856ab9de 100644
--- a/tensorflow/contrib/lite/kernels/register.h
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -16,8 +16,9 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
#include <unordered_map>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
namespace tflite {
namespace ops {
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
index 49ba0571e2..f41147b2d6 100644
--- a/tensorflow/contrib/lite/kernels/reshape.cc
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index dafa3aebab..fb045d15f3 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 3cdb5db209..3959502d91 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc
index dbcd2ef004..66d4c9e5c1 100644
--- a/tensorflow/contrib/lite/kernels/shape.cc
+++ b/tensorflow/contrib/lite/kernels/shape.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc
index c90a15b3a2..de80a4016e 100644
--- a/tensorflow/contrib/lite/kernels/skip_gram.cc
+++ b/tensorflow/contrib/lite/kernels/skip_gram.cc
@@ -33,8 +33,8 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/string_util.h"
diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc
index 55e16506df..ccfee41b9c 100644
--- a/tensorflow/contrib/lite/kernels/slice.cc
+++ b/tensorflow/contrib/lite/kernels/slice.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <string.h>
#include <cmath>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
index 8332ae32cf..3a10d2e60c 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc
index 9238e879f8..64c56c017b 100644
--- a/tensorflow/contrib/lite/kernels/space_to_depth.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
index fec2a6f0d9..178568e07c 100644
--- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc
index b144486041..719e2dc606 100644
--- a/tensorflow/contrib/lite/kernels/split.cc
+++ b/tensorflow/contrib/lite/kernels/split.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc
index 09a5662fd9..080c51cd18 100644
--- a/tensorflow/contrib/lite/kernels/squeeze.cc
+++ b/tensorflow/contrib/lite/kernels/squeeze.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index bed2117f9a..87ffcc4110 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <cmath>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index 77a1f59689..1be0c83f17 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
index 6ba7959752..9903fd5c35 100644
--- a/tensorflow/contrib/lite/kernels/svdf.cc
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -23,8 +23,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc
index 5181a8f89a..49421eb870 100644
--- a/tensorflow/contrib/lite/kernels/tile.cc
+++ b/tensorflow/contrib/lite/kernels/tile.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc
index 4f78c224e5..e73ca7b750 100644
--- a/tensorflow/contrib/lite/kernels/tile_test.cc
+++ b/tensorflow/contrib/lite/kernels/tile_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc
index 2dd760bbfe..6c38b6739e 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <algorithm>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
index 2abb89b617..16106fdafe 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc
index 800b0563d7..95359962e0 100644
--- a/tensorflow/contrib/lite/kernels/transpose.cc
+++ b/tensorflow/contrib/lite/kernels/transpose.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index a9baa5c698..6f2d98ede8 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index c678f14930..63817bd886 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 0180c2c498..744ee7c109 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc
index 4998f88b41..9ff06f8331 100644
--- a/tensorflow/contrib/lite/kernels/unpack.cc
+++ b/tensorflow/contrib/lite/kernels/unpack.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/contrib/lite/memory_planner.h
index 0294ec815c..2d4707f849 100644
--- a/tensorflow/contrib/lite/memory_planner.h
+++ b/tensorflow/contrib/lite/memory_planner.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
#define TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/mmap_allocation.cc b/tensorflow/contrib/lite/mmap_allocation.cc
index fa9a3cd1d8..92934d1fd1 100644
--- a/tensorflow/contrib/lite/mmap_allocation.cc
+++ b/tensorflow/contrib/lite/mmap_allocation.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <unistd.h>
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index aa410ab002..241865b3d8 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -20,8 +20,9 @@ limitations under the License.
#include <sys/types.h>
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/contrib/lite/model.h"
#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
@@ -42,41 +43,6 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
const char* kEmptyTensorName = "";
-TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
- ErrorReporter* error_reporter) {
- switch (tensor_type) {
- case TensorType_FLOAT32:
- *type = kTfLiteFloat32;
- break;
- case TensorType_INT16:
- *type = kTfLiteInt16;
- break;
- case TensorType_INT32:
- *type = kTfLiteInt32;
- break;
- case TensorType_UINT8:
- *type = kTfLiteUInt8;
- break;
- case TensorType_INT64:
- *type = kTfLiteInt64;
- break;
- case TensorType_STRING:
- *type = kTfLiteString;
- break;
- case TensorType_BOOL:
- *type = kTfLiteBool;
- break;
- case TensorType_COMPLEX64:
- *type = kTfLiteComplex64;
- break;
- default:
- error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
- EnumNameTensorType(tensor_type), tensor_type);
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
@@ -198,39 +164,10 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
auto opcodes = model_->operator_codes();
for (const OperatorCode* opcode : *opcodes) {
const TfLiteRegistration* registration = nullptr;
- auto builtin_code = opcode->builtin_code();
- int version = opcode->version();
-
- if (builtin_code > BuiltinOperator_MAX ||
- builtin_code < BuiltinOperator_MIN) {
- error_reporter_->Report(
- "Op builtin_code out or range: %d. Are you using old TFLite binary "
- "with newer model?",
- builtin_code);
- status = kTfLiteError;
- } else if (builtin_code != BuiltinOperator_CUSTOM) {
- registration = op_resolver_.FindOp(builtin_code, version);
- if (registration == nullptr) {
- error_reporter_->Report(
- "Didn't find op for builtin opcode '%s' version '%d'\n",
- EnumNameBuiltinOperator(builtin_code), version);
- status = kTfLiteError;
- }
- } else if (!opcode->custom_code()) {
- error_reporter_->Report(
- "Operator with CUSTOM builtin_code has no custom_code.\n");
- status = kTfLiteError;
- } else {
- const char* name = opcode->custom_code()->c_str();
- registration = op_resolver_.FindOp(name, version);
- flatbuffer_op_index_to_registration_types_.push_back(
- BuiltinOperator_CUSTOM);
- if (registration == nullptr) {
- error_reporter_->Report(
- "Didn't find custom op for name '%s' with version %d\n", name,
- version);
- status = kTfLiteError;
- }
+ status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
+ &registration);
+ if (status != kTfLiteOk) {
+ return status;
}
flatbuffer_op_index_to_registration_.push_back(registration);
}
@@ -247,565 +184,6 @@ std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
return ret;
}
-// Copies the contents from the flatbuffer int vector `flatbuffer` into the
-// int array `buffer`. `flat_vector` and `buffer` represent the same
-// configuration operation for a given operation.
-void FlatBufferIntVectorToArray(int max_size_of_buffer,
- const flatbuffers::Vector<int32_t>* flat_vector,
- int* buffer, ErrorReporter* error_reporter) {
- if (!flat_vector) {
- error_reporter->Report("Input array not provided for operation.\n");
- } else {
- int num_dimensions = flat_vector->Length();
- if (num_dimensions > max_size_of_buffer / sizeof(int)) {
- error_reporter->Report(
- "Found too many dimensions in the operation's input array.\n");
- } else {
- for (int i = 0; i < num_dimensions; ++i) {
- buffer[i] = flat_vector->Get(i);
- }
- }
- }
-}
-
-// Allocate a structure using C malloc, but make sure the structure is a
-// POD structure that doesn't require constructors to run. The reason we do
-// this, is that Interpreter's C extension part will take ownership and wants
-// to use malloc() and free().
-template <class T>
-T* MallocPOD() {
- static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
- return static_cast<T*>(malloc(sizeof(T)));
-}
-
-// Parse the appropriate data out of the op.
-//
-// This handles builtin data explicitly as there are flatbuffer schemas.
-// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
-// need to be released by calling `free`.`
-// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
-TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
- ErrorReporter* error_reporter, void** builtin_data) {
- auto parse_padding = [](Padding padding) {
- switch (padding) {
- case Padding_SAME:
- return kTfLitePaddingSame;
- case Padding_VALID:
- return kTfLitePaddingValid;
- }
- return kTfLitePaddingUnknown;
- };
- auto parse_activation = [](ActivationFunctionType activation) {
- switch (activation) {
- case ActivationFunctionType_NONE:
- return kTfLiteActNone;
- case ActivationFunctionType_RELU:
- return kTfLiteActRelu;
- case ActivationFunctionType_RELU_N1_TO_1:
- return kTfLiteActRelu1;
- case ActivationFunctionType_RELU6:
- return kTfLiteActRelu6;
- case ActivationFunctionType_TANH:
- return kTfLiteActTanh;
- case ActivationFunctionType_SIGN_BIT:
- return kTfLiteActSignBit;
- }
- return kTfLiteActNone;
- };
- auto parseLSHProjectionType = [](LSHProjectionType type) {
- switch (type) {
- case LSHProjectionType_SPARSE:
- return kTfLiteLshProjectionSparse;
- case LSHProjectionType_DENSE:
- return kTfLiteLshProjectionDense;
- default:
- return kTfLiteLshProjectionUnknown;
- }
- };
- auto parseCombinerType = [](CombinerType type) {
- switch (type) {
- case CombinerType_MEAN:
- return kTfLiteCombinerTypeMean;
- case CombinerType_SQRTN:
- return kTfLiteCombinerTypeSqrtn;
- case CombinerType_SUM:
- default:
- return kTfLiteCombinerTypeSum;
- }
- };
-
- *builtin_data = nullptr;
- switch (op_type) {
- case BuiltinOperator_CONV_2D: {
- TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
- if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
- params->padding = parse_padding(conv_params->padding());
- params->stride_width = conv_params->stride_w();
- params->stride_height = conv_params->stride_h();
- params->activation =
- parse_activation(conv_params->fused_activation_function());
-
- params->dilation_width_factor = conv_params->dilation_w_factor();
- params->dilation_height_factor = conv_params->dilation_h_factor();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_CAST: {
- TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
- if (auto* schema_params = op->builtin_options_as_CastOptions()) {
- auto in_status =
- ConvertTensorType(schema_params->in_data_type(),
- &params->in_data_type, error_reporter);
- auto out_status =
- ConvertTensorType(schema_params->out_data_type(),
- &params->out_data_type, error_reporter);
- if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
- free(params);
- return kTfLiteError;
- }
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_LSH_PROJECTION: {
- TfLiteLSHProjectionParams* params =
- MallocPOD<TfLiteLSHProjectionParams>();
- if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
- params->type = parseLSHProjectionType(lshParams->type());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_AVERAGE_POOL_2D:
- case BuiltinOperator_MAX_POOL_2D:
- case BuiltinOperator_L2_POOL_2D: {
- TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
- if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
- params->padding = parse_padding(pool_params->padding());
- params->stride_width = pool_params->stride_w();
- params->stride_height = pool_params->stride_h();
- params->filter_width = pool_params->filter_width();
- params->filter_height = pool_params->filter_height();
- params->activation =
- parse_activation(pool_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_DEPTHWISE_CONV_2D: {
- TfLiteDepthwiseConvParams* params =
- MallocPOD<TfLiteDepthwiseConvParams>();
- if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
- params->padding = parse_padding(conv_params->padding());
- params->stride_width = conv_params->stride_w();
- params->stride_height = conv_params->stride_h();
- params->depth_multiplier = conv_params->depth_multiplier();
- params->activation =
- parse_activation(conv_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SVDF: {
- TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
- if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
- params->rank = svdf_params->rank();
- params->activation =
- parse_activation(svdf_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
- TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
- if (auto* sequence_rnn_params =
- op->builtin_options_as_SequenceRNNOptions()) {
- params->activation =
- parse_activation(sequence_rnn_params->fused_activation_function());
- params->time_major = sequence_rnn_params->time_major();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_RNN: {
- TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
- if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
- params->activation =
- parse_activation(rnn_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
- TfLiteEmbeddingLookupSparseParams* params =
- MallocPOD<TfLiteEmbeddingLookupSparseParams>();
- if (auto* embedding_params =
- op->builtin_options_as_EmbeddingLookupSparseOptions()) {
- params->combiner = parseCombinerType(embedding_params->combiner());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_FULLY_CONNECTED: {
- TfLiteFullyConnectedParams* params =
- MallocPOD<TfLiteFullyConnectedParams>();
- if (auto* fully_connected_params =
- op->builtin_options_as_FullyConnectedOptions()) {
- params->activation = parse_activation(
- fully_connected_params->fused_activation_function());
- switch (fully_connected_params->weights_format()) {
- case FullyConnectedOptionsWeightsFormat_DEFAULT:
- params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
- break;
- case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
- params->weights_format =
- kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
- break;
- default:
- error_reporter->Report("Unhandled fully-connected weights format.");
- return kTfLiteError;
- }
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_HASHTABLE_LOOKUP:
- // no-op.
- break;
- case BuiltinOperator_SOFTMAX: {
- TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
- if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
- params->beta = softmax_params->beta();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_CONCATENATION: {
- TfLiteConcatenationParams* params =
- MallocPOD<TfLiteConcatenationParams>();
- if (auto* concatenation_params =
- op->builtin_options_as_ConcatenationOptions()) {
- params->activation =
- parse_activation(concatenation_params->fused_activation_function());
- params->axis = concatenation_params->axis();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_MUL: {
- auto* params = MallocPOD<TfLiteMulParams>();
- if (auto* schema_params = op->builtin_options_as_MulOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ADD: {
- auto* params = MallocPOD<TfLiteAddParams>();
- if (auto* schema_params = op->builtin_options_as_AddOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_DIV: {
- auto* params = MallocPOD<TfLiteDivParams>();
- if (auto* schema_params = op->builtin_options_as_DivOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SUB: {
- auto* params = MallocPOD<TfLiteSubParams>();
- if (auto* schema_params = op->builtin_options_as_SubOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_L2_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteL2NormParams>();
- if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
- if (auto* schema_params =
- op->builtin_options_as_LocalResponseNormalizationOptions()) {
- params->radius = schema_params->radius();
- params->bias = schema_params->bias();
- params->alpha = schema_params->alpha();
- params->beta = schema_params->beta();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
- case BuiltinOperator_LSTM: {
- TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
- if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
- params->activation =
- parse_activation(lstm_params->fused_activation_function());
- params->cell_clip = lstm_params->cell_clip();
- params->proj_clip = lstm_params->proj_clip();
- switch (lstm_params->kernel_type()) {
- case LSTMKernelType_FULL:
- params->kernel_type = kTfLiteLSTMFullKernel;
- break;
- case LSTMKernelType_BASIC:
- params->kernel_type = kTfLiteLSTMBasicKernel;
- break;
- }
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_RESIZE_BILINEAR: {
- auto* params = MallocPOD<TfLiteResizeBilinearParams>();
- if (auto* schema_params =
- op->builtin_options_as_ResizeBilinearOptions()) {
- params->align_corners = schema_params->align_corners();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_RESHAPE: {
- auto* params = MallocPOD<TfLiteReshapeParams>();
- if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
- auto* new_shape = schema_params->new_shape();
- FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
- params->shape, error_reporter);
- params->num_dimensions = new_shape->Length();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SKIP_GRAM: {
- TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
- if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
- params->ngram_size = skip_gram_params->ngram_size();
- params->max_skip_size = skip_gram_params->max_skip_size();
- params->include_all_ngrams = skip_gram_params->include_all_ngrams();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SPACE_TO_DEPTH: {
- auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
- if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
- params->block_size = schema_params->block_size();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_GATHER: {
- TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
- params->axis = 0;
- if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
- params->axis = gather_params->axis();
- }
-
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_MEAN:
- case BuiltinOperator_REDUCE_MAX:
- case BuiltinOperator_REDUCE_MIN:
- case BuiltinOperator_REDUCE_PROD:
- case BuiltinOperator_SUM:
- case BuiltinOperator_REDUCE_ANY: {
- auto* params = MallocPOD<TfLiteReducerParams>();
- if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
- params->keep_dims = schema_params->keep_dims();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SPLIT: {
- auto* params = MallocPOD<TfLiteSplitParams>();
- if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
- params->num_splits = schema_params->num_splits();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SQUEEZE: {
- auto* params = MallocPOD<TfLiteSqueezeParams>();
- if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
- const auto& squeeze_dims = schema_params->squeeze_dims();
- FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
- params->squeeze_dims, error_reporter);
- params->num_squeeze_dims = squeeze_dims->Length();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_STRIDED_SLICE: {
- auto* params = MallocPOD<TfLiteStridedSliceParams>();
- if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
- params->begin_mask = schema_params->begin_mask();
- params->end_mask = schema_params->end_mask();
- params->ellipsis_mask = schema_params->ellipsis_mask();
- params->new_axis_mask = schema_params->new_axis_mask();
- params->shrink_axis_mask = schema_params->shrink_axis_mask();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ARG_MAX: {
- auto* params = MallocPOD<TfLiteArgMaxParams>();
- if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
- ConvertTensorType(schema_params->output_type(), &params->output_type,
- error_reporter);
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ARG_MIN: {
- auto* params = MallocPOD<TfLiteArgMinParams>();
- if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
- ConvertTensorType(schema_params->output_type(), &params->output_type,
- error_reporter);
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_TRANSPOSE_CONV: {
- TfLiteTransposeConvParams* params =
- MallocPOD<TfLiteTransposeConvParams>();
- if (auto* transpose_conv_params =
- op->builtin_options_as_TransposeConvOptions()) {
- params->padding = parse_padding(transpose_conv_params->padding());
- params->stride_width = transpose_conv_params->stride_w();
- params->stride_height = transpose_conv_params->stride_h();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SPARSE_TO_DENSE: {
- TfLiteSparseToDenseParams* params =
- MallocPOD<TfLiteSparseToDenseParams>();
- if (auto* sparse_to_dense_params =
- op->builtin_options_as_SparseToDenseOptions()) {
- params->validate_indices = sparse_to_dense_params->validate_indices();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SHAPE: {
- auto* params = MallocPOD<TfLiteShapeParams>();
- if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
- ConvertTensorType(schema_params->out_type(), &params->out_type,
- error_reporter);
- }
- *builtin_data = static_cast<void*>(params);
- break;
- }
- case BuiltinOperator_PACK: {
- TfLitePackParams* params = MallocPOD<TfLitePackParams>();
- if (auto* pack_params = op->builtin_options_as_PackOptions()) {
- params->values_count = pack_params->values_count();
- params->axis = pack_params->axis();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_DELEGATE: {
- // TODO(ycling): Revisit when supporting saving delegated models.
- error_reporter->Report("DELEGATE op shouldn't exist in model.");
- return kTfLiteError;
- }
- case BuiltinOperator_FAKE_QUANT: {
- auto* params = MallocPOD<TfLiteFakeQuantParams>();
- if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
- params->min = schema_params->min();
- params->max = schema_params->max();
- params->num_bits = schema_params->num_bits();
- params->narrow_range = schema_params->narrow_range();
- }
- *builtin_data = static_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ONE_HOT: {
- auto* params = MallocPOD<TfLiteOneHotParams>();
- if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
- params->axis = schema_params->axis();
- }
- *builtin_data = static_cast<void*>(params);
- break;
- }
- case BuiltinOperator_UNPACK: {
- TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
- if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
- params->num = unpack_params->num();
- params->axis = unpack_params->axis();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
-
- // Below are the ops with no builtin_data strcture.
- case BuiltinOperator_BATCH_TO_SPACE_ND:
- // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
- // ok for now, since there is no call implementation either.
- case BuiltinOperator_CALL:
- case BuiltinOperator_CONCAT_EMBEDDINGS:
- case BuiltinOperator_CUSTOM:
- case BuiltinOperator_DEQUANTIZE:
- case BuiltinOperator_EMBEDDING_LOOKUP:
- case BuiltinOperator_EQUAL:
- case BuiltinOperator_EXP:
- case BuiltinOperator_EXPAND_DIMS:
- case BuiltinOperator_FLOOR:
- case BuiltinOperator_GREATER:
- case BuiltinOperator_GREATER_EQUAL:
- case BuiltinOperator_LESS:
- case BuiltinOperator_LESS_EQUAL:
- case BuiltinOperator_LOG:
- case BuiltinOperator_LOGISTIC:
- case BuiltinOperator_LOG_SOFTMAX:
- case BuiltinOperator_MAXIMUM:
- case BuiltinOperator_MINIMUM:
- case BuiltinOperator_NEG:
- case BuiltinOperator_NOT_EQUAL:
- case BuiltinOperator_PAD:
- case BuiltinOperator_PADV2:
- case BuiltinOperator_PRELU:
- case BuiltinOperator_RELU:
- case BuiltinOperator_RELU6:
- case BuiltinOperator_RELU_N1_TO_1:
- case BuiltinOperator_RSQRT:
- case BuiltinOperator_SELECT:
- case BuiltinOperator_SIN:
- case BuiltinOperator_SLICE:
- case BuiltinOperator_SPACE_TO_BATCH_ND:
- case BuiltinOperator_SQRT:
- case BuiltinOperator_TANH:
- case BuiltinOperator_TILE:
- case BuiltinOperator_TOPK_V2:
- case BuiltinOperator_TRANSPOSE:
- case BuiltinOperator_POW:
- case BuiltinOperator_LOGICAL_OR:
- case BuiltinOperator_LOGICAL_AND:
- case BuiltinOperator_LOGICAL_NOT:
- case BuiltinOperator_FLOOR_DIV:
- break;
- }
- return kTfLiteOk;
-}
-
} // namespace
TfLiteStatus InterpreterBuilder::ParseNodes(
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 8bc9ecd7ce..6abdfcd079 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -35,9 +35,10 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_MODEL_H_
#include <memory>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index df4f60d4ad..ec7d46af7c 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/model.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/testing/util.h"
// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
diff --git a/tensorflow/contrib/lite/op_resolver.cc b/tensorflow/contrib/lite/mutable_op_resolver.cc
index f6e435e982..8ee63d2a02 100644
--- a/tensorflow/contrib/lite/op_resolver.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver.cc
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/op_resolver.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/mutable_op_resolver.h b/tensorflow/contrib/lite/mutable_op_resolver.h
new file mode 100644
index 0000000000..c319041e9b
--- /dev/null
+++ b/tensorflow/contrib/lite/mutable_op_resolver.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/util.h"
+
+namespace tflite {
+
+// Some versions of gcc doesn't support partial specialization in class scope,
+// so these are defined in a namescope.
+namespace op_resolver_hasher {
+template <typename V>
+struct ValueHasher {
+ size_t operator()(const V& v) const { return std::hash<V>()(v); }
+};
+
+template <>
+struct ValueHasher<tflite::BuiltinOperator> {
+ size_t operator()(const tflite::BuiltinOperator& v) const {
+ return std::hash<int>()(static_cast<int>(v));
+ }
+};
+
+template <typename T>
+struct OperatorKeyHasher {
+ size_t operator()(const T& x) const {
+ size_t a = ValueHasher<typename T::first_type>()(x.first);
+ size_t b = ValueHasher<typename T::second_type>()(x.second);
+ return CombineHashes({a, b});
+ }
+};
+} // namespace op_resolver_hasher
+
+// An OpResolver that is mutable, also used as the op in gen_op_registration.
+// A typical usage:
+// MutableOpResolver resolver;
+// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
+// InterpreterBuilder(model, resolver)(&interpreter);
+class MutableOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const override;
+ const TfLiteRegistration* FindOp(const char* op, int version) const override;
+ void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+ void AddCustom(const char* name, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+
+ private:
+ typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
+ typedef std::pair<std::string, int> CustomOperatorKey;
+
+ std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
+ op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
+ builtins_;
+ std::unordered_map<CustomOperatorKey, TfLiteRegistration,
+ op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
+ custom_ops_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/op_resolver_test.cc b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
index 10b7e31972..db690eaab9 100644
--- a/tensorflow/contrib/lite/op_resolver_test.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/testing/util.h"
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 484842713d..817486e898 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h
index 2bdb2cc5c8..22359d557e 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/nnapi_delegate.h
@@ -16,8 +16,8 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/interpreter.h"
class ANeuralNetworksModel;
diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h
index 9d7e3f2085..e93134cbde 100644
--- a/tensorflow/contrib/lite/op_resolver.h
+++ b/tensorflow/contrib/lite/op_resolver.h
@@ -12,83 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// Compatibility shim for moved header location.
#ifndef TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
#define TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
-#include <unordered_map>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/schema/schema_generated.h"
-#include "tensorflow/contrib/lite/util.h"
-
-namespace tflite {
-
-// Abstract interface that returns TfLiteRegistrations given op codes or custom
-// op names. This is the mechanism that ops being referenced in the flatbuffer
-// model are mapped to executable function pointers (TfLiteRegistrations).
-class OpResolver {
- public:
- // Finds the op registration for a builtin operator by enum code.
- virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
- int version) const = 0;
- // Finds the op registration of a custom operator by op name.
- virtual const TfLiteRegistration* FindOp(const char* op,
- int version) const = 0;
- virtual ~OpResolver() {}
-};
-
-// Some versions of gcc doesn't support partial specialization in class scope,
-// so these are defined in a namescope.
-namespace op_resolver_hasher {
-template <typename V>
-struct ValueHasher {
- size_t operator()(const V& v) const { return std::hash<V>()(v); }
-};
-
-template <>
-struct ValueHasher<tflite::BuiltinOperator> {
- size_t operator()(const tflite::BuiltinOperator& v) const {
- return std::hash<int>()(static_cast<int>(v));
- }
-};
-
-template <typename T>
-struct OperatorKeyHasher {
- size_t operator()(const T& x) const {
- size_t a = ValueHasher<typename T::first_type>()(x.first);
- size_t b = ValueHasher<typename T::second_type>()(x.second);
- return CombineHashes({a, b});
- }
-};
-} // namespace op_resolver_hasher
-
-// An OpResolver that is mutable, also used as the op in gen_op_registration.
-// A typical usage:
-// MutableOpResolver resolver;
-// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
-// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
-// InterpreterBuilder(model, resolver)(&interpreter);
-class MutableOpResolver : public OpResolver {
- public:
- const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
- int version) const override;
- const TfLiteRegistration* FindOp(const char* op, int version) const override;
- void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
- int min_version = 1, int max_version = 1);
- void AddCustom(const char* name, TfLiteRegistration* registration,
- int min_version = 1, int max_version = 1);
-
- private:
- typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
- typedef std::pair<std::string, int> CustomOperatorKey;
-
- std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
- op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
- builtins_;
- std::unordered_map<CustomOperatorKey, TfLiteRegistration,
- op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
- custom_ops_;
-};
-
-} // namespace tflite
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
#endif // TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h
index f738315cf2..45d0d8735e 100644
--- a/tensorflow/contrib/lite/simple_memory_arena.h
+++ b/tensorflow/contrib/lite/simple_memory_arena.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <list>
#include <memory>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/stderr_reporter.cc
index 646913c026..e29a6345fd 100644
--- a/tensorflow/contrib/lite/error_reporter.cc
+++ b/tensorflow/contrib/lite/stderr_reporter.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
#include <cstdarg>
#include <cstdio>
@@ -22,26 +22,6 @@ limitations under the License.
namespace tflite {
-ErrorReporter::~ErrorReporter() {}
-
-int ErrorReporter::Report(const char* format, ...) {
- va_list args;
- va_start(args, format);
- int code = Report(format, args);
- va_end(args);
- return code;
-}
-
-// TODO(aselle): Make the name of ReportError on context the same, so
-// we can use the ensure functions w/o a context and w/ a reporter.
-int ErrorReporter::ReportError(void*, const char* format, ...) {
- va_list args;
- va_start(args, format);
- int code = Report(format, args);
- va_end(args);
- return code;
-}
-
int StderrReporter::Report(const char* format, va_list args) {
#ifdef __ANDROID__
// On Android stderr is not captured for applications, only for code run from
diff --git a/tensorflow/contrib/lite/stderr_reporter.h b/tensorflow/contrib/lite/stderr_reporter.h
new file mode 100644
index 0000000000..c6f4ffbdff
--- /dev/null
+++ b/tensorflow/contrib/lite/stderr_reporter.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
+
+#include <cstdarg>
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+
+namespace tflite {
+
+// An error reporter that simplify writes the message to stderr.
+struct StderrReporter : public ErrorReporter {
+ int Report(const char* format, va_list args) override;
+};
+
+// Return the default error reporter (output to stderr).
+ErrorReporter* DefaultErrorReporter();
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc
index a316a40b62..b991e999b6 100644
--- a/tensorflow/contrib/lite/string_util.cc
+++ b/tensorflow/contrib/lite/string_util.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h
index 57f129bf5e..d24627b509 100644
--- a/tensorflow/contrib/lite/string_util.h
+++ b/tensorflow/contrib/lite/string_util.h
@@ -42,7 +42,7 @@ limitations under the License.
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/string_util_test.cc b/tensorflow/contrib/lite/string_util_test.cc
index d53fec7512..a583a9184b 100644
--- a/tensorflow/contrib/lite/string_util_test.cc
+++ b/tensorflow/contrib/lite/string_util_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/string_util.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/testing/util.h"
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 0b3a97d4f5..aad1ecaeb6 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -173,7 +173,6 @@ tf_cc_test(
srcs = ["tflite_driver_test.cc"],
data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
tags = [
- "no_oss", # b/112769036
"tflite_not_portable_android",
"tflite_not_portable_ios",
],
@@ -215,6 +214,7 @@ cc_library(
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string",
+ "//tensorflow/contrib/lite/core/api",
],
)
diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h
index 8aa639157b..925791d390 100644
--- a/tensorflow/contrib/lite/testing/util.h
+++ b/tensorflow/contrib/lite/testing/util.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <cstdio>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index a75553db84..bea90f1ce8 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -372,6 +372,7 @@ cc_library(
":toco_graphviz_dump_options",
":toco_port",
":types_proto_cc",
+ "//tensorflow/contrib/lite/kernels/internal:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index c25be078ff..f103bb94ae 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1314,12 +1314,16 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
// Compute output shape
for (int axis = 0; axis < num_input_axes; ++axis) {
+ const auto strided_slice_params =
+ tflite::strided_slice::BuildStridedSliceParams(
+ op->begin_mask, op->end_mask, op->shrink_axis_mask,
+ op->start_indices, op->stop_indices, op->strides);
int start_index = tflite::strided_slice::StartForAxis(
- op->begin_mask, op->start_indices, op->strides,
- input_array.shape().dims().data(), axis);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
int stop_index = tflite::strided_slice::StopForAxis(
- op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides,
- input_array.shape().dims().data(), axis, start_index);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+ start_index);
+
int dim_size =
ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 9d8bd4fc39..8853ed87e6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -52,14 +52,18 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>();
std::vector<int> src_coord(num_input_axes);
std::vector<int> stop_for_axis(num_input_axes);
+ const auto strided_slice_params =
+ tflite::strided_slice::BuildStridedSliceParams(
+ op.begin_mask, op.end_mask, op.shrink_axis_mask, op.start_indices,
+ op.stop_indices, op.strides);
+
for (int axis = 0; axis < num_input_axes; axis++) {
- int start = tflite::strided_slice::StartForAxis(
- op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(),
- axis);
- src_coord[axis] = start;
+ int start_index = tflite::strided_slice::StartForAxis(
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
+ src_coord[axis] = start_index;
stop_for_axis[axis] = tflite::strided_slice::StopForAxis(
- op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides,
- input_shape.dims().data(), axis, start);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+ start_index);
}
// In order to handle any number (N) of dimensions, we copy elements one by
@@ -86,8 +90,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) {
// Reset axis and set carry
src_coord[axis] = tflite::strided_slice::StartForAxis(
- op.begin_mask, op.start_indices, op.strides,
- input_shape.dims().data(), axis);
+ strided_slice_params, ToRuntimeShape(input_shape), axis);
carry = true;
} else {
carry = false;
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index bdeb203024..5f4b8cb66a 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -28,6 +28,7 @@ limitations under the License.
#if TOCO_SUPPORT_PORTABLE_PROTOS
#include "third_party/protobuf/include/google/protobuf/text_format.h"
#endif // TOCO_SUPPORT_PORTABLE_PROTOS
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
@@ -139,6 +140,10 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1);
// - For the remaining indices [0..i0), d0[i0] == 1.
bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1);
+inline ::tflite::RuntimeShape ToRuntimeShape(const Shape& shape) {
+ return ::tflite::RuntimeShape(shape.dimensions_count(), shape.dims().data());
+}
+
bool IsArrayFullyConnectedWeights(const Model& model, const string& name);
// If there is a wildcard dimension (-1), this may return a negative value.
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
index a66812fe87..98e2835b2e 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
@@ -54,6 +54,7 @@ tf_cc_test(
linkopts = common_linkopts,
linkstatic = 1,
tags = [
+ "no_oss", # b/114307765
"tflite_not_portable_android",
"tflite_not_portable_ios",
],
diff --git a/tensorflow/contrib/lite/tools/make/Makefile b/tensorflow/contrib/lite/tools/make/Makefile
index e30cc1d70e..59bdb10811 100644
--- a/tensorflow/contrib/lite/tools/make/Makefile
+++ b/tensorflow/contrib/lite/tools/make/Makefile
@@ -24,6 +24,21 @@ HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32
TARGET := $(HOST_OS)
TARGET_ARCH := $(HOST_ARCH)
+INCLUDES := \
+-I. \
+-I$(MAKEFILE_DIR)/../../../../../ \
+-I$(MAKEFILE_DIR)/../../../../../../ \
+-I$(MAKEFILE_DIR)/downloads/ \
+-I$(MAKEFILE_DIR)/downloads/eigen \
+-I$(MAKEFILE_DIR)/downloads/gemmlowp \
+-I$(MAKEFILE_DIR)/downloads/neon_2_sse \
+-I$(MAKEFILE_DIR)/downloads/farmhash/src \
+-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
+-I$(OBJDIR)
+# This is at the end so any globally-installed frameworks like protobuf don't
+# override local versions in the source tree.
+INCLUDES += -I/usr/local/include
+
# These are the default libraries needed, but they can be added to or
# overridden by the platform-specific settings in target makefiles.
LIBS := \
@@ -44,55 +59,17 @@ ARFLAGS := -r
TARGET_TOOLCHAIN_PREFIX :=
CC_PREFIX :=
-# These target-specific makefiles should modify or replace options like
-# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic
-# based on platforms or architectures should happen within these files, to
-# keep this main makefile focused on the sources and dependencies.
-include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
-
-# Where compiled objects are stored.
-GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
-OBJDIR := $(GENDIR)obj/
-BINDIR := $(GENDIR)bin/
-LIBDIR := $(GENDIR)lib/
-
-INCLUDES := \
--I. \
--I$(MAKEFILE_DIR)/../../../../../ \
--I$(MAKEFILE_DIR)/../../../../../../ \
--I$(MAKEFILE_DIR)/downloads/ \
--I$(MAKEFILE_DIR)/downloads/eigen \
--I$(MAKEFILE_DIR)/downloads/gemmlowp \
--I$(MAKEFILE_DIR)/downloads/neon_2_sse \
--I$(MAKEFILE_DIR)/downloads/farmhash/src \
--I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
--I$(OBJDIR)
-# This is at the end so any globally-installed frameworks like protobuf don't
-# override local versions in the source tree.
-INCLUDES += -I/usr/local/include
-
-CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
-CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
-AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
-
# This library is the main target for this makefile. It will contain a minimal
# runtime that can be linked in to other programs.
LIB_NAME := libtensorflow-lite.a
-LIB_PATH := $(LIBDIR)$(LIB_NAME)
-
-# A small example program that shows how to link against the library.
-MINIMAL_PATH := $(BINDIR)minimal
# Benchmark static library and binary
BENCHMARK_LIB_NAME := benchmark-lib.a
BENCHMARK_BINARY_NAME := benchmark_model
-BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME)
-BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME)
+# A small example program that shows how to link against the library.
MINIMAL_SRCS := \
tensorflow/contrib/lite/examples/minimal/minimal.cc
-MINIMAL_OBJS := $(addprefix $(OBJDIR), \
-$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
# What sources we want to compile, must be kept in sync with the main Bazel
# build files.
@@ -105,7 +82,9 @@ PROFILE_SUMMARIZER_SRCS := \
CORE_CC_ALL_SRCS := \
$(wildcard tensorflow/contrib/lite/*.cc) \
-$(wildcard tensorflow/contrib/lite/*.c)
+$(wildcard tensorflow/contrib/lite/*.c) \
+$(wildcard tensorflow/contrib/lite/c/*.c) \
+$(wildcard tensorflow/contrib/lite/core/api/*.cc)
ifneq ($(BUILD_TYPE),micro)
CORE_CC_ALL_SRCS += \
$(wildcard tensorflow/contrib/lite/kernels/*.cc) \
@@ -136,10 +115,6 @@ tensorflow/contrib/lite/nnapi_delegate.cc
endif
# Filter out all the excluded files.
TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
-# File names of the intermediate files target compilation generates.
-TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \
-$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
-LIB_OBJS := $(TF_LITE_CC_OBJS)
# Benchmark sources
BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark
@@ -151,6 +126,40 @@ BENCHMARK_SRCS := $(filter-out \
$(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \
$(BENCHMARK_ALL_SRCS))
+# These target-specific makefiles should modify or replace options like
+# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic
+# based on platforms or architectures should happen within these files, to
+# keep this main makefile focused on the sources and dependencies.
+include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
+
+ALL_SRCS := \
+ $(MINIMAL_SRCS) \
+ $(PROFILER_SRCS) \
+ $(PROFILER_SUMMARY_SRCS) \
+ $(TF_LITE_CC_SRCS) \
+ $(BENCHMARK_SRCS)
+
+# Where compiled objects are stored.
+GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
+OBJDIR := $(GENDIR)obj/
+BINDIR := $(GENDIR)bin/
+LIBDIR := $(GENDIR)lib/
+
+LIB_PATH := $(LIBDIR)$(LIB_NAME)
+BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME)
+BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME)
+MINIMAL_BINARY := $(BINDIR)minimal
+
+CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
+CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
+AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
+
+MINIMAL_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
+
+LIB_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
+
BENCHMARK_OBJS := $(addprefix $(OBJDIR), \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS))))
@@ -164,7 +173,7 @@ $(OBJDIR)%.o: %.c
$(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@
# The target that's compiled if there's no command-line arguments.
-all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY)
+all: $(LIB_PATH) $(MINIMAL_BINARY) $(BENCHMARK_BINARY)
# The target that's compiled for micro-controllers
micro: $(LIB_PATH)
@@ -178,19 +187,18 @@ $(LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(LIB_OBJS)
@mkdir -p $(dir $@)
$(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS)
-$(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH)
+$(MINIMAL_BINARY): $(MINIMAL_OBJS) $(LIB_PATH)
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) \
- -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \
+ -o $(MINIMAL_BINARY) $(MINIMAL_OBJS) \
$(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
-
$(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS)
@mkdir -p $(dir $@)
$(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS)
benchmark_lib: $(BENCHMARK_LIB)
-$(info $(BENCHMARK_BINARY))
+
$(BENCHMARK_BINARY) : $(BENCHMARK_LIB)
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) \
@@ -213,4 +221,4 @@ cleantarget:
$(DEPDIR)/%.d: ;
.PRECIOUS: $(DEPDIR)/%.d
--include $(patsubst %,$(DEPDIR)/%.d,$(basename $(TF_CC_SRCS)))
+-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(ALL_SRCS)))
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
index 692efb9029..b863108aa4 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
@@ -141,6 +141,7 @@ bool IsHybridEvaluationOp(const OperatorT* op, const BuiltinOperator& op_code) {
op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
op_code == BuiltinOperator_RNN ||
+ op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
diff --git a/tensorflow/contrib/lite/tutorials/BUILD b/tensorflow/contrib/lite/tutorials/BUILD
new file mode 100644
index 0000000000..67ff1ea124
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/BUILD
@@ -0,0 +1,20 @@
+# Example Estimator model
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+ name = "mnist_tflite",
+ srcs = [
+ "dataset.py",
+ "mnist_tflite.py",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/lite/tutorials/dataset.py b/tensorflow/contrib/lite/tutorials/dataset.py
new file mode 100644
index 0000000000..ba49dfcc9b
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/dataset.py
@@ -0,0 +1,122 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""tf.data.Dataset interface to the MNIST dataset.
+
+ This is cloned from
+ https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import shutil
+import tempfile
+
+import numpy as np
+from six.moves import urllib
+import tensorflow as tf
+
+
+def read32(bytestream):
+ """Read 4 bytes from bytestream as an unsigned 32-bit integer."""
+ dt = np.dtype(np.uint32).newbyteorder('>')
+ return np.frombuffer(bytestream.read(4), dtype=dt)[0]
+
+
+def check_image_file_header(filename):
+ """Validate that filename corresponds to images for the MNIST dataset."""
+ with tf.gfile.Open(filename, 'rb') as f:
+ magic = read32(f)
+ read32(f) # num_images, unused
+ rows = read32(f)
+ cols = read32(f)
+ if magic != 2051:
+ raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
+ f.name))
+ if rows != 28 or cols != 28:
+ raise ValueError(
+ 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
+ (f.name, rows, cols))
+
+
+def check_labels_file_header(filename):
+ """Validate that filename corresponds to labels for the MNIST dataset."""
+ with tf.gfile.Open(filename, 'rb') as f:
+ magic = read32(f)
+ read32(f) # num_items, unused
+ if magic != 2049:
+ raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
+ f.name))
+
+
+def download(directory, filename):
+ """Download (and unzip) a file from the MNIST dataset if not already done."""
+ filepath = os.path.join(directory, filename)
+ if tf.gfile.Exists(filepath):
+ return filepath
+ if not tf.gfile.Exists(directory):
+ tf.gfile.MakeDirs(directory)
+ # CVDF mirror of http://yann.lecun.com/exdb/mnist/
+ url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
+ _, zipped_filepath = tempfile.mkstemp(suffix='.gz')
+ print('Downloading %s to %s' % (url, zipped_filepath))
+ urllib.request.urlretrieve(url, zipped_filepath)
+ with gzip.open(zipped_filepath, 'rb') as f_in, \
+ tf.gfile.Open(filepath, 'wb') as f_out:
+ shutil.copyfileobj(f_in, f_out)
+ os.remove(zipped_filepath)
+ return filepath
+
+
+def dataset(directory, images_file, labels_file):
+ """Download and parse MNIST dataset."""
+
+ images_file = download(directory, images_file)
+ labels_file = download(directory, labels_file)
+
+ check_image_file_header(images_file)
+ check_labels_file_header(labels_file)
+
+ def decode_image(image):
+ # Normalize from [0, 255] to [0.0, 1.0]
+ image = tf.decode_raw(image, tf.uint8)
+ image = tf.cast(image, tf.float32)
+ image = tf.reshape(image, [784])
+ return image / 255.0
+
+ def decode_label(label):
+ label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]
+ label = tf.reshape(label, []) # label is a scalar
+ return tf.to_int32(label)
+
+ images = tf.data.FixedLengthRecordDataset(
+ images_file, 28 * 28, header_bytes=16).map(decode_image)
+ labels = tf.data.FixedLengthRecordDataset(
+ labels_file, 1, header_bytes=8).map(decode_label)
+ return tf.data.Dataset.zip((images, labels))
+
+
+def train(directory):
+ """tf.data.Dataset object for MNIST training data."""
+ return dataset(directory, 'train-images-idx3-ubyte',
+ 'train-labels-idx1-ubyte')
+
+
+def test(directory):
+ """tf.data.Dataset object for MNIST test data."""
+ return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
diff --git a/tensorflow/contrib/lite/tutorials/mnist_tflite.py b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
new file mode 100644
index 0000000000..7b8bf5b5db
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf # pylint: disable=g-bad-import-order
+from tensorflow.contrib.lite.tutorials import dataset
+flags = tf.app.flags
+
+flags.DEFINE_string('data_dir', '/tmp/data_dir',
+ 'Directory where data is stored.')
+flags.DEFINE_string('model_file', '',
+ 'The path to the TFLite flatbuffer model file.')
+
+
+flags = flags.FLAGS
+
+
+def test_image_generator():
+ # Generates an iterator over images
+ with tf.Session() as sess:
+ input_data = dataset.test(
+ flags.data_dir).make_one_shot_iterator().get_next()
+ try:
+ while True:
+ yield sess.run(input_data)
+ except tf.errors.OutOfRangeError:
+ pass
+
+
+def run_eval(interpreter, input_image):
+ """Performs evaluation for input image over specified model.
+
+ Args:
+ interpreter: TFLite interpreter initialized with model to execute.
+ input_image: Image input to the model.
+
+ Returns:
+ output: output tensor of model being executed.
+ """
+
+ # Get input and output tensors.
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ # Test model on the input images.
+ input_image = np.reshape(input_image, input_details[0]['shape'])
+ interpreter.set_tensor(input_details[0]['index'], input_image)
+
+ interpreter.invoke()
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ output = np.squeeze(output_data)
+ return output
+
+
+def main(_):
+ interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file)
+ interpreter.allocate_tensors()
+ num_correct, total = 0, 0
+ for input_data in test_image_generator():
+ output = run_eval(interpreter, input_data[0])
+ total += 1
+ if output == input_data[1]:
+ num_correct += 1
+ if total % 500 == 0:
+ print('Accuracy after %i images: %f' %
+ (total, float(num_correct) / float(total)))
+
+
+if __name__ == '__main__':
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run(main)
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
index f5b208afbb..6d81f844f8 100644
--- a/tensorflow/contrib/lite/util.h
+++ b/tensorflow/contrib/lite/util.h
@@ -22,7 +22,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_UTIL_H_
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/contrib/lite/util_test.cc
index 32bf917a59..c5c1709f1d 100644
--- a/tensorflow/contrib/lite/util_test.cc
+++ b/tensorflow/contrib/lite/util_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/util.h"
namespace tflite {
diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt
index 22b11f1c57..7d26429f9c 100644
--- a/tensorflow/contrib/makefile/proto_text_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt
@@ -56,6 +56,7 @@ tensorflow/core/lib/hash/hash.cc
tensorflow/core/lib/hash/crc32c.cc
tensorflow/core/lib/hash/crc32c_accelerate.cc
tensorflow/core/lib/core/threadpool.cc
+tensorflow/core/lib/core/stringpiece.cc
tensorflow/core/lib/core/status.cc
tensorflow/core/lib/core/coding.cc
tensorflow/core/lib/core/arena.cc
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 93e589907e..2e4d61d931 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -159,8 +159,10 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:variables",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
index f026f437dc..f55209ec49 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
@@ -25,7 +25,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -48,12 +47,7 @@ class LazyAdamOptimizer(adam.AdamOptimizer):
may lead to different empirical results.
"""
- def _apply_sparse_shared(self,
- grad,
- var,
- indices,
- scatter_update,
- scatter_sub):
+ def _apply_sparse(self, grad, var):
beta1_power, beta2_power = self._get_beta_accumulators()
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
@@ -65,51 +59,56 @@ class LazyAdamOptimizer(adam.AdamOptimizer):
# \\(m := beta1 * m + (1 - beta1) * g_t\\)
m = self.get_slot(var, "m")
- m_t = scatter_update(m, indices,
- beta1_t * array_ops.gather(m, indices) +
- (1 - beta1_t) * grad)
+ m_t = state_ops.scatter_update(m, grad.indices,
+ beta1_t * array_ops.gather(m, grad.indices) +
+ (1 - beta1_t) * grad.values,
+ use_locking=self._use_locking)
# \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
v = self.get_slot(var, "v")
- v_t = scatter_update(v, indices,
- beta2_t * array_ops.gather(v, indices) +
- (1 - beta2_t) * math_ops.square(grad))
+ v_t = state_ops.scatter_update(v, grad.indices,
+ beta2_t * array_ops.gather(v, grad.indices) +
+ (1 - beta2_t) * math_ops.square(grad.values),
+ use_locking=self._use_locking)
# \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
- m_t_slice = array_ops.gather(m_t, indices)
- v_t_slice = array_ops.gather(v_t, indices)
+ m_t_slice = array_ops.gather(m_t, grad.indices)
+ v_t_slice = array_ops.gather(v_t, grad.indices)
denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
- var_update = scatter_sub(var, indices,
- lr * m_t_slice / denominator_slice)
+ var_update = state_ops.scatter_sub(var, grad.indices,
+ lr * m_t_slice / denominator_slice,
+ use_locking=self._use_locking)
return control_flow_ops.group(var_update, m_t, v_t)
- def _apply_sparse(self, grad, var):
- return self._apply_sparse_shared(
- grad.values, var, grad.indices,
- self._scatter_update,
- self._scatter_sub)
-
def _resource_apply_sparse(self, grad, var, indices):
- return self._apply_sparse_shared(
- grad, var, indices,
- self._resource_scatter_update,
- self._resource_scatter_sub)
-
- # Utility functions for updating resource or non-resource variables.
- def _scatter_update(self, x, i, v):
- return state_ops.scatter_update(
- x, i, v, use_locking=self._use_locking)
-
- def _scatter_sub(self, x, i, v):
- return state_ops.scatter_sub(
- x, i, v, use_locking=self._use_locking)
-
- def _resource_scatter_update(self, x, i, v):
- update_op = resource_variable_ops.resource_scatter_update(x.handle, i, v)
- with ops.control_dependencies([update_op]):
- return x.value()
-
- def _resource_scatter_sub(self, x, i, v):
- sub_op = resource_variable_ops.resource_scatter_sub(x.handle, i, v)
- with ops.control_dependencies([sub_op]):
- return x.value()
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+ beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
+ lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+ beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+ beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+ epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+ lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
+
+ # \\(m := beta1 * m + (1 - beta1) * g_t\\)
+ m = self.get_slot(var, "m")
+ m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad
+ m_update_op = resource_variable_ops.resource_scatter_update(m.handle,
+ indices,
+ m_t_slice)
+
+ # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
+ v = self.get_slot(var, "v")
+ v_t_slice = (beta2_t * array_ops.gather(v, indices) +
+ (1 - beta2_t) * math_ops.square(grad))
+ v_update_op = resource_variable_ops.resource_scatter_update(v.handle,
+ indices,
+ v_t_slice)
+
+ # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
+ var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
+ var_update_op = resource_variable_ops.resource_scatter_sub(var.handle,
+ indices,
+ var_slice)
+
+ return control_flow_ops.group(var_update_op, m_update_op, v_update_op)
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
index d3e9e89502..f08ffaa36f 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
@@ -19,12 +19,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.opt.python.training import lazy_adam_optimizer
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -50,9 +53,10 @@ def adam_update_numpy(param,
return param_t, m_t, v_t
-class AdamOptimizerTest(test.TestCase):
+class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
- def doTestSparse(self, use_resource=False):
+ @parameterized.parameters([False, True])
+ def testSparse(self, use_resource):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
# Initialize variables for numpy implementation.
@@ -68,6 +72,7 @@ class AdamOptimizerTest(test.TestCase):
else:
var0 = variables.Variable(var0_np)
var1 = variables.Variable(var1_np)
+
grads0_np_indices = np.array([0, 1], dtype=np.int32)
grads0 = ops.IndexedSlices(
constant_op.constant(grads0_np),
@@ -99,18 +104,17 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllCloseAccordingToType(var0_np, var0.eval())
self.assertAllCloseAccordingToType(var1_np, var1.eval())
- def testSparse(self):
- self.doTestSparse(use_resource=False)
-
- def testResourceSparse(self):
- self.doTestSparse(use_resource=True)
-
- def testSparseDevicePlacement(self):
+ @parameterized.parameters([False, True])
+ def testSparseDevicePlacement(self, use_resource):
for index_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session(force_gpu=test.is_gpu_available()):
# If a GPU is available, tests that all optimizer ops can be placed on
# it (i.e. they have GPU kernels).
- var = variables.Variable([[1.0], [2.0]])
+ if use_resource:
+ var = resource_variable_ops.ResourceVariable([[1.0], [2.0]])
+ else:
+ var = variables.Variable([[1.0], [2.0]])
+
indices = constant_op.constant([0, 1], dtype=index_dtype)
gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0)
@@ -118,13 +122,21 @@ class AdamOptimizerTest(test.TestCase):
variables.global_variables_initializer().run()
minimize_op.run()
- def testSparseRepeatedIndices(self):
+ @parameterized.parameters([False, True])
+ def testSparseRepeatedIndices(self, use_resource):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
- repeated_index_update_var = variables.Variable(
- [[1.0], [2.0]], dtype=dtype)
- aggregated_update_var = variables.Variable(
- [[1.0], [2.0]], dtype=dtype)
+ if use_resource:
+ repeated_index_update_var = resource_variable_ops.ResourceVariable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = resource_variable_ops.ResourceVariable(
+ [[1.0], [2.0]], dtype=dtype)
+ else:
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+
grad_repeated_index = ops.IndexedSlices(
constant_op.constant(
[0.1, 0.1], shape=[2, 1], dtype=dtype),
@@ -150,6 +162,204 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllClose(aggregated_update_var.eval(),
repeated_index_update_var.eval())
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ learning_rate = lambda: 0.001
+ beta1 = lambda: 0.9
+ beta2 = lambda: 0.999
+ epsilon = lambda: 1e-8
+ if not use_callable_params:
+ learning_rate = learning_rate()
+ beta1 = beta1()
+ beta2 = beta2()
+ epsilon = epsilon()
+
+ opt = lazy_adam_optimizer.LazyAdamOptimizer(learning_rate=learning_rate)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+ self.assertIsNotNone(beta1_power)
+ self.assertIsNotNone(beta2_power is not None)
+ self.assertIn(beta1_power, opt_variables)
+ self.assertIn(beta2_power, opt_variables)
+
+ if not context.executing_eagerly():
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta1_power))
+ self.assertAllCloseAccordingToType(0.999**(t + 1),
+ self.evaluate(beta2_power))
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+ if use_resource:
+ self.assertEqual("var0_%d/Adam:0" % (i,),
+ opt.get_slot(var=var0, name="m").name)
+
+ def testBasic(self):
+ with self.test_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = lazy_adam_optimizer.LazyAdamOptimizer(constant_op.constant(0.001))
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = lazy_adam_optimizer.LazyAdamOptimizer()
+ update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 3 steps of intertwined Adam1 and Adam2.
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ if t % 2 == 0:
+ update1.run()
+ else:
+ update2.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testTwoSessions(self):
+ optimizer = lazy_adam_optimizer.LazyAdamOptimizer()
+
+ with context.eager_mode():
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+ optimizer.apply_gradients([(grads0, var0)])
+
+ g = ops.Graph()
+ with g.as_default():
+ with self.session(graph=g):
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+ optimizer.apply_gradients([(grads0, var0)])
+
+ gg = ops.Graph()
+ with gg.as_default():
+ with self.session(graph=gg):
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+
+ # If the optimizer saves any state not keyed by graph the following line
+ # fails.
+ optimizer.apply_gradients([(grads0, var0)])
+
+ def testSlotsUniqueEager(self):
+ with context.eager_mode():
+ v1 = resource_variable_ops.ResourceVariable(1.)
+ v2 = resource_variable_ops.ResourceVariable(1.)
+ opt = lazy_adam_optimizer.LazyAdamOptimizer(1.)
+ opt.minimize(lambda: v1 + v2)
+ # There should be two non-slot variables, and two unique slot variables
+ # for v1 and v2 respectively.
+ self.assertEqual(6, len(set(opt.variables())))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 499fec4ffa..c59f667f6a 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -22,6 +22,7 @@ py_test(
":common",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python:variable_scope",
@@ -89,7 +90,6 @@ py_library(
":common",
":graph_matcher",
":input_to_ops",
- "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
@@ -171,7 +171,6 @@ py_library(
":graph_matcher",
":input_to_ops",
":quant_ops",
- "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index bf648e158e..b27117dd48 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -131,3 +131,29 @@ def DropStringPrefix(s, prefix):
return s[len(prefix):]
else:
return s
+
+
+def RerouteTensor(t0, t1, can_modify=None):
+ """Reroute the end of the tensor t0 to the ends of the tensor t1.
+
+ Args:
+ t0: a tf.Tensor.
+ t1: a tf.Tensor.
+ can_modify: iterable of operations which can be modified. Any operation
+ outside within_ops will be left untouched by this function.
+
+ Returns:
+ The number of individual modifications made by the function.
+ """
+ nb_update_inputs = 0
+ consumers = t1.consumers()
+ if can_modify is not None:
+ consumers = [c for c in consumers if c in can_modify]
+ consumers_indices = {}
+ for c in consumers:
+ consumers_indices[c] = [i for i, t in enumerate(c.inputs) if t is t1]
+ for c in consumers:
+ for i in consumers_indices[c]:
+ c._update_input(i, t0) # pylint: disable=protected-access
+ nb_update_inputs += 1
+ return nb_update_inputs
diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py
index 06c62f2d26..2b26302f8a 100644
--- a/tensorflow/contrib/quantize/python/common_test.py
+++ b/tensorflow/contrib/quantize/python/common_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
from tensorflow.contrib.quantize.python import common
from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -62,6 +64,29 @@ class CommonTest(test_util.TensorFlowTestCase):
_, step_val = sess.run([b, quantization_step_tensor])
self.assertEqual(step_val, 2)
+ def testRerouteTensor(self):
+ a = constant_op.constant(1, name='a')
+ b = constant_op.constant(2, name='b')
+ c = constant_op.constant(3, name='c')
+ d = constant_op.constant(4, name='d')
+
+ add_ac = math_ops.add(a, c)
+ add_ad = math_ops.add(a, d)
+
+ # Ensure that before rerouting the inputs are what we think.
+ self._CheckOpHasInputs(add_ac.op, [a, c])
+ self._CheckOpHasInputs(add_ad.op, [a, d])
+
+ # references to tensor a should be replaced with b for all ops in
+ # can_modify. This means add_ac will be changed but add_ad will not.
+ common.RerouteTensor(b, a, can_modify=[add_ac.op])
+ self._CheckOpHasInputs(add_ac.op, [b, c])
+ self._CheckOpHasInputs(add_ad.op, [a, d])
+
+ def _CheckOpHasInputs(self, op, inputs):
+ for i in inputs:
+ self.assertIn(i, op.inputs)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index d9f179bee4..2971b28f45 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import re
-from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
@@ -134,8 +133,8 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
bias_add_tensor = math_ops.add(
new_layer_tensor, bias_tensor, name='add_fold')
- nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
- match.output_tensor)
+ nodes_modified_count = common.RerouteTensor(bias_add_tensor,
+ match.output_tensor)
if nodes_modified_count == 0:
raise ValueError('Folding batch norms failed, %s had no outputs.' %
match.output_tensor.name)
@@ -370,8 +369,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
lambda: match.bn_decay_mean_tensor,
name='freeze_moving_mean')
- graph_editor.reroute_ts(
- [bn_decay_mean_out], [match.bn_decay_mean_tensor],
+ common.RerouteTensor(
+ bn_decay_mean_out,
+ match.bn_decay_mean_tensor,
can_modify=bn_decay_mean_consumers)
bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
@@ -380,8 +380,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
lambda: bn_decay_zero,
lambda: match.bn_decay_var_tensor,
name='freeze_moving_var')
- graph_editor.reroute_ts(
- [bn_decay_var_out], [match.bn_decay_var_tensor],
+ common.RerouteTensor(
+ bn_decay_var_out,
+ match.bn_decay_var_tensor,
can_modify=bn_decay_var_consumers)
correction_recip = utils.smart_cond(
@@ -486,9 +487,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
activation = common.GetEndpointActivationOp(graph, bn)
if activation:
- nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
- [original_op.outputs[0]],
- can_modify=[activation])
+ nodes_modified_count = common.RerouteTensor(
+ folded_op.outputs[0], original_op.outputs[0], can_modify=[activation])
if nodes_modified_count != 1:
raise ValueError('Unexpected inputs to op: %s' % activation.name)
continue
@@ -497,9 +497,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
# operations instead of Relu* above.
add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
- nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
- [original_op.outputs[0]],
- can_modify=[add_bypass])
+ nodes_modified_count = common.RerouteTensor(
+ folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
if nodes_modified_count != 1:
raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 2ddbd73ea6..e88db0acd5 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import re
-from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
@@ -592,8 +591,8 @@ def _InsertQuantOp(context,
name=name_prefix + '/delayed_quant')
if consumers:
- tensors_modified_count = graph_editor.reroute_ts(
- [quant], [inputs], can_modify=consumers)
+ tensors_modified_count = common.RerouteTensor(
+ quant, inputs, can_modify=consumers)
# Some operations can have multiple output tensors going to the same
# consumer. Since consumers is a set, we need to ensure that
# tensors_modified_count is greater than or equal to the length of the set
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index 5874245d58..4e67d80558 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -212,6 +212,7 @@ cuda_py_tests(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
+ tags = ["noasan"],
)
tf_custom_op_library(
@@ -279,7 +280,10 @@ cuda_py_tests(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
- tags = ["no_oss"],
+ tags = [
+ "no_oss",
+ "noasan",
+ ],
)
tf_cc_test(
@@ -287,6 +291,7 @@ tf_cc_test(
size = "small",
srcs = ["ops/gru_ops_test.cc"],
data = [":python/ops/_gru_ops.so"],
+ tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
@@ -306,6 +311,7 @@ tf_cc_test(
size = "small",
srcs = ["ops/lstm_ops_test.cc"],
data = [":python/ops/_lstm_ops.so"],
+ tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index f74c95f962..06c481672c 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -97,10 +97,10 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
The default non-peephole implementation is based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The peephole implementation is based on:
@@ -2448,10 +2448,10 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
The default non-peephole implementation is based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The peephole implementation is based on:
@@ -2802,9 +2802,11 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
Training of Deep Neural Networks
The default LSTM implementation based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
+
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The class uses optional peephole connections, optional cell clipping
and an optional projection layer.
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index db970deff5..0042d37acd 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -134,19 +134,19 @@ def _get_default_head(params, weights_name, output_type, name=None):
weight_column=weights_name,
label_dimension=params.num_outputs,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
else:
if params.num_classes == 2:
return core_head_lib.binary_classification_head(
weight_column=weights_name,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
else:
return core_head_lib.multi_class_head(
n_classes=params.num_classes,
weight_column=weights_name,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
def get_model_fn(params,
graph_builder_class,
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 537d94b797..3c0456dc2f 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -33,6 +33,7 @@
@@shard
@@batch_parallel
@@rewrite
+@@outside_compilation
@@CrossShardOptimizer
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 08e0465b71..d8c3872363 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -258,6 +258,8 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
def set_weights(self, weights):
+ # TODO(power): Figure out whether we really need this given there is no
+ # caller for this API yet.
self._opt.set_weights()
def get_weights(self):
@@ -282,9 +284,9 @@ def _valid_name(tensor_name):
def _replicated_optimizer(opt):
"""Wrap the optimizer `opt` with CrossShardOptimizer if applicable."""
- if tpu_function.get_tpu_context().number_of_shards == 1:
- return opt
-
+ # Always wrap `opt` with CrossShardOptimizer, even if we are running on a
+ # single core. This ensures Keras properly tracks and initializes optimizer
+ # variables.
if isinstance(opt, keras_optimizers.TFOptimizer):
return tpu_optimizer.CrossShardOptimizer(opt.optimizer)
else:
@@ -1420,7 +1422,7 @@ class KerasTPUModel(models.Model):
y,
sample_weights,
batch_size)
- self._pipeline_fit_loop(
+ return self._pipeline_fit_loop(
x,
y,
sample_weights=sample_weights,
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 1e21cc5252..c1f90c3963 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -652,13 +652,28 @@ def split_compile_and_replicate(computation,
# TODO(phawkins): consider removing this code. It will
# be less confusing to clients if they knowingly choose to use resource
# variables.
+ # Partitioned variables is not supported (b/112311320).
+ def custom_getter(getter, name, *args, **kwargs):
+ partitioner = kwargs["partitioner"]
+ if partitioner is None:
+ return getter(name, *args, **kwargs)
+ else:
+ raise ValueError(
+ "Partitioned variables are not supported on TPU. Got "
+ "`partitioner` that is {}.".format(partitioner))
+
vscope = variable_scope.get_variable_scope()
+
saved_use_resource = vscope.use_resource
+ saved_custom_getter = vscope.custom_getter
+
vscope.set_use_resource(True)
+ vscope.set_custom_getter(custom_getter)
outputs = computation(*computation_inputs)
vscope.set_use_resource(saved_use_resource)
+ vscope.set_custom_getter(saved_custom_getter)
# If the computation returns `None`, make it an empty tuple.
if outputs is None:
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index ad3dce1784..d4951b156c 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -63,7 +63,7 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
}
CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0);
RdmaChannel* rc = rdma_mgr_->FindChannel(src_name);
- string key(std::move(parsed.FullKey().ToString()));
+ string key(parsed.FullKey());
string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
Device* dst_dev;