diff options
author | 2018-03-19 20:06:26 -0700 | |
---|---|---|
committer | 2018-03-19 20:10:51 -0700 | |
commit | 56555d0604c029e8b92fcd354de3bf32b63b62d8 (patch) | |
tree | 372fae38a1fa92204fb3e0190083f7c5ccb79419 | |
parent | 79d06a6261a523866ace67f7b831d7f617d550e6 (diff) |
Adds final partial batch support for TPUEstimator.predict.
PiperOrigin-RevId: 189683528
-rw-r--r-- | tensorflow/contrib/tpu/BUILD | 11 | ||||
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 212 | ||||
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py | 291 |
3 files changed, 458 insertions, 56 deletions
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index ed930e44e8..eea19e9465 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -271,6 +271,17 @@ tf_py_test( ], ) +tf_py_test( + name = "tpu_estimator_signals_test", + size = "small", + srcs = ["python/tpu/tpu_estimator_signals_test.py"], + additional_deps = [ + ":tpu_estimator", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 32f15e60cd..5a8fa04e7c 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -49,6 +49,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -62,6 +63,7 @@ from tensorflow.python.training import evaluation from tensorflow.python.training import session_run_hook from tensorflow.python.training import training from tensorflow.python.training import training_util +from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect _INITIAL_LOSS = 1e7 @@ -678,8 +680,11 @@ def generate_per_host_enqueue_ops_fn_for_host( raise TypeError( 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' '`features` and `labels`.') + if batch_axis is not None: + raise TypeError('For mode PREDICT, batch_axis is not supported yet.') inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn) + dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn, + add_padding=True) if is_dataset: hooks.append(inputs.dataset_initializer_hook()) @@ -1620,11 +1625,6 @@ class TPUEstimator(estimator_lib.Estimator): 2. `input_fn` must return a `Dataset` instance rather than `features`. In fact, .train() and .evaluate() also support Dataset as return value. - 3. Each batch returned by `Dataset`'s iterator must have the *same static* - shape. This means two things: - - batch_size cannot be `None` - - the final batch must be padded by user to a full batch. - Example (MNIST): ---------------- ``` @@ -1639,41 +1639,9 @@ class TPUEstimator(estimator_lib.Estimator): [total_examples, height, width, 3], minval=-1, maxval=1) dataset = tf.data.Dataset.from_tensor_slices(images) - dataset = dataset.batch(batch_size) dataset = dataset.map(lambda images: {'image': images}) - def pad(tensor, missing_count): - # Pads out the batch dimension to the complete batch_size. - rank = len(tensor.shape) - assert rank > 0 - padding = tf.stack([[0, missing_count]] + [[0, 0]] * (rank - 1)) - padded_shape = (batch_size,) + tuple(tensor.shape[1:]) - padded_tensor = tf.pad(tensor, padding) - padded_tensor.set_shape(padded_shape) - return padded_tensor - - def pad_batch_if_incomplete(batch_features): - # Pads out the batch dimension for all features. - real_batch_size = tf.shape(batch_features["image"])[0] - - missing_count = tf.constant(batch_size, tf.int32) - real_batch_size - - padded_features = { - key: pad(tensor, missing_count) - for key, tensor in batch_features.iteritems() - } - padding_mask = tf.concat( - [ - tf.zeros((real_batch_size, 1), dtype=tf.int32), - tf.ones((missing_count, 1), dtype=tf.int32) - ], - axis=0) - padding_mask.set_shape((batch_size, 1)) - padded_features["is_padding"] = padding_mask - return padded_features - - dataset = dataset.map(pad_batch_if_incomplete) - + dataset = dataset.batch(batch_size) return dataset def model_fn(features, labels, params, mode): @@ -2089,12 +2057,14 @@ class TPUEstimator(estimator_lib.Estimator): predictions, message=( 'The estimated size for TPUEstimatorSpec.predictions is too ' 'large.')) - stopping_signals = host_call_ret['signals'] + signals = host_call_ret['signals'] with ops.control_dependencies(host_ops): host_ops = [] # Empty, we do do not need it anymore. scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal( - stopping_signals) + signals) + predictions = _PaddingSignals.slice_tensor_or_dict( + predictions, signals) hooks = [ _StoppingPredictHook(scalar_stopping_signal), @@ -2389,20 +2359,19 @@ class _Inputs(object): return self._dataset -# TODO(xiejw): Extend this to support final partial batch. class _InputsWithStoppingSignals(_Inputs): """Inputs with `_StopSignals` inserted into the dataset.""" - def __init__(self, dataset, batch_size): + def __init__(self, dataset, batch_size, add_padding=False): assert dataset is not None user_provided_dataset = dataset.map( _InputsWithStoppingSignals.insert_stopping_signal( - stop=False, batch_size=batch_size)) + stop=False, batch_size=batch_size, add_padding=add_padding)) final_batch_dataset = dataset.take(1).map( _InputsWithStoppingSignals.insert_stopping_signal( - stop=True, batch_size=batch_size)) + stop=True, batch_size=batch_size, add_padding=add_padding)) dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2) super(_InputsWithStoppingSignals, self).__init__(dataset=dataset) @@ -2432,7 +2401,7 @@ class _InputsWithStoppingSignals(_Inputs): return signals @staticmethod - def insert_stopping_signal(stop, batch_size): + def insert_stopping_signal(stop, batch_size, add_padding=False): """Inserts stopping_signal into dataset via _map_fn. Here we change the data structure in the dataset, such that the return value @@ -2443,6 +2412,7 @@ class _InputsWithStoppingSignals(_Inputs): Args: stop: bool, state of current stopping signals. batch_size: int, batch size. + add_padding: bool, whether to pad the tensor to full batch size. Returns: A map_fn passed to dataset.map API. @@ -2456,11 +2426,25 @@ class _InputsWithStoppingSignals(_Inputs): args = args[0] features, labels = _Inputs._parse_inputs(args) new_input_dict = {} - new_input_dict['features'] = features - if labels is not None: - new_input_dict['labels'] = labels + + if add_padding: + padding_mask, features, labels = ( + _PaddingSignals.pad_features_and_labels( + features, labels, batch_size)) + + new_input_dict['features'] = features + if labels is not None: + new_input_dict['labels'] = labels + + else: + new_input_dict['features'] = features + if labels is not None: + new_input_dict['labels'] = labels + padding_mask = None + new_input_dict['signals'] = _StopSignals( - stop=stop, batch_size=batch_size).as_dict() + stop=stop, batch_size=batch_size, padding_mask=padding_mask).as_dict() + return new_input_dict return _map_fn @@ -2469,23 +2453,28 @@ class _InputsWithStoppingSignals(_Inputs): class _StopSignals(object): """Signals class holding all logic to handle TPU stopping condition.""" - NON_STOPPING_SIGNAL = 0.0 - STOPPING_SIGNAL = 1.0 + NON_STOPPING_SIGNAL = False + STOPPING_SIGNAL = True - def __init__(self, stop, batch_size): + def __init__(self, stop, batch_size, padding_mask=None): self._stop = stop self._batch_size = batch_size + self._padding_mask = padding_mask def as_dict(self): + """Returns the signals as Python dict.""" shape = [self._batch_size, 1] - dtype = dtypes.float32 + dtype = dtypes.bool if self._stop: stopping = array_ops.ones(shape=shape, dtype=dtype) else: stopping = array_ops.zeros(shape=shape, dtype=dtype) - return {'stopping': stopping} + signals = {'stopping': stopping} + if self._padding_mask is not None: + signals['padding_mask'] = self._padding_mask + return signals @staticmethod def as_scalar_stopping_signal(signals): @@ -2493,7 +2482,118 @@ class _StopSignals(object): @staticmethod def should_stop(scalar_stopping_signal): - return scalar_stopping_signal >= _StopSignals.STOPPING_SIGNAL + if isinstance(scalar_stopping_signal, ops.Tensor): + # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF + # way to express the bool check whether scalar_stopping_signal is True. + return math_ops.logical_and( + scalar_stopping_signal, _StopSignals.STOPPING_SIGNAL) + else: + # For non Tensor case, it is used in SessionRunHook. So, we cannot modify + # the graph anymore. Here, we use pure Python. + return bool(scalar_stopping_signal) + + +class _PaddingSignals(object): + """Signals class holding all logic to handle padding.""" + + @staticmethod + def pad_features_and_labels(features, labels, batch_size): + """Pads out the batch dimension of features and labels.""" + real_batch_size = array_ops.shape( + _PaddingSignals._find_any_tensor(features))[0] + + batch_size_tensor = constant_op.constant(batch_size, dtypes.int32) + + check_greater = check_ops.assert_greater_equal( + batch_size_tensor, real_batch_size, + data=(batch_size_tensor, real_batch_size), + message='The real batch size should not be greater than batch_size.') + + with ops.control_dependencies([check_greater]): + missing_count = batch_size_tensor - real_batch_size + + def pad_single_tensor(tensor): + """Pads out the batch dimension of a tensor to the complete batch_size.""" + rank = len(tensor.shape) + assert rank > 0 + padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1)) + padded_shape = (batch_size,) + tuple(tensor.shape[1:]) + padded_tensor = array_ops.pad(tensor, padding) + padded_tensor.set_shape(padded_shape) + return padded_tensor + + def nest_pad(tensor_or_dict): + return nest.map_structure(pad_single_tensor, tensor_or_dict) + + features = nest_pad(features) + if labels is not None: + labels = nest_pad(labels) + + padding_mask = _PaddingSignals._padding_mask( + real_batch_size, missing_count, batch_size) + + return padding_mask, features, labels + + @staticmethod + def slice_tensor_or_dict(tensor_or_dict, signals): + """Slice the real Tensors according to padding mask in signals.""" + + padding_mask = signals['padding_mask'] + batch_size = array_ops.shape(padding_mask)[0] + + def verify_batch_size(tensor): + check_batch_size = math_ops.equal(batch_size, tensor.shape[0]) + with ops.control_dependencies([check_batch_size]): + return array_ops.identity(tensor) + + def slice_single_tensor(tensor): + rank = len(tensor.shape) + assert rank > 0 + real_batch_size = batch_size - math_ops.reduce_sum(padding_mask) + return verify_batch_size(tensor)[0:real_batch_size] + + # As we split the Tensors to all TPU cores and concat them back, it is + # important to ensure the real data is placed before padded ones, i.e., + # order is preserved. By that, the sliced padding mask should have all 0's. + # If this assertion failed, # the slice logic here would not hold. + sliced_padding_mask = slice_single_tensor(padding_mask) + assert_padding_mask = math_ops.equal( + math_ops.reduce_sum(sliced_padding_mask), 0) + + with ops.control_dependencies([assert_padding_mask]): + should_stop = _StopSignals.should_stop( + _StopSignals.as_scalar_stopping_signal(signals)) + + is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0) + + def slice_fn(tensor): + # If the current batch is full batch or part of stopping signals, we do + # not need to slice to save performance. + return control_flow_ops.cond( + math_ops.logical_or(should_stop, is_full_batch), + (lambda: verify_batch_size(tensor)), + (lambda: slice_single_tensor(tensor))) + + return nest.map_structure(slice_fn, tensor_or_dict) + + @staticmethod + def _find_any_tensor(batch_features): + tensors = [x for x in nest.flatten(batch_features) + if isinstance(x, ops.Tensor)] + if not tensors: + raise ValueError('Cannot find any Tensor in features dict.') + return tensors[0] + + @staticmethod + def _padding_mask(real_batch_size, missing_count, batch_size): + padding_mask = array_ops.concat( + [ + array_ops.zeros((real_batch_size,), dtype=dtypes.int32), + array_ops.ones((missing_count,), dtype=dtypes.int32) + ], + axis=0) + padding_mask.set_shape((batch_size,)) + return padding_mask class _SignalsHelper(object): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py new file mode 100644 index 0000000000..3e90957e6d --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator_signals_test.py @@ -0,0 +1,291 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TPU Estimator Signalling Tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tpu.python.tpu import tpu_estimator +from tensorflow.python import data as dataset_lib +from tensorflow.python.client import session +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +def make_input_fn(num_samples): + a = np.linspace(0, 100.0, num=num_samples) + b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1)) + + def input_fn(params): + batch_size = params['batch_size'] + da1 = dataset_lib.Dataset.from_tensor_slices(a) + da2 = dataset_lib.Dataset.from_tensor_slices(b) + + dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb}) + dataset = dataset.batch(batch_size) + return dataset + return input_fn, (a, b) + + +def make_input_fn_with_labels(num_samples): + a = np.linspace(0, 100.0, num=num_samples) + b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1)) + + def input_fn(params): + batch_size = params['batch_size'] + da1 = dataset_lib.Dataset.from_tensor_slices(a) + da2 = dataset_lib.Dataset.from_tensor_slices(b) + + dataset = dataset_lib.Dataset.zip((da1, da2)) + dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb)) + dataset = dataset.batch(batch_size) + return dataset + return input_fn, (a, b) + + +class TPUEstimatorStoppingSignalsTest(test.TestCase): + + def test_normal_output_without_signals(self): + num_samples = 4 + batch_size = 2 + + params = {'batch_size': batch_size} + input_fn, (a, b) = make_input_fn(num_samples=num_samples) + + with ops.Graph().as_default(): + dataset = input_fn(params) + features = dataset.make_one_shot_iterator().get_next() + + # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape. + self.assertIsNone(features['a'].shape.as_list()[0]) + + with session.Session() as sess: + result = sess.run(features) + self.assertAllEqual(a[:batch_size], result['a']) + self.assertAllEqual(b[:batch_size], result['b']) + + # This run should work as num_samples / batch_size = 2. + result = sess.run(features) + self.assertAllEqual(a[batch_size:num_samples], result['a']) + self.assertAllEqual(b[batch_size:num_samples], result['b']) + + with self.assertRaises(errors.OutOfRangeError): + # Given num_samples and batch_size, this run should fail. + sess.run(features) + + def test_output_with_stopping_signals(self): + num_samples = 4 + batch_size = 2 + + params = {'batch_size': batch_size} + input_fn, (a, b) = make_input_fn(num_samples=num_samples) + + with ops.Graph().as_default(): + dataset = input_fn(params) + inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size) + hook = inputs.dataset_initializer_hook() + features, _ = inputs.features_and_labels() + signals = inputs.signals() + + # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape. + self.assertIsNone(features['a'].shape.as_list()[0]) + + with session.Session() as sess: + hook.begin() + hook.after_create_session(sess, coord=None) + + result, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual(a[:batch_size], result['a']) + self.assertAllEqual(b[:batch_size], result['b']) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + + # This run should work as num_samples / batch_size = 2. + result, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual(a[batch_size:num_samples], result['a']) + self.assertAllEqual(b[batch_size:num_samples], result['b']) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + + # This run should work, *but* see STOP ('1') as signals + _, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping']) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(features) + + +class TPUEstimatorStoppingSignalsWithPaddingTest(test.TestCase): + + def test_num_samples_divisible_by_batch_size(self): + num_samples = 4 + batch_size = 2 + + params = {'batch_size': batch_size} + input_fn, (a, b) = make_input_fn(num_samples=num_samples) + + with ops.Graph().as_default(): + dataset = input_fn(params) + inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size, + add_padding=True) + hook = inputs.dataset_initializer_hook() + features, _ = inputs.features_and_labels() + signals = inputs.signals() + + # With padding, all shapes are static now. + self.assertEqual(batch_size, features['a'].shape.as_list()[0]) + + with session.Session() as sess: + hook.begin() + hook.after_create_session(sess, coord=None) + + result, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual(a[:batch_size], result['a']) + self.assertAllEqual(b[:batch_size], result['b']) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + self.assertAllEqual([0.] * batch_size, + evaluated_signals['padding_mask']) + + # This run should work as num_samples / batch_size = 2. + result, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual(a[batch_size:num_samples], result['a']) + self.assertAllEqual(b[batch_size:num_samples], result['b']) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + self.assertAllEqual([0.] * batch_size, + evaluated_signals['padding_mask']) + + # This run should work, *but* see STOP ('1') as signals + _, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping']) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(features) + + def test_num_samples_not_divisible_by_batch_size(self): + num_samples = 5 + batch_size = 2 + + params = {'batch_size': batch_size} + input_fn, (a, b) = make_input_fn_with_labels(num_samples=num_samples) + + with ops.Graph().as_default(): + dataset = input_fn(params) + inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size, + add_padding=True) + hook = inputs.dataset_initializer_hook() + features, labels = inputs.features_and_labels() + signals = inputs.signals() + + # With padding, all shapes are static. + self.assertEqual(batch_size, features['a'].shape.as_list()[0]) + + with session.Session() as sess: + hook.begin() + hook.after_create_session(sess, coord=None) + + evaluated_features, evaluated_labels, evaluated_signals = ( + sess.run([features, labels, signals])) + self.assertAllEqual(a[:batch_size], evaluated_features['a']) + self.assertAllEqual(b[:batch_size], evaluated_labels) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + self.assertAllEqual([0.] * batch_size, + evaluated_signals['padding_mask']) + + # This run should work as num_samples / batch_size >= 2. + evaluated_features, evaluated_labels, evaluated_signals = ( + sess.run([features, labels, signals])) + self.assertAllEqual(a[batch_size:2*batch_size], evaluated_features['a']) + self.assertAllEqual(b[batch_size:2*batch_size], evaluated_labels) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + self.assertAllEqual([0.] * batch_size, + evaluated_signals['padding_mask']) + + # This is the final partial batch. + evaluated_features, evaluated_labels, evaluated_signals = ( + sess.run([features, labels, signals])) + real_batch_size = num_samples % batch_size + + # Assert the real part. + self.assertAllEqual(a[2*batch_size:num_samples], + evaluated_features['a'][:real_batch_size]) + self.assertAllEqual(b[2*batch_size:num_samples], + evaluated_labels[:real_batch_size]) + # Assert the padded part. + self.assertAllEqual([0.0] * (batch_size - real_batch_size), + evaluated_features['a'][real_batch_size:]) + self.assertAllEqual([[0.0]] * (batch_size - real_batch_size), + evaluated_labels[real_batch_size:]) + + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + + padding = ([.0] * real_batch_size + + [1.] * (batch_size - real_batch_size)) + self.assertAllEqual(padding, evaluated_signals['padding_mask']) + + # This run should work, *but* see STOP ('1') as signals + _, evaluated_signals = sess.run([features, signals]) + self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping']) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(features) + + def test_slice(self): + num_samples = 3 + batch_size = 2 + + params = {'batch_size': batch_size} + input_fn, (a, b) = make_input_fn(num_samples=num_samples) + + with ops.Graph().as_default(): + dataset = input_fn(params) + inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size, + add_padding=True) + hook = inputs.dataset_initializer_hook() + features, _ = inputs.features_and_labels() + signals = inputs.signals() + + sliced_features = ( + tpu_estimator._PaddingSignals.slice_tensor_or_dict( + features, signals)) + + with session.Session() as sess: + hook.begin() + hook.after_create_session(sess, coord=None) + + result, evaluated_signals = sess.run([sliced_features, signals]) + self.assertAllEqual(a[:batch_size], result['a']) + self.assertAllEqual(b[:batch_size], result['b']) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + + # This is the final partial batch. + result, evaluated_signals = sess.run([sliced_features, signals]) + self.assertEqual(1, len(result['a'])) + self.assertAllEqual(a[batch_size:num_samples], result['a']) + self.assertAllEqual(b[batch_size:num_samples], result['b']) + self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping']) + + # This run should work, *but* see STOP ('1') as signals + _, evaluated_signals = sess.run([sliced_features, signals]) + self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping']) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(sliced_features) + + +if __name__ == '__main__': + test.main() |