diff options
Diffstat (limited to 'tensorflow/contrib')
213 files changed, 4072 insertions, 2323 deletions
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py index 16eb1f0e3f..41c3424fa3 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions.py @@ -57,8 +57,8 @@ class LogicalExpressionTransformer(converter.Base): gast.NotEq: 'tf.not_equal', gast.Or: 'tf.logical_or', gast.USub: 'tf.negative', - gast.Is: 'autograph_utils.dynamic_is', - gast.IsNot: 'autograph_utils.dynamic_is_not' + gast.Is: 'ag__.utils.dynamic_is', + gast.IsNot: 'ag__.utils.dynamic_is_not' } def _expect_simple_symbol(self, operand): diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py index 8f9eee7081..409a73afba 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py @@ -47,6 +47,15 @@ class GradientsFunctionTest(converter_testing.TestCase): with self.cached_session() as sess: self.assertTrue(sess.run(result.test_fn(True, False, True))) + def test_ag_utils_lookup(self): + def test_fn(a, b): + return a is b or a is not b + + with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or + ) as result: + with self.cached_session() as sess: + self.assertTrue(sess.run(result.test_fn(True, False))) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index 803fde9089..a4c6fed265 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -38,9 +38,6 @@ class ApiTest(test.TestCase): def setUp(self): config.COMPILED_IMPORT_STATEMENTS = ( 'from __future__ import print_function', - 'from tensorflow.contrib.autograph import utils' - ' as autograph_utils', - 'tf = autograph_utils.fake_tf()', ) def test_decorator_recurses(self): diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py index e42f679cfe..d77c15915b 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py @@ -394,10 +394,16 @@ class AnfTransformer(transformer.Base): # just recur. def visit_List(self, node): - return self._visit_strict_expression(node) + node = self.generic_visit(node) + if not isinstance(node.ctx, gast.Store): + self._ensure_fields_trivial(node) + return node def visit_Tuple(self, node): - return self._visit_strict_expression(node) + node = self.generic_visit(node) + if not isinstance(node.ctx, gast.Store): + self._ensure_fields_trivial(node) + return node def transform(node, entity_info, gensym_source=None): diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py index 951974820c..1ffd4bbe55 100644 --- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py @@ -165,6 +165,46 @@ class AnfTransformerTest(test.TestCase): self.assert_body_anfs_as_expected(expected_result, test_function) + def test_nested_multi_value_assign(self): + + def test_function(a, b, c): + x, y = a, a + b + (z, y), x = (c, y + b), x + a + return z, (y, x) + + def expected_result(a, b, c): + tmp_1001 = a + b + x, y = a, tmp_1001 + tmp_1002 = y + b + tmp_1003 = (c, tmp_1002) + tmp_1004 = x + a + (z, y), x = tmp_1003, tmp_1004 + tmp_1005 = y, x + tmp_1006 = z, tmp_1005 + return tmp_1006 + + self.assert_body_anfs_as_expected(expected_result, test_function) + + def test_deeply_nested_multi_value_assign(self): + + def test_function(a): + [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a + return [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] + + def expected_result(a): + [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a + tmp_1001 = b, c + tmp_1002 = [d, e] + tmp_1003 = [tmp_1001, tmp_1002] + tmp_1004 = f, g + tmp_1005 = h, i, j + tmp_1006 = tmp_1003, tmp_1004 + tmp_1007 = [tmp_1005, k] + tmp_1008 = [tmp_1006, tmp_1007] + return tmp_1008 + + self.assert_body_anfs_as_expected(expected_result, test_function) + def test_local_definition_and_binary_compare(self): def test_function(): diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py index 2d8f922a45..e7baa244b2 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py @@ -29,6 +29,11 @@ from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno +# TODO(aqj): Do we need this? Do other builtins fail in similar ways +# See b/114389775 for a related bug in pyct +# These symbols are legal in Python, but don't appear in the namespace. +_special_symbols = {'range': range} + class LiveValueResolver(transformer.Base): """Annotates nodes with live values.""" @@ -66,6 +71,8 @@ class LiveValueResolver(transformer.Base): # If the symbol value is for example a primitive, then it will not # have a name. pass + elif node.id in _special_symbols: + anno.setanno(node, 'live_val', _special_symbols[node.id]) else: pass # TODO(mdan): Should we raise an error here? diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 870ce2442b..4c7a538b38 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -52,7 +52,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): center_bias=True, use_core_libs=False, output_leaf_index=False, - override_global_step_value=None): + override_global_step_value=None, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -94,6 +95,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): trees were trained), this parameter can be used to set the global step to a large value, making it look like that number of training steps ran. If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. Raises: ValueError: If learner_config is not valid. @@ -134,7 +136,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'use_core_libs': use_core_libs, 'output_leaf_index': output_leaf_index, - 'override_global_step_value': override_global_step_value + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, }, model_dir=model_dir, config=config, @@ -159,7 +162,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): center_bias=True, use_core_libs=False, output_leaf_index=False, - override_global_step_value=None): + override_global_step_value=None, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -201,6 +205,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): trees were trained), this parameter can be used to set the global step to a large value, making it look like that number of training steps ran. If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. """ head = head_lib.regression_head( label_name=label_name, @@ -224,7 +229,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, - 'override_global_step_value': override_global_step_value + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, }, model_dir=model_dir, config=config, @@ -251,7 +257,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): center_bias=True, use_core_libs=False, output_leaf_index=False, - override_global_step_value=None): + override_global_step_value=None, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. Args: @@ -289,6 +296,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): trees were trained), this parameter can be used to set the global step to a large value, making it look like that number of training steps ran. If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. """ super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -303,7 +311,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'center_bias': center_bias, 'use_core_libs': use_core_libs, 'output_leaf_index': False, - 'override_global_step_value': override_global_step_value + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, }, model_dir=model_dir, config=config, @@ -329,7 +338,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): center_bias=False, use_core_libs=False, output_leaf_index=False, - override_global_step_value=None): + override_global_step_value=None, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeRanker instance. This is an estimator that can be trained off the pairwise data and can be @@ -377,6 +387,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): trees were trained), this parameter can be used to set the global step to a large value, making it look like that number of training steps ran. If None, no override of global step will happen. + num_quantiles: Number of quantiles to build for numeric feature values. + Raises: ValueError: If learner_config is not valid. """ @@ -395,7 +407,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator): 'use_core_libs': use_core_libs, 'output_leaf_index': output_leaf_index, 'ranking_model_pair_keys': ranking_model_pair_keys, - 'override_global_step_value': override_global_step_value + 'override_global_step_value': override_global_step_value, + 'num_quantiles': num_quantiles, }, model_dir=model_dir, config=config, @@ -444,7 +457,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - output_leaf_index=False): + output_leaf_index=False, + num_quantiles=100): """Initializes a core version of GradientBoostedDecisionTreeEstimator. Args: @@ -474,6 +488,7 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): for example_prediction_result in result_dict: # access leaf index list by example_prediction_result["leaf_index"] # which contains one leaf index per tree + num_quantiles: Number of quantiles to build for numeric feature values. """ def _model_fn(features, labels, mode, config): @@ -493,7 +508,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'use_core_libs': True, 'output_leaf_index': output_leaf_index, - 'override_global_step_value': None + 'override_global_step_value': None, + 'num_quantiles': num_quantiles, }, output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) @@ -517,7 +533,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): label_keys=None, logits_modifier_function=None, center_bias=False, - output_leaf_index=False): + output_leaf_index=False, + num_quantiles=100): """Initializes a GradientBoostedDecisionTreeRanker instance. This is an estimator that can be trained off the pairwise data and can be @@ -552,6 +569,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): for result_dict in result_iter: # access leaf index list by result_dict["leaf_index"] # which contains one leaf index per tree + num_quantiles: Number of quantiles to build for numeric feature values. Raises: ValueError: If learner_config is not valid. @@ -576,7 +594,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator): 'use_core_libs': True, 'output_leaf_index': output_leaf_index, 'ranking_model_pair_keys': ranking_model_pair_keys, - 'override_global_step_value': None + 'override_global_step_value': None, + 'num_quantiles': num_quantiles, }, output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 04b46c3483..a6e422847d 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -81,6 +81,7 @@ def model_builder(features, logits_modifier_function = params["logits_modifier_function"] output_leaf_index = params["output_leaf_index"] override_global_step_value = params.get("override_global_step_value", None) + num_quantiles = params["num_quantiles"] if features is None: raise ValueError("At least one feature must be specified.") @@ -116,7 +117,8 @@ def model_builder(features, logits_dimension=head.logits_dimension, features=training_features, use_core_columns=use_core_libs, - output_leaf_index=output_leaf_index) + output_leaf_index=output_leaf_index, + num_quantiles=num_quantiles) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] @@ -237,6 +239,7 @@ def ranking_model_builder(features, output_leaf_index = params["output_leaf_index"] ranking_model_pair_keys = params["ranking_model_pair_keys"] override_global_step_value = params.get("override_global_step_value", None) + num_quantiles = params["num_quantiles"] if features is None: raise ValueError("At least one feature must be specified.") @@ -299,7 +302,8 @@ def ranking_model_builder(features, logits_dimension=head.logits_dimension, features=main_features, use_core_columns=use_core_libs, - output_leaf_index=output_leaf_index) + output_leaf_index=output_leaf_index, + num_quantiles=num_quantiles) with ops.name_scope("gbdt", "gbdt_optimizer"): # Logits for inference. diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index b008c6e534..c7eb2493a8 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -304,7 +304,8 @@ class GradientBoostedDecisionTreeModel(object): feature_columns=None, use_core_columns=False, output_leaf_index=False, - output_leaf_index_modes=None): + output_leaf_index_modes=None, + num_quantiles=100): """Construct a new GradientBoostedDecisionTreeModel function. Args: @@ -327,6 +328,7 @@ class GradientBoostedDecisionTreeModel(object): output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which dictates when leaf indices will be outputted. By default, leaf indices are only outputted in INFER mode. + num_quantiles: Number of quantiles to build for numeric feature values. Raises: ValueError: if inputs are not valid. @@ -399,6 +401,7 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config = learner_config self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() + self._num_quantiles = num_quantiles self._max_tree_depth = variables.Variable( initial_value=self._learner_config.constraints.max_tree_depth) self._attempted_trees = variables.Variable( @@ -689,8 +692,8 @@ class GradientBoostedDecisionTreeModel(object): loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction) weak_learner_type = constant_op.constant( self._learner_config.weak_learner_type) - epsilon = 0.01 - num_quantiles = 100 + num_quantiles = self._num_quantiles + epsilon = 1.0 / num_quantiles strategy_tensor = constant_op.constant(strategy) with ops.device(self._get_replica_device_setter(worker_device)): # Create handlers for dense float columns diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 1ab150d74a..1056894f18 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -229,6 +229,10 @@ class TPUClusterResolver(ClusterResolver): def get_master(self): return self.master() + def get_job_name(self): + if self._shouldResolve(): + return self._job_name + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 34f594f741..b9320e5fef 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -279,7 +279,9 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:function", + "//tensorflow/python:functional_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:session", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 9d8e955245..67242fecfe 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -428,10 +428,10 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) @parameterized.named_parameters( - ("default", None, None), - ("sequential_calls", 1, None), - ("parallel_calls", 2, None), - ("parallel_batches", None, 10), + ("Default", None, None), + ("SequentialCalls", 1, None), + ("ParallelCalls", 2, None), + ("ParallelBatches", None, 10), ) def testMapAndBatch(self, num_parallel_calls, num_parallel_batches): """Test a dataset that maps a TF function across its input elements.""" @@ -505,8 +505,8 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) @parameterized.named_parameters( - ("even", False), - ("uneven", True), + ("Even", False), + ("Uneven", True), ) def testMapAndBatchPartialBatch(self, drop_remainder): iterator = ( @@ -663,7 +663,14 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): for _ in range(3): sess.run(get_next) - @parameterized.parameters(0, 5, 10, 90, 95, 99) + @parameterized.named_parameters( + ("1", 0), + ("2", 5), + ("3", 10), + ("4", 90), + ("5", 95), + ("6", 99), + ) def testMapAndBatchOutOfRangeError(self, threshold): def raising_py_fn(i): @@ -689,18 +696,18 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @parameterized.parameters( - (False, dtypes.bool), - (-42, dtypes.int8), - (-42, dtypes.int16), - (-42, dtypes.int32), - (-42, dtypes.int64), - (42, dtypes.uint8), - (42, dtypes.uint16), - (42.0, dtypes.float16), - (42.0, dtypes.float32), - (42.0, dtypes.float64), - (b"hello", dtypes.string), + @parameterized.named_parameters( + ("1", False, dtypes.bool), + ("2", -42, dtypes.int8), + ("3", -42, dtypes.int16), + ("4", -42, dtypes.int32), + ("5", -42, dtypes.int64), + ("6", 42, dtypes.uint8), + ("7", 42, dtypes.uint16), + ("8", 42.0, dtypes.float16), + ("9", 42.0, dtypes.float32), + ("10", 42.0, dtypes.float64), + ("11", b"hello", dtypes.string), ) def testMapAndBatchTypes(self, element, dtype): def gen(): diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py index 091eb5ce37..61567bc8d7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -17,7 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + from tensorflow.contrib.data.python.ops import map_defun +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -25,10 +28,10 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test - class MapDefunTest(test.TestCase): def testMapDefunSimple(self): @@ -146,6 +149,105 @@ class MapDefunTest(test.TestCase): r"indices = 10 is not in \[0, 5\)"): self.evaluate(map_defun_op) + def testMapDefunWithUnspecifiedOutputShape(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + res = x * 2 + 3 + return (res, res + 1, res + 2) + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], + [dtypes.int32, dtypes.int32, dtypes.int32], + [None, (None,), (2,)]) + expected = elems * 2 + 3 + self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected)) + self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1)) + self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2)) + + def testMapDefunWithDifferentOutputShapeEachRun(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + elems = array_ops.placeholder(dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0] + with session.Session() as sess: + self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3]) + self.assertAllEqual( + sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]]) + + def testMapDefunWithWrongOutputShape(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + 3 + + nums = [[1, 2], [3, 4], [5, 6]] + elems = constant_op.constant(nums, dtype=dtypes.int32, name="data") + r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0] + with self.assertRaises(errors.InvalidArgumentError): + self.evaluate(r) + + def testMapDefunWithInvalidInput(self): + + @function.Defun(dtypes.int32) + def simple_fn(x): + return x * 2 + + c = constant_op.constant(2) + with self.assertRaises(ValueError): + # Fails at graph construction time for inputs with known shapes. + r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0] + p = array_ops.placeholder(dtypes.int32) + r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0] + with session.Session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(r, feed_dict={p: 0}) + + +class MapDefunBenchmark(test.Benchmark): + + def _run(self, op, name=None, num_iters=3000): + with session.Session() as sess: + # Warm up the session + for _ in range(5): + sess.run(op) + start = time.time() + for _ in range(num_iters): + sess.run(op) + end = time.time() + mean_us = (end - start) * 1e6 / num_iters + self.report_benchmark( + name=name, + iters=num_iters, + wall_time=mean_us, + extras={"examples_per_sec": num_iters / (end - start)}) + + def benchmarkDefunVsMapFn(self): + """Benchmarks to compare the performance of MapDefun vs tf.map_fn.""" + + @function.Defun(dtypes.int32) + def defun(x): + return array_ops.identity(x) + + def map_fn(x): + return array_ops.identity(x) + + base = math_ops.range(100) + for input_size in [10, 100, 1000, 10000]: + num_iters = 100000 // input_size + map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()]) + map_fn_op = functional_ops.map_fn(map_fn, base) + + self._run( + map_defun_op, + "benchmarkMapDefun_size_%d" % input_size, + num_iters=num_iters) + self._run( + map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py index 586b4bee5f..6a7ef877f9 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py @@ -44,22 +44,22 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): for i, fun1 in enumerate(functions): for j, fun2 in enumerate(functions): tests.append(( - "test_{}_{}".format(i, j), + "Test{}{}".format(i, j), [fun1, fun2], )) for k, fun3 in enumerate(functions): tests.append(( - "test_{}_{}_{}".format(i, j, k), + "Test{}{}{}".format(i, j, k), [fun1, fun2, fun3], )) swap = lambda x, n: (n, x) tests.append(( - "swap1", + "Swap1", [lambda x: (x, 42), swap], )) tests.append(( - "swap2", + "Swap2", [lambda x: (x, 42), swap, swap], )) return tuple(tests) @@ -109,13 +109,13 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): for x, fun in enumerate(functions): for y, predicate in enumerate(filters): - tests.append(("mixed_{}_{}".format(x, y), fun, predicate)) + tests.append(("Mixed{}{}".format(x, y), fun, predicate)) # Multi output - tests.append(("multiOne", lambda x: (x, x), + tests.append(("Multi1", lambda x: (x, x), lambda x, y: constant_op.constant(True))) tests.append( - ("multiTwo", lambda x: (x, 2), + ("Multi2", lambda x: (x, 2), lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))) return tuple(tests) @@ -172,17 +172,17 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): identity = lambda x: x for x, predicate_1 in enumerate(filters): for y, predicate_2 in enumerate(filters): - tests.append(("mixed_{}_{}".format(x, y), identity, + tests.append(("Mixed{}{}".format(x, y), identity, [predicate_1, predicate_2])) for z, predicate_3 in enumerate(filters): - tests.append(("mixed_{}_{}_{}".format(x, y, z), identity, + tests.append(("Mixed{}{}{}".format(x, y, z), identity, [predicate_1, predicate_2, predicate_3])) take_all_multiple = lambda x, y: constant_op.constant(True) # Multi output - tests.append(("multiOne", lambda x: (x, x), + tests.append(("Multi1", lambda x: (x, x), [take_all_multiple, take_all_multiple])) - tests.append(("multiTwo", lambda x: (x, 2), [ + tests.append(("Multi2", lambda x: (x, 2), [ take_all_multiple, lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0) ])) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD index 4881f63ab9..aa89674c6e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -210,6 +210,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py index ac3892fe81..243f6405a1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base @@ -27,42 +28,38 @@ from tensorflow.python.platform import test class InterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): + dataset_serialization_test_base.DatasetSerializationTestBase, + parameterized.TestCase): - def _build_iterator_graph(self, input_values, cycle_length, block_length): + def _build_iterator_graph(self, input_values, cycle_length, block_length, + num_parallel_calls): repeat_count = 2 return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( repeat_count).interleave( lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length) + cycle_length, block_length, num_parallel_calls) - def testSerializationCore(self): + @parameterized.named_parameters( + ("1", 2, 3, None), + ("2", 2, 3, 1), + ("3", 2, 3, 2), + ("4", 1, 3, None), + ("5", 1, 3, 1), + ("6", 2, 1, None), + ("7", 2, 1, 1), + ("8", 2, 1, 2), + ) + def testSerializationCore(self, cycle_length, block_length, + num_parallel_calls): input_values = np.array([4, 5, 6], dtype=np.int64) num_outputs = np.sum(input_values) * 2 - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 # pylint: disable=g-long-lambda self.run_core_tests( lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), + input_values, cycle_length, block_length, num_parallel_calls), lambda: self._build_iterator_graph( - input_values, cycle_length * 2, block_length * 1), + input_values, cycle_length * 2, block_length, num_parallel_calls), num_outputs) - # cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) # pylint: enable=g-long-lambda def testSparseCore(self): @@ -82,5 +79,5 @@ class InterleaveDatasetSerializationTest( self.run_core_tests(_build_dataset, None, 20) -if __name__ == '__main__': +if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 8b2f846494..6b3e8e9f6e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -32,18 +32,18 @@ from tensorflow.python.platform import test class SlideDatasetTest(test.TestCase, parameterized.TestCase): - @parameterized.parameters( - (20, 14, 7, 1), - (20, 17, 9, 1), - (20, 14, 14, 1), - (20, 10, 14, 1), - (20, 14, 19, 1), - (20, 4, 1, 2), - (20, 2, 1, 6), - (20, 4, 7, 2), - (20, 2, 7, 6), - (1, 10, 4, 1), - (0, 10, 4, 1), + @parameterized.named_parameters( + ("1", 20, 14, 7, 1), + ("2", 20, 17, 9, 1), + ("3", 20, 14, 14, 1), + ("4", 20, 10, 14, 1), + ("5", 20, 14, 19, 1), + ("6", 20, 4, 1, 2), + ("7", 20, 2, 1, 6), + ("8", 20, 4, 7, 2), + ("9", 20, 2, 7, 6), + ("10", 1, 10, 4, 1), + ("11", 0, 10, 4, 1), ) def testSlideDataset(self, count, window_size, window_shift, window_stride): """Tests a dataset that slides a window its input elements.""" @@ -96,18 +96,18 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @parameterized.parameters( - (20, 14, 7, 1), - (20, 17, 9, 1), - (20, 14, 14, 1), - (20, 10, 14, 1), - (20, 14, 19, 1), - (20, 4, 1, 2), - (20, 2, 1, 6), - (20, 4, 7, 2), - (20, 2, 7, 6), - (1, 10, 4, 1), - (0, 10, 4, 1), + @parameterized.named_parameters( + ("1", 20, 14, 7, 1), + ("2", 20, 17, 9, 1), + ("3", 20, 14, 14, 1), + ("4", 20, 10, 14, 1), + ("5", 20, 14, 19, 1), + ("6", 20, 4, 1, 2), + ("7", 20, 2, 1, 6), + ("8", 20, 4, 7, 2), + ("9", 20, 2, 7, 6), + ("10", 1, 10, 4, 1), + ("11", 0, 10, 4, 1), ) def testSlideDatasetDeprecated(self, count, window_size, stride, window_stride): @@ -160,10 +160,10 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - @parameterized.parameters( - (14, 0, 3, 1), - (14, 3, 0, 1), - (14, 3, 3, 0), + @parameterized.named_parameters( + ("1", 14, 0, 3, 1), + ("2", 14, 3, 0, 1), + ("3", 14, 3, 3, 0), ) def testSlideDatasetInvalid(self, count, window_size, window_shift, window_stride): diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py index 0486e2bce2..4b08ec759d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -33,8 +33,17 @@ from tensorflow.python.platform import test class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): - @parameterized.parameters((1, None), (2, None), (4, None), (8, None), - (16, None), (4, -1), (4, 0), (4, 1), (4, 4)) + @parameterized.named_parameters( + ("1", 1, None), + ("2", 2, None), + ("3", 4, None), + ("4", 8, None), + ("5", 16, None), + ("6", 4, -1), + ("7", 4, 0), + ("8", 4, 1), + ("9", 4, 4), + ) def testNumThreads(self, num_threads, max_intra_op_parallelism): def get_thread_id(_): diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py index 33d95d6754..ff4d9b3260 100644 --- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py @@ -64,15 +64,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): else: self.assertEqual(xs, ys) - @parameterized.parameters( - (None, np.int32([]), dtypes.bool), - (None, np.int32([]), dtypes.int32), - (None, np.int32([]), dtypes.float32), - (None, np.int32([]), dtypes.string), - (None, np.int32([2]), dtypes.int32), - (None, np.int32([2, 2]), dtypes.int32), - ((None, None, None), np.int32([]), dtypes.int32), - ((None, (None, None)), np.int32([]), dtypes.int32), + @parameterized.named_parameters( + ("1", None, np.int32([]), dtypes.bool), + ("2", None, np.int32([]), dtypes.int32), + ("3", None, np.int32([]), dtypes.float32), + ("4", None, np.int32([]), dtypes.string), + ("5", None, np.int32([2]), dtypes.int32), + ("6", None, np.int32([2, 2]), dtypes.int32), + ("7", (None, None, None), np.int32([]), dtypes.int32), + ("8", (None, (None, None)), np.int32([]), dtypes.int32), ) def testWindowDatasetFlatMap(self, structure, shape, dtype): """Tests windowing by chaining it with flat map. @@ -97,15 +97,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (None, np.int32([]), dtypes.bool), - (None, np.int32([]), dtypes.int32), - (None, np.int32([]), dtypes.float32), - (None, np.int32([]), dtypes.string), - (None, np.int32([2]), dtypes.int32), - (None, np.int32([2, 2]), dtypes.int32), - ((None, None, None), np.int32([]), dtypes.int32), - ((None, (None, None)), np.int32([]), dtypes.int32), + @parameterized.named_parameters( + ("1", None, np.int32([]), dtypes.bool), + ("2", None, np.int32([]), dtypes.int32), + ("3", None, np.int32([]), dtypes.float32), + ("4", None, np.int32([]), dtypes.string), + ("5", None, np.int32([2]), dtypes.int32), + ("6", None, np.int32([2, 2]), dtypes.int32), + ("7", (None, None, None), np.int32([]), dtypes.int32), + ("8", (None, (None, None)), np.int32([]), dtypes.int32), ) def testWindowDatasetBatchDense(self, structure, shape, dtype): """Tests batching of dense tensor windows. @@ -135,10 +135,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int32([]),), - (np.int32([1]),), - (np.int32([1, 2, 3]),), + @parameterized.named_parameters( + ("1", np.int32([])), + ("2", np.int32([1])), + ("3", np.int32([1, 2, 3])), ) def testWindowDatasetBatchDenseDynamicShape(self, shape): """Tests batching of dynamically shaped dense tensor windows. @@ -203,15 +203,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): for substructure in structure ]) - @parameterized.parameters( - (None, np.int32([]), dtypes.bool), - (None, np.int32([]), dtypes.int32), - (None, np.int32([]), dtypes.float32), - (None, np.int32([]), dtypes.string), - (None, np.int32([2]), dtypes.int32), - (None, np.int32([2, 2]), dtypes.int32), - ((None, None, None), np.int32([]), dtypes.int32), - ((None, (None, None)), np.int32([]), dtypes.int32), + @parameterized.named_parameters( + ("1", None, np.int32([]), dtypes.bool), + ("2", None, np.int32([]), dtypes.int32), + ("3", None, np.int32([]), dtypes.float32), + ("4", None, np.int32([]), dtypes.string), + ("5", None, np.int32([2]), dtypes.int32), + ("6", None, np.int32([2, 2]), dtypes.int32), + ("7", (None, None, None), np.int32([]), dtypes.int32), + ("8", (None, (None, None)), np.int32([]), dtypes.int32), ) def testWindowDatasetBatchSparse(self, structure, shape, dtype): """Tests batching of sparse tensor windows. @@ -243,10 +243,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int32([]),), - (np.int32([1]),), - (np.int32([1, 2, 3]),), + @parameterized.named_parameters( + ("1", np.int32([])), + ("2", np.int32([1])), + ("3", np.int32([1, 2, 3])), ) def testWindowDatasetBatchSparseDynamicShape(self, shape): """Tests batching of dynamically shaped sparse tensor windows. @@ -284,17 +284,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): for substructure in structure ])) - @parameterized.parameters( - (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]), - (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), - (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]), - ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])), + @parameterized.named_parameters( + ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]), + ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]), + ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]), + ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), + ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]), + ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ("8", (None, + (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]), + ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])), ) def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype, padded_shape): @@ -329,10 +330,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int32([[1], [2], [3]]), [-1]), - (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]), - (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), + @parameterized.named_parameters( + ("1", np.int32([[1], [2], [3]]), [-1]), + ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]), + ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), ) def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape): """Tests padded batching of dynamically shaped dense tensor windows. @@ -361,9 +362,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int32([[1]]), np.int32([0])), - (np.int32([[10], [20]]), np.int32([15])), + @parameterized.named_parameters( + ("1", np.int32([[1]]), np.int32([0])), + ("2", np.int32([[10], [20]]), np.int32([15])), ) def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape): """Tests invalid padded batching of dense tensor windows. @@ -420,17 +421,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): for substructure in structure ]) - @parameterized.parameters( - (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]), - (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), - (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]), - ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), - (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])), + @parameterized.named_parameters( + ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]), + ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]), + ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]), + ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]), + ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]), + ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ("8", (None, + (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]), + ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])), ) def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype, padded_shape): @@ -463,10 +465,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int64([[1], [2], [3]]), [-1]), - (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]), - (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), + @parameterized.named_parameters( + ("1", np.int64([[1], [2], [3]]), [-1]), + ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]), + ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]), ) def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes, padded_shape): @@ -495,9 +497,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase): actual = sess.run(get_next) self._assertEqual(expected, actual) - @parameterized.parameters( - (np.int64([[1]]), [0]), - (np.int64([[10], [20]]), [15]), + @parameterized.named_parameters( + ("1", np.int64([[1]]), [0]), + ("2", np.int64([[10], [20]]), [15]), ) def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape): """Tests invalid padded batching of sparse tensor windows. diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py index 54d5cd6da0..3d0d0993c9 100644 --- a/tensorflow/contrib/data/python/ops/map_defun.py +++ b/tensorflow/contrib/data/python/ops/map_defun.py @@ -53,6 +53,4 @@ def map_defun(fn, elems, output_dtypes, output_shapes): elems = [ops.convert_to_tensor(e) for e in elems] output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes] - if not all(s.is_fully_defined() for s in output_shapes): - raise ValueError("All fn output shapes must be fully defined.") return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index d39fd57294..3cee3e37a7 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -446,8 +446,7 @@ class TestWithDistributionStrategy(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - with self.assertRaisesRegexp(ValueError, - 'expected input to have 2 dimensions'): + with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) # Wrong input shape diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 4fb70ec685..6ba83976fc 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -310,7 +310,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def get_host_cpu_device(self, host_id): if self._tpu_cluster_resolver.get_master() in ('', 'local'): return '/replica:0/task:0/device:CPU:0' - return '/job:tpu_worker/task:%d/device:CPU:0' % (host_id,) + job_name = self._tpu_cluster_resolver.get_job_name() or 'tpu_worker' + return '/job:%s/task:%d/device:CPU:0' % (job_name, host_id) def configure(self, session_config=None, diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 77f62df99d..437b3d965d 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -446,6 +446,7 @@ py_library( "//tensorflow/python/estimator", "//tensorflow/python/estimator:head", "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/ops/losses", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py index 7c49cd00d1..98660bb731 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.training import training_util @@ -405,6 +406,7 @@ class RNNClassifier(estimator.Estimator): weight_column=None, label_vocabulary=None, optimizer='Adagrad', + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, input_layer_partitioner=None, config=None): """Initializes a `RNNClassifier` instance. @@ -454,6 +456,8 @@ class RNNClassifier(estimator.Estimator): string. optimizer: An instance of `tf.Optimizer` or string specifying optimizer type. Defaults to Adagrad optimizer. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how + to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`. input_layer_partitioner: Optional. Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings. @@ -467,11 +471,15 @@ class RNNClassifier(estimator.Estimator): if n_classes == 2: head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access weight_column=weight_column, - label_vocabulary=label_vocabulary) + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) else: head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access - n_classes, weight_column=weight_column, - label_vocabulary=label_vocabulary) + n_classes, + weight_column=weight_column, + label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction) + def _model_fn(features, labels, mode, config): return _rnn_model_fn( features=features, diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py index 959b40371a..1aebed348d 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py @@ -713,7 +713,7 @@ class RNNClassifierTrainingTest(test.TestCase): # Uses same checkpoint and examples as testBinaryClassEvaluationMetrics. # See that test for loss calculation. - mock_optimizer = self._mock_optimizer(expected_loss=1.119661) + mock_optimizer = self._mock_optimizer(expected_loss=0.559831) sequence_feature_columns = [ seq_fc.sequence_numeric_column('price', shape=(1,))] @@ -748,7 +748,7 @@ class RNNClassifierTrainingTest(test.TestCase): # Uses same checkpoint and examples as testMultiClassEvaluationMetrics. # See that test for loss calculation. - mock_optimizer = self._mock_optimizer(expected_loss=2.662932) + mock_optimizer = self._mock_optimizer(expected_loss=1.331465) sequence_feature_columns = [ seq_fc.sequence_numeric_column('price', shape=(1,))] @@ -812,20 +812,32 @@ class RNNClassifierEvaluationTest(test.TestCase): # probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]] # loss = -label * ln(p) - (1 - label) * ln(1 - p) # = [[0.436326], [0.683335]] + # sum_over_batch_size = (0.436326 + 0.683335)/2 expected_metrics = { - ops.GraphKeys.GLOBAL_STEP: global_step, - metric_keys.MetricKeys.LOSS: 1.119661, - metric_keys.MetricKeys.LOSS_MEAN: 0.559831, - metric_keys.MetricKeys.ACCURACY: 1.0, - metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262, - metric_keys.MetricKeys.LABEL_MEAN: 0.5, - metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, + ops.GraphKeys.GLOBAL_STEP: + global_step, + metric_keys.MetricKeys.LOSS: + 0.559831, + metric_keys.MetricKeys.LOSS_MEAN: + 0.559831, + metric_keys.MetricKeys.ACCURACY: + 1.0, + metric_keys.MetricKeys.PREDICTION_MEAN: + 0.429262, + metric_keys.MetricKeys.LABEL_MEAN: + 0.5, + metric_keys.MetricKeys.ACCURACY_BASELINE: + 0.5, # With default threshold of 0.5, the model is a perfect classifier. - metric_keys.MetricKeys.RECALL: 1.0, - metric_keys.MetricKeys.PRECISION: 1.0, + metric_keys.MetricKeys.RECALL: + 1.0, + metric_keys.MetricKeys.PRECISION: + 1.0, # Positive example is scored above negative, so AUC = 1.0. - metric_keys.MetricKeys.AUC: 1.0, - metric_keys.MetricKeys.AUC_PR: 1.0, + metric_keys.MetricKeys.AUC: + 1.0, + metric_keys.MetricKeys.AUC_PR: + 1.0, } self.assertAllClose( sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics)) @@ -871,9 +883,10 @@ class RNNClassifierEvaluationTest(test.TestCase): # [0.059494, 0.572639, 0.367866]] # loss = -1. * log(softmax[label]) # = [[2.105432], [0.557500]] + # sum_over_batch_size = (2.105432 + 0.557500)/2 expected_metrics = { ops.GraphKeys.GLOBAL_STEP: global_step, - metric_keys.MetricKeys.LOSS: 2.662932, + metric_keys.MetricKeys.LOSS: 1.331465, metric_keys.MetricKeys.LOSS_MEAN: 1.331466, metric_keys.MetricKeys.ACCURACY: 0.5, } diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 0ccb4583ab..716bb87e38 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -174,7 +174,7 @@ class FusedConv2DBiasActivationOp : public OpKernel { // Input bias is a 1-D tensor, with size matching output depth. const Tensor& bias = context->input(kBias); - OP_REQUIRES_OK(context, CheckShape(bias, "conv_input")); + OP_REQUIRES_OK(context, CheckShape(bias, "bias")); const Tensor& conv_input_scale_tensor = context->input(kConvInputScale); const Tensor& side_input_scale_tensor = context->input(kSideInputScale); diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 418b0cf392..61185f65a9 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -403,6 +403,7 @@ py_test( srcs = ["python/learn/estimators/dnn_test.py"], shard_count = 4, srcs_version = "PY2AND3", + tags = ["notap"], deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 0091587bf7..f320b53d94 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -36,10 +36,10 @@ cc_library( srcs = ["arena_planner.cc"], hdrs = ["arena_planner.h"], deps = [ - ":context", ":graph_info", ":memory_planner", ":simple_memory_arena", + "//tensorflow/contrib/lite/c:c_api_internal", ], ) @@ -54,6 +54,7 @@ cc_test( deps = [ ":arena_planner", "//tensorflow/contrib/lite/testing:util", + "//tensorflow/core:framework", "//tensorflow/core:lib", "@com_google_googletest//:gtest", ], @@ -63,27 +64,27 @@ cc_test( # TODO(aselle): Resolve problems preventing C99 usage. cc_library( name = "context", - srcs = ["context.c"], hdrs = ["context.h"], + deps = ["//tensorflow/contrib/lite/c:c_api_internal"], ) cc_library( name = "graph_info", hdrs = ["graph_info.h"], - deps = [":context"], + deps = ["//tensorflow/contrib/lite/c:c_api_internal"], ) cc_library( name = "memory_planner", hdrs = ["memory_planner.h"], - deps = [":context"], + deps = ["//tensorflow/contrib/lite/c:c_api_internal"], ) cc_library( name = "simple_memory_arena", srcs = ["simple_memory_arena.cc"], hdrs = ["simple_memory_arena.h"], - deps = [":context"], + deps = ["//tensorflow/contrib/lite/c:c_api_internal"], ) cc_library( @@ -91,7 +92,7 @@ cc_library( hdrs = [ "builtin_op_data.h", ], - deps = [":context"], + deps = ["//tensorflow/contrib/lite/c:c_api_internal"], ) cc_library( @@ -121,12 +122,12 @@ cc_library( name = "framework", srcs = [ "allocation.cc", - "error_reporter.cc", "graph_info.cc", "interpreter.cc", "model.cc", - "op_resolver.cc", + "mutable_op_resolver.cc", "optional_debug_tools.cc", + "stderr_reporter.cc", ] + select({ "//tensorflow:android": [ "nnapi_delegate.cc", @@ -149,9 +150,11 @@ cc_library( "graph_info.h", "interpreter.h", "model.h", + "mutable_op_resolver.h", "nnapi_delegate.h", "op_resolver.h", "optional_debug_tools.h", + "stderr_reporter.h", ], copts = tflite_copts(), linkopts = [ @@ -164,14 +167,14 @@ cc_library( }), deps = [ ":arena_planner", - ":builtin_op_data", - ":context", ":graph_info", ":memory_planner", ":schema_fbs_version", ":simple_memory_arena", ":string", ":util", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/core/api", "//tensorflow/contrib/lite/kernels:eigen_support", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/nnapi:nnapi_lib", @@ -210,6 +213,8 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/core/api", "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", @@ -259,6 +264,8 @@ cc_test( ], deps = [ ":framework", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/core/api", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], @@ -266,9 +273,9 @@ cc_test( # Test OpResolver. cc_test( - name = "op_resolver_test", + name = "mutable_op_resolver_test", size = "small", - srcs = ["op_resolver_test.cc"], + srcs = ["mutable_op_resolver_test.cc"], tags = ["no_oss"], deps = [ ":framework", @@ -277,24 +284,12 @@ cc_test( ], ) -# Test the C extension API code. -cc_test( - name = "context_test", - size = "small", - srcs = ["context_test.cc"], - deps = [ - ":framework", - "//tensorflow/contrib/lite/testing:util", - "@com_google_googletest//:gtest", - ], -) - cc_library( name = "util", srcs = ["util.cc"], hdrs = ["util.h"], deps = [ - ":context", + "//tensorflow/contrib/lite/c:c_api_internal", ], ) @@ -304,7 +299,6 @@ cc_test( srcs = ["util_test.cc"], tags = ["no_oss"], deps = [ - ":context", ":util", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc index 8946261814..21cb1832a7 100644 --- a/tensorflow/contrib/lite/allocation.cc +++ b/tensorflow/contrib/lite/allocation.cc @@ -23,8 +23,8 @@ limitations under the License. #include <cstring> #include <utility> -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" namespace tflite { diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h index 121f3d2646..182bc0977f 100644 --- a/tensorflow/contrib/lite/allocation.h +++ b/tensorflow/contrib/lite/allocation.h @@ -20,8 +20,8 @@ limitations under the License. #include <cstdio> #include <cstdlib> #include <vector> -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/simple_memory_arena.h" #include "tensorflow/contrib/lite/string.h" diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h index 55003cf4e9..382577045b 100644 --- a/tensorflow/contrib/lite/arena_planner.h +++ b/tensorflow/contrib/lite/arena_planner.h @@ -18,7 +18,7 @@ limitations under the License. #include <memory> #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/graph_info.h" #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/simple_memory_arena.h" @@ -37,8 +37,8 @@ struct AllocationInfo; // each tensor needs to be allocated and deallocated, and preallocates all the // necessary memory (the PlanAllocations phase). It then assigns portions of // this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may -// share some of the buffer if a tensor B is to be allocated after another tensor -// A has been deallocated. +// share some of the buffer if a tensor B is to be allocated after another +// tensor A has been deallocated. // // If dynamic tensors are used the planning steps can be repeated during model // execution. Since dynamic tensors don't have sizes until after the diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 0246e7fa30..9317e2bb6e 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -49,6 +49,9 @@ def tflite_linkopts_unstripped(): Returns: a select object with proper linkopts """ + + # In case you wonder why there's no --icf is because the gains were + # negligible, and created potential compatibility problems. return select({ "//tensorflow:android": [ "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj. @@ -56,13 +59,7 @@ def tflite_linkopts_unstripped(): "-Wl,--gc-sections", # Eliminate unused code and data. "-Wl,--as-needed", # Don't link unused libs. ], - "//tensorflow:darwin": [], - "//tensorflow:ios": [], - "//tensorflow/contrib/lite:mips": [], - "//tensorflow/contrib/lite:mips64": [], - "//conditions:default": [ - "-Wl,--icf=all", # Identical code folding. - ], + "//conditions:default": [], }) def tflite_jni_linkopts_unstripped(): @@ -74,17 +71,15 @@ def tflite_jni_linkopts_unstripped(): Returns: a select object with proper linkopts """ + + # In case you wonder why there's no --icf is because the gains were + # negligible, and created potential compatibility problems. return select({ "//tensorflow:android": [ "-Wl,--gc-sections", # Eliminate unused code and data. "-Wl,--as-needed", # Don't link unused libs. ], - "//tensorflow:darwin": [], - "//tensorflow/contrib/lite:mips": [], - "//tensorflow/contrib/lite:mips64": [], - "//conditions:default": [ - "-Wl,--icf=all", # Identical code folding. - ], + "//conditions:default": [], }) def tflite_linkopts(): diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index aecd71910c..30901bd0fa 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -12,297 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// Compatibility shim for new location of interface definitions. + #ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ -#include <stdint.h> - -#include "tensorflow/contrib/lite/context.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// TODO(aselle): Consider using "if this then that" for testing. - -// Useful placeholder to put in otherwise empty structs to avoid size warnings. -typedef struct { - char dummy_; -} EmptyStructPlaceholder; - -// Possible padding types (for convolutions) -typedef enum { - kTfLitePaddingUnknown = 0, - kTfLitePaddingSame, - kTfLitePaddingValid, -} TfLitePadding; - -typedef struct { - int width; - int height; -} TfLitePaddingValues; - -// Possible fused activation functions. -// TODO(aselle): rename to TfLiteActivation -typedef enum { - kTfLiteActNone = 0, - kTfLiteActRelu, - kTfLiteActRelu1, - kTfLiteActRelu6, - kTfLiteActTanh, - kTfLiteActSignBit, - kTfLiteActSigmoid, -} TfLiteFusedActivation; - -typedef struct { - TfLitePadding padding; - int stride_width; - int stride_height; - int dilation_width_factor; - int dilation_height_factor; - TfLiteFusedActivation activation; -} TfLiteConvParams; - -typedef struct { - TfLitePadding padding; - int stride_width; - int stride_height; - int filter_width; - int filter_height; - TfLiteFusedActivation activation; - struct { - TfLitePaddingValues padding; - } computed; -} TfLitePoolParams; - -typedef struct { - TfLitePadding padding; - int stride_width; - int stride_height; - int depth_multiplier; - TfLiteFusedActivation activation; -} TfLiteDepthwiseConvParams; - -typedef struct { - int rank; - TfLiteFusedActivation activation; -} TfLiteSVDFParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteRNNParams; - -typedef struct { - bool time_major; - TfLiteFusedActivation activation; -} TfLiteSequenceRNNParams; - -typedef enum { - kTfLiteFullyConnectedWeightsFormatDefault = 0, - kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, -} TfLiteFullyConnectedWeightsFormat; - -typedef struct { - // Parameters for FullyConnected version 1 or above. - TfLiteFusedActivation activation; - - // Parameters for FullyConnected version 2 or above. - TfLiteFullyConnectedWeightsFormat weights_format; -} TfLiteFullyConnectedParams; - -typedef enum { - kTfLiteLshProjectionUnknown = 0, - kTfLiteLshProjectionSparse = 1, - kTfLiteLshProjectionDense = 2, -} TfLiteLSHProjectionType; - -typedef struct { - TfLiteLSHProjectionType type; -} TfLiteLSHProjectionParams; - -typedef struct { - float beta; -} TfLiteSoftmaxParams; - -typedef struct { - int axis; - TfLiteFusedActivation activation; -} TfLiteConcatenationParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteAddParams; - -typedef struct { - EmptyStructPlaceholder placeholder_; -} TfLiteSpaceToBatchNDParams; - -typedef struct { - EmptyStructPlaceholder placeholder_; -} TfLiteBatchToSpaceNDParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteMulParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteSubParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteDivParams; - -typedef struct { - TfLiteFusedActivation activation; -} TfLiteL2NormParams; - -typedef struct { - int radius; - float bias; - float alpha; - float beta; -} TfLiteLocalResponseNormParams; - -typedef enum { - kTfLiteLSTMFullKernel = 0, - kTfLiteLSTMBasicKernel -} TfLiteLSTMKernelType; - -typedef struct { - // Parameters for LSTM version 1. - TfLiteFusedActivation activation; - float cell_clip; - float proj_clip; - - // Parameters for LSTM version 2. - // kTfLiteLSTMBasicKernel is only supported in version 2 or above. - TfLiteLSTMKernelType kernel_type; -} TfLiteLSTMParams; - -typedef struct { - bool align_corners; -} TfLiteResizeBilinearParams; - -typedef struct { - EmptyStructPlaceholder placeholder_; -} TfLitePadParams; - -typedef struct { - EmptyStructPlaceholder placeholder_; -} TfLitePadV2Params; - -typedef struct { - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int shape[8]; - int num_dimensions; -} TfLiteReshapeParams; - -typedef struct { - int ngram_size; - int max_skip_size; - bool include_all_ngrams; -} TfLiteSkipGramParams; - -typedef struct { - int block_size; -} TfLiteSpaceToDepthParams; - -typedef struct { - TfLiteType in_data_type; - TfLiteType out_data_type; -} TfLiteCastParams; - -typedef enum { - kTfLiteCombinerTypeSum = 0, - kTfLiteCombinerTypeMean = 1, - kTfLiteCombinerTypeSqrtn = 2, -} TfLiteCombinerType; - -typedef struct { - TfLiteCombinerType combiner; -} TfLiteEmbeddingLookupSparseParams; - -typedef struct { - int axis; -} TfLiteGatherParams; - -typedef struct { - EmptyStructPlaceholder placeholder_; -} TfLiteTransposeParams; - -typedef struct { - bool keep_dims; -} TfLiteReducerParams; - -typedef struct { - int num_splits; -} TfLiteSplitParams; - -typedef struct { - // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. - // For now we will fix the maximum possible number of dimensions. - int squeeze_dims[8]; - int num_squeeze_dims; -} TfLiteSqueezeParams; - -typedef struct { - int begin_mask; - int end_mask; - int ellipsis_mask; - int new_axis_mask; - int shrink_axis_mask; -} TfLiteStridedSliceParams; - -typedef struct { - TfLiteType output_type; -} TfLiteArgMaxParams; - -typedef struct { - TfLiteType output_type; -} TfLiteArgMinParams; - -typedef struct { - TfLitePadding padding; - int stride_width; - int stride_height; -} TfLiteTransposeConvParams; - -typedef struct { - bool validate_indices; -} TfLiteSparseToDenseParams; - -typedef struct { - TfLiteType out_type; -} TfLiteShapeParams; - -typedef struct { - // Parameters supported by version 1: - float min; - float max; - int num_bits; - - // Parameters supported by version 2: - bool narrow_range; -} TfLiteFakeQuantParams; - -typedef struct { - int values_count; - int axis; -} TfLitePackParams; - -typedef struct { - int axis; -} TfLiteOneHotParams; - -typedef struct { - int num; - int axis; -} TfLiteUnpackParams; - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/contrib/lite/c/BUILD b/tensorflow/contrib/lite/c/BUILD new file mode 100644 index 0000000000..663eb63cad --- /dev/null +++ b/tensorflow/contrib/lite/c/BUILD @@ -0,0 +1,39 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "c_api_internal", + srcs = ["c_api_internal.c"], + hdrs = [ + "builtin_op_data.h", + "c_api_internal.h", + ], + visibility = [ + "//tensorflow/contrib/lite:__subpackages__", + ], +) + +# Test the C extension API code. +cc_test( + name = "c_api_internal_test", + size = "small", + srcs = ["c_api_internal_test.cc"], + deps = [ + ":c_api_internal", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "builtin_op_data_test", + size = "small", + srcs = ["builtin_op_data_test.cc"], + copts = ["-Wno-unused-variable"], + deps = [ + ":c_api_internal", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h new file mode 100644 index 0000000000..fa43e6a024 --- /dev/null +++ b/tensorflow/contrib/lite/c/builtin_op_data.h @@ -0,0 +1,298 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_ + +#include <stdint.h> + +#include "tensorflow/contrib/lite/c/c_api_internal.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// TODO(aselle): Consider using "if this then that" for testing. + +// Possible padding types (for convolutions) +typedef enum { + kTfLitePaddingUnknown = 0, + kTfLitePaddingSame, + kTfLitePaddingValid, +} TfLitePadding; + +typedef struct { + int width; + int height; +} TfLitePaddingValues; + +// Possible fused activation functions. +// TODO(aselle): rename to TfLiteActivation +typedef enum { + kTfLiteActNone = 0, + kTfLiteActRelu, + kTfLiteActRelu1, + kTfLiteActRelu6, + kTfLiteActTanh, + kTfLiteActSignBit, + kTfLiteActSigmoid, +} TfLiteFusedActivation; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int dilation_width_factor; + int dilation_height_factor; + TfLiteFusedActivation activation; +} TfLiteConvParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int filter_width; + int filter_height; + TfLiteFusedActivation activation; + struct { + TfLitePaddingValues padding; + } computed; +} TfLitePoolParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int depth_multiplier; + TfLiteFusedActivation activation; +} TfLiteDepthwiseConvParams; + +typedef struct { + int rank; + TfLiteFusedActivation activation; +} TfLiteSVDFParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteRNNParams; + +typedef struct { + bool time_major; + TfLiteFusedActivation activation; +} TfLiteSequenceRNNParams; + +typedef enum { + kTfLiteFullyConnectedWeightsFormatDefault = 0, + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, +} TfLiteFullyConnectedWeightsFormat; + +typedef struct { + // Parameters for FullyConnected version 1 or above. + TfLiteFusedActivation activation; + + // Parameters for FullyConnected version 2 or above. + TfLiteFullyConnectedWeightsFormat weights_format; +} TfLiteFullyConnectedParams; + +typedef enum { + kTfLiteLshProjectionUnknown = 0, + kTfLiteLshProjectionSparse = 1, + kTfLiteLshProjectionDense = 2, +} TfLiteLSHProjectionType; + +typedef struct { + TfLiteLSHProjectionType type; +} TfLiteLSHProjectionParams; + +typedef struct { + float beta; +} TfLiteSoftmaxParams; + +typedef struct { + int axis; + TfLiteFusedActivation activation; +} TfLiteConcatenationParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteAddParams; + +typedef struct { +} TfLiteSpaceToBatchNDParams; + +typedef struct { +} TfLiteBatchToSpaceNDParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteMulParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteSubParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteDivParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteL2NormParams; + +typedef struct { + int radius; + float bias; + float alpha; + float beta; +} TfLiteLocalResponseNormParams; + +typedef enum { + kTfLiteLSTMFullKernel = 0, + kTfLiteLSTMBasicKernel +} TfLiteLSTMKernelType; + +typedef struct { + // Parameters for LSTM version 1. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // Parameters for LSTM version 2. + // kTfLiteLSTMBasicKernel is only supported in version 2 or above. + TfLiteLSTMKernelType kernel_type; +} TfLiteLSTMParams; + +typedef struct { + bool align_corners; +} TfLiteResizeBilinearParams; + +typedef struct { +} TfLitePadParams; + +typedef struct { +} TfLitePadV2Params; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int shape[8]; + int num_dimensions; +} TfLiteReshapeParams; + +typedef struct { + int ngram_size; + int max_skip_size; + bool include_all_ngrams; +} TfLiteSkipGramParams; + +typedef struct { + int block_size; +} TfLiteSpaceToDepthParams; + +typedef struct { + TfLiteType in_data_type; + TfLiteType out_data_type; +} TfLiteCastParams; + +typedef enum { + kTfLiteCombinerTypeSum = 0, + kTfLiteCombinerTypeMean = 1, + kTfLiteCombinerTypeSqrtn = 2, +} TfLiteCombinerType; + +typedef struct { + TfLiteCombinerType combiner; +} TfLiteEmbeddingLookupSparseParams; + +typedef struct { + int axis; +} TfLiteGatherParams; + +typedef struct { +} TfLiteTransposeParams; + +typedef struct { + bool keep_dims; +} TfLiteReducerParams; + +typedef struct { + int num_splits; +} TfLiteSplitParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int squeeze_dims[8]; + int num_squeeze_dims; +} TfLiteSqueezeParams; + +typedef struct { + int begin_mask; + int end_mask; + int ellipsis_mask; + int new_axis_mask; + int shrink_axis_mask; +} TfLiteStridedSliceParams; + +typedef struct { + TfLiteType output_type; +} TfLiteArgMaxParams; + +typedef struct { + TfLiteType output_type; +} TfLiteArgMinParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; +} TfLiteTransposeConvParams; + +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + +typedef struct { + TfLiteType out_type; +} TfLiteShapeParams; + +typedef struct { + // Parameters supported by version 1: + float min; + float max; + int num_bits; + + // Parameters supported by version 2: + bool narrow_range; +} TfLiteFakeQuantParams; + +typedef struct { + int values_count; + int axis; +} TfLitePackParams; + +typedef struct { + int axis; +} TfLiteOneHotParams; + +typedef struct { + int num; + int axis; +} TfLiteUnpackParams; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc new file mode 100644 index 0000000000..4d0ba75e68 --- /dev/null +++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include <gtest/gtest.h> + +namespace tflite { + +// Builtin op data is just a set of data definitions, so the only meaningful +// test we can run is whether we can create the structs we expect to find. +// Testing each struct's members might be possible, but it seems unnecessary +// until we've locked down the API. The build rule has copts set to ignore the +// unused variable warning, since this is just a compilation test. +TEST(IntArray, CanCompileStructs) { + TfLitePadding padding = kTfLitePaddingSame; + TfLitePaddingValues padding_values; + TfLiteFusedActivation fused_activation = kTfLiteActRelu; + TfLiteConvParams conv_params; + TfLitePoolParams pool_params; + TfLiteDepthwiseConvParams depthwise_conv_params; + TfLiteSVDFParams svdf_params; + TfLiteRNNParams rnn_params; + TfLiteSequenceRNNParams sequence_rnn_params; + TfLiteFullyConnectedWeightsFormat fully_connected_weights_format = + kTfLiteFullyConnectedWeightsFormatDefault; + TfLiteFullyConnectedParams fully_connected_params; + TfLiteLSHProjectionType projection_type = kTfLiteLshProjectionDense; + TfLiteLSHProjectionParams projection_params; + TfLiteSoftmaxParams softmax_params; + TfLiteConcatenationParams concatenation_params; + TfLiteAddParams add_params; + TfLiteSpaceToBatchNDParams space_to_batch_nd_params; + TfLiteBatchToSpaceNDParams batch_to_space_nd_params; + TfLiteMulParams mul_params; + TfLiteSubParams sub_params; + TfLiteDivParams div_params; + TfLiteL2NormParams l2_norm_params; + TfLiteLocalResponseNormParams local_response_norm_params; + TfLiteLSTMKernelType lstm_kernel_type = kTfLiteLSTMBasicKernel; + TfLiteLSTMParams lstm_params; + TfLiteResizeBilinearParams resize_bilinear_params; + TfLitePadParams pad_params; + TfLitePadV2Params pad_v2_params; + TfLiteReshapeParams reshape_params; + TfLiteSkipGramParams skip_gram_params; + TfLiteSpaceToDepthParams space_to_depth_params; + TfLiteCastParams cast_params; + TfLiteCombinerType combiner_type = kTfLiteCombinerTypeSqrtn; + TfLiteEmbeddingLookupSparseParams lookup_sparse_params; + TfLiteGatherParams gather_params; + TfLiteTransposeParams transpose_params; + TfLiteReducerParams reducer_params; + TfLiteSplitParams split_params; + TfLiteSqueezeParams squeeze_params; + TfLiteStridedSliceParams strided_slice_params; + TfLiteArgMaxParams arg_max_params; + TfLiteArgMinParams arg_min_params; + TfLiteTransposeConvParams transpose_conv_params; + TfLiteSparseToDenseParams sparse_to_dense_params; + TfLiteShapeParams shape_params; + TfLiteFakeQuantParams fake_quant_params; + TfLitePackParams pack_params; + TfLiteOneHotParams one_hot_params; +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/c/c_api_internal.c index 7f2aa316f4..1846bad4b7 100644 --- a/tensorflow/contrib/lite/context.c +++ b/tensorflow/contrib/lite/c/c_api_internal.c @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include <stdio.h> +#include <stdlib.h> #include <string.h> int TfLiteIntArrayGetSizeInBytes(int size) { @@ -76,7 +77,8 @@ void TfLiteTensorFree(TfLiteTensor* t) { void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, TfLiteQuantizationParams quantization, char* buffer, size_t size, TfLiteAllocationType allocation_type, - const void* allocation, bool is_variable, TfLiteTensor* tensor) { + const void* allocation, bool is_variable, + TfLiteTensor* tensor) { TfLiteTensorFree(tensor); tensor->type = type; tensor->name = name; diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h new file mode 100644 index 0000000000..48df68a654 --- /dev/null +++ b/tensorflow/contrib/lite/c/c_api_internal.h @@ -0,0 +1,491 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file defines a C API for implementing operations in tflite. +// These operations can be defined using c++ but the interface between +// the interpreter and the operations are C. +// +// Summary of abstractions +// TF_LITE_ENSURE - Self-sufficient error checking +// TfLiteStatus - Status reporting +// TfLiteIntArray - stores tensor shapes (dims), +// TfLiteContext - allows an op to access the tensors +// TfLiteTensor - tensor (a multidimensional array) +// TfLiteNode - a single node or operation +// TfLiteRegistration - the implementation of a conceptual operation. +// +// Some abstractions in this file are created and managed by Interpreter. +#ifndef TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_ +#define TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_ + +#include <stdbool.h> +#include <stddef.h> +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; + +// The list of external context types known to TF Lite. This list exists solely +// to avoid conflicts and to ensure ops can share the external contexts they +// need. Access to the external contexts is controled by one of the +// corresponding support files. +typedef enum { + kTfLiteEigenContext = 0, // include eigen_support.h to use. + kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. + kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support. + kTfLiteMaxExternalContexts = 3 +} TfLiteExternalContextType; + +// An external context is a collection of information unrelated to the TF Lite +// framework, but useful to a subset of the ops. TF Lite knows very little +// about about the actual contexts, but it keeps a list of them, and is able to +// refresh them if configurations like the number of recommended threads +// change. +typedef struct { + TfLiteExternalContextType type; + TfLiteStatus (*Refresh)(struct TfLiteContext* context); +} TfLiteExternalContext; + +// Forward declare so GetNode can use this is in Context. +typedef struct _TfLiteRegistration TfLiteRegistration; +typedef struct _TfLiteDelegate TfLiteDelegate; + +#define kOptionalTensor (-1) + +// Fixed size list of integers. Used for dimensions and inputs/outputs tensor +// indices +typedef struct { + int size; +// gcc 6.1+ have a bug where flexible members aren't properly handled +// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c +#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \ + __GNUC_MINOR__ >= 1 + int data[0]; +#else + int data[]; +#endif +} TfLiteIntArray; + +// Given the size (number of elements) in a TfLiteIntArray, calculate its size +// in bytes. +int TfLiteIntArrayGetSizeInBytes(int size); + +// Create a array of a given `size` (uninitialized entries). +// This returns a pointer, that you must free using TfLiteIntArrayFree(). +TfLiteIntArray* TfLiteIntArrayCreate(int size); + +// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise. +int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b); + +// Create a copy of an array passed as `src`. +// You are expected to free memory with TfLiteIntArrayFree +TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src); + +// Free memory of array `v`. +void TfLiteIntArrayFree(TfLiteIntArray* v); + +// Since we must not depend on any libraries, define a minimal subset of +// error macros while avoiding names that have pre-conceived meanings like +// assert and check. + +// Check whether value is true, and if not return kTfLiteError from +// the current function (and report the error string msg). +#define TF_LITE_ENSURE_MSG(context, value, msg) \ + do { \ + if (!(value)) { \ + (context)->ReportError((context), __FILE__ " " msg); \ + return kTfLiteError; \ + } \ + } while (0) + +// Check whether the value `a` is true, and if not return kTfLiteError from +// the current function, while also reporting the location of the error. +#define TF_LITE_ENSURE(context, a) \ + do { \ + if (!(a)) { \ + (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \ + __LINE__, #a); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_STATUS(a) \ + do { \ + if ((a) != kTfLiteOk) { \ + return kTfLiteError; \ + } \ + } while (0) + +// Check whether the value `a == b` is true, and if not return kTfLiteError from +// the current function, while also reporting the location of the error. +// `a` and `b` may be evaluated more than once, so no side effects or +// extremely expensive computations should be done. +#define TF_LITE_ENSURE_EQ(context, a, b) \ + do { \ + if ((a) != (b)) { \ + (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \ + __LINE__, #a, #b, (a), (b)); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_OK(context, status) \ + do { \ + if ((status) != kTfLiteOk) { \ + return status; \ + } \ + } while (0) + +// Single-precision complex data type compatible with the C99 definition. +typedef struct { + float re, im; // real and imaginary parts, respectively. +} TfLiteComplex64; + +// Types supported by tensor +typedef enum { + kTfLiteNoType = 0, + kTfLiteFloat32 = 1, + kTfLiteInt32 = 2, + kTfLiteUInt8 = 3, + kTfLiteInt64 = 4, + kTfLiteString = 5, + kTfLiteBool = 6, + kTfLiteInt16 = 7, + kTfLiteComplex64 = 8, +} TfLiteType; + +// Parameters for asymmetric quantization. Quantized values can be converted +// back to float using: +// real_value = scale * (quantized_value - zero_point); +typedef struct { + float scale; + int32_t zero_point; +} TfLiteQuantizationParams; + +// A union of pointers that points to memory for a given tensor. +typedef union { + int* i32; + int64_t* i64; + float* f; + char* raw; + const char* raw_const; + uint8_t* uint8; + bool* b; + int16_t* i16; + TfLiteComplex64* c64; +} TfLitePtrUnion; + +// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped +// data (or data externally allocated). kTfLiteArenaRw is arena allocated +// data. kTfLiteDynamic is for tensors that are allocated during evaluation. +typedef enum { + kTfLiteMemNone = 0, + kTfLiteMmapRo, + kTfLiteArenaRw, + kTfLiteArenaRwPersistent, + kTfLiteDynamic, +} TfLiteAllocationType; + +// The delegates should use zero or positive integers to represent handles. +// -1 is reserved from unallocated status. +typedef int TfLiteBufferHandle; +const TfLiteBufferHandle kTfLiteNullBufferHandle = -1; + +// An tensor in the interpreter system which is a wrapper around a buffer of +// data including a dimensionality (or NULL if not currently defined). +typedef struct { + // The data type specification for data stored in `data`. This affects + // what member of `data` union should be used. + TfLiteType type; + // A union of data pointers. The appropriate type should be used for a typed + // tensor based on `type`. + TfLitePtrUnion data; + // A pointer to a structure representing the dimensionality interpretation + // that the buffer should have. NOTE: the product of elements of `dims` + // and the element datatype size should be equal to `bytes` below. + TfLiteIntArray* dims; + // Quantization information. + TfLiteQuantizationParams params; + // How memory is mapped + // kTfLiteMmapRo: Memory mapped read only. + // i.e. weights + // kTfLiteArenaRw: Arena allocated read write memory + // (i.e. temporaries, outputs). + TfLiteAllocationType allocation_type; + // The number of bytes required to store the data of this Tensor. I.e. + // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if + // type is kTfLiteFloat32 and dims = {3, 2} then + // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24. + size_t bytes; + + // An opaque pointer to a tflite::MMapAllocation + const void* allocation; + + // Null-terminated name of this tensor. + const char* name; + + // The delegate which knows how to handle `buffer_handle`. + // WARNING: This is an experimental interface that is subject to change. + TfLiteDelegate* delegate; + + // An integer buffer handle that can be handled by `delegate`. + // The value is valid only when delegate is not null. + // WARNING: This is an experimental interface that is subject to change. + TfLiteBufferHandle buffer_handle; + + // If the delegate uses its own buffer (e.g. GPU memory), the delegate is + // responsible to set data_is_stale to true. + // `delegate->CopyFromBufferHandle` can be called to copy the data from + // delegate buffer. + // WARNING: This is an // experimental interface that is subject to change. + bool data_is_stale; + + // True if the tensor is a variable. + bool is_variable; +} TfLiteTensor; + +// Free data memory of tensor `t`; +void TfLiteTensorDataFree(TfLiteTensor* t); + +// Free memory of tensor `t`; +void TfLiteTensorFree(TfLiteTensor* t); + +// Set all of a tensor's fields (and free any previously allocated data). +void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, + TfLiteQuantizationParams quantization, char* buffer, + size_t size, TfLiteAllocationType allocation_type, + const void* allocation, bool is_variable, + TfLiteTensor* tensor); + +// Resize the allocated data of a (dynamic) tensor. Tensors with allocation +// types other than kTfLiteDynamic will be ignored. +void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); + +// A structure representing an instance of a node. +// This structure only exhibits the inputs, outputs and user defined data, not +// other features like the type. +typedef struct { + // Inputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* inputs; + + // Outputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* outputs; + + // Temporary tensors uses during the computations. This usually contains no + // tensors, but ops are allowed to change that if they need scratch space of + // any sort. + TfLiteIntArray* temporaries; + + // Opaque data provided by the node implementer through `Registration.init`. + void* user_data; + + // Opaque data provided to the node if the node is a builtin. This is usually + // a structure defined in builtin_op_data.h + void* builtin_data; + + // Custom initial data. This is the opaque data provided in the flatbuffer. + // WARNING: This is an experimental interface that is subject to change. + const void* custom_initial_data; + int custom_initial_data_size; + + // The pointer to the delegate. This is non-null only when the node is + // created by calling `interpreter.ModifyGraphWithDelegate`. + // WARNING: This is an experimental interface that is subject to change. + TfLiteDelegate* delegate; +} TfLiteNode; + +typedef struct TfLiteContext { + // Number of tensors in the context. + size_t tensors_size; + + // The execution plan contains a list of the node indices in execution + // order. execution_plan->size is the current number of nodes. And, + // execution_plan->data[0] is the first node that needs to be run. + // TfLiteDelegates can traverse the current execution plan by iterating + // through each member of this array and using GetNodeAndRegistration() to + // access details about a node. i.e. + // TfLiteIntArray* execution_plan; + // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan)); + // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) { + // int node_index = execution_plan->data[exec_index]; + // TfLiteNode* node; + // TfLiteRegistration* reg; + // context->GetNodeAndRegistration(context, node_index, &node, ®); + // } + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context, + TfLiteIntArray** execution_plan); + + // An array of tensors in the interpreter context (of length `tensors_size`) + TfLiteTensor* tensors; + + // opaque full context ptr (an opaque c++ data structure) + void* impl_; + + // Request memory pointer be resized. Updates dimensions on the tensor. + // NOTE: ResizeTensor takes ownership of newSize. + TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor, + TfLiteIntArray* new_size); + // Request that a error be reported with format string msg. + void (*ReportError)(struct TfLiteContext*, const char* msg, ...); + + // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If + // non-null, the value pointed to by `first_new_tensor_index` will be set to + // the index of the first new tensor. + TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add, + int* first_new_tensor_index); + + // Get a Tensor node by node_index. + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index, + TfLiteNode** node, + TfLiteRegistration** registration); + + // Replace ops with one or more stub delegate operations. This function + // does not take ownership of `nodes_to_replace`. + TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)( + struct TfLiteContext*, TfLiteRegistration registration, + const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate); + + // Number of threads that are recommended to subsystems like gemmlowp and + // eigen. + int recommended_num_threads; + + // Access external contexts by type. + // WARNING: This is an experimental interface that is subject to change. + TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*, + TfLiteExternalContextType); + // Set the value of a external context. Does not take ownership of the + // pointer. + // WARNING: This is an experimental interface that is subject to change. + void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType, + TfLiteExternalContext*); +} TfLiteContext; + +typedef struct _TfLiteRegistration { + // Initializes the op from serialized data. + // If a built-in op: + // `buffer` is the op's params data (TfLiteLSTMParams*). + // `length` is zero. + // If custom op: + // `buffer` is the op's `custom_options`. + // `length` is the size of the buffer. + // + // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer + // or an instance of a struct). + // + // The returned pointer will be stored with the node in the `user_data` field, + // accessible within prepare and invoke functions below. + // NOTE: if the data is already in the desired format, simply implement this + // function to return `nullptr` and implement the free function to be a no-op. + void* (*init)(TfLiteContext* context, const char* buffer, size_t length); + + // The pointer `buffer` is the data previously returned by an init invocation. + void (*free)(TfLiteContext* context, void* buffer); + + // prepare is called when the inputs this node depends on have been resized. + // context->ResizeTensor() can be called to request output tensors to be + // resized. + // + // Returns kTfLiteOk on success. + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); + + // Execute the node (should read node->inputs and output to node->outputs). + // Returns kTfLiteOk on success. + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); + + // profiling_string is called during summarization of profiling information + // in order to group executions together. Providing a value here will cause a + // given op to appear multiple times is the profiling report. This is + // particularly useful for custom ops that can perform significantly + // different calculations depending on their `user-data`. + const char* (*profiling_string)(const TfLiteContext* context, + const TfLiteNode* node); + + // Builtin codes. If this kernel refers to a builtin this is the code + // of the builtin. This is so we can do marshaling to other frameworks like + // NN API. + // Note: It is the responsibility of the registration binder to set this + // properly. + int32_t builtin_code; + + // Custom op name. If the op is a builtin, this will be null. + // Note: It is the responsibility of the registration binder to set this + // properly. + // WARNING: This is an experimental interface that is subject to change. + const char* custom_name; + + // The version of the op. + // Note: It is the responsibility of the registration binder to set this + // properly. + int version; +} TfLiteRegistration; + +// WARNING: This is an experimental interface that is subject to change. +typedef struct _TfLiteDelegate { + // Data that delegate needs to identify itself. This data is owned by the + // delegate. The delegate is owned in the user code, so the delegate is + // responsible for doing this when it is destroyed. + void* data_; + + // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the + // delegate a view of the current graph through TfLiteContext*. It typically + // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels() + // to ask the TensorFlow lite runtime to create macro-nodes to represent + // delegated subgraphs of the original graph. + TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate); + + // Copy the data from delegate buffer handle to raw memory. + // This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, size_t size); + + // Copy the data from raw memory to delegate buffer handle. + // This can be null if the delegate doesn't use its own buffer. + TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, size_t size); + + // Free the Delegate Buffer Handle. Note: This only frees the handle, but + // this doesn't release the underlying resource (e.g. textures). The + // resources are either owned by application layer or the delegate. + // This can be null if the delegate doesn't use its own buffer. + void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate, + TfLiteBufferHandle* handle); +} TfLiteDelegate; + +// WARNING: This is an experimental interface that is subject to change. +// +// Currently, TfLiteDelegateParams has to be allocated in a way that it's +// trivially destructable. It will be stored as `builtin_data` field in +// `TfLiteNode` of the delegate node. +// +// See also the `CreateDelegateParams` function in `interpreter.cc` details. +typedef struct { + TfLiteDelegate* delegate; + TfLiteIntArray* nodes_to_replace; + TfLiteIntArray* input_tensors; + TfLiteIntArray* output_tensors; +} TfLiteDelegateParams; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_ diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/c/c_api_internal_test.cc index 20d6f69a25..af398f3207 100644 --- a/tensorflow/contrib/lite/context_test.cc +++ b/tensorflow/contrib/lite/c/c_api_internal_test.cc @@ -13,16 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/testing/util.h" namespace tflite { // NOTE: this tests only the TfLiteIntArray part of context. -// most of context.h is provided in the context of using it with interpreter.h -// and interpreter.cc, so interpreter_test.cc tests context structures more -// thoroughly. +// most of c_api_internal.h is provided in the context of using it with +// interpreter.h and interpreter.cc, so interpreter_test.cc tests context +// structures more thoroughly. TEST(IntArray, TestIntArrayCreate) { TfLiteIntArray* a = TfLiteIntArrayCreate(0); @@ -69,7 +68,6 @@ TEST(IntArray, TestIntArrayEqual) { } // namespace tflite int main(int argc, char** argv) { - ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index b23183b743..b86c2819b8 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -12,484 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines a C API for implementing operations in tflite. -// These operations can be defined using c++ but the interface between -// the interpreter and the operations are C. -// -// Summary of abstractions -// TF_LITE_ENSURE - Self-sufficient error checking -// TfLiteStatus - Status reporting -// TfLiteIntArray - stores tensor shapes (dims), -// TfLiteContext - allows an op to access the tensors -// TfLiteTensor - tensor (a multidimensional array) -// TfLiteNode - a single node or operation -// TfLiteRegistration - the implementation of a conceptual operation. -// -// Some abstractions in this file are created and managed by Interpreter. +// Compatibility shim for moved header location. #ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ #define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ -#include <stdbool.h> -#include <stdint.h> -#include <stdlib.h> +#include "tensorflow/contrib/lite/c/c_api_internal.h" -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; - -// Forward declarations for use with dependent types. -struct TfLiteContext; -struct TfLiteNode; -struct _TfLiteRegistration; -struct _TfLiteDelegate; - -// The list of external context types known to TF Lite. This list exists solely -// to avoid conflicts and to ensure ops can share the external contexts they -// need. Access to the external contexts is controled by one of the -// corresponding support files. -typedef enum { - kTfLiteEigenContext = 0, // include eigen_support.h to use. - kTfLiteGemmLowpContext = 1, // include gemm_support.h to use. - kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support. - kTfLiteMaxExternalContexts = 3 -} TfLiteExternalContextType; - -// An external context is a collection of information unrelated to the TF Lite -// framework, but useful to a subset of the ops. TF Lite knows very little -// about about the actual contexts, but it keeps a list of them, and is able to -// refresh them if configurations like the number of recommended threads -// change. -typedef struct { - TfLiteExternalContextType type; - TfLiteStatus (*Refresh)(struct TfLiteContext* context); -} TfLiteExternalContext; - -#define kOptionalTensor (-1) - -// Fixed size list of integers. Used for dimensions and inputs/outputs tensor -// indices -typedef struct { - int size; -// gcc 6.1+ have a bug where flexible members aren't properly handled -// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c -#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \ - __GNUC_MINOR__ >= 1 - int data[0]; -#else - int data[]; -#endif -} TfLiteIntArray; - -// Given the size (number of elements) in a TfLiteIntArray, calculate its size -// in bytes. -int TfLiteIntArrayGetSizeInBytes(int size); - -// Create a array of a given `size` (uninitialized entries). -// This returns a pointer, that you must free using TfLiteIntArrayFree(). -TfLiteIntArray* TfLiteIntArrayCreate(int size); - -// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise. -int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b); - -// Create a copy of an array passed as `src`. -// You are expected to free memory with TfLiteIntArrayFree -TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src); - -// Free memory of array `v`. -void TfLiteIntArrayFree(TfLiteIntArray* v); - -// Since we must not depend on any libraries, define a minimal subset of -// error macros while avoiding names that have pre-conceived meanings like -// assert and check. - -// Check whether value is true, and if not return kTfLiteError from -// the current function (and report the error string msg). -#define TF_LITE_ENSURE_MSG(context, value, msg) \ - do { \ - if (!(value)) { \ - (context)->ReportError((context), __FILE__ " " msg); \ - return kTfLiteError; \ - } \ - } while (0) - -// Check whether the value `a` is true, and if not return kTfLiteError from -// the current function, while also reporting the location of the error. -#define TF_LITE_ENSURE(context, a) \ - do { \ - if (!(a)) { \ - (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \ - __LINE__, #a); \ - return kTfLiteError; \ - } \ - } while (0) - -#define TF_LITE_ENSURE_STATUS(a) \ - do { \ - if ((a) != kTfLiteOk) { \ - return kTfLiteError; \ - } \ - } while (0) - -// Check whether the value `a == b` is true, and if not return kTfLiteError from -// the current function, while also reporting the location of the error. -// `a` and `b` may be evaluated more than once, so no side effects or -// extremely expensive computations should be done. -#define TF_LITE_ENSURE_EQ(context, a, b) \ - do { \ - if ((a) != (b)) { \ - (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \ - __LINE__, #a, #b, (a), (b)); \ - return kTfLiteError; \ - } \ - } while (0) - -#define TF_LITE_ENSURE_OK(context, status) \ - do { \ - if ((status) != kTfLiteOk) { \ - return status; \ - } \ - } while (0) - -// Single-precision complex data type compatible with the C99 definition. -typedef struct { - float re, im; // real and imaginary parts, respectively. -} TfLiteComplex64; - -// Types supported by tensor -typedef enum { - kTfLiteNoType = 0, - kTfLiteFloat32 = 1, - kTfLiteInt32 = 2, - kTfLiteUInt8 = 3, - kTfLiteInt64 = 4, - kTfLiteString = 5, - kTfLiteBool = 6, - kTfLiteInt16 = 7, - kTfLiteComplex64 = 8, -} TfLiteType; - -// Parameters for asymmetric quantization. Quantized values can be converted -// back to float using: -// real_value = scale * (quantized_value - zero_point); -typedef struct { - float scale; - int32_t zero_point; -} TfLiteQuantizationParams; - -// A union of pointers that points to memory for a given tensor. -typedef union { - int* i32; - int64_t* i64; - float* f; - char* raw; - const char* raw_const; - uint8_t* uint8; - bool* b; - int16_t* i16; - TfLiteComplex64* c64; -} TfLitePtrUnion; - -// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped -// data (or data externally allocated). kTfLiteArenaRw is arena allocated -// data. kTfLiteDynamic is for tensors that are allocated during evaluation. -typedef enum { - kTfLiteMemNone = 0, - kTfLiteMmapRo, - kTfLiteArenaRw, - kTfLiteArenaRwPersistent, - kTfLiteDynamic, -} TfLiteAllocationType; - -// The delegates should use zero or positive integers to represent handles. -// -1 is reserved from unallocated status. -typedef int TfLiteBufferHandle; -const TfLiteBufferHandle kTfLiteNullBufferHandle = -1; - -// An tensor in the interpreter system which is a wrapper around a buffer of -// data including a dimensionality (or NULL if not currently defined). -typedef struct { - // The data type specification for data stored in `data`. This affects - // what member of `data` union should be used. - TfLiteType type; - // A union of data pointers. The appropriate type should be used for a typed - // tensor based on `type`. - TfLitePtrUnion data; - // A pointer to a structure representing the dimensionality interpretation - // that the buffer should have. NOTE: the product of elements of `dims` - // and the element datatype size should be equal to `bytes` below. - TfLiteIntArray* dims; - // Quantization information. - TfLiteQuantizationParams params; - // How memory is mapped - // kTfLiteMmapRo: Memory mapped read only. - // i.e. weights - // kTfLiteArenaRw: Arena allocated read write memory - // (i.e. temporaries, outputs). - TfLiteAllocationType allocation_type; - // The number of bytes required to store the data of this Tensor. I.e. - // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if - // type is kTfLiteFloat32 and dims = {3, 2} then - // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24. - size_t bytes; - - // An opaque pointer to a tflite::MMapAllocation - const void* allocation; - - // Null-terminated name of this tensor. - const char* name; - - // The delegate which knows how to handle `buffer_handle`. - // WARNING: This is an experimental interface that is subject to change. - struct _TfLiteDelegate* delegate; - - // An integer buffer handle that can be handled by `delegate`. - // The value is valid only when delegate is not null. - // WARNING: This is an experimental interface that is subject to change. - TfLiteBufferHandle buffer_handle; - - // If the delegate uses its own buffer (e.g. GPU memory), the delegate is - // responsible to set data_is_stale to true. - // `delegate->CopyFromBufferHandle` can be called to copy the data from - // delegate buffer. - // WARNING: This is an // experimental interface that is subject to change. - bool data_is_stale; - - // True if the tensor is a variable. - bool is_variable; -} TfLiteTensor; - -// Free data memory of tensor `t`; -void TfLiteTensorDataFree(TfLiteTensor* t); - -// Free memory of tensor `t`; -void TfLiteTensorFree(TfLiteTensor* t); - -// Set all of a tensor's fields (and free any previously allocated data). -void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, - TfLiteQuantizationParams quantization, char* buffer, - size_t size, TfLiteAllocationType allocation_type, - const void* allocation, bool is_variable, - TfLiteTensor* tensor); - -// Resize the allocated data of a (dynamic) tensor. Tensors with allocation -// types other than kTfLiteDynamic will be ignored. -void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); - -// A structure representing an instance of a node. -// This structure only exhibits the inputs, outputs and user defined data, not -// other features like the type. -typedef struct TfLiteNode { - // Inputs to this node expressed as indices into the simulator's tensors. - TfLiteIntArray* inputs; - - // Outputs to this node expressed as indices into the simulator's tensors. - TfLiteIntArray* outputs; - - // Temporary tensors uses during the computations. This usually contains no - // tensors, but ops are allowed to change that if they need scratch space of - // any sort. - TfLiteIntArray* temporaries; - - // Opaque data provided by the node implementer through `Registration.init`. - void* user_data; - - // Opaque data provided to the node if the node is a builtin. This is usually - // a structure defined in builtin_op_data.h - void* builtin_data; - - // Custom initial data. This is the opaque data provided in the flatbuffer. - // WARNING: This is an experimental interface that is subject to change. - const void* custom_initial_data; - int custom_initial_data_size; - - // The pointer to the delegate. This is non-null only when the node is - // created by calling `interpreter.ModifyGraphWithDelegate`. - // WARNING: This is an experimental interface that is subject to change. - struct _TfLiteDelegate* delegate; -} TfLiteNode; - -typedef struct TfLiteContext { - // Number of tensors in the context. - size_t tensors_size; - - // The execution plan contains a list of the node indices in execution - // order. execution_plan->size is the current number of nodes. And, - // execution_plan->data[0] is the first node that needs to be run. - // TfLiteDelegates can traverse the current execution plan by iterating - // through each member of this array and using GetNodeAndRegistration() to - // access details about a node. i.e. - // TfLiteIntArray* execution_plan; - // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan)); - // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) { - // int node_index = execution_plan->data[exec_index]; - // TfLiteNode* node; - // TfLiteRegistration* reg; - // context->GetNodeAndRegistration(context, node_index, &node, ®); - // } - // WARNING: This is an experimental interface that is subject to change. - TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context, - TfLiteIntArray** execution_plan); - - // An array of tensors in the interpreter context (of length `tensors_size`) - TfLiteTensor* tensors; - - // opaque full context ptr (an opaque c++ data structure) - void* impl_; - - // Request memory pointer be resized. Updates dimensions on the tensor. - // NOTE: ResizeTensor takes ownership of newSize. - TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor, - TfLiteIntArray* new_size); - // Request that a error be reported with format string msg. - void (*ReportError)(struct TfLiteContext*, const char* msg, ...); - - // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If - // non-null, the value pointed to by `first_new_tensor_index` will be set to - // the index of the first new tensor. - TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add, - int* first_new_tensor_index); - - // Get a Tensor node by node_index. - // WARNING: This is an experimental interface that is subject to change. - TfLiteStatus (*GetNodeAndRegistration)( - struct TfLiteContext*, int node_index, struct TfLiteNode** node, - struct _TfLiteRegistration** registration); - - // Replace ops with one or more stub delegate operations. This function - // does not take ownership of `nodes_to_replace`. - TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)( - struct TfLiteContext*, struct _TfLiteRegistration registration, - const TfLiteIntArray* nodes_to_replace, struct _TfLiteDelegate* delegate); - - // Number of threads that are recommended to subsystems like gemmlowp and - // eigen. - int recommended_num_threads; - - // Access external contexts by type. - // WARNING: This is an experimental interface that is subject to change. - TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*, - TfLiteExternalContextType); - // Set the value of a external context. Does not take ownership of the - // pointer. - // WARNING: This is an experimental interface that is subject to change. - void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType, - TfLiteExternalContext*); -} TfLiteContext; - -typedef struct _TfLiteRegistration { - // Initializes the op from serialized data. - // If a built-in op: - // `buffer` is the op's params data (TfLiteLSTMParams*). - // `length` is zero. - // If custom op: - // `buffer` is the op's `custom_options`. - // `length` is the size of the buffer. - // - // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer - // or an instance of a struct). - // - // The returned pointer will be stored with the node in the `user_data` field, - // accessible within prepare and invoke functions below. - // NOTE: if the data is already in the desired format, simply implement this - // function to return `nullptr` and implement the free function to be a no-op. - void* (*init)(TfLiteContext* context, const char* buffer, size_t length); - - // The pointer `buffer` is the data previously returned by an init invocation. - void (*free)(TfLiteContext* context, void* buffer); - - // prepare is called when the inputs this node depends on have been resized. - // context->ResizeTensor() can be called to request output tensors to be - // resized. - // - // Returns kTfLiteOk on success. - TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); - - // Execute the node (should read node->inputs and output to node->outputs). - // Returns kTfLiteOk on success. - TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); - - // profiling_string is called during summarization of profiling information - // in order to group executions together. Providing a value here will cause a - // given op to appear multiple times is the profiling report. This is - // particularly useful for custom ops that can perform significantly - // different calculations depending on their `user-data`. - const char* (*profiling_string)(const TfLiteContext* context, - const TfLiteNode* node); - - // Builtin codes. If this kernel refers to a builtin this is the code - // of the builtin. This is so we can do marshaling to other frameworks like - // NN API. - // Note: It is the responsibility of the registration binder to set this - // properly. - int32_t builtin_code; - - // Custom op name. If the op is a builtin, this will be null. - // Note: It is the responsibility of the registration binder to set this - // properly. - // WARNING: This is an experimental interface that is subject to change. - const char* custom_name; - - // The version of the op. - // Note: It is the responsibility of the registration binder to set this - // properly. - int version; -} TfLiteRegistration; - -// WARNING: This is an experimental interface that is subject to change. -typedef struct _TfLiteDelegate { - // Data that delegate needs to identify itself. This data is owned by the - // delegate. The delegate is owned in the user code, so the delegate is - // responsible for doing this when it is destroyed. - void* data_; - - // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the - // delegate a view of the current graph through TfLiteContext*. It typically - // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels() - // to ask the TensorFlow lite runtime to create macro-nodes to represent - // delegated subgraphs of the original graph. - TfLiteStatus (*Prepare)(struct TfLiteContext* context, - struct _TfLiteDelegate* delegate); - - // Copy the data from delegate buffer handle to raw memory. - // This can be null if the delegate doesn't use its own buffer. - TfLiteStatus (*CopyFromBufferHandle)(struct TfLiteContext* context, - struct _TfLiteDelegate* delegate, - TfLiteBufferHandle buffer_handle, - void* data, size_t size); - - // Copy the data from raw memory to delegate buffer handle. - // This can be null if the delegate doesn't use its own buffer. - TfLiteStatus (*CopyToBufferHandle)(struct TfLiteContext* context, - struct _TfLiteDelegate* delegate, - TfLiteBufferHandle buffer_handle, - void* data, size_t size); - - // Free the Delegate Buffer Handle. Note: This only frees the handle, but - // this doesn't release the underlying resource (e.g. textures). The - // resources are either owned by application layer or the delegate. - // This can be null if the delegate doesn't use its own buffer. - void (*FreeBufferHandle)(struct TfLiteContext* context, - struct _TfLiteDelegate* delegate, - TfLiteBufferHandle* handle); -} TfLiteDelegate; - -// WARNING: This is an experimental interface that is subject to change. -// -// Currently, TfLiteDelegateParams has to be allocated in a way that it's -// trivially destructable. It will be stored as `builtin_data` field in -// `TfLiteNode` of the delegate node. -// -// See also the `CreateDelegateParams` function in `interpreter.cc` details. -typedef struct { - TfLiteDelegate* delegate; - TfLiteIntArray* nodes_to_replace; - TfLiteIntArray* input_tensors; - TfLiteIntArray* output_tensors; -} TfLiteDelegateParams; - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus #endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h index abe802e342..ccda4c7393 100644 --- a/tensorflow/contrib/lite/context_util.h +++ b/tensorflow/contrib/lite/context_util.h @@ -17,7 +17,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ #define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/core/api/BUILD b/tensorflow/contrib/lite/core/api/BUILD new file mode 100644 index 0000000000..e4500534f3 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/BUILD @@ -0,0 +1,57 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +cc_library( + name = "api", + srcs = [ + "error_reporter.cc", + "flatbuffer_conversions.cc", + "op_resolver.cc", + ], + hdrs = [ + "error_reporter.h", + "flatbuffer_conversions.h", + "op_resolver.h", + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/schema:schema_fbs", + ], +) + +cc_test( + name = "error_reporter_test", + size = "small", + srcs = ["error_reporter_test.cc"], + deps = [ + ":api", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "op_resolver_test", + size = "small", + srcs = ["op_resolver_test.cc"], + deps = [ + ":api", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "flatbuffer_conversions_test", + size = "small", + srcs = ["flatbuffer_conversions_test.cc"], + deps = [ + ":api", + "//tensorflow/contrib/lite/c:c_api_internal", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/core/api/error_reporter.cc b/tensorflow/contrib/lite/core/api/error_reporter.cc new file mode 100644 index 0000000000..423f83b1a9 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/error_reporter.cc @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include <cstdarg> + +namespace tflite { + +int ErrorReporter::Report(const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +// TODO(aselle): Make the name of ReportError on context the same, so +// we can use the ensure functions w/o a context and w/ a reporter. +int ErrorReporter::ReportError(void*, const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/core/api/error_reporter.h b/tensorflow/contrib/lite/core/api/error_reporter.h new file mode 100644 index 0000000000..a2f780b003 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/error_reporter.h @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_ +#define TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_ + +#include <cstdarg> + +namespace tflite { + +// A functor that reports error to supporting system. Invoked similar to +// printf. +// +// Usage: +// ErrorReporter foo; +// foo.Report("test %d", 5); +// or +// va_list args; +// foo.Report("test %d", args); // where args is va_list +// +// Subclass ErrorReporter to provide another reporting destination. +// For example, if you have a GUI program, you might redirect to a buffer +// that drives a GUI error log box. +class ErrorReporter { + public: + virtual ~ErrorReporter() {} + virtual int Report(const char* format, va_list args) = 0; + int Report(const char* format, ...); + int ReportError(void*, const char* format, ...); +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/core/api/error_reporter_test.cc b/tensorflow/contrib/lite/core/api/error_reporter_test.cc new file mode 100644 index 0000000000..0463eee6be --- /dev/null +++ b/tensorflow/contrib/lite/core/api/error_reporter_test.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/core/api/error_reporter.h" + +#include <cstdio> + +#include <gtest/gtest.h> + +namespace tflite { + +class MockErrorReporter : public ErrorReporter { + public: + int Report(const char* format, va_list args) override { + vsnprintf(buffer_, kBufferSize, format, args); + return 0; + } + char* GetBuffer() { return buffer_; } + + private: + static constexpr int kBufferSize = 256; + char buffer_[kBufferSize]; +}; + +TEST(ErrorReporter, TestReport) { + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + reporter->Report("Error: %d", 23); + EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23")); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc new file mode 100644 index 0000000000..1420fbcdc6 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc @@ -0,0 +1,622 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" + +#include <cstdlib> + +#include "tensorflow/contrib/lite/c/builtin_op_data.h" + +namespace tflite { + +namespace { + +// Copies the contents from the flatbuffer int vector `flatbuffer` into the +// int array `buffer`. `flat_vector` and `buffer` represent the same +// configuration operation for a given operation. +void FlatBufferIntVectorToArray(int max_size_of_buffer, + const flatbuffers::Vector<int32_t>* flat_vector, + int* buffer, ErrorReporter* error_reporter) { + if (!flat_vector) { + error_reporter->Report("Input array not provided for operation.\n"); + } else { + int num_dimensions = flat_vector->Length(); + if (num_dimensions > max_size_of_buffer / sizeof(int)) { + error_reporter->Report( + "Found too many dimensions in the operation's input array.\n"); + } else { + for (int i = 0; i < num_dimensions; ++i) { + buffer[i] = flat_vector->Get(i); + } + } + } +} + +// Allocate a structure using malloc, but make sure the structure is a POD +// structure that doesn't require constructors to run. The reason we do this, +// is that Interpreter's C extension part will take ownership so destructors +// will not be run during deallocation. +template <class T> +T* MallocPOD() { + static_assert(std::is_pod<T>::value, "Builtin data structure must be POD."); + return static_cast<T*>(malloc(sizeof(T))); +} + +} // namespace + +TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, + ErrorReporter* error_reporter) { + switch (tensor_type) { + case TensorType_FLOAT32: + *type = kTfLiteFloat32; + break; + case TensorType_INT16: + *type = kTfLiteInt16; + break; + case TensorType_INT32: + *type = kTfLiteInt32; + break; + case TensorType_UINT8: + *type = kTfLiteUInt8; + break; + case TensorType_INT64: + *type = kTfLiteInt64; + break; + case TensorType_STRING: + *type = kTfLiteString; + break; + case TensorType_BOOL: + *type = kTfLiteBool; + break; + case TensorType_COMPLEX64: + *type = kTfLiteComplex64; + break; + default: + error_reporter->Report("Unimplemented data type %s (%d) in tensor\n", + EnumNameTensorType(tensor_type), tensor_type); + return kTfLiteError; + } + return kTfLiteOk; +} + +// Parse the appropriate data out of the op. +// +// This handles builtin data explicitly as there are flatbuffer schemas. +// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which +// need to be released by calling `free`.` +// If it returns kTfLiteError, `builtin_data` will be `nullptr`. +TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter, void** builtin_data) { + auto parse_padding = [](Padding padding) { + switch (padding) { + case Padding_SAME: + return kTfLitePaddingSame; + case Padding_VALID: + return kTfLitePaddingValid; + } + return kTfLitePaddingUnknown; + }; + auto parse_activation = [](ActivationFunctionType activation) { + switch (activation) { + case ActivationFunctionType_NONE: + return kTfLiteActNone; + case ActivationFunctionType_RELU: + return kTfLiteActRelu; + case ActivationFunctionType_RELU_N1_TO_1: + return kTfLiteActRelu1; + case ActivationFunctionType_RELU6: + return kTfLiteActRelu6; + case ActivationFunctionType_TANH: + return kTfLiteActTanh; + case ActivationFunctionType_SIGN_BIT: + return kTfLiteActSignBit; + } + return kTfLiteActNone; + }; + auto parseLSHProjectionType = [](LSHProjectionType type) { + switch (type) { + case LSHProjectionType_SPARSE: + return kTfLiteLshProjectionSparse; + case LSHProjectionType_DENSE: + return kTfLiteLshProjectionDense; + default: + return kTfLiteLshProjectionUnknown; + } + }; + auto parseCombinerType = [](CombinerType type) { + switch (type) { + case CombinerType_MEAN: + return kTfLiteCombinerTypeMean; + case CombinerType_SQRTN: + return kTfLiteCombinerTypeSqrtn; + case CombinerType_SUM: + default: + return kTfLiteCombinerTypeSum; + } + }; + + *builtin_data = nullptr; + switch (op_type) { + case BuiltinOperator_CONV_2D: { + TfLiteConvParams* params = MallocPOD<TfLiteConvParams>(); + if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + + params->dilation_width_factor = conv_params->dilation_w_factor(); + params->dilation_height_factor = conv_params->dilation_h_factor(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_CAST: { + TfLiteCastParams* params = MallocPOD<TfLiteCastParams>(); + if (auto* schema_params = op->builtin_options_as_CastOptions()) { + auto in_status = + ConvertTensorType(schema_params->in_data_type(), + ¶ms->in_data_type, error_reporter); + auto out_status = + ConvertTensorType(schema_params->out_data_type(), + ¶ms->out_data_type, error_reporter); + if (in_status != kTfLiteOk || out_status != kTfLiteOk) { + free(params); + return kTfLiteError; + } + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_LSH_PROJECTION: { + TfLiteLSHProjectionParams* params = + MallocPOD<TfLiteLSHProjectionParams>(); + if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) { + params->type = parseLSHProjectionType(lshParams->type()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_AVERAGE_POOL_2D: + case BuiltinOperator_MAX_POOL_2D: + case BuiltinOperator_L2_POOL_2D: { + TfLitePoolParams* params = MallocPOD<TfLitePoolParams>(); + if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) { + params->padding = parse_padding(pool_params->padding()); + params->stride_width = pool_params->stride_w(); + params->stride_height = pool_params->stride_h(); + params->filter_width = pool_params->filter_width(); + params->filter_height = pool_params->filter_height(); + params->activation = + parse_activation(pool_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_DEPTHWISE_CONV_2D: { + TfLiteDepthwiseConvParams* params = + MallocPOD<TfLiteDepthwiseConvParams>(); + if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->depth_multiplier = conv_params->depth_multiplier(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SVDF: { + TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>(); + if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) { + params->rank = svdf_params->rank(); + params->activation = + parse_activation(svdf_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { + TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>(); + if (auto* sequence_rnn_params = + op->builtin_options_as_SequenceRNNOptions()) { + params->activation = + parse_activation(sequence_rnn_params->fused_activation_function()); + params->time_major = sequence_rnn_params->time_major(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_RNN: { + TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>(); + if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { + params->activation = + parse_activation(rnn_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { + TfLiteEmbeddingLookupSparseParams* params = + MallocPOD<TfLiteEmbeddingLookupSparseParams>(); + if (auto* embedding_params = + op->builtin_options_as_EmbeddingLookupSparseOptions()) { + params->combiner = parseCombinerType(embedding_params->combiner()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_FULLY_CONNECTED: { + TfLiteFullyConnectedParams* params = + MallocPOD<TfLiteFullyConnectedParams>(); + if (auto* fully_connected_params = + op->builtin_options_as_FullyConnectedOptions()) { + params->activation = parse_activation( + fully_connected_params->fused_activation_function()); + switch (fully_connected_params->weights_format()) { + case FullyConnectedOptionsWeightsFormat_DEFAULT: + params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; + break; + case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + params->weights_format = + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; + break; + default: + error_reporter->Report("Unhandled fully-connected weights format."); + return kTfLiteError; + } + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_HASHTABLE_LOOKUP: + // no-op. + break; + case BuiltinOperator_SOFTMAX: { + TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>(); + if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) { + params->beta = softmax_params->beta(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_CONCATENATION: { + TfLiteConcatenationParams* params = + MallocPOD<TfLiteConcatenationParams>(); + if (auto* concatenation_params = + op->builtin_options_as_ConcatenationOptions()) { + params->activation = + parse_activation(concatenation_params->fused_activation_function()); + params->axis = concatenation_params->axis(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_MUL: { + auto* params = MallocPOD<TfLiteMulParams>(); + if (auto* schema_params = op->builtin_options_as_MulOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_ADD: { + auto* params = MallocPOD<TfLiteAddParams>(); + if (auto* schema_params = op->builtin_options_as_AddOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_DIV: { + auto* params = MallocPOD<TfLiteDivParams>(); + if (auto* schema_params = op->builtin_options_as_DivOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SUB: { + auto* params = MallocPOD<TfLiteSubParams>(); + if (auto* schema_params = op->builtin_options_as_SubOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_L2_NORMALIZATION: { + auto* params = MallocPOD<TfLiteL2NormParams>(); + if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { + auto* params = MallocPOD<TfLiteLocalResponseNormParams>(); + if (auto* schema_params = + op->builtin_options_as_LocalResponseNormalizationOptions()) { + params->radius = schema_params->radius(); + params->bias = schema_params->bias(); + params->alpha = schema_params->alpha(); + params->beta = schema_params->beta(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: + case BuiltinOperator_LSTM: { + TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>(); + if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { + params->activation = + parse_activation(lstm_params->fused_activation_function()); + params->cell_clip = lstm_params->cell_clip(); + params->proj_clip = lstm_params->proj_clip(); + switch (lstm_params->kernel_type()) { + case LSTMKernelType_FULL: + params->kernel_type = kTfLiteLSTMFullKernel; + break; + case LSTMKernelType_BASIC: + params->kernel_type = kTfLiteLSTMBasicKernel; + break; + } + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_RESIZE_BILINEAR: { + auto* params = MallocPOD<TfLiteResizeBilinearParams>(); + if (auto* schema_params = + op->builtin_options_as_ResizeBilinearOptions()) { + params->align_corners = schema_params->align_corners(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_RESHAPE: { + auto* params = MallocPOD<TfLiteReshapeParams>(); + if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { + auto* new_shape = schema_params->new_shape(); + FlatBufferIntVectorToArray(sizeof(params->shape), new_shape, + params->shape, error_reporter); + params->num_dimensions = new_shape->Length(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SKIP_GRAM: { + TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>(); + if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) { + params->ngram_size = skip_gram_params->ngram_size(); + params->max_skip_size = skip_gram_params->max_skip_size(); + params->include_all_ngrams = skip_gram_params->include_all_ngrams(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SPACE_TO_DEPTH: { + auto* params = MallocPOD<TfLiteSpaceToDepthParams>(); + if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) { + params->block_size = schema_params->block_size(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_GATHER: { + TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>(); + params->axis = 0; + if (auto* gather_params = op->builtin_options_as_GatherOptions()) { + params->axis = gather_params->axis(); + } + + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_MEAN: + case BuiltinOperator_REDUCE_MAX: + case BuiltinOperator_REDUCE_MIN: + case BuiltinOperator_REDUCE_PROD: + case BuiltinOperator_REDUCE_ANY: + case BuiltinOperator_SUM: { + auto* params = MallocPOD<TfLiteReducerParams>(); + if (auto* schema_params = op->builtin_options_as_ReducerOptions()) { + params->keep_dims = schema_params->keep_dims(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SPLIT: { + auto* params = MallocPOD<TfLiteSplitParams>(); + if (auto* schema_params = op->builtin_options_as_SplitOptions()) { + params->num_splits = schema_params->num_splits(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SQUEEZE: { + auto* params = MallocPOD<TfLiteSqueezeParams>(); + if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) { + const auto& squeeze_dims = schema_params->squeeze_dims(); + FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims, + params->squeeze_dims, error_reporter); + params->num_squeeze_dims = squeeze_dims->Length(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_STRIDED_SLICE: { + auto* params = MallocPOD<TfLiteStridedSliceParams>(); + if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) { + params->begin_mask = schema_params->begin_mask(); + params->end_mask = schema_params->end_mask(); + params->ellipsis_mask = schema_params->ellipsis_mask(); + params->new_axis_mask = schema_params->new_axis_mask(); + params->shrink_axis_mask = schema_params->shrink_axis_mask(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_ARG_MAX: { + auto* params = MallocPOD<TfLiteArgMaxParams>(); + if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) { + ConvertTensorType(schema_params->output_type(), ¶ms->output_type, + error_reporter); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_ARG_MIN: { + auto* params = MallocPOD<TfLiteArgMinParams>(); + if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) { + ConvertTensorType(schema_params->output_type(), ¶ms->output_type, + error_reporter); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_TRANSPOSE_CONV: { + TfLiteTransposeConvParams* params = + MallocPOD<TfLiteTransposeConvParams>(); + if (auto* transpose_conv_params = + op->builtin_options_as_TransposeConvOptions()) { + params->padding = parse_padding(transpose_conv_params->padding()); + params->stride_width = transpose_conv_params->stride_w(); + params->stride_height = transpose_conv_params->stride_h(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SPARSE_TO_DENSE: { + TfLiteSparseToDenseParams* params = + MallocPOD<TfLiteSparseToDenseParams>(); + if (auto* sparse_to_dense_params = + op->builtin_options_as_SparseToDenseOptions()) { + params->validate_indices = sparse_to_dense_params->validate_indices(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SHAPE: { + auto* params = MallocPOD<TfLiteShapeParams>(); + if (auto* schema_params = op->builtin_options_as_ShapeOptions()) { + ConvertTensorType(schema_params->out_type(), ¶ms->out_type, + error_reporter); + } + *builtin_data = static_cast<void*>(params); + break; + } + case BuiltinOperator_PACK: { + TfLitePackParams* params = MallocPOD<TfLitePackParams>(); + if (auto* pack_params = op->builtin_options_as_PackOptions()) { + params->values_count = pack_params->values_count(); + params->axis = pack_params->axis(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_DELEGATE: { + // TODO(ycling): Revisit when supporting saving delegated models. + error_reporter->Report("DELEGATE op shouldn't exist in model."); + return kTfLiteError; + } + case BuiltinOperator_FAKE_QUANT: { + auto* params = MallocPOD<TfLiteFakeQuantParams>(); + if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) { + params->min = schema_params->min(); + params->max = schema_params->max(); + params->num_bits = schema_params->num_bits(); + params->narrow_range = schema_params->narrow_range(); + } + *builtin_data = static_cast<void*>(params); + break; + } + case BuiltinOperator_ONE_HOT: { + auto* params = MallocPOD<TfLiteOneHotParams>(); + if (auto* schema_params = op->builtin_options_as_OneHotOptions()) { + params->axis = schema_params->axis(); + } + *builtin_data = static_cast<void*>(params); + break; + } + case BuiltinOperator_UNPACK: { + TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>(); + if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) { + params->num = unpack_params->num(); + params->axis = unpack_params->axis(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + + // Below are the ops with no builtin_data strcture. + case BuiltinOperator_BATCH_TO_SPACE_ND: + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either. + case BuiltinOperator_CALL: + case BuiltinOperator_CONCAT_EMBEDDINGS: + case BuiltinOperator_CUSTOM: + case BuiltinOperator_DEQUANTIZE: + case BuiltinOperator_EMBEDDING_LOOKUP: + case BuiltinOperator_EQUAL: + case BuiltinOperator_EXP: + case BuiltinOperator_EXPAND_DIMS: + case BuiltinOperator_FLOOR: + case BuiltinOperator_GREATER: + case BuiltinOperator_GREATER_EQUAL: + case BuiltinOperator_LESS: + case BuiltinOperator_LESS_EQUAL: + case BuiltinOperator_LOG: + case BuiltinOperator_LOGISTIC: + case BuiltinOperator_LOG_SOFTMAX: + case BuiltinOperator_MAXIMUM: + case BuiltinOperator_MINIMUM: + case BuiltinOperator_NEG: + case BuiltinOperator_NOT_EQUAL: + case BuiltinOperator_PAD: + case BuiltinOperator_PADV2: + case BuiltinOperator_PRELU: + case BuiltinOperator_RELU: + case BuiltinOperator_RELU6: + case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_RSQRT: + case BuiltinOperator_SELECT: + case BuiltinOperator_SIN: + case BuiltinOperator_SLICE: + case BuiltinOperator_SPACE_TO_BATCH_ND: + case BuiltinOperator_SQRT: + case BuiltinOperator_TANH: + case BuiltinOperator_TILE: + case BuiltinOperator_TOPK_V2: + case BuiltinOperator_TRANSPOSE: + case BuiltinOperator_POW: + case BuiltinOperator_LOGICAL_OR: + case BuiltinOperator_LOGICAL_AND: + case BuiltinOperator_LOGICAL_NOT: + case BuiltinOperator_FLOOR_DIV: + break; + } + return kTfLiteOk; +} // NOLINT[readability/fn_size] + +} // namespace tflite diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h new file mode 100644 index 0000000000..4dec6f9cfc --- /dev/null +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ +#define TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ + +// These functions transform codes and data structures that are defined in the +// flatbuffer serialization format into in-memory values that are used by the +// runtime API and interpreter. + +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/op_resolver.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { + +// Parse the appropriate data out of the op. +// +// This handles builtin data explicitly as there are flatbuffer schemas. +// If it returns kTfLiteOk, it passes the data out with `builtin_data`. The +// calling function has to pass in an allocator object, and this allocator +// will be called to reserve space for the output data. If the calling +// function's allocator reserves memory on the heap, then it's the calling +// function's responsibility to free it. +// If it returns kTfLiteError, `builtin_data` will be `nullptr`. +TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter, void** builtin_data); + +// Converts the tensor data type used in the flat buffer to the representation +// used by the runtime. +TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, + ErrorReporter* error_reporter); + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc new file mode 100644 index 0000000000..b12bdf43b2 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" + +#include <cstring> + +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/c/builtin_op_data.h" + +namespace tflite { +namespace { + +class MockErrorReporter : public ErrorReporter { + public: + MockErrorReporter() : buffer_size_(0) {} + int Report(const char* format, va_list args) override { + buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args); + return buffer_size_; + } + char* GetBuffer() { return buffer_; } + int GetBufferSize() { return buffer_size_; } + + private: + static constexpr int kBufferSize = 256; + char buffer_[kBufferSize]; + int buffer_size_; +}; + +} // namespace + +TEST(FlatbufferConversions, TestParseOpDataConv) { + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + + flatbuffers::FlatBufferBuilder builder; + flatbuffers::Offset<void> conv_options = + CreateConv2DOptions(builder, Padding_SAME, 1, 2, + ActivationFunctionType_RELU, 3, 4) + .Union(); + flatbuffers::Offset<Operator> conv_offset = CreateOperatorDirect( + builder, 0, nullptr, nullptr, BuiltinOptions_Conv2DOptions, conv_options, + nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr); + builder.Finish(conv_offset); + void* conv_pointer = builder.GetBufferPointer(); + const Operator* conv_op = flatbuffers::GetRoot<Operator>(conv_pointer); + void* output_data = nullptr; + EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter, + &output_data)); + EXPECT_NE(nullptr, output_data); + TfLiteConvParams* params = reinterpret_cast<TfLiteConvParams*>(output_data); + EXPECT_EQ(kTfLitePaddingSame, params->padding); + EXPECT_EQ(1, params->stride_width); + EXPECT_EQ(2, params->stride_height); + EXPECT_EQ(kTfLiteActRelu, params->activation); + EXPECT_EQ(3, params->dilation_width_factor); + EXPECT_EQ(4, params->dilation_height_factor); + free(output_data); +} + +TEST(FlatbufferConversions, TestParseOpDataCustom) { + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + + flatbuffers::FlatBufferBuilder builder; + flatbuffers::Offset<void> null_options; + flatbuffers::Offset<Operator> custom_offset = CreateOperatorDirect( + builder, 0, nullptr, nullptr, BuiltinOptions_NONE, null_options, nullptr, + CustomOptionsFormat_FLEXBUFFERS, nullptr); + builder.Finish(custom_offset); + void* custom_pointer = builder.GetBufferPointer(); + const Operator* custom_op = flatbuffers::GetRoot<Operator>(custom_pointer); + void* output_data = nullptr; + EXPECT_EQ(kTfLiteOk, ParseOpData(custom_op, BuiltinOperator_CUSTOM, reporter, + &output_data)); + EXPECT_EQ(nullptr, output_data); +} + +TEST(FlatbufferConversions, TestConvertTensorType) { + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + TfLiteType type; + EXPECT_EQ(kTfLiteOk, ConvertTensorType(TensorType_FLOAT32, &type, reporter)); + EXPECT_EQ(kTfLiteFloat32, type); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/core/api/op_resolver.cc b/tensorflow/contrib/lite/core/api/op_resolver.cc new file mode 100644 index 0000000000..55ee924843 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/op_resolver.cc @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/core/api/op_resolver.h" + +namespace tflite { + +TfLiteStatus GetRegistrationFromOpCode( + const OperatorCode* opcode, const OpResolver& op_resolver, + ErrorReporter* error_reporter, const TfLiteRegistration** registration) { + TfLiteStatus status = kTfLiteOk; + *registration = nullptr; + auto builtin_code = opcode->builtin_code(); + int version = opcode->version(); + + if (builtin_code > BuiltinOperator_MAX || + builtin_code < BuiltinOperator_MIN) { + error_reporter->Report( + "Op builtin_code out of range: %d. Are you using old TFLite binary " + "with newer model?", + builtin_code); + status = kTfLiteError; + } else if (builtin_code != BuiltinOperator_CUSTOM) { + *registration = op_resolver.FindOp(builtin_code, version); + if (*registration == nullptr) { + error_reporter->Report( + "Didn't find op for builtin opcode '%s' version '%d'\n", + EnumNameBuiltinOperator(builtin_code), version); + status = kTfLiteError; + } + } else if (!opcode->custom_code()) { + error_reporter->Report( + "Operator with CUSTOM builtin_code has no custom_code.\n"); + status = kTfLiteError; + } else { + const char* name = opcode->custom_code()->c_str(); + *registration = op_resolver.FindOp(name, version); + if (*registration == nullptr) { + error_reporter->Report( + "Didn't find custom op for name '%s' with version %d\n", name, + version); + status = kTfLiteError; + } + } + return status; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/core/api/op_resolver.h b/tensorflow/contrib/lite/core/api/op_resolver.h new file mode 100644 index 0000000000..5f5e6b2736 --- /dev/null +++ b/tensorflow/contrib/lite/core/api/op_resolver.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_ +#define TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_ + +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { + +// Abstract interface that returns TfLiteRegistrations given op codes or custom +// op names. This is the mechanism that ops being referenced in the flatbuffer +// model are mapped to executable function pointers (TfLiteRegistrations). +class OpResolver { + public: + // Finds the op registration for a builtin operator by enum code. + virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const = 0; + // Finds the op registration of a custom operator by op name. + virtual const TfLiteRegistration* FindOp(const char* op, + int version) const = 0; + virtual ~OpResolver() {} +}; + +// Handles the logic for converting between an OperatorCode structure extracted +// from a flatbuffer and information about a registered operator implementation. +TfLiteStatus GetRegistrationFromOpCode(const OperatorCode* opcode, + const OpResolver& op_resolver, + ErrorReporter* error_reporter, + const TfLiteRegistration** registration); + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/core/api/op_resolver_test.cc b/tensorflow/contrib/lite/core/api/op_resolver_test.cc new file mode 100644 index 0000000000..167463110e --- /dev/null +++ b/tensorflow/contrib/lite/core/api/op_resolver_test.cc @@ -0,0 +1,197 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/core/api/op_resolver.h" + +#include <cstring> + +#include <gtest/gtest.h> + +namespace tflite { +namespace { +void* MockInit(TfLiteContext* context, const char* buffer, size_t length) { + // Do nothing. + return nullptr; +} + +void MockFree(TfLiteContext* context, void* buffer) { + // Do nothing. +} + +TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) { + return kTfLiteOk; +} + +class MockOpResolver : public OpResolver { + public: + const TfLiteRegistration* FindOp(BuiltinOperator op, + int version) const override { + if (op == BuiltinOperator_CONV_2D) { + static TfLiteRegistration r = {MockInit, MockFree, MockPrepare, + MockInvoke}; + return &r; + } else { + return nullptr; + } + } + const TfLiteRegistration* FindOp(const char* op, int version) const override { + if (strcmp(op, "mock_custom") == 0) { + static TfLiteRegistration r = {MockInit, MockFree, MockPrepare, + MockInvoke}; + return &r; + } else { + return nullptr; + } + } +}; + +class MockErrorReporter : public ErrorReporter { + public: + MockErrorReporter() : buffer_size_(0) {} + int Report(const char* format, va_list args) override { + buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args); + return buffer_size_; + } + char* GetBuffer() { return buffer_; } + int GetBufferSize() { return buffer_size_; } + + private: + static constexpr int kBufferSize = 256; + char buffer_[kBufferSize]; + int buffer_size_; +}; + +} // namespace + +TEST(OpResolver, TestResolver) { + MockOpResolver mock_resolver; + OpResolver* resolver = &mock_resolver; + + const TfLiteRegistration* registration = + resolver->FindOp(BuiltinOperator_CONV_2D, 0); + EXPECT_NE(nullptr, registration); + EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); + EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); + EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); + + registration = resolver->FindOp(BuiltinOperator_CAST, 0); + EXPECT_EQ(nullptr, registration); + + registration = resolver->FindOp("mock_custom", 0); + EXPECT_NE(nullptr, registration); + EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); + EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); + EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); + + registration = resolver->FindOp("nonexistent_custom", 0); + EXPECT_EQ(nullptr, registration); +} + +TEST(OpResolver, TestGetRegistrationFromOpCodeConv) { + MockOpResolver mock_resolver; + OpResolver* resolver = &mock_resolver; + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + + flatbuffers::FlatBufferBuilder builder; + flatbuffers::Offset<OperatorCode> conv_offset = + CreateOperatorCodeDirect(builder, BuiltinOperator_CONV_2D, nullptr, 0); + builder.Finish(conv_offset); + void* conv_pointer = builder.GetBufferPointer(); + const OperatorCode* conv_code = + flatbuffers::GetRoot<OperatorCode>(conv_pointer); + const TfLiteRegistration* registration = nullptr; + EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter, + ®istration)); + EXPECT_NE(nullptr, registration); + EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); + EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); + EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); + EXPECT_EQ(0, mock_reporter.GetBufferSize()); +} + +TEST(OpResolver, TestGetRegistrationFromOpCodeCast) { + MockOpResolver mock_resolver; + OpResolver* resolver = &mock_resolver; + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + + flatbuffers::FlatBufferBuilder builder; + flatbuffers::Offset<OperatorCode> conv_offset = + CreateOperatorCodeDirect(builder, BuiltinOperator_CAST, nullptr, 0); + builder.Finish(conv_offset); + void* conv_pointer = builder.GetBufferPointer(); + const OperatorCode* conv_code = + flatbuffers::GetRoot<OperatorCode>(conv_pointer); + const TfLiteRegistration* registration = nullptr; + EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver, + reporter, ®istration)); + EXPECT_EQ(nullptr, registration); + EXPECT_NE(0, mock_reporter.GetBufferSize()); +} + +TEST(OpResolver, TestGetRegistrationFromOpCodeCustom) { + MockOpResolver mock_resolver; + OpResolver* resolver = &mock_resolver; + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + + flatbuffers::FlatBufferBuilder builder; + flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect( + builder, BuiltinOperator_CUSTOM, "mock_custom", 0); + builder.Finish(conv_offset); + void* conv_pointer = builder.GetBufferPointer(); + const OperatorCode* conv_code = + flatbuffers::GetRoot<OperatorCode>(conv_pointer); + const TfLiteRegistration* registration = nullptr; + EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter, + ®istration)); + EXPECT_NE(nullptr, registration); + EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0)); + EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr)); + EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr)); + EXPECT_EQ(0, mock_reporter.GetBufferSize()); +} + +TEST(OpResolver, TestGetRegistrationFromOpCodeNonexistentCustom) { + MockOpResolver mock_resolver; + OpResolver* resolver = &mock_resolver; + MockErrorReporter mock_reporter; + ErrorReporter* reporter = &mock_reporter; + + flatbuffers::FlatBufferBuilder builder; + flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect( + builder, BuiltinOperator_CUSTOM, "nonexistent_custom", 0); + builder.Finish(conv_offset); + void* conv_pointer = builder.GetBufferPointer(); + const OperatorCode* conv_code = + flatbuffers::GetRoot<OperatorCode>(conv_pointer); + const TfLiteRegistration* registration = nullptr; + EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver, + reporter, ®istration)); + EXPECT_EQ(nullptr, registration); + EXPECT_NE(0, mock_reporter.GetBufferSize()); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD index b6b2357873..bf5d91899c 100644 --- a/tensorflow/contrib/lite/delegates/eager/BUILD +++ b/tensorflow/contrib/lite/delegates/eager/BUILD @@ -16,6 +16,7 @@ cc_library( deps = [ ":util", "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite:kernel_api", ] + select({ "//tensorflow:android": [ @@ -54,6 +55,7 @@ cc_library( ":delegate_data", ":kernel", ":util", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite:util", ] + select({ @@ -104,6 +106,7 @@ tf_cc_test( ":delegate_data", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], @@ -117,6 +120,7 @@ cc_library( ":delegate_data", ":util", "@flatbuffers", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite:string", "//tensorflow/contrib/lite/kernels:kernel_util", @@ -170,6 +174,7 @@ cc_library( hdrs = ["util.h"], deps = [ "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite:kernel_api", ] + select({ "//tensorflow:android": [ diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/eager/buffer_map.h index a28329ae7d..aaaa045840 100644 --- a/tensorflow/contrib/lite/delegates/eager/buffer_map.h +++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.h @@ -17,7 +17,7 @@ limitations under the License. #include <map> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/core/framework/tensor.h" namespace tflite { diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h index 6d15ba47dc..70f3c15af4 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate.h +++ b/tensorflow/contrib/lite/delegates/eager/delegate.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ #define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" namespace tflite { diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc index b3a0ffcec1..def063309f 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include <gmock/gmock.h> #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc index 0ee4db1ffb..274c3c082a 100644 --- a/tensorflow/contrib/lite/delegates/eager/kernel.cc +++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc @@ -16,7 +16,7 @@ limitations under the License. #include "flatbuffers/flexbuffers.h" // flatbuffers #include "tensorflow/contrib/lite/builtin_ops.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/context_util.h" #include "tensorflow/contrib/lite/delegates/eager/delegate_data.h" #include "tensorflow/contrib/lite/delegates/eager/util.h" diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/eager/kernel.h index 100672c82d..2478abccaa 100644 --- a/tensorflow/contrib/lite/delegates/eager/kernel.h +++ b/tensorflow/contrib/lite/delegates/eager/kernel.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ #define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { namespace eager { diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h index ff500d18f3..930cb99cb9 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.h +++ b/tensorflow/contrib/lite/delegates/eager/util.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_ #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD index 954955f24b..4e7b2948fb 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/BUILD +++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD @@ -13,6 +13,7 @@ cc_library( deps = [ "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/nnapi:nnapi_lib", ], @@ -29,6 +30,7 @@ tf_cc_test( deps = [ ":nnapi_delegate", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc index 980a1cb4a0..e3eebac4da 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/builtin_ops.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/context_util.h" #include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h index 44cca2fd28..4852b76974 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ #define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h index 3c5f805f12..5c20eedc25 100644 --- a/tensorflow/contrib/lite/error_reporter.h +++ b/tensorflow/contrib/lite/error_reporter.h @@ -12,43 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// Compatibility shim for moved header location. #ifndef TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ #define TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ -#include <cstdarg> -#include "tensorflow/contrib/lite/context.h" - -namespace tflite { - -// A functor that reports error to supporting system. Invoked similar to -// printf. -// -// Usage: -// ErrorReporter foo; -// foo.Report("test %d", 5); -// or -// va_list args; -// foo.Report("test %d", args); // where args is va_list -// -// Subclass ErrorReporter to provide another reporting destination. -// For example, if you have a GUI program, you might redirect to a buffer -// that drives a GUI error log box. -class ErrorReporter { - public: - virtual ~ErrorReporter(); - virtual int Report(const char* format, va_list args) = 0; - int Report(const char* format, ...); - int ReportError(void*, const char* format, ...); -}; - -// An error reporter that simplify writes the message to stderr. -struct StderrReporter : public ErrorReporter { - int Report(const char* format, va_list args) override; -}; - -// Return the default error reporter (output to stderr). -ErrorReporter* DefaultErrorReporter(); - -} // namespace tflite +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/stderr_reporter.h" #endif // TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD index 8fc07e8eb7..ea4a543252 100644 --- a/tensorflow/contrib/lite/experimental/c/BUILD +++ b/tensorflow/contrib/lite/experimental/c/BUILD @@ -78,6 +78,7 @@ cc_test( data = ["//tensorflow/contrib/lite:testdata/add.bin"], deps = [ ":c_api", + "//tensorflow/contrib/lite:context", "//tensorflow/contrib/lite:kernel_api", "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc index a4ab0e8c30..c589cf71ea 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/experimental/c/c_api.h" +#include <memory> + #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/experimental/c/c_api_internal.h" #include "tensorflow/contrib/lite/interpreter.h" @@ -29,12 +31,14 @@ extern "C" { TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) { auto model = tflite::FlatBufferModel::BuildFromBuffer( static_cast<const char*>(model_data), model_size); - return model ? new TFL_Model{std::move(model)} : nullptr; + std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release()); + return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr; } TFL_Model* TFL_NewModelFromFile(const char* model_path) { auto model = tflite::FlatBufferModel::BuildFromFile(model_path); - return model ? new TFL_Model{std::move(model)} : nullptr; + std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release()); + return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr; } void TFL_DeleteModel(TFL_Model* model) { delete model; } @@ -72,7 +76,7 @@ TFL_Interpreter* TFL_NewInterpreter( } } - return new TFL_Interpreter{std::move(interpreter)}; + return new TFL_Interpreter{model->impl, std::move(interpreter)}; } void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; } @@ -129,6 +133,8 @@ void* TFL_TensorData(const TFL_Tensor* tensor) { return static_cast<void*>(tensor->data.raw); } +const char* TFL_TensorName(const TFL_Tensor* tensor) { return tensor->name; } + TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data, size_t input_data_size) { if (tensor->bytes != input_data_size) { diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h index 3757349b55..b429e76870 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api.h +++ b/tensorflow/contrib/lite/experimental/c/c_api.h @@ -93,7 +93,8 @@ typedef struct TFL_Interpreter TFL_Interpreter; // failure. // // * `model` must be a valid model instance. The caller retains ownership of the -// object, and can destroy it immediately after creating the interpreter. +// object, and can destroy it immediately after creating the interpreter; the +// interpreter will maintain its own reference to the underlying model data. // * `optional_options` may be null. The caller retains ownership of the object, // and can safely destroy it immediately after creating the interpreter. // @@ -145,6 +146,11 @@ TFL_CAPI_EXPORT extern int32_t TFL_InterpreterGetOutputTensorCount( // Returns the tensor associated with the output index. // REQUIRES: 0 <= input_index < TFL_InterpreterGetOutputTensorCount(tensor) +// +// NOTE: The shape and underlying data buffer for output tensors may be not +// be available until after the output tensor has been both sized and allocated. +// In general, best practice is to interact with the output tensor *after* +// calling TFL_InterpreterInvoke(). TFL_CAPI_EXPORT extern const TFL_Tensor* TFL_InterpreterGetOutputTensor( const TFL_Interpreter* interpreter, int32_t output_index); @@ -172,12 +178,15 @@ TFL_CAPI_EXPORT extern size_t TFL_TensorByteSize(const TFL_Tensor* tensor); // Returns a pointer to the underlying data buffer. // -// Note: The result may be null if tensors have not yet been allocated, e.g., +// NOTE: The result may be null if tensors have not yet been allocated, e.g., // if the Tensor has just been created or resized and `TFL_AllocateTensors()` // has yet to be called, or if the output tensor is dynamically sized and the // interpreter hasn't been invoked. TFL_CAPI_EXPORT extern void* TFL_TensorData(const TFL_Tensor* tensor); +// Returns the (null-terminated) name of the tensor. +TFL_CAPI_EXPORT extern const char* TFL_TensorName(const TFL_Tensor* tensor); + // Copies from the provided input buffer into the tensor's buffer. // REQUIRES: input_data_size == TFL_TensorByteSize(tensor) TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyFromBuffer( diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h index c5c612a4c6..60c2e4e2cd 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h +++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h @@ -24,7 +24,8 @@ limitations under the License. // not be depended on. struct TFL_Model { - std::unique_ptr<tflite::FlatBufferModel> impl; + // Sharing is safe as FlatBufferModel is const. + std::shared_ptr<const tflite::FlatBufferModel> impl; }; struct TFL_InterpreterOptions { @@ -35,6 +36,9 @@ struct TFL_InterpreterOptions { }; struct TFL_Interpreter { + // Taking a reference to the (const) model data avoids lifetime-related issues + // and complexity with the TFL_Model's existence. + std::shared_ptr<const tflite::FlatBufferModel> model; std::unique_ptr<tflite::Interpreter> impl; }; diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc index a631dae890..649dac8d1a 100644 --- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc +++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc @@ -55,6 +55,8 @@ TEST(CApiSimple, Smoke) { EXPECT_EQ(TFL_TensorNumDims(input_tensor), 1); EXPECT_EQ(TFL_TensorDim(input_tensor, 0), 2); EXPECT_EQ(TFL_TensorByteSize(input_tensor), sizeof(float) * 2); + EXPECT_NE(TFL_TensorData(input_tensor), nullptr); + EXPECT_STREQ(TFL_TensorName(input_tensor), "input"); std::array<float, 2> input = {1.f, 3.f}; ASSERT_EQ(TFL_TensorCopyFromBuffer(input_tensor, input.data(), @@ -70,6 +72,8 @@ TEST(CApiSimple, Smoke) { EXPECT_EQ(TFL_TensorNumDims(output_tensor), 1); EXPECT_EQ(TFL_TensorDim(output_tensor, 0), 2); EXPECT_EQ(TFL_TensorByteSize(output_tensor), sizeof(float) * 2); + EXPECT_NE(TFL_TensorData(output_tensor), nullptr); + EXPECT_STREQ(TFL_TensorName(output_tensor), "output"); std::array<float, 2> output; ASSERT_EQ(TFL_TensorCopyToBuffer(output_tensor, output.data(), diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/contrib/lite/experimental/kernels/BUILD index 9c06c4ebd9..4786cc62f9 100644 --- a/tensorflow/contrib/lite/experimental/kernels/BUILD +++ b/tensorflow/contrib/lite/experimental/kernels/BUILD @@ -53,6 +53,7 @@ cc_library( "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/kernels:kernel_util", @@ -61,8 +62,8 @@ cc_library( "//tensorflow/contrib/lite/kernels/internal:optimized", "//tensorflow/contrib/lite/kernels/internal:optimized_base", "//tensorflow/contrib/lite/kernels/internal:quantization_util", - "//tensorflow/contrib/lite/kernels/internal:reference", "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:tensor", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "@flatbuffers", ], diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc index 121997dcb2..8442c4d46c 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include <vector> #include "flatbuffers/flexbuffers.h" // flatbuffers -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h index 77268d7aeb..8ee83827bb 100644 --- a/tensorflow/contrib/lite/graph_info.h +++ b/tensorflow/contrib/lite/graph_info.h @@ -17,7 +17,7 @@ limitations under the License. #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 5ab53f4c1d..3f8f4d198f 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -21,9 +21,9 @@ limitations under the License. #include <cstring> #include "tensorflow/contrib/lite/arena_planner.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/context_util.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/graph_info.h" #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/nnapi_delegate.h" diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 2b1f1819b9..f0cd178c19 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -23,10 +23,11 @@ limitations under the License. #include <vector> #include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/profiling/profiler.h" +#include "tensorflow/contrib/lite/stderr_reporter.h" namespace tflite { diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 5bcf0927d8..cdede430e2 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/contrib/lite/interpreter.h" #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD index 06f46fb923..781289ceb2 100644 --- a/tensorflow/contrib/lite/java/ovic/BUILD +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -35,6 +35,7 @@ java_binary( "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", ], main_class = "org.tensorflow.ovic.OvicValidator", + tags = ["no_oss"], deps = [ "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", ], @@ -47,6 +48,7 @@ android_library( "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", ], manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml", + tags = ["no_oss"], deps = [ "//tensorflow/contrib/lite/java:tensorflowlite", "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", @@ -61,6 +63,7 @@ java_library( "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", ], javacopts = JAVACOPTS, + tags = ["no_oss"], deps = [ "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so", "//tensorflow/contrib/lite/java:tensorflowlite_java", diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index 55ca47fed7..06b35d77c8 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -20,7 +20,7 @@ limitations under the License. #include <stdio.h> #include <time.h> #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" #include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" @@ -124,9 +124,9 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, */ JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env, - jclass clazz, - jlong handle, - jint num_threads); + jclass clazz, + jlong handle, + jint num_threads); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h index c020f13d9c..2f73128bdf 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_ #include <jni.h> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #ifdef __cplusplus extern "C" { diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index b7c5cbf207..40f28aeab4 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -66,7 +66,7 @@ cc_library( deps = [ ":op_macros", "//tensorflow/contrib/lite:arena_planner", - "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels/internal:optimized", ], ) @@ -82,7 +82,7 @@ cc_library( copts = tflite_copts(), deps = [ ":op_macros", - "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite/c:c_api_internal", "@gemmlowp", ], ) @@ -93,7 +93,7 @@ cc_library( "activation_functor.h", ], deps = [ - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", ], ) @@ -113,9 +113,9 @@ cc_library( "kernel_util.h", ], deps = [ - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels/internal:round", + "//tensorflow/contrib/lite/kernels/internal:types", ], ) @@ -147,6 +147,15 @@ tf_cc_test( ) cc_library( + name = "padding", + srcs = [], + hdrs = ["padding.h"], + deps = [ + "//tensorflow/contrib/lite/c:c_api_internal", + ], +) + +cc_library( name = "builtin_op_kernels", srcs = [ "activations.cc", @@ -216,7 +225,6 @@ cc_library( "unpack.cc", ], hdrs = [ - "padding.h", ], copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS, visibility = ["//visibility:private"], @@ -225,18 +233,19 @@ cc_library( ":eigen_support", ":kernel_util", ":op_macros", - "//tensorflow/contrib/lite:builtin_op_data", + ":padding", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/kernels/internal:audio_utils", "//tensorflow/contrib/lite/kernels/internal:kernel_utils", "//tensorflow/contrib/lite/kernels/internal:optimized", "//tensorflow/contrib/lite/kernels/internal:optimized_base", "//tensorflow/contrib/lite/kernels/internal:quantization_util", - "//tensorflow/contrib/lite/kernels/internal:reference", "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:tensor", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "@farmhash_archive//:farmhash", "@flatbuffers", @@ -251,6 +260,7 @@ cc_library( ":builtin_op_kernels", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/c:c_api_internal", ], ) @@ -757,8 +767,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -774,8 +784,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1044,8 +1054,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1147,8 +1157,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1164,8 +1174,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1181,8 +1191,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1198,8 +1208,8 @@ tf_cc_test( ], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1212,8 +1222,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], @@ -1239,8 +1249,8 @@ tf_cc_test( tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", - "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h index 41ec3cca33..e075dc7054 100644 --- a/tensorflow/contrib/lite/kernels/activation_functor.h +++ b/tensorflow/contrib/lite/kernels/activation_functor.h @@ -19,7 +19,7 @@ limitations under the License. #include <cmath> #include <cstdlib> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 5cdd9fc94f..b2d9b84979 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index af9b5c7013..b4393e8097 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc index 6e05f5a9b2..b91e348c27 100644 --- a/tensorflow/contrib/lite/kernels/arg_min_max.cc +++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc index 1170d84553..44ef587244 100644 --- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc +++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/spectrogram.h" diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index c5a5c0182f..1aa27602e5 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -15,8 +15,8 @@ limitations under the License. #include <stddef.h> #include <stdint.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc index 4efa9d596d..fe2865dfb9 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index 6b8ecdd5c3..541f320138 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -20,8 +20,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc index d988ef8b33..2f896c5289 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc index 8dd48af57f..a7972140ac 100644 --- a/tensorflow/contrib/lite/kernels/cast.cc +++ b/tensorflow/contrib/lite/kernels/cast.cc @@ -15,8 +15,8 @@ limitations under the License. #include <string.h> #include <algorithm> #include <complex> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index 8b4d778332..4cd96348a2 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc index 605a20ac3e..25ea556d5a 100644 --- a/tensorflow/contrib/lite/kernels/concatenation.cc +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 3ed0cdb131..ab6bdaecaa 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -20,8 +20,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/eigen_support.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h" diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc index 21518156b8..347515f289 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc index 2b0f04489a..3a08f48b00 100644 --- a/tensorflow/contrib/lite/kernels/dequantize.cc +++ b/tensorflow/contrib/lite/kernels/dequantize.cc @@ -15,8 +15,8 @@ limitations under the License. #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc index 136697f945..d2906632d7 100644 --- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc +++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc @@ -16,8 +16,8 @@ limitations under the License. #include <numeric> #include <vector> #include "flatbuffers/flexbuffers.h" // flatbuffers -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc index d7420ddd8e..7945c095b1 100644 --- a/tensorflow/contrib/lite/kernels/div.cc +++ b/tensorflow/contrib/lite/kernels/div.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h index b235829642..feb1543f7b 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.h +++ b/tensorflow/contrib/lite/kernels/eigen_support.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace EigenForTFLite { struct ThreadPoolDevice; diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index e19779ea59..04995d70dd 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include <cmath> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc index b2dff87e62..fe33f98eb0 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -37,8 +37,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc index d3be36993c..aa75b03990 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc @@ -65,8 +65,8 @@ limitations under the License. #include <algorithm> #include <cmath> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc index ce03cdfe26..673e7be90a 100644 --- a/tensorflow/contrib/lite/kernels/exp.cc +++ b/tensorflow/contrib/lite/kernels/exp.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc index ed33012864..fa1140b19c 100644 --- a/tensorflow/contrib/lite/kernels/expand_dims.cc +++ b/tensorflow/contrib/lite/kernels/expand_dims.cc @@ -15,8 +15,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc index 50dc860e5a..a3bc1813db 100644 --- a/tensorflow/contrib/lite/kernels/expand_dims_test.cc +++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc index 0ef1a50b30..f9bc3747cb 100644 --- a/tensorflow/contrib/lite/kernels/fake_quant.cc +++ b/tensorflow/contrib/lite/kernels/fake_quant.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc index f7d5f5146d..59ff77f35b 100644 --- a/tensorflow/contrib/lite/kernels/floor.cc +++ b/tensorflow/contrib/lite/kernels/floor.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/floor_div.cc b/tensorflow/contrib/lite/kernels/floor_div.cc index 75cf19a5a7..5d62cd2755 100644 --- a/tensorflow/contrib/lite/kernels/floor_div.cc +++ b/tensorflow/contrib/lite/kernels/floor_div.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index eaf5a67d67..7a71fcc219 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -20,8 +20,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc index 2b2a9e6620..badd2de11a 100644 --- a/tensorflow/contrib/lite/kernels/gather.cc +++ b/tensorflow/contrib/lite/kernels/gather.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <string.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc index 1d4292955c..1b48884e09 100644 --- a/tensorflow/contrib/lite/kernels/gather_test.cc +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h index 37af772c68..43cd2b3055 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.h +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ #include "public/gemmlowp.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { namespace gemm_support { diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc index f37c66acb3..c0b3c3c0c5 100644 --- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc @@ -39,8 +39,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" #include "tensorflow/contrib/lite/string_util.h" diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 464163bd78..a6fd4ac2dd 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -163,7 +163,7 @@ cc_library( ":tensor_utils", "//third_party/eigen3", "@gemmlowp", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -198,7 +198,7 @@ cc_library( ":round", "//third_party/eigen3", "@gemmlowp", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -220,13 +220,15 @@ cc_library( "optimized/eigen_spatial_convolutions.h", "optimized/eigen_tensor_reduced_instantiations_oss.h", "optimized/multithreaded_conv.h", + # FIXME(petewarden) - This should be removed, since it's a header from the + # :tensor dependency below. "tensor.h", ], deps = [ ":optimized_base", + ":tensor", ":types", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite/c:c_api_internal", "//third_party/eigen3", ], ) @@ -236,7 +238,7 @@ cc_test( srcs = ["tensor_test.cc"], tags = ["no_oss"], deps = [ - ":reference", + ":tensor", "@com_google_googletest//:gtest", ], ) @@ -296,7 +298,7 @@ cc_library( ":strided_slice_logic", ":types", "@gemmlowp", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -326,7 +328,7 @@ cc_library( ":strided_slice_logic", ":types", "@gemmlowp", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", ] + select({ ":haswell": tflite_deps_intel, ":ios_x86_64": tflite_deps_intel, @@ -341,11 +343,27 @@ cc_library( ) cc_library( + name = "tensor", + hdrs = [ + "tensor.h", + "tensor_ctypes.h", + ], + deps = [ + ":types", + "//tensorflow/contrib/lite/c:c_api_internal", + ], +) + +# Deprecated version of :tensor, kept for backwards compatibility. +cc_library( name = "reference", - hdrs = ["tensor.h"], + hdrs = [ + "tensor.h", + "tensor_ctypes.h", + ], deps = [ ":types", - "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite/c:c_api_internal", ], ) @@ -359,7 +377,7 @@ cc_library( ], deps = [ ":round", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:activation_functor", "//tensorflow/contrib/lite/kernels:op_macros", ], @@ -384,7 +402,7 @@ cc_library( ":cpu_check", ":round", ":types", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:activation_functor", "//tensorflow/contrib/lite/kernels:op_macros", "@arm_neon_2_x86_sse", @@ -398,7 +416,7 @@ cc_library( hdrs = ["kernel_utils.h"], deps = [ ":tensor_utils", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", ], ) @@ -441,7 +459,7 @@ cc_library( copts = NEON_FLAGS_IF_APPLICABLE, deps = [ "//tensorflow/contrib/lite/kernels:activation_functor", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", "@arm_neon_2_x86_sse", "@gemmlowp", ] + select({ @@ -517,7 +535,7 @@ cc_test( ], deps = [ ":tensor_utils", - "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/c:c_api_internal", "//tensorflow/contrib/lite/kernels:test_util", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index eb4d0108bd..e67fee11b8 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -45,7 +45,7 @@ limitations under the License. #endif #endif -#include "public/gemmlowp.h" +#include "fixedpoint/fixedpoint.h" #include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index b9dd40ddf9..56e9367878 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -14,8 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" -#include <algorithm> - #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index 215ad04add..b5558cce55 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" namespace tflite { namespace kernel_utils { diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h index 921aae1303..5fb31889fe 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -26,7 +26,7 @@ limitations under the License. #include <tuple> #include <type_traits> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc index 70b6994a2b..27418178fd 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -15,7 +15,7 @@ limitations under the License. #include <stdlib.h> #include <string.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index 5ca1b4b76f..630a6bbf29 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -17,7 +17,7 @@ limitations under the License. // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index 7e53dc2fa2..f87760a6c3 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -17,7 +17,7 @@ limitations under the License. // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #if defined(_MSC_VER) #define __restrict__ __restrict diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc index 2a30910c3f..77e60adc18 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -16,7 +16,7 @@ limitations under the License. #include <string.h> #include <algorithm> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/round.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index f5b3a84f07..714b1164ee 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -17,7 +17,7 @@ limitations under the License. // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #if defined(_MSC_VER) #define __restrict__ __restrict diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index a027a47726..0abacf85e1 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -3488,8 +3488,7 @@ inline void Gather(const tflite::GatherParams& op_params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& coords_shape, const int32* coords_data, const RuntimeShape& output_shape, T* output_data) { - // TODO(b/80418076): Enable these checks when moving legacy ops to - // legacy_reference_ops. + // Enable these checks when moving legacy ops to legacy_reference_ops. // // TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1); const int input_rank = op_params.input_rank; @@ -3808,58 +3807,110 @@ inline void Pad(const tflite::PadParams& op_params, } template <typename T> -inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, int shrink_axis_mask, - const std::vector<int>& start_indices, - const std::vector<int>& stop_indices, - const std::vector<int>& strides, T* output_data, - const Dims<4>& output_dims) { - // Note that the axis orders are reversed for runtime ops, so the indices, - // strides and masks must be as well too. - TFLITE_DCHECK_EQ(start_indices.size(), 4); - TFLITE_DCHECK_EQ(stop_indices.size(), 4); - TFLITE_DCHECK_EQ(strides.size(), 4); - const int start_b = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 3); +inline void StridedSlice(const tflite::StridedSliceParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + // Note that the output_shape is not used herein. + tflite::StridedSliceParams params_copy = op_params; + + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + // Reverse and pad to 4 dimensions because that is what the runtime code + // requires (ie. all shapes must be 4D and are given backwards). + strided_slice::StridedSlicePadIndices(¶ms_copy, 4); + + const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0); const int stop_b = - strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, - strides, input_dims.sizes, 3, start_b); - const int start_h = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 2); + strided_slice::StopForAxis(params_copy, input_shape, 0, start_b); + const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1); const int stop_h = - strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, - strides, input_dims.sizes, 2, start_h); - const int start_w = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 1); + strided_slice::StopForAxis(params_copy, input_shape, 1, start_h); + const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2); const int stop_w = - strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, - strides, input_dims.sizes, 1, start_w); - const int start_d = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 0); + strided_slice::StopForAxis(params_copy, input_shape, 2, start_w); + const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3); const int stop_d = - strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices, - strides, input_dims.sizes, 0, start_d); + strided_slice::StopForAxis(params_copy, input_shape, 3, start_d); T* out_ptr = output_data; for (int in_b = start_b; - !strided_slice::LoopCondition(in_b, stop_b, strides[3]); - in_b += strides[3]) { + !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]); + in_b += params_copy.strides[0]) { for (int in_h = start_h; - !strided_slice::LoopCondition(in_h, stop_h, strides[2]); - in_h += strides[2]) { + !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]); + in_h += params_copy.strides[1]) { for (int in_w = start_w; - !strided_slice::LoopCondition(in_w, stop_w, strides[1]); - in_w += strides[1]) { - for (int in_d = start_d; - !strided_slice::LoopCondition(in_d, stop_d, strides[0]); - in_d += strides[0]) { - *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]); + in_w += params_copy.strides[2]) { + for (int in_d = start_d; !strided_slice::LoopCondition( + in_d, stop_d, params_copy.strides[3]); + in_d += params_copy.strides[3]) { + *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)]; } } } } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +inline uint32 LegacyReverseBits32(uint32 n) { + n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1); + n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2); + n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4); + return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) | + ((n & 0xFF000000) >> 24)); +} + +inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) { + TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count); + TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count); + + std::reverse(p->start_indices, p->start_indices + p->start_indices_count); + std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count); + std::reverse(p->strides, p->strides + p->strides_count); + + p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >> + (32 - p->start_indices_count); + p->ellipsis_mask = + LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >> + (32 - p->start_indices_count); + p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >> + (32 - p->start_indices_count); + p->new_axis_mask = + LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >> + (32 - p->start_indices_count); + p->shrink_axis_mask = + LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >> + (32 - p->start_indices_count); +} + +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +template <typename T> +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, int shrink_axis_mask, + const std::vector<int>& start_indices, + const std::vector<int>& stop_indices, + const std::vector<int>& strides, T* output_data, + const Dims<4>& output_dims) { + TFLITE_DCHECK_EQ(start_indices.size(), 4); + auto op_params = strided_slice::BuildStridedSliceParams( + begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices, + strides); + StridedSliceReverseIndices(&op_params); + + StridedSlice(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + template <typename T> inline void Slice(const tflite::SliceParams& op_params, const RuntimeShape& input_shape, const T* input_data, diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h index 5994fad5c7..af5db1064c 100644 --- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h +++ b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h @@ -19,9 +19,9 @@ limitations under the License. #include <limits> #include <vector> #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { - namespace strided_slice { // Use until std::clamp() is available from C++17. @@ -32,15 +32,51 @@ inline int Clamp(const int v, const int lo, const int hi) { return v; } +inline void StridedSlicePadIndices(tflite::StridedSliceParams* p, + int dim_count) { + // Add indices and mask bits to fully include extra dimensions + TFLITE_CHECK_LE(dim_count, 4); + TFLITE_CHECK_GE(dim_count, p->start_indices_count); + TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count); + TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count); + + const int pad_count = dim_count - p->start_indices_count; + + // Pad indices at start, so move arrays by pad_count. + for (int i = p->start_indices_count - 1; i > 0; --i) { + p->strides[i + pad_count] = p->strides[i]; + p->start_indices[i + pad_count] = p->start_indices[i]; + p->stop_indices[i + pad_count] = p->stop_indices[i]; + } + for (int i = 0; i < pad_count; ++i) { + p->start_indices[i] = 0; + p->stop_indices[i] = 0; + p->strides[i] = 1; + } + + // Pad masks with 0s or 1s as required. + p->shrink_axis_mask <<= pad_count; + p->ellipsis_mask <<= pad_count; + p->new_axis_mask <<= pad_count; + p->begin_mask <<= pad_count; + p->end_mask <<= pad_count; + p->begin_mask |= (1 << pad_count) - 1; + p->end_mask |= (1 << pad_count) - 1; + + p->start_indices_count = dim_count; + p->stop_indices_count = dim_count; + p->strides_count = dim_count; +} + // Return the index for the first element along that axis. This index will be a // positive integer between [0, axis_size - 1] that can be used to index // directly into the data. -template <typename IntType> -inline int StartForAxis(int begin_mask, - std::vector<IntType> const& start_indices, - std::vector<IntType> const& strides, - int const* input_shape, int axis) { - // Begin with the specified index +inline int StartForAxis(const tflite::StridedSliceParams& params, + const RuntimeShape& input_shape, int axis) { + const auto begin_mask = params.begin_mask; + const auto* start_indices = params.start_indices; + const auto* strides = params.strides; + // Begin with the specified index. int start = start_indices[axis]; // begin_mask override @@ -57,7 +93,7 @@ inline int StartForAxis(int begin_mask, } // Handle negative indices - int axis_size = input_shape[axis]; + int axis_size = input_shape.Dims(axis); if (start < 0) { start += axis_size; } @@ -73,11 +109,14 @@ inline int StartForAxis(int begin_mask, // element. ie. So if you were iterating through all elements of a 1D array of // size 4, this function would return 4 as the stop, because it is one past the // "real" indices of 0, 1, 2 & 3. -template <typename IntType> -inline int StopForAxis(int end_mask, int shrink_axis_mask, - std::vector<IntType> const& stop_indices, - std::vector<IntType> const& strides, - int const* input_shape, int axis, int start_for_axis) { +inline int StopForAxis(const tflite::StridedSliceParams& params, + const RuntimeShape& input_shape, int axis, + int start_for_axis) { + const auto end_mask = params.end_mask; + const auto shrink_axis_mask = params.shrink_axis_mask; + const auto* stop_indices = params.stop_indices; + const auto* strides = params.strides; + // Begin with the specified index const bool shrink_axis = shrink_axis_mask & (1 << axis); int stop = stop_indices[axis]; @@ -103,7 +142,7 @@ inline int StopForAxis(int end_mask, int shrink_axis_mask, } // Handle negative indices - const int axis_size = input_shape[axis]; + const int axis_size = input_shape.Dims(axis); if (stop < 0) { stop += axis_size; } @@ -127,6 +166,31 @@ inline bool LoopCondition(int index, int stop, int stride) { return stride > 0 ? index >= stop : index <= stop; } +inline tflite::StridedSliceParams BuildStridedSliceParams( + int begin_mask, int end_mask, int shrink_axis_mask, + const std::vector<int>& start_indices, const std::vector<int>& stop_indices, + const std::vector<int>& strides) { + tflite::StridedSliceParams op_params; + const int dims_count = start_indices.size(); + + op_params.start_indices_count = dims_count; + op_params.stop_indices_count = dims_count; + op_params.strides_count = dims_count; + for (int i = 0; i < dims_count; ++i) { + op_params.start_indices[i] = start_indices[i]; + op_params.stop_indices[i] = stop_indices[i]; + op_params.strides[i] = strides[i]; + } + + op_params.begin_mask = begin_mask; + op_params.ellipsis_mask = 0; + op_params.end_mask = end_mask; + op_params.new_axis_mask = 0; + op_params.shrink_axis_mask = shrink_axis_mask; + + return op_params; +} + } // namespace strided_slice } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index ee2af5b460..13106456df 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -17,44 +17,12 @@ limitations under the License. #include <complex> #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { -template <typename T> -inline T* GetTensorData(TfLiteTensor* tensor); - -template <> -inline float* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.f : nullptr; -} - -template <> -inline uint8_t* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.uint8 : nullptr; -} - -template <> -inline int16_t* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.i16 : nullptr; -} - -template <> -inline int32_t* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.i32 : nullptr; -} - -template <> -inline int64_t* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.i64 : nullptr; -} - -template <> -inline bool* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.b : nullptr; -} - template <> inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr @@ -62,39 +30,6 @@ inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) { : nullptr; } -template <typename T> -inline const T* GetTensorData(const TfLiteTensor* tensor); - -template <> -inline const float* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.f : nullptr; -} - -template <> -inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.uint8 : nullptr; -} - -template <> -inline const int16_t* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.i16 : nullptr; -} - -template <> -inline const int32_t* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.i32 : nullptr; -} - -template <> -inline const int64_t* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.i64 : nullptr; -} - -template <> -inline const bool* GetTensorData(const TfLiteTensor* tensor) { - return tensor != nullptr ? tensor->data.b : nullptr; -} - template <> inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr @@ -102,56 +37,14 @@ inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) { : nullptr; } -inline int RemapDim(int max_dimensions, int d) { - return max_dimensions - d - 1; -} - -// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object -// even if the original tensors were not 4D. We should consider rewriting them -// to take a more generic 'shape' object. -inline Dims<4> GetTensorDims(const int data[], const int size) { - Dims<4> d; - for (int i = 0; i < 4; ++i) { - int src = size - i - 1; - if (src >= 0) { - d.sizes[i] = data[src]; - } else { - d.sizes[i] = 1; - } - } - d.strides[0] = 1; - for (int i = 1; i < 4; i++) { - d.strides[i] = d.strides[i - 1] * d.sizes[i - 1]; - } - return d; -} - inline Dims<4> GetTensorDims(std::vector<int32_t> data) { return GetTensorDims(data.data(), data.size()); } -inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { - if (tensor == nullptr) { - return Dims<4>(); - } - - auto* dims = tensor->dims; - return GetTensorDims(dims->data, dims->size); -} - inline RuntimeShape GetTensorShape(std::vector<int32_t> data) { return RuntimeShape(data.size(), data.data()); } -inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) { - if (tensor == nullptr) { - return RuntimeShape(); - } - - auto* dims = tensor->dims; - return RuntimeShape(dims->size, dims->data); -} - // A list of tensors in a format that can be used by kernels like split and // concatenation. template <typename T> diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h new file mode 100644 index 0000000000..77e22a08b4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h @@ -0,0 +1,135 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_ + +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +template <typename T> +inline T* GetTensorData(TfLiteTensor* tensor); + +template <> +inline float* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.f : nullptr; +} + +template <> +inline uint8_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.uint8 : nullptr; +} + +template <> +inline int16_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i16 : nullptr; +} + +template <> +inline int32_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i32 : nullptr; +} + +template <> +inline int64_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i64 : nullptr; +} + +template <> +inline bool* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.b : nullptr; +} + +template <typename T> +inline const T* GetTensorData(const TfLiteTensor* tensor); + +template <> +inline const float* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.f : nullptr; +} + +template <> +inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.uint8 : nullptr; +} + +template <> +inline const int16_t* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i16 : nullptr; +} + +template <> +inline const int32_t* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i32 : nullptr; +} + +template <> +inline const int64_t* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i64 : nullptr; +} + +template <> +inline const bool* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.b : nullptr; +} + +inline int RemapDim(int max_dimensions, int d) { + return max_dimensions - d - 1; +} + +// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object +// even if the original tensors were not 4D. We should consider rewriting them +// to take a more generic 'shape' object. +inline Dims<4> GetTensorDims(const int data[], const int size) { + Dims<4> d; + for (int i = 0; i < 4; ++i) { + int src = size - i - 1; + if (src >= 0) { + d.sizes[i] = data[src]; + } else { + d.sizes[i] = 1; + } + } + d.strides[0] = 1; + for (int i = 1; i < 4; i++) { + d.strides[i] = d.strides[i - 1] * d.sizes[i - 1]; + } + return d; +} + +inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { + if (tensor == nullptr) { + return Dims<4>(); + } + + auto* dims = tensor->dims; + return GetTensorDims(dims->data, dims->size); +} + +inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) { + if (tensor == nullptr) { + return RuntimeShape(); + } + + TfLiteIntArray* dims = tensor->dims; + const int dims_size = dims->size; + const int32_t* dims_data = dims->data; + return RuntimeShape(dims_size, dims_data); +} + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index 1439bf8c37..b0fe5adf65 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #if defined(_MSC_VER) #define __restrict__ __restrict diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc index dad924fc28..6458af714b 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include <gmock/gmock.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/kernels/test_util.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h index ed46cd984f..e9a5fd7a40 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -16,9 +16,10 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ #include <algorithm> +#include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc index 5b3536de0c..e02d7df9ef 100644 --- a/tensorflow/contrib/lite/kernels/l2norm.cc +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc index 799c1528bd..334d2a2788 100644 --- a/tensorflow/contrib/lite/kernels/local_response_norm.cc +++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc index c71f3b4701..f770cb35d1 100644 --- a/tensorflow/contrib/lite/kernels/logical.cc +++ b/tensorflow/contrib/lite/kernels/logical.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc index 69523b02cc..9fa1c5f100 100644 --- a/tensorflow/contrib/lite/kernels/lsh_projection.cc +++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc @@ -59,8 +59,8 @@ limitations under the License. #include <limits> #include <memory> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" #include <farmhash.h> diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 74dc3f25f9..aaa3ce966e 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -20,8 +20,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/contrib/lite/kernels/maximum_minimum.cc index 0308a3976a..7cb01465ee 100644 --- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc +++ b/tensorflow/contrib/lite/kernels/maximum_minimum.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc index 306f676619..66cf147d75 100644 --- a/tensorflow/contrib/lite/kernels/mfcc.cc +++ b/tensorflow/contrib/lite/kernels/mfcc.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/internal/mfcc.h" #include "flatbuffers/flexbuffers.h" // flatbuffers -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h" #include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index 92d8bc8b67..e0aac8a842 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/contrib/lite/kernels/neg.cc index 4124c05388..0ddd0644f5 100644 --- a/tensorflow/contrib/lite/kernels/neg.cc +++ b/tensorflow/contrib/lite/kernels/neg.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc index 9ff3dca932..910aed6f14 100644 --- a/tensorflow/contrib/lite/kernels/one_hot.cc +++ b/tensorflow/contrib/lite/kernels/one_hot.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc index cc326a7d51..4cb98fdd19 100644 --- a/tensorflow/contrib/lite/kernels/pack.cc +++ b/tensorflow/contrib/lite/kernels/pack.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 3bce05353d..0d939405f6 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h index 3cb55f19a9..42b6b45d3b 100644 --- a/tensorflow/contrib/lite/kernels/padding.h +++ b/tensorflow/contrib/lite/kernels/padding.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc index 29a5be0683..6451142391 100644 --- a/tensorflow/contrib/lite/kernels/pooling.cc +++ b/tensorflow/contrib/lite/kernels/pooling.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc index d676de5b1d..1e96cc80b1 100644 --- a/tensorflow/contrib/lite/kernels/pow.cc +++ b/tensorflow/contrib/lite/kernels/pow.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc index ca83797936..d94d821e87 100644 --- a/tensorflow/contrib/lite/kernels/reduce.cc +++ b/tensorflow/contrib/lite/kernels/reduce.cc @@ -15,8 +15,8 @@ limitations under the License. #include <string.h> #include <limits> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h index 0296152d68..61856ab9de 100644 --- a/tensorflow/contrib/lite/kernels/register.h +++ b/tensorflow/contrib/lite/kernels/register.h @@ -16,8 +16,9 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ #include <unordered_map> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/mutable_op_resolver.h" namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc index 49ba0571e2..f41147b2d6 100644 --- a/tensorflow/contrib/lite/kernels/reshape.cc +++ b/tensorflow/contrib/lite/kernels/reshape.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <string.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index dafa3aebab..fb045d15f3 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc index 3cdb5db209..3959502d91 100644 --- a/tensorflow/contrib/lite/kernels/select.cc +++ b/tensorflow/contrib/lite/kernels/select.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc index dbcd2ef004..66d4c9e5c1 100644 --- a/tensorflow/contrib/lite/kernels/shape.cc +++ b/tensorflow/contrib/lite/kernels/shape.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc index c90a15b3a2..de80a4016e 100644 --- a/tensorflow/contrib/lite/kernels/skip_gram.cc +++ b/tensorflow/contrib/lite/kernels/skip_gram.cc @@ -33,8 +33,8 @@ limitations under the License. #include <string> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" #include "tensorflow/contrib/lite/string_util.h" diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc index 55e16506df..ccfee41b9c 100644 --- a/tensorflow/contrib/lite/kernels/slice.cc +++ b/tensorflow/contrib/lite/kernels/slice.cc @@ -16,8 +16,8 @@ limitations under the License. #include <string.h> #include <cmath> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc index 8332ae32cf..3a10d2e60c 100644 --- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc index 9238e879f8..64c56c017b 100644 --- a/tensorflow/contrib/lite/kernels/space_to_depth.cc +++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc index fec2a6f0d9..178568e07c 100644 --- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc index b144486041..719e2dc606 100644 --- a/tensorflow/contrib/lite/kernels/split.cc +++ b/tensorflow/contrib/lite/kernels/split.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc index 09a5662fd9..080c51cd18 100644 --- a/tensorflow/contrib/lite/kernels/squeeze.cc +++ b/tensorflow/contrib/lite/kernels/squeeze.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc index bed2117f9a..87ffcc4110 100644 --- a/tensorflow/contrib/lite/kernels/strided_slice.cc +++ b/tensorflow/contrib/lite/kernels/strided_slice.cc @@ -15,8 +15,8 @@ limitations under the License. #include <string.h> #include <cmath> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index 77a1f59689..1be0c83f17 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 6ba7959752..9903fd5c35 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -23,8 +23,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc index 5181a8f89a..49421eb870 100644 --- a/tensorflow/contrib/lite/kernels/tile.cc +++ b/tensorflow/contrib/lite/kernels/tile.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc index 4f78c224e5..e73ca7b750 100644 --- a/tensorflow/contrib/lite/kernels/tile_test.cc +++ b/tensorflow/contrib/lite/kernels/tile_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc index 2dd760bbfe..6c38b6739e 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <algorithm> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc index 2abb89b617..16106fdafe 100644 --- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc +++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc index 800b0563d7..95359962e0 100644 --- a/tensorflow/contrib/lite/kernels/transpose.cc +++ b/tensorflow/contrib/lite/kernels/transpose.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc index a9baa5c698..6f2d98ede8 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index c678f14930..63817bd886 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -20,8 +20,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index 0180c2c498..744ee7c109 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -19,8 +19,8 @@ limitations under the License. #include <iostream> #include <limits> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc index 4998f88b41..9ff06f8331 100644 --- a/tensorflow/contrib/lite/kernels/unpack.cc +++ b/tensorflow/contrib/lite/kernels/unpack.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/contrib/lite/memory_planner.h index 0294ec815c..2d4707f849 100644 --- a/tensorflow/contrib/lite/memory_planner.h +++ b/tensorflow/contrib/lite/memory_planner.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ #define TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/mmap_allocation.cc b/tensorflow/contrib/lite/mmap_allocation.cc index fa9a3cd1d8..92934d1fd1 100644 --- a/tensorflow/contrib/lite/mmap_allocation.cc +++ b/tensorflow/contrib/lite/mmap_allocation.cc @@ -20,7 +20,7 @@ limitations under the License. #include <unistd.h> #include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" namespace tflite { diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index aa410ab002..241865b3d8 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -20,8 +20,9 @@ limitations under the License. #include <sys/types.h> #include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h" #include "tensorflow/contrib/lite/model.h" #ifndef TFLITE_MCU #include "tensorflow/contrib/lite/nnapi_delegate.h" @@ -42,41 +43,6 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) { const char* kEmptyTensorName = ""; -TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, - ErrorReporter* error_reporter) { - switch (tensor_type) { - case TensorType_FLOAT32: - *type = kTfLiteFloat32; - break; - case TensorType_INT16: - *type = kTfLiteInt16; - break; - case TensorType_INT32: - *type = kTfLiteInt32; - break; - case TensorType_UINT8: - *type = kTfLiteUInt8; - break; - case TensorType_INT64: - *type = kTfLiteInt64; - break; - case TensorType_STRING: - *type = kTfLiteString; - break; - case TensorType_BOOL: - *type = kTfLiteBool; - break; - case TensorType_COMPLEX64: - *type = kTfLiteComplex64; - break; - default: - error_reporter->Report("Unimplemented data type %s (%d) in tensor\n", - EnumNameTensorType(tensor_type), tensor_type); - return kTfLiteError; - } - return kTfLiteOk; -} - #ifndef TFLITE_MCU // Loads a model from `filename`. If `mmap_file` is true then use mmap, // otherwise make a copy of the model in a buffer. @@ -198,39 +164,10 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { auto opcodes = model_->operator_codes(); for (const OperatorCode* opcode : *opcodes) { const TfLiteRegistration* registration = nullptr; - auto builtin_code = opcode->builtin_code(); - int version = opcode->version(); - - if (builtin_code > BuiltinOperator_MAX || - builtin_code < BuiltinOperator_MIN) { - error_reporter_->Report( - "Op builtin_code out or range: %d. Are you using old TFLite binary " - "with newer model?", - builtin_code); - status = kTfLiteError; - } else if (builtin_code != BuiltinOperator_CUSTOM) { - registration = op_resolver_.FindOp(builtin_code, version); - if (registration == nullptr) { - error_reporter_->Report( - "Didn't find op for builtin opcode '%s' version '%d'\n", - EnumNameBuiltinOperator(builtin_code), version); - status = kTfLiteError; - } - } else if (!opcode->custom_code()) { - error_reporter_->Report( - "Operator with CUSTOM builtin_code has no custom_code.\n"); - status = kTfLiteError; - } else { - const char* name = opcode->custom_code()->c_str(); - registration = op_resolver_.FindOp(name, version); - flatbuffer_op_index_to_registration_types_.push_back( - BuiltinOperator_CUSTOM); - if (registration == nullptr) { - error_reporter_->Report( - "Didn't find custom op for name '%s' with version %d\n", name, - version); - status = kTfLiteError; - } + status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_, + ®istration); + if (status != kTfLiteOk) { + return status; } flatbuffer_op_index_to_registration_.push_back(registration); } @@ -247,565 +184,6 @@ std::vector<int> FlatBufferIntArrayToVector(T* flat_array) { return ret; } -// Copies the contents from the flatbuffer int vector `flatbuffer` into the -// int array `buffer`. `flat_vector` and `buffer` represent the same -// configuration operation for a given operation. -void FlatBufferIntVectorToArray(int max_size_of_buffer, - const flatbuffers::Vector<int32_t>* flat_vector, - int* buffer, ErrorReporter* error_reporter) { - if (!flat_vector) { - error_reporter->Report("Input array not provided for operation.\n"); - } else { - int num_dimensions = flat_vector->Length(); - if (num_dimensions > max_size_of_buffer / sizeof(int)) { - error_reporter->Report( - "Found too many dimensions in the operation's input array.\n"); - } else { - for (int i = 0; i < num_dimensions; ++i) { - buffer[i] = flat_vector->Get(i); - } - } - } -} - -// Allocate a structure using C malloc, but make sure the structure is a -// POD structure that doesn't require constructors to run. The reason we do -// this, is that Interpreter's C extension part will take ownership and wants -// to use malloc() and free(). -template <class T> -T* MallocPOD() { - static_assert(std::is_pod<T>::value, "Builtin data structure must be POD."); - return static_cast<T*>(malloc(sizeof(T))); -} - -// Parse the appropriate data out of the op. -// -// This handles builtin data explicitly as there are flatbuffer schemas. -// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which -// need to be released by calling `free`.` -// If it returns kTfLiteError, `builtin_data` will be `nullptr`. -TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, - ErrorReporter* error_reporter, void** builtin_data) { - auto parse_padding = [](Padding padding) { - switch (padding) { - case Padding_SAME: - return kTfLitePaddingSame; - case Padding_VALID: - return kTfLitePaddingValid; - } - return kTfLitePaddingUnknown; - }; - auto parse_activation = [](ActivationFunctionType activation) { - switch (activation) { - case ActivationFunctionType_NONE: - return kTfLiteActNone; - case ActivationFunctionType_RELU: - return kTfLiteActRelu; - case ActivationFunctionType_RELU_N1_TO_1: - return kTfLiteActRelu1; - case ActivationFunctionType_RELU6: - return kTfLiteActRelu6; - case ActivationFunctionType_TANH: - return kTfLiteActTanh; - case ActivationFunctionType_SIGN_BIT: - return kTfLiteActSignBit; - } - return kTfLiteActNone; - }; - auto parseLSHProjectionType = [](LSHProjectionType type) { - switch (type) { - case LSHProjectionType_SPARSE: - return kTfLiteLshProjectionSparse; - case LSHProjectionType_DENSE: - return kTfLiteLshProjectionDense; - default: - return kTfLiteLshProjectionUnknown; - } - }; - auto parseCombinerType = [](CombinerType type) { - switch (type) { - case CombinerType_MEAN: - return kTfLiteCombinerTypeMean; - case CombinerType_SQRTN: - return kTfLiteCombinerTypeSqrtn; - case CombinerType_SUM: - default: - return kTfLiteCombinerTypeSum; - } - }; - - *builtin_data = nullptr; - switch (op_type) { - case BuiltinOperator_CONV_2D: { - TfLiteConvParams* params = MallocPOD<TfLiteConvParams>(); - if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { - params->padding = parse_padding(conv_params->padding()); - params->stride_width = conv_params->stride_w(); - params->stride_height = conv_params->stride_h(); - params->activation = - parse_activation(conv_params->fused_activation_function()); - - params->dilation_width_factor = conv_params->dilation_w_factor(); - params->dilation_height_factor = conv_params->dilation_h_factor(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_CAST: { - TfLiteCastParams* params = MallocPOD<TfLiteCastParams>(); - if (auto* schema_params = op->builtin_options_as_CastOptions()) { - auto in_status = - ConvertTensorType(schema_params->in_data_type(), - ¶ms->in_data_type, error_reporter); - auto out_status = - ConvertTensorType(schema_params->out_data_type(), - ¶ms->out_data_type, error_reporter); - if (in_status != kTfLiteOk || out_status != kTfLiteOk) { - free(params); - return kTfLiteError; - } - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_LSH_PROJECTION: { - TfLiteLSHProjectionParams* params = - MallocPOD<TfLiteLSHProjectionParams>(); - if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) { - params->type = parseLSHProjectionType(lshParams->type()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_AVERAGE_POOL_2D: - case BuiltinOperator_MAX_POOL_2D: - case BuiltinOperator_L2_POOL_2D: { - TfLitePoolParams* params = MallocPOD<TfLitePoolParams>(); - if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) { - params->padding = parse_padding(pool_params->padding()); - params->stride_width = pool_params->stride_w(); - params->stride_height = pool_params->stride_h(); - params->filter_width = pool_params->filter_width(); - params->filter_height = pool_params->filter_height(); - params->activation = - parse_activation(pool_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_DEPTHWISE_CONV_2D: { - TfLiteDepthwiseConvParams* params = - MallocPOD<TfLiteDepthwiseConvParams>(); - if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { - params->padding = parse_padding(conv_params->padding()); - params->stride_width = conv_params->stride_w(); - params->stride_height = conv_params->stride_h(); - params->depth_multiplier = conv_params->depth_multiplier(); - params->activation = - parse_activation(conv_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SVDF: { - TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>(); - if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) { - params->rank = svdf_params->rank(); - params->activation = - parse_activation(svdf_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { - TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>(); - if (auto* sequence_rnn_params = - op->builtin_options_as_SequenceRNNOptions()) { - params->activation = - parse_activation(sequence_rnn_params->fused_activation_function()); - params->time_major = sequence_rnn_params->time_major(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_RNN: { - TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>(); - if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { - params->activation = - parse_activation(rnn_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { - TfLiteEmbeddingLookupSparseParams* params = - MallocPOD<TfLiteEmbeddingLookupSparseParams>(); - if (auto* embedding_params = - op->builtin_options_as_EmbeddingLookupSparseOptions()) { - params->combiner = parseCombinerType(embedding_params->combiner()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_FULLY_CONNECTED: { - TfLiteFullyConnectedParams* params = - MallocPOD<TfLiteFullyConnectedParams>(); - if (auto* fully_connected_params = - op->builtin_options_as_FullyConnectedOptions()) { - params->activation = parse_activation( - fully_connected_params->fused_activation_function()); - switch (fully_connected_params->weights_format()) { - case FullyConnectedOptionsWeightsFormat_DEFAULT: - params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault; - break; - case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: - params->weights_format = - kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8; - break; - default: - error_reporter->Report("Unhandled fully-connected weights format."); - return kTfLiteError; - } - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_HASHTABLE_LOOKUP: - // no-op. - break; - case BuiltinOperator_SOFTMAX: { - TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>(); - if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) { - params->beta = softmax_params->beta(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_CONCATENATION: { - TfLiteConcatenationParams* params = - MallocPOD<TfLiteConcatenationParams>(); - if (auto* concatenation_params = - op->builtin_options_as_ConcatenationOptions()) { - params->activation = - parse_activation(concatenation_params->fused_activation_function()); - params->axis = concatenation_params->axis(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_MUL: { - auto* params = MallocPOD<TfLiteMulParams>(); - if (auto* schema_params = op->builtin_options_as_MulOptions()) { - params->activation = - parse_activation(schema_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_ADD: { - auto* params = MallocPOD<TfLiteAddParams>(); - if (auto* schema_params = op->builtin_options_as_AddOptions()) { - params->activation = - parse_activation(schema_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_DIV: { - auto* params = MallocPOD<TfLiteDivParams>(); - if (auto* schema_params = op->builtin_options_as_DivOptions()) { - params->activation = - parse_activation(schema_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SUB: { - auto* params = MallocPOD<TfLiteSubParams>(); - if (auto* schema_params = op->builtin_options_as_SubOptions()) { - params->activation = - parse_activation(schema_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_L2_NORMALIZATION: { - auto* params = MallocPOD<TfLiteL2NormParams>(); - if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { - params->activation = - parse_activation(schema_params->fused_activation_function()); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { - auto* params = MallocPOD<TfLiteLocalResponseNormParams>(); - if (auto* schema_params = - op->builtin_options_as_LocalResponseNormalizationOptions()) { - params->radius = schema_params->radius(); - params->bias = schema_params->bias(); - params->alpha = schema_params->alpha(); - params->beta = schema_params->beta(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: - case BuiltinOperator_LSTM: { - TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>(); - if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { - params->activation = - parse_activation(lstm_params->fused_activation_function()); - params->cell_clip = lstm_params->cell_clip(); - params->proj_clip = lstm_params->proj_clip(); - switch (lstm_params->kernel_type()) { - case LSTMKernelType_FULL: - params->kernel_type = kTfLiteLSTMFullKernel; - break; - case LSTMKernelType_BASIC: - params->kernel_type = kTfLiteLSTMBasicKernel; - break; - } - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_RESIZE_BILINEAR: { - auto* params = MallocPOD<TfLiteResizeBilinearParams>(); - if (auto* schema_params = - op->builtin_options_as_ResizeBilinearOptions()) { - params->align_corners = schema_params->align_corners(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_RESHAPE: { - auto* params = MallocPOD<TfLiteReshapeParams>(); - if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { - auto* new_shape = schema_params->new_shape(); - FlatBufferIntVectorToArray(sizeof(params->shape), new_shape, - params->shape, error_reporter); - params->num_dimensions = new_shape->Length(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SKIP_GRAM: { - TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>(); - if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) { - params->ngram_size = skip_gram_params->ngram_size(); - params->max_skip_size = skip_gram_params->max_skip_size(); - params->include_all_ngrams = skip_gram_params->include_all_ngrams(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SPACE_TO_DEPTH: { - auto* params = MallocPOD<TfLiteSpaceToDepthParams>(); - if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) { - params->block_size = schema_params->block_size(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_GATHER: { - TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>(); - params->axis = 0; - if (auto* gather_params = op->builtin_options_as_GatherOptions()) { - params->axis = gather_params->axis(); - } - - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_MEAN: - case BuiltinOperator_REDUCE_MAX: - case BuiltinOperator_REDUCE_MIN: - case BuiltinOperator_REDUCE_PROD: - case BuiltinOperator_SUM: - case BuiltinOperator_REDUCE_ANY: { - auto* params = MallocPOD<TfLiteReducerParams>(); - if (auto* schema_params = op->builtin_options_as_ReducerOptions()) { - params->keep_dims = schema_params->keep_dims(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SPLIT: { - auto* params = MallocPOD<TfLiteSplitParams>(); - if (auto* schema_params = op->builtin_options_as_SplitOptions()) { - params->num_splits = schema_params->num_splits(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SQUEEZE: { - auto* params = MallocPOD<TfLiteSqueezeParams>(); - if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) { - const auto& squeeze_dims = schema_params->squeeze_dims(); - FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims, - params->squeeze_dims, error_reporter); - params->num_squeeze_dims = squeeze_dims->Length(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_STRIDED_SLICE: { - auto* params = MallocPOD<TfLiteStridedSliceParams>(); - if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) { - params->begin_mask = schema_params->begin_mask(); - params->end_mask = schema_params->end_mask(); - params->ellipsis_mask = schema_params->ellipsis_mask(); - params->new_axis_mask = schema_params->new_axis_mask(); - params->shrink_axis_mask = schema_params->shrink_axis_mask(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_ARG_MAX: { - auto* params = MallocPOD<TfLiteArgMaxParams>(); - if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) { - ConvertTensorType(schema_params->output_type(), ¶ms->output_type, - error_reporter); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_ARG_MIN: { - auto* params = MallocPOD<TfLiteArgMinParams>(); - if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) { - ConvertTensorType(schema_params->output_type(), ¶ms->output_type, - error_reporter); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_TRANSPOSE_CONV: { - TfLiteTransposeConvParams* params = - MallocPOD<TfLiteTransposeConvParams>(); - if (auto* transpose_conv_params = - op->builtin_options_as_TransposeConvOptions()) { - params->padding = parse_padding(transpose_conv_params->padding()); - params->stride_width = transpose_conv_params->stride_w(); - params->stride_height = transpose_conv_params->stride_h(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SPARSE_TO_DENSE: { - TfLiteSparseToDenseParams* params = - MallocPOD<TfLiteSparseToDenseParams>(); - if (auto* sparse_to_dense_params = - op->builtin_options_as_SparseToDenseOptions()) { - params->validate_indices = sparse_to_dense_params->validate_indices(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_SHAPE: { - auto* params = MallocPOD<TfLiteShapeParams>(); - if (auto* schema_params = op->builtin_options_as_ShapeOptions()) { - ConvertTensorType(schema_params->out_type(), ¶ms->out_type, - error_reporter); - } - *builtin_data = static_cast<void*>(params); - break; - } - case BuiltinOperator_PACK: { - TfLitePackParams* params = MallocPOD<TfLitePackParams>(); - if (auto* pack_params = op->builtin_options_as_PackOptions()) { - params->values_count = pack_params->values_count(); - params->axis = pack_params->axis(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - case BuiltinOperator_DELEGATE: { - // TODO(ycling): Revisit when supporting saving delegated models. - error_reporter->Report("DELEGATE op shouldn't exist in model."); - return kTfLiteError; - } - case BuiltinOperator_FAKE_QUANT: { - auto* params = MallocPOD<TfLiteFakeQuantParams>(); - if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) { - params->min = schema_params->min(); - params->max = schema_params->max(); - params->num_bits = schema_params->num_bits(); - params->narrow_range = schema_params->narrow_range(); - } - *builtin_data = static_cast<void*>(params); - break; - } - case BuiltinOperator_ONE_HOT: { - auto* params = MallocPOD<TfLiteOneHotParams>(); - if (auto* schema_params = op->builtin_options_as_OneHotOptions()) { - params->axis = schema_params->axis(); - } - *builtin_data = static_cast<void*>(params); - break; - } - case BuiltinOperator_UNPACK: { - TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>(); - if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) { - params->num = unpack_params->num(); - params->axis = unpack_params->axis(); - } - *builtin_data = reinterpret_cast<void*>(params); - break; - } - - // Below are the ops with no builtin_data strcture. - case BuiltinOperator_BATCH_TO_SPACE_ND: - // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are - // ok for now, since there is no call implementation either. - case BuiltinOperator_CALL: - case BuiltinOperator_CONCAT_EMBEDDINGS: - case BuiltinOperator_CUSTOM: - case BuiltinOperator_DEQUANTIZE: - case BuiltinOperator_EMBEDDING_LOOKUP: - case BuiltinOperator_EQUAL: - case BuiltinOperator_EXP: - case BuiltinOperator_EXPAND_DIMS: - case BuiltinOperator_FLOOR: - case BuiltinOperator_GREATER: - case BuiltinOperator_GREATER_EQUAL: - case BuiltinOperator_LESS: - case BuiltinOperator_LESS_EQUAL: - case BuiltinOperator_LOG: - case BuiltinOperator_LOGISTIC: - case BuiltinOperator_LOG_SOFTMAX: - case BuiltinOperator_MAXIMUM: - case BuiltinOperator_MINIMUM: - case BuiltinOperator_NEG: - case BuiltinOperator_NOT_EQUAL: - case BuiltinOperator_PAD: - case BuiltinOperator_PADV2: - case BuiltinOperator_PRELU: - case BuiltinOperator_RELU: - case BuiltinOperator_RELU6: - case BuiltinOperator_RELU_N1_TO_1: - case BuiltinOperator_RSQRT: - case BuiltinOperator_SELECT: - case BuiltinOperator_SIN: - case BuiltinOperator_SLICE: - case BuiltinOperator_SPACE_TO_BATCH_ND: - case BuiltinOperator_SQRT: - case BuiltinOperator_TANH: - case BuiltinOperator_TILE: - case BuiltinOperator_TOPK_V2: - case BuiltinOperator_TRANSPOSE: - case BuiltinOperator_POW: - case BuiltinOperator_LOGICAL_OR: - case BuiltinOperator_LOGICAL_AND: - case BuiltinOperator_LOGICAL_NOT: - case BuiltinOperator_FLOOR_DIV: - break; - } - return kTfLiteOk; -} - } // namespace TfLiteStatus InterpreterBuilder::ParseNodes( diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h index 8bc9ecd7ce..6abdfcd079 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/contrib/lite/model.h @@ -35,9 +35,10 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_MODEL_H_ #include <memory> -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/op_resolver.h" #include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/contrib/lite/mutable_op_resolver.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" namespace tflite { diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index df4f60d4ad..ec7d46af7c 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/contrib/lite/model.h" #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/testing/util.h" // Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object, diff --git a/tensorflow/contrib/lite/op_resolver.cc b/tensorflow/contrib/lite/mutable_op_resolver.cc index f6e435e982..8ee63d2a02 100644 --- a/tensorflow/contrib/lite/op_resolver.cc +++ b/tensorflow/contrib/lite/mutable_op_resolver.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/op_resolver.h" -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/mutable_op_resolver.h" namespace tflite { diff --git a/tensorflow/contrib/lite/mutable_op_resolver.h b/tensorflow/contrib/lite/mutable_op_resolver.h new file mode 100644 index 0000000000..c319041e9b --- /dev/null +++ b/tensorflow/contrib/lite/mutable_op_resolver.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_ +#define TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_ + +#include <unordered_map> +#include "tensorflow/contrib/lite/core/api/op_resolver.h" +#include "tensorflow/contrib/lite/util.h" + +namespace tflite { + +// Some versions of gcc doesn't support partial specialization in class scope, +// so these are defined in a namescope. +namespace op_resolver_hasher { +template <typename V> +struct ValueHasher { + size_t operator()(const V& v) const { return std::hash<V>()(v); } +}; + +template <> +struct ValueHasher<tflite::BuiltinOperator> { + size_t operator()(const tflite::BuiltinOperator& v) const { + return std::hash<int>()(static_cast<int>(v)); + } +}; + +template <typename T> +struct OperatorKeyHasher { + size_t operator()(const T& x) const { + size_t a = ValueHasher<typename T::first_type>()(x.first); + size_t b = ValueHasher<typename T::second_type>()(x.second); + return CombineHashes({a, b}); + } +}; +} // namespace op_resolver_hasher + +// An OpResolver that is mutable, also used as the op in gen_op_registration. +// A typical usage: +// MutableOpResolver resolver; +// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD()); +// resolver.AddCustom("CustomOp", Register_CUSTOM_OP()); +// InterpreterBuilder(model, resolver)(&interpreter); +class MutableOpResolver : public OpResolver { + public: + const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, + int version) const override; + const TfLiteRegistration* FindOp(const char* op, int version) const override; + void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); + void AddCustom(const char* name, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); + + private: + typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey; + typedef std::pair<std::string, int> CustomOperatorKey; + + std::unordered_map<BuiltinOperatorKey, TfLiteRegistration, + op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> > + builtins_; + std::unordered_map<CustomOperatorKey, TfLiteRegistration, + op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> > + custom_ops_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/op_resolver_test.cc b/tensorflow/contrib/lite/mutable_op_resolver_test.cc index 10b7e31972..db690eaab9 100644 --- a/tensorflow/contrib/lite/op_resolver_test.cc +++ b/tensorflow/contrib/lite/mutable_op_resolver_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/contrib/lite/mutable_op_resolver.h" #include <gtest/gtest.h> #include "tensorflow/contrib/lite/testing/util.h" diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 484842713d..817486e898 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -18,8 +18,8 @@ limitations under the License. #include <sys/mman.h> #include <sys/stat.h> #include <sys/types.h> -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/c/builtin_op_data.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h index 2bdb2cc5c8..22359d557e 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.h +++ b/tensorflow/contrib/lite/nnapi_delegate.h @@ -16,8 +16,8 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_ #include "tensorflow/contrib/lite/allocation.h" -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/interpreter.h" class ANeuralNetworksModel; diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h index 9d7e3f2085..e93134cbde 100644 --- a/tensorflow/contrib/lite/op_resolver.h +++ b/tensorflow/contrib/lite/op_resolver.h @@ -12,83 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// Compatibility shim for moved header location. #ifndef TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_ #define TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_ -#include <unordered_map> -#include "tensorflow/contrib/lite/context.h" -#include "tensorflow/contrib/lite/schema/schema_generated.h" -#include "tensorflow/contrib/lite/util.h" - -namespace tflite { - -// Abstract interface that returns TfLiteRegistrations given op codes or custom -// op names. This is the mechanism that ops being referenced in the flatbuffer -// model are mapped to executable function pointers (TfLiteRegistrations). -class OpResolver { - public: - // Finds the op registration for a builtin operator by enum code. - virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, - int version) const = 0; - // Finds the op registration of a custom operator by op name. - virtual const TfLiteRegistration* FindOp(const char* op, - int version) const = 0; - virtual ~OpResolver() {} -}; - -// Some versions of gcc doesn't support partial specialization in class scope, -// so these are defined in a namescope. -namespace op_resolver_hasher { -template <typename V> -struct ValueHasher { - size_t operator()(const V& v) const { return std::hash<V>()(v); } -}; - -template <> -struct ValueHasher<tflite::BuiltinOperator> { - size_t operator()(const tflite::BuiltinOperator& v) const { - return std::hash<int>()(static_cast<int>(v)); - } -}; - -template <typename T> -struct OperatorKeyHasher { - size_t operator()(const T& x) const { - size_t a = ValueHasher<typename T::first_type>()(x.first); - size_t b = ValueHasher<typename T::second_type>()(x.second); - return CombineHashes({a, b}); - } -}; -} // namespace op_resolver_hasher - -// An OpResolver that is mutable, also used as the op in gen_op_registration. -// A typical usage: -// MutableOpResolver resolver; -// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD()); -// resolver.AddCustom("CustomOp", Register_CUSTOM_OP()); -// InterpreterBuilder(model, resolver)(&interpreter); -class MutableOpResolver : public OpResolver { - public: - const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, - int version) const override; - const TfLiteRegistration* FindOp(const char* op, int version) const override; - void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, - int min_version = 1, int max_version = 1); - void AddCustom(const char* name, TfLiteRegistration* registration, - int min_version = 1, int max_version = 1); - - private: - typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey; - typedef std::pair<std::string, int> CustomOperatorKey; - - std::unordered_map<BuiltinOperatorKey, TfLiteRegistration, - op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> > - builtins_; - std::unordered_map<CustomOperatorKey, TfLiteRegistration, - op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> > - custom_ops_; -}; - -} // namespace tflite +#include "tensorflow/contrib/lite/core/api/op_resolver.h" +#include "tensorflow/contrib/lite/mutable_op_resolver.h" #endif // TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_ diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h index f738315cf2..45d0d8735e 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.h +++ b/tensorflow/contrib/lite/simple_memory_arena.h @@ -17,7 +17,7 @@ limitations under the License. #include <list> #include <memory> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/stderr_reporter.cc index 646913c026..e29a6345fd 100644 --- a/tensorflow/contrib/lite/error_reporter.cc +++ b/tensorflow/contrib/lite/stderr_reporter.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/stderr_reporter.h" #include <cstdarg> #include <cstdio> @@ -22,26 +22,6 @@ limitations under the License. namespace tflite { -ErrorReporter::~ErrorReporter() {} - -int ErrorReporter::Report(const char* format, ...) { - va_list args; - va_start(args, format); - int code = Report(format, args); - va_end(args); - return code; -} - -// TODO(aselle): Make the name of ReportError on context the same, so -// we can use the ensure functions w/o a context and w/ a reporter. -int ErrorReporter::ReportError(void*, const char* format, ...) { - va_list args; - va_start(args, format); - int code = Report(format, args); - va_end(args); - return code; -} - int StderrReporter::Report(const char* format, va_list args) { #ifdef __ANDROID__ // On Android stderr is not captured for applications, only for code run from diff --git a/tensorflow/contrib/lite/stderr_reporter.h b/tensorflow/contrib/lite/stderr_reporter.h new file mode 100644 index 0000000000..c6f4ffbdff --- /dev/null +++ b/tensorflow/contrib/lite/stderr_reporter.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_ +#define TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_ + +#include <cstdarg> +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" + +namespace tflite { + +// An error reporter that simplify writes the message to stderr. +struct StderrReporter : public ErrorReporter { + int Report(const char* format, va_list args) override; +}; + +// Return the default error reporter (output to stderr). +ErrorReporter* DefaultErrorReporter(); + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc index a316a40b62..b991e999b6 100644 --- a/tensorflow/contrib/lite/string_util.cc +++ b/tensorflow/contrib/lite/string_util.cc @@ -17,7 +17,7 @@ limitations under the License. #include <string.h> #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/interpreter.h" namespace tflite { diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h index 57f129bf5e..d24627b509 100644 --- a/tensorflow/contrib/lite/string_util.h +++ b/tensorflow/contrib/lite/string_util.h @@ -42,7 +42,7 @@ limitations under the License. #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/string.h" namespace tflite { diff --git a/tensorflow/contrib/lite/string_util_test.cc b/tensorflow/contrib/lite/string_util_test.cc index d53fec7512..a583a9184b 100644 --- a/tensorflow/contrib/lite/string_util_test.cc +++ b/tensorflow/contrib/lite/string_util_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/contrib/lite/string_util.h" #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/testing/util.h" diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 0b3a97d4f5..aad1ecaeb6 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -173,7 +173,6 @@ tf_cc_test( srcs = ["tflite_driver_test.cc"], data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], tags = [ - "no_oss", # b/112769036 "tflite_not_portable_android", "tflite_not_portable_ios", ], @@ -215,6 +214,7 @@ cc_library( deps = [ "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite/core/api", ], ) diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h index 8aa639157b..925791d390 100644 --- a/tensorflow/contrib/lite/testing/util.h +++ b/tensorflow/contrib/lite/testing/util.h @@ -17,7 +17,7 @@ limitations under the License. #include <cstdio> -#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/core/api/error_reporter.h" #include "tensorflow/contrib/lite/string.h" namespace tflite { diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index a75553db84..bea90f1ce8 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -372,6 +372,7 @@ cc_library( ":toco_graphviz_dump_options", ":toco_port", ":types_proto_cc", + "//tensorflow/contrib/lite/kernels/internal:types", "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_googlesource_code_re2//:re2", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index c25be078ff..f103bb94ae 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1314,12 +1314,16 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { // Compute output shape for (int axis = 0; axis < num_input_axes; ++axis) { + const auto strided_slice_params = + tflite::strided_slice::BuildStridedSliceParams( + op->begin_mask, op->end_mask, op->shrink_axis_mask, + op->start_indices, op->stop_indices, op->strides); int start_index = tflite::strided_slice::StartForAxis( - op->begin_mask, op->start_indices, op->strides, - input_array.shape().dims().data(), axis); + strided_slice_params, ToRuntimeShape(input_array.shape()), axis); int stop_index = tflite::strided_slice::StopForAxis( - op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides, - input_array.shape().dims().data(), axis, start_index); + strided_slice_params, ToRuntimeShape(input_array.shape()), axis, + start_index); + int dim_size = ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 9d8bd4fc39..8853ed87e6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -52,14 +52,18 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>(); std::vector<int> src_coord(num_input_axes); std::vector<int> stop_for_axis(num_input_axes); + const auto strided_slice_params = + tflite::strided_slice::BuildStridedSliceParams( + op.begin_mask, op.end_mask, op.shrink_axis_mask, op.start_indices, + op.stop_indices, op.strides); + for (int axis = 0; axis < num_input_axes; axis++) { - int start = tflite::strided_slice::StartForAxis( - op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(), - axis); - src_coord[axis] = start; + int start_index = tflite::strided_slice::StartForAxis( + strided_slice_params, ToRuntimeShape(input_array.shape()), axis); + src_coord[axis] = start_index; stop_for_axis[axis] = tflite::strided_slice::StopForAxis( - op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides, - input_shape.dims().data(), axis, start); + strided_slice_params, ToRuntimeShape(input_array.shape()), axis, + start_index); } // In order to handle any number (N) of dimensions, we copy elements one by @@ -86,8 +90,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) { // Reset axis and set carry src_coord[axis] = tflite::strided_slice::StartForAxis( - op.begin_mask, op.start_indices, op.strides, - input_shape.dims().data(), axis); + strided_slice_params, ToRuntimeShape(input_shape), axis); carry = true; } else { carry = false; diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index bdeb203024..5f4b8cb66a 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -28,6 +28,7 @@ limitations under the License. #if TOCO_SUPPORT_PORTABLE_PROTOS #include "third_party/protobuf/include/google/protobuf/text_format.h" #endif // TOCO_SUPPORT_PORTABLE_PROTOS +#include "tensorflow/contrib/lite/kernels/internal/types.h" #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" @@ -139,6 +140,10 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1); // - For the remaining indices [0..i0), d0[i0] == 1. bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1); +inline ::tflite::RuntimeShape ToRuntimeShape(const Shape& shape) { + return ::tflite::RuntimeShape(shape.dimensions_count(), shape.dims().data()); +} + bool IsArrayFullyConnectedWeights(const Model& model, const string& name); // If there is a wildcard dimension (-1), this may return a negative value. diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD index a66812fe87..98e2835b2e 100644 --- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD @@ -54,6 +54,7 @@ tf_cc_test( linkopts = common_linkopts, linkstatic = 1, tags = [ + "no_oss", # b/114307765 "tflite_not_portable_android", "tflite_not_portable_ios", ], diff --git a/tensorflow/contrib/lite/tools/make/Makefile b/tensorflow/contrib/lite/tools/make/Makefile index e30cc1d70e..59bdb10811 100644 --- a/tensorflow/contrib/lite/tools/make/Makefile +++ b/tensorflow/contrib/lite/tools/make/Makefile @@ -24,6 +24,21 @@ HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32 TARGET := $(HOST_OS) TARGET_ARCH := $(HOST_ARCH) +INCLUDES := \ +-I. \ +-I$(MAKEFILE_DIR)/../../../../../ \ +-I$(MAKEFILE_DIR)/../../../../../../ \ +-I$(MAKEFILE_DIR)/downloads/ \ +-I$(MAKEFILE_DIR)/downloads/eigen \ +-I$(MAKEFILE_DIR)/downloads/gemmlowp \ +-I$(MAKEFILE_DIR)/downloads/neon_2_sse \ +-I$(MAKEFILE_DIR)/downloads/farmhash/src \ +-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ +-I$(OBJDIR) +# This is at the end so any globally-installed frameworks like protobuf don't +# override local versions in the source tree. +INCLUDES += -I/usr/local/include + # These are the default libraries needed, but they can be added to or # overridden by the platform-specific settings in target makefiles. LIBS := \ @@ -44,55 +59,17 @@ ARFLAGS := -r TARGET_TOOLCHAIN_PREFIX := CC_PREFIX := -# These target-specific makefiles should modify or replace options like -# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic -# based on platforms or architectures should happen within these files, to -# keep this main makefile focused on the sources and dependencies. -include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc) - -# Where compiled objects are stored. -GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/ -OBJDIR := $(GENDIR)obj/ -BINDIR := $(GENDIR)bin/ -LIBDIR := $(GENDIR)lib/ - -INCLUDES := \ --I. \ --I$(MAKEFILE_DIR)/../../../../../ \ --I$(MAKEFILE_DIR)/../../../../../../ \ --I$(MAKEFILE_DIR)/downloads/ \ --I$(MAKEFILE_DIR)/downloads/eigen \ --I$(MAKEFILE_DIR)/downloads/gemmlowp \ --I$(MAKEFILE_DIR)/downloads/neon_2_sse \ --I$(MAKEFILE_DIR)/downloads/farmhash/src \ --I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ --I$(OBJDIR) -# This is at the end so any globally-installed frameworks like protobuf don't -# override local versions in the source tree. -INCLUDES += -I/usr/local/include - -CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++ -CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc -AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar - # This library is the main target for this makefile. It will contain a minimal # runtime that can be linked in to other programs. LIB_NAME := libtensorflow-lite.a -LIB_PATH := $(LIBDIR)$(LIB_NAME) - -# A small example program that shows how to link against the library. -MINIMAL_PATH := $(BINDIR)minimal # Benchmark static library and binary BENCHMARK_LIB_NAME := benchmark-lib.a BENCHMARK_BINARY_NAME := benchmark_model -BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME) -BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME) +# A small example program that shows how to link against the library. MINIMAL_SRCS := \ tensorflow/contrib/lite/examples/minimal/minimal.cc -MINIMAL_OBJS := $(addprefix $(OBJDIR), \ -$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS)))) # What sources we want to compile, must be kept in sync with the main Bazel # build files. @@ -105,7 +82,9 @@ PROFILE_SUMMARIZER_SRCS := \ CORE_CC_ALL_SRCS := \ $(wildcard tensorflow/contrib/lite/*.cc) \ -$(wildcard tensorflow/contrib/lite/*.c) +$(wildcard tensorflow/contrib/lite/*.c) \ +$(wildcard tensorflow/contrib/lite/c/*.c) \ +$(wildcard tensorflow/contrib/lite/core/api/*.cc) ifneq ($(BUILD_TYPE),micro) CORE_CC_ALL_SRCS += \ $(wildcard tensorflow/contrib/lite/kernels/*.cc) \ @@ -136,10 +115,6 @@ tensorflow/contrib/lite/nnapi_delegate.cc endif # Filter out all the excluded files. TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) -# File names of the intermediate files target compilation generates. -TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \ -$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS)))) -LIB_OBJS := $(TF_LITE_CC_OBJS) # Benchmark sources BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark @@ -151,6 +126,40 @@ BENCHMARK_SRCS := $(filter-out \ $(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \ $(BENCHMARK_ALL_SRCS)) +# These target-specific makefiles should modify or replace options like +# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic +# based on platforms or architectures should happen within these files, to +# keep this main makefile focused on the sources and dependencies. +include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc) + +ALL_SRCS := \ + $(MINIMAL_SRCS) \ + $(PROFILER_SRCS) \ + $(PROFILER_SUMMARY_SRCS) \ + $(TF_LITE_CC_SRCS) \ + $(BENCHMARK_SRCS) + +# Where compiled objects are stored. +GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/ +OBJDIR := $(GENDIR)obj/ +BINDIR := $(GENDIR)bin/ +LIBDIR := $(GENDIR)lib/ + +LIB_PATH := $(LIBDIR)$(LIB_NAME) +BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME) +BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME) +MINIMAL_BINARY := $(BINDIR)minimal + +CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++ +CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc +AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar + +MINIMAL_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS)))) + +LIB_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS)))) + BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS)))) @@ -164,7 +173,7 @@ $(OBJDIR)%.o: %.c $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ # The target that's compiled if there's no command-line arguments. -all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY) +all: $(LIB_PATH) $(MINIMAL_BINARY) $(BENCHMARK_BINARY) # The target that's compiled for micro-controllers micro: $(LIB_PATH) @@ -178,19 +187,18 @@ $(LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(LIB_OBJS) @mkdir -p $(dir $@) $(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS) -$(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH) +$(MINIMAL_BINARY): $(MINIMAL_OBJS) $(LIB_PATH) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ - -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \ + -o $(MINIMAL_BINARY) $(MINIMAL_OBJS) \ $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS) - $(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS) @mkdir -p $(dir $@) $(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS) benchmark_lib: $(BENCHMARK_LIB) -$(info $(BENCHMARK_BINARY)) + $(BENCHMARK_BINARY) : $(BENCHMARK_LIB) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ @@ -213,4 +221,4 @@ cleantarget: $(DEPDIR)/%.d: ; .PRECIOUS: $(DEPDIR)/%.d --include $(patsubst %,$(DEPDIR)/%.d,$(basename $(TF_CC_SRCS))) +-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(ALL_SRCS))) diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc index 692efb9029..b863108aa4 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc @@ -141,6 +141,7 @@ bool IsHybridEvaluationOp(const OperatorT* op, const BuiltinOperator& op_code) { op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF || op_code == BuiltinOperator_EMBEDDING_LOOKUP || op_code == BuiltinOperator_RNN || + op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM || op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN || op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM || op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { diff --git a/tensorflow/contrib/lite/tutorials/BUILD b/tensorflow/contrib/lite/tutorials/BUILD new file mode 100644 index 0000000000..67ff1ea124 --- /dev/null +++ b/tensorflow/contrib/lite/tutorials/BUILD @@ -0,0 +1,20 @@ +# Example Estimator model + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_binary( + name = "mnist_tflite", + srcs = [ + "dataset.py", + "mnist_tflite.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/lite/tutorials/dataset.py b/tensorflow/contrib/lite/tutorials/dataset.py new file mode 100644 index 0000000000..ba49dfcc9b --- /dev/null +++ b/tensorflow/contrib/lite/tutorials/dataset.py @@ -0,0 +1,122 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""tf.data.Dataset interface to the MNIST dataset. + + This is cloned from + https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import shutil +import tempfile + +import numpy as np +from six.moves import urllib +import tensorflow as tf + + +def read32(bytestream): + """Read 4 bytes from bytestream as an unsigned 32-bit integer.""" + dt = np.dtype(np.uint32).newbyteorder('>') + return np.frombuffer(bytestream.read(4), dtype=dt)[0] + + +def check_image_file_header(filename): + """Validate that filename corresponds to images for the MNIST dataset.""" + with tf.gfile.Open(filename, 'rb') as f: + magic = read32(f) + read32(f) # num_images, unused + rows = read32(f) + cols = read32(f) + if magic != 2051: + raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, + f.name)) + if rows != 28 or cols != 28: + raise ValueError( + 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' % + (f.name, rows, cols)) + + +def check_labels_file_header(filename): + """Validate that filename corresponds to labels for the MNIST dataset.""" + with tf.gfile.Open(filename, 'rb') as f: + magic = read32(f) + read32(f) # num_items, unused + if magic != 2049: + raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, + f.name)) + + +def download(directory, filename): + """Download (and unzip) a file from the MNIST dataset if not already done.""" + filepath = os.path.join(directory, filename) + if tf.gfile.Exists(filepath): + return filepath + if not tf.gfile.Exists(directory): + tf.gfile.MakeDirs(directory) + # CVDF mirror of http://yann.lecun.com/exdb/mnist/ + url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz' + _, zipped_filepath = tempfile.mkstemp(suffix='.gz') + print('Downloading %s to %s' % (url, zipped_filepath)) + urllib.request.urlretrieve(url, zipped_filepath) + with gzip.open(zipped_filepath, 'rb') as f_in, \ + tf.gfile.Open(filepath, 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) + os.remove(zipped_filepath) + return filepath + + +def dataset(directory, images_file, labels_file): + """Download and parse MNIST dataset.""" + + images_file = download(directory, images_file) + labels_file = download(directory, labels_file) + + check_image_file_header(images_file) + check_labels_file_header(labels_file) + + def decode_image(image): + # Normalize from [0, 255] to [0.0, 1.0] + image = tf.decode_raw(image, tf.uint8) + image = tf.cast(image, tf.float32) + image = tf.reshape(image, [784]) + return image / 255.0 + + def decode_label(label): + label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8] + label = tf.reshape(label, []) # label is a scalar + return tf.to_int32(label) + + images = tf.data.FixedLengthRecordDataset( + images_file, 28 * 28, header_bytes=16).map(decode_image) + labels = tf.data.FixedLengthRecordDataset( + labels_file, 1, header_bytes=8).map(decode_label) + return tf.data.Dataset.zip((images, labels)) + + +def train(directory): + """tf.data.Dataset object for MNIST training data.""" + return dataset(directory, 'train-images-idx3-ubyte', + 'train-labels-idx1-ubyte') + + +def test(directory): + """tf.data.Dataset object for MNIST test data.""" + return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte') diff --git a/tensorflow/contrib/lite/tutorials/mnist_tflite.py b/tensorflow/contrib/lite/tutorials/mnist_tflite.py new file mode 100644 index 0000000000..7b8bf5b5db --- /dev/null +++ b/tensorflow/contrib/lite/tutorials/mnist_tflite.py @@ -0,0 +1,87 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +import tensorflow as tf # pylint: disable=g-bad-import-order +from tensorflow.contrib.lite.tutorials import dataset +flags = tf.app.flags + +flags.DEFINE_string('data_dir', '/tmp/data_dir', + 'Directory where data is stored.') +flags.DEFINE_string('model_file', '', + 'The path to the TFLite flatbuffer model file.') + + +flags = flags.FLAGS + + +def test_image_generator(): + # Generates an iterator over images + with tf.Session() as sess: + input_data = dataset.test( + flags.data_dir).make_one_shot_iterator().get_next() + try: + while True: + yield sess.run(input_data) + except tf.errors.OutOfRangeError: + pass + + +def run_eval(interpreter, input_image): + """Performs evaluation for input image over specified model. + + Args: + interpreter: TFLite interpreter initialized with model to execute. + input_image: Image input to the model. + + Returns: + output: output tensor of model being executed. + """ + + # Get input and output tensors. + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Test model on the input images. + input_image = np.reshape(input_image, input_details[0]['shape']) + interpreter.set_tensor(input_details[0]['index'], input_image) + + interpreter.invoke() + output_data = interpreter.get_tensor(output_details[0]['index']) + output = np.squeeze(output_data) + return output + + +def main(_): + interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file) + interpreter.allocate_tensors() + num_correct, total = 0, 0 + for input_data in test_image_generator(): + output = run_eval(interpreter, input_data[0]) + total += 1 + if output == input_data[1]: + num_correct += 1 + if total % 500 == 0: + print('Accuracy after %i images: %f' % + (total, float(num_correct) / float(total))) + + +if __name__ == '__main__': + tf.logging.set_verbosity(tf.logging.INFO) + tf.app.run(main) diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h index f5b208afbb..6d81f844f8 100644 --- a/tensorflow/contrib/lite/util.h +++ b/tensorflow/contrib/lite/util.h @@ -22,7 +22,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_UTIL_H_ #include <vector> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" namespace tflite { diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/contrib/lite/util_test.cc index 32bf917a59..c5c1709f1d 100644 --- a/tensorflow/contrib/lite/util_test.cc +++ b/tensorflow/contrib/lite/util_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include <gmock/gmock.h> #include <gtest/gtest.h> -#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" #include "tensorflow/contrib/lite/util.h" namespace tflite { diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index 22b11f1c57..7d26429f9c 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -56,6 +56,7 @@ tensorflow/core/lib/hash/hash.cc tensorflow/core/lib/hash/crc32c.cc tensorflow/core/lib/hash/crc32c_accelerate.cc tensorflow/core/lib/core/threadpool.cc +tensorflow/core/lib/core/stringpiece.cc tensorflow/core/lib/core/status.cc tensorflow/core/lib/core/coding.cc tensorflow/core/lib/core/arena.cc diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 93e589907e..2e4d61d931 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -159,8 +159,10 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:variables", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py index f026f437dc..f55209ec49 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py @@ -25,7 +25,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -48,12 +47,7 @@ class LazyAdamOptimizer(adam.AdamOptimizer): may lead to different empirical results. """ - def _apply_sparse_shared(self, - grad, - var, - indices, - scatter_update, - scatter_sub): + def _apply_sparse(self, grad, var): beta1_power, beta2_power = self._get_beta_accumulators() beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) @@ -65,51 +59,56 @@ class LazyAdamOptimizer(adam.AdamOptimizer): # \\(m := beta1 * m + (1 - beta1) * g_t\\) m = self.get_slot(var, "m") - m_t = scatter_update(m, indices, - beta1_t * array_ops.gather(m, indices) + - (1 - beta1_t) * grad) + m_t = state_ops.scatter_update(m, grad.indices, + beta1_t * array_ops.gather(m, grad.indices) + + (1 - beta1_t) * grad.values, + use_locking=self._use_locking) # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) v = self.get_slot(var, "v") - v_t = scatter_update(v, indices, - beta2_t * array_ops.gather(v, indices) + - (1 - beta2_t) * math_ops.square(grad)) + v_t = state_ops.scatter_update(v, grad.indices, + beta2_t * array_ops.gather(v, grad.indices) + + (1 - beta2_t) * math_ops.square(grad.values), + use_locking=self._use_locking) # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) - m_t_slice = array_ops.gather(m_t, indices) - v_t_slice = array_ops.gather(v_t, indices) + m_t_slice = array_ops.gather(m_t, grad.indices) + v_t_slice = array_ops.gather(v_t, grad.indices) denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t - var_update = scatter_sub(var, indices, - lr * m_t_slice / denominator_slice) + var_update = state_ops.scatter_sub(var, grad.indices, + lr * m_t_slice / denominator_slice, + use_locking=self._use_locking) return control_flow_ops.group(var_update, m_t, v_t) - def _apply_sparse(self, grad, var): - return self._apply_sparse_shared( - grad.values, var, grad.indices, - self._scatter_update, - self._scatter_sub) - def _resource_apply_sparse(self, grad, var, indices): - return self._apply_sparse_shared( - grad, var, indices, - self._resource_scatter_update, - self._resource_scatter_sub) - - # Utility functions for updating resource or non-resource variables. - def _scatter_update(self, x, i, v): - return state_ops.scatter_update( - x, i, v, use_locking=self._use_locking) - - def _scatter_sub(self, x, i, v): - return state_ops.scatter_sub( - x, i, v, use_locking=self._use_locking) - - def _resource_scatter_update(self, x, i, v): - update_op = resource_variable_ops.resource_scatter_update(x.handle, i, v) - with ops.control_dependencies([update_op]): - return x.value() - - def _resource_scatter_sub(self, x, i, v): - sub_op = resource_variable_ops.resource_scatter_sub(x.handle, i, v) - with ops.control_dependencies([sub_op]): - return x.value() + beta1_power, beta2_power = self._get_beta_accumulators() + beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) + beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) + lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) + beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) + beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) + epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) + lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) + + # \\(m := beta1 * m + (1 - beta1) * g_t\\) + m = self.get_slot(var, "m") + m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad + m_update_op = resource_variable_ops.resource_scatter_update(m.handle, + indices, + m_t_slice) + + # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) + v = self.get_slot(var, "v") + v_t_slice = (beta2_t * array_ops.gather(v, indices) + + (1 - beta2_t) * math_ops.square(grad)) + v_update_op = resource_variable_ops.resource_scatter_update(v.handle, + indices, + v_t_slice) + + # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) + var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t) + var_update_op = resource_variable_ops.resource_scatter_sub(var.handle, + indices, + var_slice) + + return control_flow_ops.group(var_update_op, m_update_op, v_update_op) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py index d3e9e89502..f08ffaa36f 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py @@ -19,12 +19,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.contrib.opt.python.training import lazy_adam_optimizer +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -50,9 +53,10 @@ def adam_update_numpy(param, return param_t, m_t, v_t -class AdamOptimizerTest(test.TestCase): +class AdamOptimizerTest(test.TestCase, parameterized.TestCase): - def doTestSparse(self, use_resource=False): + @parameterized.parameters([False, True]) + def testSparse(self, use_resource): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): # Initialize variables for numpy implementation. @@ -68,6 +72,7 @@ class AdamOptimizerTest(test.TestCase): else: var0 = variables.Variable(var0_np) var1 = variables.Variable(var1_np) + grads0_np_indices = np.array([0, 1], dtype=np.int32) grads0 = ops.IndexedSlices( constant_op.constant(grads0_np), @@ -99,18 +104,17 @@ class AdamOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(var0_np, var0.eval()) self.assertAllCloseAccordingToType(var1_np, var1.eval()) - def testSparse(self): - self.doTestSparse(use_resource=False) - - def testResourceSparse(self): - self.doTestSparse(use_resource=True) - - def testSparseDevicePlacement(self): + @parameterized.parameters([False, True]) + def testSparseDevicePlacement(self, use_resource): for index_dtype in [dtypes.int32, dtypes.int64]: with self.test_session(force_gpu=test.is_gpu_available()): # If a GPU is available, tests that all optimizer ops can be placed on # it (i.e. they have GPU kernels). - var = variables.Variable([[1.0], [2.0]]) + if use_resource: + var = resource_variable_ops.ResourceVariable([[1.0], [2.0]]) + else: + var = variables.Variable([[1.0], [2.0]]) + indices = constant_op.constant([0, 1], dtype=index_dtype) gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices)) optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0) @@ -118,13 +122,21 @@ class AdamOptimizerTest(test.TestCase): variables.global_variables_initializer().run() minimize_op.run() - def testSparseRepeatedIndices(self): + @parameterized.parameters([False, True]) + def testSparseRepeatedIndices(self, use_resource): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): - repeated_index_update_var = variables.Variable( - [[1.0], [2.0]], dtype=dtype) - aggregated_update_var = variables.Variable( - [[1.0], [2.0]], dtype=dtype) + if use_resource: + repeated_index_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = resource_variable_ops.ResourceVariable( + [[1.0], [2.0]], dtype=dtype) + else: + repeated_index_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + aggregated_update_var = variables.Variable( + [[1.0], [2.0]], dtype=dtype) + grad_repeated_index = ops.IndexedSlices( constant_op.constant( [0.1, 0.1], shape=[2, 1], dtype=dtype), @@ -150,6 +162,204 @@ class AdamOptimizerTest(test.TestCase): self.assertAllClose(aggregated_update_var.eval(), repeated_index_update_var.eval()) + def doTestBasic(self, use_resource=False, use_callable_params=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = lambda: 0.001 + beta1 = lambda: 0.9 + beta2 = lambda: 0.999 + epsilon = lambda: 1e-8 + if not use_callable_params: + learning_rate = learning_rate() + beta1 = beta1() + beta2 = beta2() + epsilon = epsilon() + + opt = lazy_adam_optimizer.LazyAdamOptimizer(learning_rate=learning_rate) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + opt_variables = opt.variables() + beta1_power, beta2_power = opt._get_beta_accumulators() + self.assertIsNotNone(beta1_power) + self.assertIsNotNone(beta2_power is not None) + self.assertIn(beta1_power, opt_variables) + self.assertIn(beta2_power, opt_variables) + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + self.assertAllCloseAccordingToType(0.9**(t + 1), + self.evaluate(beta1_power)) + self.assertAllCloseAccordingToType(0.999**(t + 1), + self.evaluate(beta2_power)) + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/Adam:0" % (i,), + opt.get_slot(var=var0, name="m").name) + + def testBasic(self): + with self.test_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + def testBasicCallableParams(self): + with context.eager_mode(): + self.doTestBasic(use_resource=True, use_callable_params=True) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_optimizer.LazyAdamOptimizer(constant_op.constant(0.001)) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Run 3 steps of Adam + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + update.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSharing(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + opt = lazy_adam_optimizer.LazyAdamOptimizer() + update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + + beta1_power, beta2_power = opt._get_beta_accumulators() + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 3 steps of intertwined Adam1 and Adam2. + for t in range(1, 4): + self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval()) + self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval()) + if t % 2 == 0: + update1.run() + else: + update2.run() + + var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0) + var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testTwoSessions(self): + optimizer = lazy_adam_optimizer.LazyAdamOptimizer() + + with context.eager_mode(): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + g = ops.Graph() + with g.as_default(): + with self.session(graph=g): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + optimizer.apply_gradients([(grads0, var0)]) + + gg = ops.Graph() + with gg.as_default(): + with self.session(graph=gg): + var0 = variables.Variable(np.array([1.0, 2.0]), name="v0") + grads0 = constant_op.constant(np.array([0.1, 0.1])) + + # If the optimizer saves any state not keyed by graph the following line + # fails. + optimizer.apply_gradients([(grads0, var0)]) + + def testSlotsUniqueEager(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.) + v2 = resource_variable_ops.ResourceVariable(1.) + opt = lazy_adam_optimizer.LazyAdamOptimizer(1.) + opt.minimize(lambda: v1 + v2) + # There should be two non-slot variables, and two unique slot variables + # for v1 and v2 respectively. + self.assertEqual(6, len(set(opt.variables()))) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index 499fec4ffa..c59f667f6a 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -22,6 +22,7 @@ py_test( ":common", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:session", "//tensorflow/python:variable_scope", @@ -89,7 +90,6 @@ py_library( ":common", ":graph_matcher", ":input_to_ops", - "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", @@ -171,7 +171,6 @@ py_library( ":graph_matcher", ":input_to_ops", ":quant_ops", - "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py index bf648e158e..b27117dd48 100644 --- a/tensorflow/contrib/quantize/python/common.py +++ b/tensorflow/contrib/quantize/python/common.py @@ -131,3 +131,29 @@ def DropStringPrefix(s, prefix): return s[len(prefix):] else: return s + + +def RerouteTensor(t0, t1, can_modify=None): + """Reroute the end of the tensor t0 to the ends of the tensor t1. + + Args: + t0: a tf.Tensor. + t1: a tf.Tensor. + can_modify: iterable of operations which can be modified. Any operation + outside within_ops will be left untouched by this function. + + Returns: + The number of individual modifications made by the function. + """ + nb_update_inputs = 0 + consumers = t1.consumers() + if can_modify is not None: + consumers = [c for c in consumers if c in can_modify] + consumers_indices = {} + for c in consumers: + consumers_indices[c] = [i for i, t in enumerate(c.inputs) if t is t1] + for c in consumers: + for i in consumers_indices[c]: + c._update_input(i, t0) # pylint: disable=protected-access + nb_update_inputs += 1 + return nb_update_inputs diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py index 06c62f2d26..2b26302f8a 100644 --- a/tensorflow/contrib/quantize/python/common_test.py +++ b/tensorflow/contrib/quantize/python/common_test.py @@ -20,8 +20,10 @@ from __future__ import print_function from tensorflow.contrib.quantize.python import common from tensorflow.python.client import session +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -62,6 +64,29 @@ class CommonTest(test_util.TensorFlowTestCase): _, step_val = sess.run([b, quantization_step_tensor]) self.assertEqual(step_val, 2) + def testRerouteTensor(self): + a = constant_op.constant(1, name='a') + b = constant_op.constant(2, name='b') + c = constant_op.constant(3, name='c') + d = constant_op.constant(4, name='d') + + add_ac = math_ops.add(a, c) + add_ad = math_ops.add(a, d) + + # Ensure that before rerouting the inputs are what we think. + self._CheckOpHasInputs(add_ac.op, [a, c]) + self._CheckOpHasInputs(add_ad.op, [a, d]) + + # references to tensor a should be replaced with b for all ops in + # can_modify. This means add_ac will be changed but add_ad will not. + common.RerouteTensor(b, a, can_modify=[add_ac.op]) + self._CheckOpHasInputs(add_ac.op, [b, c]) + self._CheckOpHasInputs(add_ad.op, [a, d]) + + def _CheckOpHasInputs(self, op, inputs): + for i in inputs: + self.assertIn(i, op.inputs) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index d9f179bee4..2971b28f45 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import re -from tensorflow.contrib import graph_editor from tensorflow.contrib.quantize.python import common from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.contrib.quantize.python import input_to_ops @@ -134,8 +133,8 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): bias_add_tensor = math_ops.add( new_layer_tensor, bias_tensor, name='add_fold') - nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor, - match.output_tensor) + nodes_modified_count = common.RerouteTensor(bias_add_tensor, + match.output_tensor) if nodes_modified_count == 0: raise ValueError('Folding batch norms failed, %s had no outputs.' % match.output_tensor.name) @@ -370,8 +369,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, lambda: match.bn_decay_mean_tensor, name='freeze_moving_mean') - graph_editor.reroute_ts( - [bn_decay_mean_out], [match.bn_decay_mean_tensor], + common.RerouteTensor( + bn_decay_mean_out, + match.bn_decay_mean_tensor, can_modify=bn_decay_mean_consumers) bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) @@ -380,8 +380,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, lambda: bn_decay_zero, lambda: match.bn_decay_var_tensor, name='freeze_moving_var') - graph_editor.reroute_ts( - [bn_decay_var_out], [match.bn_decay_var_tensor], + common.RerouteTensor( + bn_decay_var_out, + match.bn_decay_var_tensor, can_modify=bn_decay_var_consumers) correction_recip = utils.smart_cond( @@ -486,9 +487,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay): activation = common.GetEndpointActivationOp(graph, bn) if activation: - nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], - [original_op.outputs[0]], - can_modify=[activation]) + nodes_modified_count = common.RerouteTensor( + folded_op.outputs[0], original_op.outputs[0], can_modify=[activation]) if nodes_modified_count != 1: raise ValueError('Unexpected inputs to op: %s' % activation.name) continue @@ -497,9 +497,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay): # operations instead of Relu* above. add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1) add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add') - nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], - [original_op.outputs[0]], - can_modify=[add_bypass]) + nodes_modified_count = common.RerouteTensor( + folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass]) if nodes_modified_count != 1: raise ValueError('Unexpected inputs to op: %s' % add_bypass.name) diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 2ddbd73ea6..e88db0acd5 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import re -from tensorflow.contrib import graph_editor from tensorflow.contrib.quantize.python import common from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.contrib.quantize.python import input_to_ops @@ -592,8 +591,8 @@ def _InsertQuantOp(context, name=name_prefix + '/delayed_quant') if consumers: - tensors_modified_count = graph_editor.reroute_ts( - [quant], [inputs], can_modify=consumers) + tensors_modified_count = common.RerouteTensor( + quant, inputs, can_modify=consumers) # Some operations can have multiple output tensors going to the same # consumer. Since consumers is a set, we need to ensure that # tensors_modified_count is greater than or equal to the length of the set diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 5874245d58..4e67d80558 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -212,6 +212,7 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = ["noasan"], ) tf_custom_op_library( @@ -279,7 +280,10 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], - tags = ["no_oss"], + tags = [ + "no_oss", + "noasan", + ], ) tf_cc_test( @@ -287,6 +291,7 @@ tf_cc_test( size = "small", srcs = ["ops/gru_ops_test.cc"], data = [":python/ops/_gru_ops.so"], + tags = ["noasan"], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), @@ -306,6 +311,7 @@ tf_cc_test( size = "small", srcs = ["ops/lstm_ops_test.cc"], data = [":python/ops/_lstm_ops.so"], + tags = ["noasan"], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index f74c95f962..06c481672c 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -97,10 +97,10 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): The default non-peephole implementation is based on: - http://www.bioinf.jku.at/publications/older/2604.pdf + https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf - S. Hochreiter and J. Schmidhuber. - "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + Felix Gers, Jurgen Schmidhuber, and Fred Cummins. + "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. The peephole implementation is based on: @@ -2448,10 +2448,10 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell): The default non-peephole implementation is based on: - http://www.bioinf.jku.at/publications/older/2604.pdf + https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf - S. Hochreiter and J. Schmidhuber. - "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + Felix Gers, Jurgen Schmidhuber, and Fred Cummins. + "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. The peephole implementation is based on: @@ -2802,9 +2802,11 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell): Training of Deep Neural Networks The default LSTM implementation based on: - http://www.bioinf.jku.at/publications/older/2604.pdf - S. Hochreiter and J. Schmidhuber. - "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + + https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf + + Felix Gers, Jurgen Schmidhuber, and Fred Cummins. + "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. The class uses optional peephole connections, optional cell clipping and an optional projection layer. diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index db970deff5..0042d37acd 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -134,19 +134,19 @@ def _get_default_head(params, weights_name, output_type, name=None): weight_column=weights_name, label_dimension=params.num_outputs, name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) else: if params.num_classes == 2: return core_head_lib.binary_classification_head( weight_column=weights_name, name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) else: return core_head_lib.multi_class_head( n_classes=params.num_classes, weight_column=weights_name, name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) def get_model_fn(params, graph_builder_class, diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py index 537d94b797..3c0456dc2f 100644 --- a/tensorflow/contrib/tpu/__init__.py +++ b/tensorflow/contrib/tpu/__init__.py @@ -33,6 +33,7 @@ @@shard @@batch_parallel @@rewrite +@@outside_compilation @@CrossShardOptimizer diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 08e0465b71..d8c3872363 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -258,6 +258,8 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer): return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads] def set_weights(self, weights): + # TODO(power): Figure out whether we really need this given there is no + # caller for this API yet. self._opt.set_weights() def get_weights(self): @@ -282,9 +284,9 @@ def _valid_name(tensor_name): def _replicated_optimizer(opt): """Wrap the optimizer `opt` with CrossShardOptimizer if applicable.""" - if tpu_function.get_tpu_context().number_of_shards == 1: - return opt - + # Always wrap `opt` with CrossShardOptimizer, even if we are running on a + # single core. This ensures Keras properly tracks and initializes optimizer + # variables. if isinstance(opt, keras_optimizers.TFOptimizer): return tpu_optimizer.CrossShardOptimizer(opt.optimizer) else: @@ -1420,7 +1422,7 @@ class KerasTPUModel(models.Model): y, sample_weights, batch_size) - self._pipeline_fit_loop( + return self._pipeline_fit_loop( x, y, sample_weights=sample_weights, diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 1e21cc5252..c1f90c3963 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -652,13 +652,28 @@ def split_compile_and_replicate(computation, # TODO(phawkins): consider removing this code. It will # be less confusing to clients if they knowingly choose to use resource # variables. + # Partitioned variables is not supported (b/112311320). + def custom_getter(getter, name, *args, **kwargs): + partitioner = kwargs["partitioner"] + if partitioner is None: + return getter(name, *args, **kwargs) + else: + raise ValueError( + "Partitioned variables are not supported on TPU. Got " + "`partitioner` that is {}.".format(partitioner)) + vscope = variable_scope.get_variable_scope() + saved_use_resource = vscope.use_resource + saved_custom_getter = vscope.custom_getter + vscope.set_use_resource(True) + vscope.set_custom_getter(custom_getter) outputs = computation(*computation_inputs) vscope.set_use_resource(saved_use_resource) + vscope.set_custom_getter(saved_custom_getter) # If the computation returns `None`, make it an empty tuple. if outputs is None: diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc index ad3dce1784..d4951b156c 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc @@ -63,7 +63,7 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( } CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0); RdmaChannel* rc = rdma_mgr_->FindChannel(src_name); - string key(std::move(parsed.FullKey().ToString())); + string key(parsed.FullKey()); string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_); Device* dst_dev; |