author     2018-03-07 23:42:50 -0800
committer  2018-03-07 23:42:50 -0800
commit     f8363dc424f78ec06c9fe2faee7623624aa0392e (patch)
tree       9da5395dbd3707994964ac03b37f36b5545ce37a
parent     9d867e0c34ea34ac74ebdab2cdcfc5b8c61fed25 (diff)
parent     9cdfd3878935fb6c3c2a5da7f65ee0db6c751170 (diff)
Merge pull request #17536 from yifeif/branch_188272354
Branch 188272354
34 files changed, 636 insertions, 912 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py index 819095a060..dabadfc7b6 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py @@ -462,138 +462,6 @@ class HMCTest(test.TestCase): def testKernelLeavesTargetInvariant3(self): self._kernel_leaves_target_invariant_wrapper(3) - def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims, - sess, feed_dict=None): - counter = collections.Counter() - - def proposal_log_prob(x): - counter["proposal_calls"] += 1 - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), - axis=event_dims) - - def target_log_prob(x): - counter["target_calls"] += 1 - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - if feed_dict is None: - feed_dict = {} - - num_steps = 200 - - _, ais_weights, _ = hmc.sample_annealed_importance_chain( - proposal_log_prob_fn=proposal_log_prob, - num_steps=num_steps, - target_log_prob_fn=target_log_prob, - step_size=0.5, - current_state=init, - num_leapfrog_steps=2, - seed=45) - - # We have three calls because the calculation of `ais_weights` entails - # another call to the `convex_combined_log_prob_fn`. We could refactor - # things to avoid this, if needed (eg, b/72994218). - self.assertAllEqual(dict(target_calls=3, proposal_calls=3), counter) - - event_shape = array_ops.shape(init)[independent_chain_ndims:] - event_size = math_ops.reduce_prod(event_shape) - - log_true_normalizer = ( - -self._shape_param * math_ops.log(self._rate_param) - + math_ops.lgamma(self._shape_param)) - log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype) - - log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights) - - np.log(num_steps)) - - ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer) - ais_weights_size = array_ops.size(ais_weights) - standard_error = math_ops.sqrt( - _reduce_variance(ratio_estimate_true) - / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype)) - - [ - ratio_estimate_true_, - log_true_normalizer_, - log_estimated_normalizer_, - standard_error_, - ais_weights_size_, - event_size_, - ] = sess.run([ - ratio_estimate_true, - log_true_normalizer, - log_estimated_normalizer, - standard_error, - ais_weights_size, - event_size, - ], feed_dict) - - logging_ops.vlog(1, " log_true_normalizer: {}\n" - " log_estimated_normalizer: {}\n" - " ais_weights_size: {}\n" - " event_size: {}\n".format( - log_true_normalizer_, - log_estimated_normalizer_, - ais_weights_size_, - event_size_)) - self.assertNear(ratio_estimate_true_.mean(), 1., 4. 
* standard_error_) - - def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims): - """Tests that AIS yields reasonable estimates of normalizers.""" - with self.test_session(graph=ops.Graph()) as sess: - x_ph = array_ops.placeholder(np.float32, name="x_ph") - initial_draws = np.random.normal(size=[30, 2, 1]) - self._ais_gets_correct_log_normalizer( - x_ph, - independent_chain_ndims, - sess, - feed_dict={x_ph: initial_draws}) - - def testAIS1(self): - self._ais_gets_correct_log_normalizer_wrapper(1) - - def testAIS2(self): - self._ais_gets_correct_log_normalizer_wrapper(2) - - def testAIS3(self): - self._ais_gets_correct_log_normalizer_wrapper(3) - - def testSampleAIChainSeedReproducibleWorksCorrectly(self): - with self.test_session(graph=ops.Graph()) as sess: - independent_chain_ndims = 1 - x = np.random.rand(4, 3, 2) - - def proposal_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), - axis=event_dims) - - def target_log_prob(x): - event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) - return self._log_gamma_log_prob(x, event_dims) - - ais_kwargs = dict( - proposal_log_prob_fn=proposal_log_prob, - num_steps=200, - target_log_prob_fn=target_log_prob, - step_size=0.5, - current_state=x, - num_leapfrog_steps=2, - seed=53) - - _, ais_weights0, _ = hmc.sample_annealed_importance_chain( - **ais_kwargs) - - _, ais_weights1, _ = hmc.sample_annealed_importance_chain( - **ais_kwargs) - - [ais_weights0_, ais_weights1_] = sess.run([ - ais_weights0, ais_weights1]) - - self.assertAllClose(ais_weights0_, ais_weights1_, - atol=1e-5, rtol=1e-5) - def testNanRejection(self): """Tests that an update that yields NaN potentials gets rejected. diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/bayesflow/python/ops/hmc.py index 7fd5652c5c..c8a5a195d3 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc.py @@ -24,7 +24,6 @@ from tensorflow.python.util import all_util _allowed_symbols = [ "sample_chain", - "sample_annealed_importance_chain", "kernel", ] diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py index 82693c2b7b..66afcc7497 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py @@ -15,7 +15,6 @@ """Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. @@sample_chain -@@sample_annealed_importance_chain @@kernel """ @@ -38,7 +37,6 @@ from tensorflow.python.ops.distributions import util as distributions_util __all__ = [ "sample_chain", - "sample_annealed_importance_chain", "kernel", ] @@ -330,221 +328,6 @@ def sample_chain( return functional_ops.scan(**scan_kwargs) -def sample_annealed_importance_chain( - proposal_log_prob_fn, - num_steps, - target_log_prob_fn, - current_state, - step_size, - num_leapfrog_steps, - seed=None, - name=None): - """Runs annealed importance sampling (AIS) to estimate normalizing constants. - - This function uses Hamiltonian Monte Carlo to sample from a series of - distributions that slowly interpolates between an initial "proposal" - distribution: - - `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)` - - and the target distribution: - - `exp(target_log_prob_fn(x) - target_log_normalizer)`, - - accumulating importance weights along the way. 
The product of these - importance weights gives an unbiased estimate of the ratio of the - normalizing constants of the initial distribution and the target - distribution: - - `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`. - - Note: `proposal_log_prob_fn` and `target_log_prob_fn` are called exactly three - times (although this may be reduced to two times, in the future). - - #### Examples: - - ##### Estimate the normalizing constant of a log-gamma distribution. - - ```python - tfd = tf.contrib.distributions - - # Run 100 AIS chains in parallel - num_chains = 100 - dims = 20 - dtype = np.float32 - - proposal = tfd.MultivatiateNormalDiag( - loc=tf.zeros([dims], dtype=dtype)) - - target = tfd.TransformedDistribution( - distribution=tfd.Gamma(concentration=dtype(2), - rate=dtype(3)), - bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()), - event_shape=[dims]) - - chains_state, ais_weights, kernels_results = ( - hmc.sample_annealed_importance_chain( - proposal_log_prob_fn=proposal.log_prob, - num_steps=1000, - target_log_prob_fn=target.log_prob, - step_size=0.2, - current_state=proposal.sample(num_chains), - num_leapfrog_steps=2)) - - log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights) - - np.log(num_chains)) - log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.) - ``` - - ##### Estimate marginal likelihood of a Bayesian regression model. - - ```python - tfd = tf.contrib.distributions - - def make_prior(dims, dtype): - return tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - def make_likelihood(weights, x): - return tfd.MultivariateNormalDiag( - loc=tf.tensordot(weights, x, axes=[[0], [-1]])) - - # Run 100 AIS chains in parallel - num_chains = 100 - dims = 10 - dtype = np.float32 - - # Make training data. - x = np.random.randn(num_chains, dims).astype(dtype) - true_weights = np.random.randn(dims).astype(dtype) - y = np.dot(x, true_weights) + np.random.randn(num_chains) - - # Setup model. - prior = make_prior(dims, dtype) - def target_log_prob_fn(weights): - return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y) - - proposal = tfd.MultivariateNormalDiag( - loc=tf.zeros(dims, dtype)) - - weight_samples, ais_weights, kernel_results = ( - hmc.sample_annealed_importance_chain( - num_steps=1000, - proposal_log_prob_fn=proposal.log_prob, - target_log_prob_fn=target_log_prob_fn - current_state=tf.zeros([num_chains, dims], dtype), - step_size=0.1, - num_leapfrog_steps=2)) - log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights) - - np.log(num_chains)) - ``` - - Args: - proposal_log_prob_fn: Python callable that returns the log density of the - initial distribution. - num_steps: Integer number of Markov chain updates to run. More - iterations means more expense, but smoother annealing between q - and p, which in turn means exponentially lower variance for the - normalizing constant estimator. - target_log_prob_fn: Python callable which takes an argument like - `current_state` (or `*current_state` if it's a list) and returns its - (possibly unnormalized) log-density under the target distribution. - current_state: `Tensor` or Python `list` of `Tensor`s representing the - current state(s) of the Markov chain(s). The first `r` dimensions index - independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. - step_size: `Tensor` or Python `list` of `Tensor`s representing the step size - for the leapfrog integrator. Must broadcast with the shape of - `current_state`. 
Larger step sizes lead to faster progress, but too-large - step sizes make rejection exponentially more likely. When possible, it's - often helpful to match per-variable step sizes to the standard deviations - of the target distribution in each variable. - num_leapfrog_steps: Integer number of steps to run the leapfrog integrator - for. Total progress per HMC step is roughly proportional to `step_size * - num_leapfrog_steps`. - seed: Python integer to seed the random number generator. - name: Python `str` name prefixed to Ops created by this function. - Default value: `None` (i.e., "hmc_sample_annealed_importance_chain"). - - Returns: - next_state: `Tensor` or Python list of `Tensor`s representing the - state(s) of the Markov chain(s) at the final iteration. Has same shape as - input `current_state`. - ais_weights: Tensor with the estimated weight(s). Has shape matching - `target_log_prob_fn(current_state)`. - kernel_results: `collections.namedtuple` of internal calculations used to - advance the chain. - """ - def make_convex_combined_log_prob_fn(iter_): - def _fn(*args): - p = proposal_log_prob_fn(*args) - t = target_log_prob_fn(*args) - dtype = p.dtype.base_dtype - beta = (math_ops.cast(iter_ + 1, dtype) - / math_ops.cast(num_steps, dtype)) - return (1. - beta) * p + beta * t - return _fn - - with ops.name_scope( - name, "hmc_sample_annealed_importance_chain", - [num_steps, current_state, step_size, num_leapfrog_steps, seed]): - with ops.name_scope("initialize"): - [ - current_state, - step_size, - current_log_prob, - current_grads_log_prob, - ] = _prepare_args( - make_convex_combined_log_prob_fn(iter_=0), - current_state, - step_size, - description="convex_combined_log_prob") - num_steps = ops.convert_to_tensor( - num_steps, - dtype=dtypes.int32, - name="num_steps") - num_leapfrog_steps = ops.convert_to_tensor( - num_leapfrog_steps, - dtype=dtypes.int32, - name="num_leapfrog_steps") - def _loop_body(iter_, ais_weights, current_state, kernel_results): - """Closure which implements `tf.while_loop` body.""" - current_state_parts = (list(current_state) - if _is_list_like(current_state) - else [current_state]) - # TODO(b/72994218): Consider refactoring things to avoid this unecessary - # call. - ais_weights += ((target_log_prob_fn(*current_state_parts) - - proposal_log_prob_fn(*current_state_parts)) - / math_ops.cast(num_steps, ais_weights.dtype)) - return [iter_ + 1, ais_weights] + list(kernel( - make_convex_combined_log_prob_fn(iter_), - current_state, - step_size, - num_leapfrog_steps, - seed, - kernel_results.current_target_log_prob, - kernel_results.current_grads_target_log_prob)) - - while_loop_kwargs = dict( - cond=lambda iter_, *args: iter_ < num_steps, - body=_loop_body, - loop_vars=[ - np.int32(0), # iter_ - array_ops.zeros_like(current_log_prob), # ais_weights - current_state, - _make_dummy_kernel_results(current_state, - current_log_prob, - current_grads_log_prob), - ]) - if seed is not None: - while_loop_kwargs["parallel_iterations"] = 1 - - [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop( - **while_loop_kwargs)[1:] # Lop-off "iter_". 
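The removed function estimates a ratio of normalizing constants by annealing from the proposal to the target along the convex bridge `(1 - beta) * proposal_log_prob(x) + beta * target_log_prob(x)`. A minimal NumPy sketch of that estimator, with a crude random-walk Metropolis transition standing in for the HMC kernel (every name below is illustrative, not part of the TensorFlow API):

```python
import numpy as np

def ais_log_normalizer_ratio(proposal_log_prob, target_log_prob,
                             initial_draws, num_steps=1000, seed=0):
  """Estimates log(Z_target / Z_proposal) by annealed importance sampling.

  Both log-prob callables are assumed to map a [num_chains, dims] array to
  a [num_chains] array of per-chain log densities.
  """
  rng = np.random.RandomState(seed)
  x = np.array(initial_draws, dtype=float)
  num_chains = x.shape[0]
  log_weights = np.zeros(num_chains)
  for step in range(num_steps):
    # Same per-step increment as the removed loop body: with the linear
    # schedule beta_i = i / num_steps, the AIS increment
    # (beta_i - beta_{i-1}) * (target - proposal) reduces to this.
    log_weights += (target_log_prob(x) - proposal_log_prob(x)) / num_steps
    beta = (step + 1.) / num_steps
    def bridge(z):
      return (1. - beta) * proposal_log_prob(z) + beta * target_log_prob(z)
    # Stand-in MCMC transition targeting the bridge (the real code runs HMC).
    candidate = x + 0.5 * rng.randn(*x.shape)
    accept = np.log(rng.rand(num_chains)) < bridge(candidate) - bridge(x)
    x[accept] = candidate[accept]
  # E[exp(log_weights)] = Z_target / Z_proposal, so average in log space.
  return np.logaddexp.reduce(log_weights) - np.log(num_chains)
```

Averaging `exp(log_weights)` across chains (done in log space above) is exactly the `reduce_logsumexp(ais_weights) - log(num_chains)` estimate that the deleted tests compared against the analytic log-gamma normalizer.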
- - return [current_state, ais_weights, kernel_results] - - def kernel(target_log_prob_fn, current_state, step_size, diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 91874f9b5c..300b19733e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -147,7 +147,9 @@ class TPUClusterResolver(ClusterResolver): if service is None and should_resolve: if not _GOOGLE_API_CLIENT_INSTALLED: raise ImportError('googleapiclient must be installed before using the ' - 'TPU cluster resolver') + 'TPU cluster resolver. Execute: `pip install ' + '--upgrade google-api-python-client` to install with ' + 'pip.') self._service = discovery.build( 'tpu', 'v1alpha1', diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 54921aeec6..9212b69700 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -31,6 +31,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@enumerate_dataset @@group_by_window @@ignore_errors +@@make_batched_features_dataset @@make_saveable_from_iterator @@map_and_batch @@padded_batch_and_drop_remainder @@ -66,6 +67,7 @@ from tensorflow.contrib.data.python.ops.grouping import group_by_window from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator +from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset from tensorflow.contrib.data.python.ops.readers import read_batch_features from tensorflow.contrib.data.python.ops.readers import SqlDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index a157acc020..107625ab75 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -297,6 +297,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 6efe97444a..15bd55bf64 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,6 +21,8 @@ import gzip import os import zlib +import numpy as np + from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 @@ -262,12 +264,19 @@ class ReadBatchFeaturesTest(test.TestCase): self._num_records = 7 self.test_filenames = self._createFiles() - def _read_batch_features(self, filenames, num_epochs, batch_size): + def _read_batch_features(self, + filenames, + num_epochs, + batch_size, + reader_num_threads=1, + parser_num_threads=1, + shuffle=False, + shuffle_seed=None): self.filenames = filenames self.num_epochs = num_epochs self.batch_size = batch_size - return readers.read_batch_features( + return 
readers.make_batched_features_dataset( file_pattern=self.filenames, batch_size=self.batch_size, features={ @@ -276,8 +285,12 @@ class ReadBatchFeaturesTest(test.TestCase): "keywords": parsing_ops.VarLenFeature(dtypes.string) }, reader=core_readers.TFRecordDataset, - randomize_input=False, - num_epochs=self.num_epochs) + num_epochs=self.num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads).make_one_shot_iterator( + ).get_next() def _record(self, f, r): example = example_pb2.Example(features=feature_pb2.Features( @@ -312,24 +325,35 @@ class ReadBatchFeaturesTest(test.TestCase): writer.close() return filenames - def _next_actual_batch(self, sess): - file_op = self.outputs["file"] - keywords_indices_op = self.outputs["keywords"].indices - keywords_values_op = self.outputs["keywords"].values - keywords_dense_shape_op = self.outputs["keywords"].dense_shape - record_op = self.outputs["record"] + def _run_actual_batch(self, outputs, sess): + file_op = outputs["file"] + keywords_indices_op = outputs["keywords"].indices + keywords_values_op = outputs["keywords"].values + keywords_dense_shape_op = outputs["keywords"].dense_shape + record_op = outputs["record"] return sess.run([ file_op, keywords_indices_op, keywords_values_op, keywords_dense_shape_op, record_op ]) - def _next_expected_batch(self, file_indices, batch_size, num_epochs): + def _next_actual_batch(self, sess): + return self._run_actual_batch(self.outputs, sess) + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length=1): def _next_record(file_indices): for j in file_indices: for i in range(self._num_records): yield j, i + def _next_record_interleaved(file_indices, cycle_length): + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) + file_batch = [] keywords_batch_indices = [] keywords_batch_values = [] @@ -337,7 +361,11 @@ class ReadBatchFeaturesTest(test.TestCase): record_batch = [] batch_index = 0 for _ in range(num_epochs): - for record in _next_record(file_indices): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for record in next_records: f = record[0] r = record[1] file_batch.append(f) @@ -365,14 +393,41 @@ class ReadBatchFeaturesTest(test.TestCase): [len(file_batch), keywords_batch_max_len], record_batch ] - def _verify_records(self, sess, batch_size, file_index=None, num_epochs=1): + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + def _verify_records(self, + sess, + batch_size, + file_index=None, + num_epochs=1, + interleave_cycle_length=1): if file_index is not None: file_indices = [file_index] else: file_indices = range(self._num_files) - for expected_batch in self._next_expected_batch(file_indices, batch_size, - num_epochs): + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length): actual_batch = self._next_actual_batch(sess) 
for i in range(len(expected_batch)): self.assertAllEqual(expected_batch[i], actual_batch[i]) @@ -435,6 +490,75 @@ class ReadBatchFeaturesTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testReadWithFusedShuffleRepeatDataset(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + for batch_size in [1, 2]: + # Test that shuffling with same seed produces the same result. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs1 = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + outputs2 = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + self.assertAllEqual(batch1[i], batch2[i]) + + # Test that shuffling with different seeds produces a different order. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs1 = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5) + outputs2 = self._read_batch_features( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=15) + all_equal = True + for _ in range(total_records // batch_size): + batch1 = self._run_actual_batch(outputs1, sess) + batch2 = self._run_actual_batch(outputs2, sess) + for i in range(len(batch1)): + all_equal = all_equal and np.array_equal(batch1[i], batch2[i]) + self.assertFalse(all_equal) + + def testParallelReadersAndParsers(self): + num_epochs = 5 + for batch_size in [1, 2]: + for reader_num_threads in [2, 4]: + for parser_num_threads in [2, 4]: + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + self.outputs = self._read_batch_features( + filenames=self.test_filenames, + num_epochs=num_epochs, + batch_size=batch_size, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads) + self._verify_records( + sess, + batch_size, + num_epochs=num_epochs, + interleave_cycle_length=reader_num_threads) + with self.assertRaises(errors.OutOfRangeError): + self._next_actual_batch(sess) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 1c26296d62..fd871ef5ce 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -67,6 +67,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":dataset_ops", + ":shuffle_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 57f3010277..b346bed3e6 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,7 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers as 
core_readers from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -25,12 +28,150 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile +from tensorflow.python.util import deprecation +def make_batched_features_dataset(file_pattern, + batch_size, + features, + reader=core_readers.TFRecordDataset, + reader_args=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=10000, + shuffle_seed=None, + prefetch_buffer_size=1, + reader_num_threads=1, + parser_num_threads=2, + sloppy_ordering=False): + """Returns a `Dataset` of feature dictionaries from `Example` protos. + + Example: + + ``` + serialized_examples = [ + features { + feature { key: "age" value { int64_list { value: [ 0 ] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } + feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } } + }, + features { + feature { key: "age" value { int64_list { value: [] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } + feature { key: "kws" value { bytes_list { value: [ "sports" ] } } } + } + ] + ``` + + We can use arguments: + + ``` + features: { + "age": FixedLenFeature([], dtype=tf.int64, default_value=-1), + "gender": FixedLenFeature([], dtype=tf.string), + "kws": VarLenFeature(dtype=tf.string), + } + ``` + + And the expected output is: + + ```python + { + "age": [[0], [-1]], + "gender": [["f"], ["f"]], + "kws": SparseTensor( + indices=[[0, 0], [0, 1], [1, 0]], + values=["code", "art", "sports"] + dense_shape=[2, 2]), + } + ``` + + Args: + file_pattern: List of files or patterns of file paths containing + `Example` records. See `tf.gfile.Glob` for pattern rules. + batch_size: An int representing the number of consecutive elements of this + dataset to combine in a single batch. + features: A `dict` mapping feature keys to `FixedLenFeature` or + `VarLenFeature` values. See `tf.parse_example`. + reader: A function or class that can be + called with a `filenames` tensor and (optional) `reader_args` and returns + a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. + reader_args: Additional arguments to pass to the reader class. + num_epochs: Integer specifying the number of times to read through the + dataset. If None, cycles through the dataset forever. Defaults to `None`. + shuffle: A boolean, indicates whether the input should be shuffled. Defaults + to `True`. + shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity + ensures better shuffling but would increase memory usage and startup time. + shuffle_seed: Randomization seed to use for shuffling. + prefetch_buffer_size: Number of feature batches to prefetch in order to + improve performance. Recommended value is the number of batches consumed + per training step (default is 1). + reader_num_threads: Number of threads used to read `Example` records. If >1, + the results will be interleaved. + parser_num_threads: Number of threads to use for parsing `Example` tensors + into a dictionary of `Feature` tensors. + sloppy_ordering: If `True`, reading performance will be improved at + the cost of non-deterministic ordering. If `False`, the order of elements + produced is deterministic prior to shuffling (elements are still + randomized if `shuffle=True`. 
Note that if the seed is set, then order + of elements after shuffling is deterministic). Defaults to `False`. + + Returns: + A dataset of `dict` elements. Each `dict` maps feature keys to + `Tensor` or `SparseTensor` objects. + """ + # Create dataset of all matching filenames + if shuffle: + dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=True) + else: + # TODO(b/73959787): Use Dataset.list_files() once ordering is deterministic. + filenames = _get_file_names(file_pattern, shuffle) + dataset = dataset_ops.Dataset.from_tensor_slices(filenames) + + # Read `Example` records from files as tensor objects. + if reader_args is None: + reader_args = [] + + # Read files sequentially (if reader_num_threads=1) or in parallel + dataset = dataset.apply( + interleave_ops.parallel_interleave( + lambda filename: reader(filename, *reader_args), + cycle_length=reader_num_threads, + sloppy=sloppy_ordering)) + + # Extract values if the `Example` tensors are stored as key-value tuples. + if dataset.output_types == (dtypes.string, dtypes.string): + dataset = dataset.map(lambda _, v: v) + + # Apply dataset repeat and shuffle transformations. + repeat_dataset = (num_epochs != 1) + if repeat_dataset and shuffle: + # Used fused shuffle_and_repeat operation for better performance + dataset = dataset.apply( + shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, + shuffle_seed)) + elif repeat_dataset: + dataset = dataset.repeat(num_epochs) + elif shuffle: + dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) + + dataset = dataset.batch(batch_size) + + # Parse `Example` tensors to a dictionary of `Feature` tensors. + dataset = dataset.map( + lambda x: parsing_ops.parse_example(x, features), + num_parallel_calls=parser_num_threads) + dataset = dataset.prefetch(prefetch_buffer_size) + return dataset + + +@deprecation.deprecated(None, + "Use `tf.contrib.data.make_batched_features_dataset`") def read_batch_features(file_pattern, batch_size, features, - reader, + reader=core_readers.TFRecordDataset, reader_args=None, randomize_input=True, num_epochs=None, @@ -84,43 +225,38 @@ def read_batch_features(file_pattern, dataset to combine in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. - reader: A function or class that can be called with a `filenames` tensor - and (optional) `reader_args` and returns a `Dataset` of Examples. + reader: A function or class that can be + called with a `filenames` tensor and (optional) `reader_args` and returns + a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`. reader_args: Additional arguments to pass to the reader class. randomize_input: Whether the input should be randomized. num_epochs: Integer specifying the number of times to read through the dataset. If None, cycles through the dataset forever. - capacity: Capacity of the ShuffleDataset. A large capacity ensures better + capacity: Buffer size of the ShuffleDataset. A large capacity ensures better shuffling but would increase memory usage and startup time. - Returns: A dict from keys in features to `Tensor` or `SparseTensor` objects. 
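Putting the new reader API together, a hypothetical end-to-end invocation of `tf.contrib.data.make_batched_features_dataset` (the file pattern and feature spec below are illustrative, not taken from the change):

```python
import tensorflow as tf

dataset = tf.contrib.data.make_batched_features_dataset(
    file_pattern="/tmp/train-*.tfrecord",  # assumed glob, per tf.gfile.Glob rules
    batch_size=32,
    features={
        "age": tf.FixedLenFeature([], tf.int64, default_value=-1),
        "gender": tf.FixedLenFeature([], tf.string),
        "kws": tf.VarLenFeature(tf.string),
    },
    num_epochs=1,
    shuffle=True,
    shuffle_seed=5,        # fixed seed makes the shuffled order reproducible
    reader_num_threads=2,  # >1 interleaves records from multiple files
    parser_num_threads=2)

# Each element is a dict of batched Tensor/SparseTensor features.
next_batch = dataset.make_one_shot_iterator().get_next()
```

When `num_epochs` is not 1 and `shuffle=True`, the implementation takes the fused `shuffle_and_repeat` path; a fixed `shuffle_seed` then makes the multi-epoch order reproducible, which is what the new `testReadWithFusedShuffleRepeatDataset` exercises.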
""" - filenames = _get_file_names(file_pattern, randomize_input) - if reader_args: - dataset = reader(filenames, *reader_args) - else: - dataset = reader(filenames) - if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset.map(lambda _, v: v) - if num_epochs != 1: - dataset = dataset.repeat(num_epochs) - if randomize_input: - dataset = dataset.shuffle(capacity) - dataset = dataset.batch(batch_size) - dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features)) - dataset = dataset.prefetch(1) + dataset = make_batched_features_dataset( + file_pattern, + batch_size, + features, + reader=reader, + reader_args=reader_args, + shuffle=randomize_input, + num_epochs=num_epochs, + shuffle_buffer_size=capacity) iterator = dataset.make_one_shot_iterator() outputs = iterator.get_next() return outputs -def _get_file_names(file_pattern, randomize_input): +def _get_file_names(file_pattern, shuffle): """Parse list of file names from pattern, optionally shuffled. Args: file_pattern: File glob pattern, or list of glob patterns. - randomize_input: Whether to shuffle the order of file names. + shuffle: Whether to shuffle the order of file names. Returns: List of file names matching `file_pattern`. @@ -141,7 +277,7 @@ def _get_file_names(file_pattern, randomize_input): raise ValueError("No files match %s." % file_pattern) # Sort files so it will be deterministic for unit tests. - if not randomize_input: + if not shuffle: file_names = sorted(file_names) return file_names diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py index 1fa150f3c6..d07121df63 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils.py @@ -493,8 +493,9 @@ class NameBasedSaverStatus(_LoadStatus): """Load the name-based training checkpoint using a new `tf.train.Saver`.""" if session is None and not context.executing_eagerly(): session = ops.get_default_session() - saver_lib.Saver(self._object_saver._global_variable_names()).restore( # pylint: disable=protected-access - sess=session, save_path=self._save_path) + with ops.device("/cpu:0"): + saver_lib.Saver(self._object_saver._global_variable_names()).restore( # pylint: disable=protected-access + sess=session, save_path=self._save_path) def initialize_or_restore(self, session=None): """Alias for `run_restore_ops`.""" diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py index fd9fc098b3..2054878bf8 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py +++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py @@ -993,20 +993,21 @@ class CheckpointCompatibilityTests(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testLoadFromNameBasedSaver(self): """Save a name-based checkpoint, load it using the object-based API.""" - save_path = self._write_name_based_checkpoint() - root = self._initialized_model() - self._set_sentinels(root) - with self.assertRaises(AssertionError): + with test_util.device(use_gpu=True): + save_path = self._write_name_based_checkpoint() + root = self._initialized_model() + self._set_sentinels(root) + with self.assertRaises(AssertionError): + self._check_sentinels(root) + object_saver = checkpointable_utils.CheckpointableSaver(root) + status = object_saver.restore(save_path) + with self.assertRaises(AssertionError): + status.assert_consumed() + status.run_restore_ops() + 
self._check_sentinels(root) + self._set_sentinels(root) + status.initialize_or_restore() self._check_sentinels(root) - object_saver = checkpointable_utils.CheckpointableSaver(root) - status = object_saver.restore(save_path) - with self.assertRaises(AssertionError): - status.assert_consumed() - status.run_restore_ops() - self._check_sentinels(root) - self._set_sentinels(root) - status.initialize_or_restore() - self._check_sentinels(root) # TODO(allenl): Test for the core name-based saver loading object-based # checkpoints once object-based checkpointing is in core. diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index c7290c2aaa..aa3957bee1 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -213,7 +213,10 @@ cc_library( "compatibility.h", "quantization_util.h", ], - deps = [":round"], + deps = [ + ":round", + ":types", + ], ) cc_test( diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index b84d2f9ee1..f7706c7938 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -15,10 +15,88 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ +#include <cmath> #include <cstdint> +#include <limits> + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { +// Given the min and max values of a float array, return +// reasonable quantization parameters to use for this array. +template <typename T> +QuantizationParams ChooseQuantizationParams(double rmin, double rmax) { + const T qmin = std::numeric_limits<T>::min(); + const T qmax = std::numeric_limits<T>::max(); + const double qmin_double = qmin; + const double qmax_double = qmax; + // 0 should always be a representable value. Let's assume that the initial + // min,max range contains 0. + TFLITE_CHECK_LE(rmin, 0.); + TFLITE_CHECK_GE(rmax, 0.); + if (rmin == rmax) { + // Special case where the min,max range is a point. Should be {0}. + TFLITE_CHECK_EQ(rmin, 0.); + TFLITE_CHECK_EQ(rmax, 0.); + QuantizationParams quantization_params; + quantization_params.zero_point = 0; + quantization_params.scale = 0.; + return quantization_params; + } + + // General case. + // + // First determine the scale. + const double scale = (rmax - rmin) / (qmax_double - qmin_double); + + // Zero-point computation. + // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. 
+ const double zero_point_from_min = qmin_double - rmin / scale; + const double zero_point_from_max = qmax_double - rmax / scale; + const double zero_point_from_min_error = + std::abs(qmin_double) + std::abs(rmin / scale); + const double zero_point_from_max_error = + std::abs(qmax_double) + std::abs(rmax / scale); + + const double zero_point_double = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with SAME + // padding). + T nudged_zero_point = 0; + if (zero_point_double < qmin_double) { + nudged_zero_point = qmin; + } else if (zero_point_double > qmax_double) { + nudged_zero_point = qmax; + } else { + nudged_zero_point = static_cast<T>(round(zero_point_double)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + TFLITE_CHECK_GE(nudged_zero_point, qmin); + TFLITE_CHECK_LE(nudged_zero_point, qmax); + + // Finally, store the result nudged quantization params. + QuantizationParams quantization_params; + quantization_params.zero_point = nudged_zero_point; + quantization_params.scale = scale; + return quantization_params; +} + // Decompose a double multiplier into a Q0.31 int32 representation of its // significand, and shift representation of NEGATIVE its exponent --- // this is intended as a RIGHT-shift. diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc index 19b1b408ec..4ae2085c30 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -22,6 +22,51 @@ namespace { using ::testing::Pair; +// Example taken from http://www.tensorflow.org/performance/quantization +// +// Quantized | Float +// --------- | ----- +// 0 | -10.0 +// 255 | 30.0 +// 128 | 10.0 +TEST(QuantizationUtilTest, ChooseQuantizationParams) { + QuantizationParams qp = ChooseQuantizationParams<uint8>(-10.0, 30.0); + EXPECT_NEAR(qp.scale, 0.156863, 1e-5); + EXPECT_EQ(qp.zero_point, 64); +} + +TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMinBoundary) { + QuantizationParams qp = ChooseQuantizationParams<uint8>(0.0, 30.0); + EXPECT_NEAR(qp.scale, 0.117647, 1e-5); + EXPECT_EQ(qp.zero_point, 0); +} + +TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroNotInRange) { + // Assumption is that zero is within the range. + EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, 30.0), ""); +} + +TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangePositive) { + // Assumption is that zero is within the range. 
+ EXPECT_DEATH(ChooseQuantizationParams<uint8>(30.0, 30.0), ""); +} + +TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangeZero) { + QuantizationParams qp = ChooseQuantizationParams<uint8>(0.0, 0.0); + EXPECT_NEAR(qp.scale, 0.0, 1e-5); + EXPECT_EQ(qp.zero_point, 0); +} + +TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) { + QuantizationParams qp = ChooseQuantizationParams<uint8>(-10.0, 0.0); + EXPECT_NEAR(qp.scale, 0.039216, 1e-5); + EXPECT_EQ(qp.zero_point, 255); +} + +TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) { + EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), ""); +} + TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { auto quantize = [](double d) { int32_t q; diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index afe131b06e..293538fcbb 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -21,6 +21,22 @@ namespace tflite { enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; +// Quantization parameters, determining the mapping of quantized values +// to real values (i.e. determining how quantized values are mathematically +// interpreted). +// +// The correspondence is as follows: +// +// real_value = scale * (quantized_value - zero_point); +// +// In other words, zero_point designates which quantized value corresponds to +// the real 0 value, and scale designates the difference between the real values +// corresponding to consecutive quantized values differing by 1. +struct QuantizationParams { + int32 zero_point = 0; + double scale = 0.0; +}; + template <int N> struct Dims { int sizes[N]; diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 845bc0460f..031db2bd7c 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -329,6 +329,7 @@ cc_library( ":toco_graphviz_dump_options", ":toco_port", ":types_proto_cc", + "//tensorflow/contrib/lite/kernels/internal:quantization_util", "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@protobuf_archive//:protobuf_headers", diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index cd3eb06602..3fa0089cba 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -29,6 +29,8 @@ limitations under the License. namespace toco { +using tflite::QuantizationParams; + enum class OperatorType { kNone, // General-purpose neural network operators. @@ -1463,22 +1465,6 @@ inline bool operator<(const Alloc& a, const Alloc& b) { return a.start < b.start; } -// Quantization parameters, determining the mapping of quantized values -// to real values (i.e. determining how quantized values are mathematically -// interpreted). -// -// The correspondence is as follows: -// -// real_value = scale * (quantized_value - zero_point); -// -// In other words, zero_point designates which quantized value corresponds to -// the real 0 value, and scale designates the difference between the real values -// corresponding to consecutive quantized values differing by 1. 
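To make that correspondence concrete, here is a hypothetical Python transcription of the `ChooseQuantizationParams` logic introduced above, applied to the worked example from the new unit tests (real range [-10, 30] quantized to uint8):

```python
def choose_quantization_params(rmin, rmax, qmin=0, qmax=255):
  """Mirrors tflite::ChooseQuantizationParams<uint8> for illustration."""
  # 0 must be exactly representable, so the real range has to contain it.
  assert rmin <= 0.0 <= rmax
  if rmin == rmax:  # degenerate range: must be exactly {0}
    return 0.0, 0
  scale = (rmax - rmin) / float(qmax - qmin)
  # Solve real = scale * (q - zero_point) at whichever endpoint carries the
  # smaller floating-point error, then nudge to an integer in [qmin, qmax].
  zp_from_min = qmin - rmin / scale
  zp_from_max = qmax - rmax / scale
  err_from_min = abs(qmin) + abs(rmin / scale)
  err_from_max = abs(qmax) + abs(rmax / scale)
  zp = zp_from_min if err_from_min < err_from_max else zp_from_max
  zero_point = int(round(min(max(zp, qmin), qmax)))
  return scale, zero_point

scale, zero_point = choose_quantization_params(-10.0, 30.0)
# scale == 40 / 255 ~= 0.156863 and zero_point == 64, matching the
# ChooseQuantizationParams test: code 64 represents the real value 0.0.
```

The integer nudge is what guarantees that real 0.0 lands exactly on a quantized code, which the surrounding comments call out as a hard requirement (e.g. for Im2col with SAME padding).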
-struct QuantizationParams { - int32 zero_point = 0; - double scale = 0.; -}; - class Shape { public: // For Shape, we stick to half-way encapsulation for now: diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index d5796486c5..05360e3b0a 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -28,6 +28,7 @@ limitations under the License. #if TOCO_SUPPORT_PORTABLE_PROTOS #include "third_party/protobuf/src/google/protobuf/text_format.h" #endif // TOCO_SUPPORT_PORTABLE_PROTOS +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" @@ -149,71 +150,11 @@ template <ArrayDataType A> void GetQuantizationParamsFromMinMax(const MinMax& minmax, QuantizationParams* quantization_params) { using Integer = DataType<A>; - const Integer qmin = std::numeric_limits<Integer>::min(); - const Integer qmax = std::numeric_limits<Integer>::max(); - const double qmin_double = qmin; - const double qmax_double = qmax; const double rmin = minmax.min; const double rmax = minmax.max; - // 0 should always be a representable value. Let's assume that the initial - // min,max range contains 0. - CHECK_LE(rmin, 0.); - CHECK_GE(rmax, 0.); - if (rmin == rmax) { - // Special case where the min,max range is a point. Should be {0}. - CHECK_EQ(rmin, 0.); - CHECK_EQ(rmax, 0.); - quantization_params->zero_point = 0; - quantization_params->scale = 0.; - return; - } - // General case. - // - // First determine the scale. - const double scale = (rmax - rmin) / (qmax_double - qmin_double); - - // Zero-point computation. - // First the initial floating-point computation. The zero-point can be - // determined from solving an affine equation for any known pair - // (real value, corresponding quantized value). - // We know two such pairs: (rmin, qmin) and (rmax, qmax). - // The arithmetic error on the zero point computed from either pair - // will be roughly machine_epsilon * (sum of absolute values of terms) - // so we want to use the variant that adds the smaller terms. - const double zero_point_from_min = qmin_double - rmin / scale; - const double zero_point_from_max = qmax_double - rmax / scale; - const double zero_point_from_min_error = - std::abs(qmin_double) + std::abs(rmin / scale); - const double zero_point_from_max_error = - std::abs(qmax_double) + std::abs(rmax / scale); - - const double zero_point_double = - zero_point_from_min_error < zero_point_from_max_error - ? zero_point_from_min - : zero_point_from_max; - - // Now we need to nudge the zero point to be an integer - // (our zero points are integer, and this is motivated by the requirement - // to be able to represent the real value "0" exactly as a quantized value, - // which is required in multiple places, for example in Im2col with SAME - // padding). - Integer nudged_zero_point = 0; - if (zero_point_double < qmin_double) { - nudged_zero_point = qmin; - } else if (zero_point_double > qmax_double) { - nudged_zero_point = qmax; - } else { - nudged_zero_point = static_cast<Integer>(std::round(zero_point_double)); - } - // The zero point should always be in the range of quantized value, - // [qmin, qmax]. - CHECK_GE(nudged_zero_point, qmin); - CHECK_LE(nudged_zero_point, qmax); - - // Finally, store the result nudged quantization params. 
- quantization_params->zero_point = nudged_zero_point; - quantization_params->scale = scale; + *quantization_params = + ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax); } void CheckIsReadyForQuantization(const Model& model); diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py index 51b67bd6fa..465c668fd8 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -117,7 +117,7 @@ def StreamingFilesDataset(files, file_reader_job = file_reader_job or 'coordinator' - worker_job = worker_job or 'tpu_worker' + worker_job = worker_job or 'worker' if filename_shuffle_buffer_size is None: filename_shuffle_buffer_size = 4096 diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py index 6e6a7ce809..918cf0ed8e 100644 --- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -44,7 +44,7 @@ class DatasetsTest(test.TestCase): self._cluster_def = cluster_pb2.ClusterDef() worker_job = self._cluster_def.job.add() - worker_job.name = 'tpu_worker' + worker_job.name = 'worker' worker_job.tasks[0] = self._worker.target[len('grpc://'):] coord_job = self._cluster_def.job.add() coord_job.name = 'coordinator' diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 10f4d42147..27a96217fd 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1951,6 +1951,7 @@ tf_kernel_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], ) diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index e3c78d6b70..7c302e2fc2 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/core/graph/gradients.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -317,6 +318,8 @@ class RemoteCallOp : public AsyncOpKernel { if (cached_entry != handle_cache_.end()) { handle = cached_entry->second; } else { + port::Tracing::TraceMe activity(strings::StrCat( + "RemoteCall: Instantiate: ", func_.name(), " on ", target_device)); OP_REQUIRES_OK_ASYNC( ctx, lib->Instantiate(func_.name(), AttrSlice(&attr_values), @@ -344,21 +347,24 @@ class RemoteCallOp : public AsyncOpKernel { args.push_back(argument); } auto* rets = new std::vector<Tensor>; - lib->Run(opts, handle, args, rets, [rets, done, ctx](const Status& status) { - if (!status.ok()) { - ctx->SetStatus(status); - } else { - for (size_t i = 0; i < rets->size(); ++i) { - ctx->set_output(i, (*rets)[i]); - } - } - delete rets; - done(); - }); + auto* trace = new port::Tracing::TraceMe(strings::StrCat( + "RemoteCall: Run: ", func_.name(), " on ", target_device)); + lib->Run(opts, handle, args, rets, + [rets, trace, done, ctx](const Status& status) { + if (!status.ok()) { + ctx->SetStatus(status); + } else { + for (size_t i = 0; i < rets->size(); ++i) { + ctx->set_output(i, (*rets)[i]); + } + } + delete rets; + delete trace; + done(); + }); } private: - string target_; NameAttrList func_; mutex mu_; diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index d6715fa522..5a9cd7531d 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -139,6 +139,10 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import tensor_array_ops +# Eager execution +from tensorflow.python.eager.context import executing_eagerly +from tensorflow.python.framework.ops import enable_eager_execution + # Symbols whitelisted for export without documentation. # TODO(cwhipkey): review these and move to contrib, expose through # documentation, or remove. @@ -290,6 +294,12 @@ _allowed_symbols.extend([ 'MONOLITHIC_BUILD', ]) +# Eager execution +_allowed_symbols.extend([ + 'enable_eager_execution', + 'executing_eagerly', +]) + # Remove all extra symbols that don't have a docstring or are not explicitly # referenced in the whitelist. remove_undocumented(__name__, _allowed_symbols, [ diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 5d13aada63..87d3ed880a 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import errors from tensorflow.python.util import compat from tensorflow.python.util import is_in_graph_mode from tensorflow.python.util import tf_contextlib +from tensorflow.python.util.tf_export import tf_export GRAPH_MODE = 0 EAGER_MODE = 1 @@ -518,8 +519,14 @@ def internal_operation_seed(): return context()._internal_operation_seed() # pylint: disable=protected-access +@tf_export("executing_eagerly") def executing_eagerly(): - """Returns True if the current thread has eager execution enabled.""" + """Returns True if the current thread has eager execution enabled. + + Eager execution is typically enabled via @{tf.enable_eager_execution}, + but may also be enabled within the context of a Python function via + tf.contrib.eager.py_func. 
+ """ return context().executing_eagerly() diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 8ff247fdb1..f5dde3a358 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -5169,41 +5169,60 @@ def init_scope(): yield +@tf_export("enable_eager_execution") def enable_eager_execution(config=None, device_policy=None): - """Enables, for the rest of the lifetime of this program, eager execution. + """Enables eager execution for the lifetime of this program. - If not called immediately on startup risks creating breakage and bugs. + Eager execution provides an imperative interface to TensorFlow. With eager + execution enabled, TensorFlow functions execute operations immediately (as + opposed to adding to a graph to be executed later in a @{tf.Session}) and + return concrete values (as opposed to symbolic references to a node in a + computational graph). - Example: + For example: ```python - tfe.enable_eager_execution() + tf.enable_eager_execution() # After eager execution is enabled, operations are executed as they are - # defined and `Tensor`s hold concrete values, which can be accessed as - # `numpy.ndarray`s through the `numpy()` method. + # defined and Tensor objects hold concrete values, which can be accessed as + # numpy.ndarray`s through the numpy() method. assert tf.multiply(6, 7).numpy() == 42 ``` + Eager execution cannot be enabled after TensorFlow APIs have been used to + create or execute graphs. It is typically recommended to invoke this function + at program startup and not in a library (as most libraries should be usable + both with and without eager execution). + Args: - config: (Optional.) A `ConfigProto` protocol buffer with configuration - options for the Context. Note that a lot of these options may be - currently unimplemented or irrelevant when eager execution is enabled. - device_policy: (Optional.) What policy to use when trying to run an - operation on a device with inputs which are not on that device. + config: (Optional.) A @{tf.ConfigProto} to use to configure the environment + in which operations are executed. Note that @{tf.ConfigProto} is also + used to configure graph execution (via @{tf.Session}) and many options + within `tf.ConfigProto` are not implemented (or are irrelevant) when + eager execution is enabled. + device_policy: (Optional.) Policy controlling how operations requiring + inputs on a specific device (e.g., a GPU 0) handle inputs on a different + device (e.g. GPU 1 or CPU). Valid values: - tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not - correct. - tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the - right device but raises a warning. - tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might - hide performance problems. - tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, - raising errors on the other ones. + + - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the + placement is not correct. + + - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not + on the right device but logs a warning. + + - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors. + Note that this may hide performance problems as there is no notification + provided when operations are blocked on the tensor being copied between + devices. + + - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies + int32 tensors, raising errors on the other ones. 
Raises: - ValueError: If trying to create a context after using graph operations - or if trying to create a context with nontrivial options which differ - from those of the existing context. + ValueError: If eager execution is enabled after creating/executing a + TensorFlow graph, or if options provided conflict with a previous call + to this function. """ if config is not None and not isinstance(config, config_pb2.ConfigProto): raise TypeError( @@ -5213,7 +5232,7 @@ def enable_eager_execution(config=None, device_policy=None): context.DEVICE_PLACEMENT_SILENT, context.DEVICE_PLACEMENT_SILENT_FOR_INT32): raise ValueError( - "device_policy must be one of None, tfe.DEVICE_PLACEMENT_*" + "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*" ) # pylint: disable=protected-access if context._default_mode == context.GRAPH_MODE: @@ -5222,7 +5241,7 @@ def enable_eager_execution(config=None, device_policy=None): _default_graph_stack._global_default_graph is not None) if graph_mode_has_been_used: raise ValueError( - "tfe.enable_eager_execution has to be called at program startup.") + "tf.enable_eager_execution must be called at program startup.") context._default_mode = context.EAGER_MODE if context._context is None: context._context = context.Context(config=config, @@ -5245,7 +5264,7 @@ def enable_eager_execution(config=None, device_policy=None): context._context._device_policy)) else: raise ValueError( - "tfe.enable_eager_execution has to be called at program startup.") + "tf.enable_eager_execution must be called at program startup.") def eager_run(main=None, argv=None): diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 63203a0043..36142801d6 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re + import numpy as np from six.moves import queue from six.moves import xrange # pylint: disable=redefined-builtin @@ -356,12 +358,22 @@ class PyFuncTest(test.TestCase): def _testExceptionHandling(self, py_exp, tf_exp, eager=False): - def raise_exception(): + def inner_exception(): raise py_exp("blah") # pylint: disable=not-callable + def raise_exception(): + inner_exception() + + expected_regexp = r": blah.*" # Error at the top + expected_regexp += r"in raise_exception.*" # Stacktrace outer + expected_regexp += r"in inner_exception.*" # Stacktrace inner + expected_regexp += r": blah" # Stacktrace of raise + def expected_error_check(exception): + return re.search(expected_regexp, str(exception), re.DOTALL) + if eager: if context.executing_eagerly(): - with self.assertRaisesRegexp(tf_exp, "blah"): + with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check): f = script_ops.eager_py_func(raise_exception, [], []) return else: @@ -370,7 +382,7 @@ class PyFuncTest(test.TestCase): f = script_ops.py_func(raise_exception, [], []) with self.test_session(): - with self.assertRaisesRegexp(tf_exp, "blah"): + with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check): self.evaluate(f) def testExceptionHandling(self): diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc index 2635694e23..00cbf0c532 100644 --- a/tensorflow/python/lib/core/py_util.cc +++ b/tensorflow/python/lib/core/py_util.cc @@ -41,6 +41,55 @@ const char* ClassName(PyObject* py) { } // end namespace +// Returns a 
diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc
index 2635694e23..00cbf0c532 100644
--- a/tensorflow/python/lib/core/py_util.cc
+++ b/tensorflow/python/lib/core/py_util.cc
@@ -41,6 +41,55 @@ const char* ClassName(PyObject* py) {
 
 }  // end namespace
 
+// Best-effort: formats the given exception with Python's "traceback" module
+// and appends the result to 'out'; on any failure, 'out' is left unchanged.
+void TryAppendTraceback(PyObject* ptype, PyObject* pvalue, PyObject* ptraceback,
+                        string* out) {
+  // The "traceback" module is assumed to be imported already by script_ops.py.
+  PyObject* tb_module = PyImport_AddModule("traceback");
+
+  if (!tb_module) {
+    return;
+  }
+
+  PyObject* format_exception =
+      PyObject_GetAttrString(tb_module, "format_exception");
+
+  if (!format_exception) {
+    return;
+  }
+
+  if (!PyCallable_Check(format_exception)) {
+    Py_DECREF(format_exception);
+    return;
+  }
+
+  PyObject* ret_val = PyObject_CallFunctionObjArgs(format_exception, ptype,
+                                                   pvalue, ptraceback, nullptr);
+  Py_DECREF(format_exception);
+
+  if (!ret_val) {
+    return;
+  }
+
+  if (!PyList_Check(ret_val)) {
+    Py_DECREF(ret_val);
+    return;
+  }
+
+  // Iterate through ret_val, appending each formatted line to the output.
+  Py_ssize_t n = PyList_GET_SIZE(ret_val);
+  for (Py_ssize_t i = 0; i < n; ++i) {
+    PyObject* v = PyList_GET_ITEM(ret_val, i);
+#if PY_MAJOR_VERSION < 3
+    strings::StrAppend(out, PyString_AS_STRING(v), "\n");
+#else
+    strings::StrAppend(out, PyUnicode_AsUTF8(v), "\n");
+#endif
+  }
+
+  Py_DECREF(ret_val);
+}
+
 string PyExceptionFetch() {
   CHECK(PyErr_Occurred())
       << "Must only call PyExceptionFetch after an exception.";
@@ -52,14 +101,20 @@ string PyExceptionFetch() {
   string err = ClassName(ptype);
   if (pvalue) {
     PyObject* str = PyObject_Str(pvalue);
+    if (str) {
 #if PY_MAJOR_VERSION < 3
-    strings::StrAppend(&err, ": ", PyString_AS_STRING(str));
+      strings::StrAppend(&err, ": ", PyString_AS_STRING(str), "\n");
 #else
-    strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str));
+      strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str), "\n");
 #endif
-    Py_DECREF(str);
+      Py_DECREF(str);
+    } else {
+      strings::StrAppend(&err, "(unknown error message)\n");
     }
+
+    TryAppendTraceback(ptype, pvalue, ptraceback, &err);
+
     Py_DECREF(pvalue);
   }
   Py_DECREF(ptype);
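`TryAppendTraceback` drives the standard `traceback` module through the C API. For reference, its effect corresponds roughly to the following Python (a sketch of the behavior, not code from the patch; `fetch_exception_string` is a hypothetical name):

```python
import sys
import traceback

def fetch_exception_string():
    # Mirrors PyExceptionFetch: the exception class and message first,
    # then the lines produced by traceback.format_exception.
    etype, value, tb = sys.exc_info()
    err = "%s: %s\n" % (etype.__name__, value)
    err += "\n".join(traceback.format_exception(etype, value, tb))
    return err

try:
    raise ValueError("blah")
except ValueError:
    print(fetch_exception_string())
```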
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 529eebe769..fb59bbba5e 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -25,6 +25,9 @@ from __future__ import print_function
 
 import threading
 
+# Used by py_util.cc to get tracebacks.
+import traceback  # pylint: disable=unused-import
+
 import numpy as np
 import six
 
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index fb0862c016..123d67fd9b 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -969,6 +969,10 @@ tf_module {
     argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None"
   }
   member_method {
+    name: "enable_eager_execution"
+    argspec: "args=[\'config\', \'device_policy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
     name: "encode_base64"
     argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
   }
@@ -985,6 +989,10 @@ tf_module {
     argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
+    name: "executing_eagerly"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
     name: "exp"
     argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 4fe4fc3b13..b7d7fac315 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -91,7 +91,6 @@ cc_library(
     srcs = [
         "add_default_attributes.cc",
         "backports.cc",
-        "fake_quantize_training.cc",
         "flatten_atrous.cc",
         "fold_batch_norms.cc",
         "fold_constants_lib.cc",
@@ -105,7 +104,6 @@ cc_library(
         "remove_attribute.cc",
        "remove_control_dependencies.cc",
         "remove_device.cc",
-        "remove_ema.cc",
         "remove_nodes.cc",
         "rename_attribute.cc",
         "rename_op.cc",
@@ -148,7 +146,6 @@ tf_cc_test(
     srcs = [
         "add_default_attributes_test.cc",
         "backports_test.cc",
-        "fake_quantize_training_test.cc",
         "flatten_atrous_test.cc",
         "fold_batch_norms_test.cc",
         "fold_constants_test.cc",
@@ -161,7 +158,6 @@ tf_cc_test(
         "quantize_weights_test.cc",
         "remove_attribute_test.cc",
         "remove_device_test.cc",
-        "remove_ema_test.cc",
         "remove_nodes_test.cc",
         "rename_attribute_test.cc",
         "rename_op_test.cc",
diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training.cc b/tensorflow/tools/graph_transforms/fake_quantize_training.cc
deleted file mode 100644
index 61aecc6e16..0000000000
--- a/tensorflow/tools/graph_transforms/fake_quantize_training.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include "tensorflow/core/graph/quantize_training.h"
-#include "tensorflow/tools/graph_transforms/transform_utils.h"
-
-namespace tensorflow {
-namespace graph_transforms {
-
-// EXPERIMENTAL: This can change without warning.
-// Rewrites the GraphDef for quantized training.
-// Rewrites the forward pass to include the precision loss with quantization so
-// the model can learn to deal with such loss and achieve better accuracy when
-// it is quantized later for inference.
-// Quantization range information is collected in FakeQuantizeWithMinMaxVars
-// ops.
-//
-// TODO(suharshs): Provide instructions on converting the resulting graph for
-// inference.
-// TODO(suharshs): Implement this using the GTT rather than calling the old
-// prototype function.
-Status FakeQuantizeTraining(const GraphDef& input_graph_def,
-                            const TransformFuncContext& context,
-                            GraphDef* output_graph_def) {
-  // TODO(suharshs): Make num_bits a parameter.
-  const int32 num_bits = 8;
-  // TODO(suharshs): Make quantization op a parameter?
-  const string quant_op_type = "FakeQuantWithMinMaxVars";
-
-  return DoQuantizeTrainingOnGraphDef(input_graph_def, num_bits, quant_op_type,
-                                      output_graph_def);
-}
-
-REGISTER_GRAPH_TRANSFORM("fake_quantize_training", FakeQuantizeTraining);
-
-}  // namespace graph_transforms
-}  // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
deleted file mode 100644
index 5e4ab209e9..0000000000
--- a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/cc/ops/const_op.h"
-#include "tensorflow/cc/ops/math_ops.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/tools/graph_transforms/transform_utils.h"
-
-namespace tensorflow {
-namespace graph_transforms {
-
-// Declare here, so we don't need a public header.
-Status FakeQuantizeTraining(const GraphDef& input_graph_def,
-                            const TransformFuncContext& context,
-                            GraphDef* output_graph_def);
-
-class FakeQuantizeTrainingTest : public ::testing::Test {};
-
-// For now, since the fake_quantize_training transform just calls the
-// quantize_training rewrite from tensorflow/core/graph/quantize_training.h,
-// we just test that the graph has been changed by the transform.
-// TODO(suharshs): Once we implement the fake_quantize_training transform
-// using the GTT, write proper tests of the transform here.
-TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
-  auto root = tensorflow::Scope::DisabledShapeInferenceScope();
-  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
-
-  Tensor a_data(DT_FLOAT, TensorShape());
-  test::FillIota<float>(&a_data, 1.0f);
-  Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
-
-  Tensor b_data(DT_FLOAT, TensorShape());
-  test::FillIota<float>(&b_data, 1.0f);
-  Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
-
-  Output matmul = MatMul(root.WithOpName("matmul"), a_const, b_const);
-  GraphDef graph_def;
-  TF_ASSERT_OK(root.ToGraphDef(&graph_def));
-
-  GraphDef result;
-  TransformFuncContext context;
-  TF_ASSERT_OK(FakeQuantizeTraining(graph_def, context, &result));
-
-  // Test that the transformation resulted in a graph with more nodes.
-  EXPECT_GT(result.node_size(), graph_def.node_size());
-}
-
-}  // namespace graph_transforms
-}  // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/remove_ema.cc b/tensorflow/tools/graph_transforms/remove_ema.cc
deleted file mode 100644
index 22e2626702..0000000000
--- a/tensorflow/tools/graph_transforms/remove_ema.cc
+++ /dev/null
@@ -1,146 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include "tensorflow/tools/graph_transforms/transform_utils.h"
-
-namespace tensorflow {
-namespace graph_transforms {
-
-// EXPERIMENTAL: This can change without warning.
-// Given a graph that has gone through the FakeQuantizeTraining transform and
-// has been frozen afterwards, RemoveEMA simplifies the FakeQuantize estimated
-// moving average subgraphs to make it compatible with the QuantizeNodes
-// transform.
-Status RemoveEMA(const GraphDef& input_graph_def,
-                 const TransformFuncContext& context,
-                 GraphDef* output_graph_def) {
-  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
-      input_graph_def,  // clang-format off
-      {"FakeQuantWithMinMaxVars",
-        {
-          {"*"},
-          {"Assign",
-            {
-              {"Const"},
-              {"Merge",
-                {
-                  {"Switch",
-                    {
-                      {"Min",
-                        {
-                          {"*"},
-                          {"Range",
-                            {
-                              {"*"},
-                              {"*"},
-                              {"*"},
-                            }
-                          }
-                        }
-                      },
-                      {"IsVariableInitialized"}
-                    }
-                  },
-                  {"Sub",
-                    {
-                      {"Const"},
-                      {"Mul",
-                        {
-                          {"Sub"},
-                          {"Sub",
-                            {
-                              {"Const"},
-                              {"Const"}
-                            }
-                          }
-                        }
-                      }
-                    }
-                  }
-                }
-              }
-            }
-          },
-          {"Assign",
-            {
-              {"Const"},
-              {"Merge",
-                {
-                  {"Switch",
-                    {
-                      {"Max"},
-                      {"IsVariableInitialized"}
-                    }
-                  },
-                  {"Sub",
-                    {
-                      {"Const"},
-                      {"Mul",
-                        {
-                          {"Sub"},
-                          {"Sub",
-                            {
-                              {"Const"},
-                              {"Const"}
-                            }
-                          }
-                        }
-                      }
-                    }
-                  }
-                }
-              }
-            }
-          },
-        }
-      },  // clang-format on
-      [](const NodeMatch& match, const std::set<string>& input_nodes,
-         const std::set<string>& output_nodes,
-         std::vector<NodeDef>* new_nodes) {
-        const NodeDef& fake_quant_node = match.node;
-        const NodeDef& input_node = match.inputs[0].node;
-        const NodeDef& min_var_node = match.inputs[1].inputs[0].node;
-        const NodeDef& max_var_node = match.inputs[2].inputs[0].node;
-
-        // Make a new FakeQuantizeWithMinMaxVars operation that uses constants
-        // for its min/max arguments rather than an entire EMA subgraph.
-        NodeDef new_fake_quant_node;
-        new_fake_quant_node.set_op(fake_quant_node.op());
-        new_fake_quant_node.set_name(fake_quant_node.name());
-        AddNodeInput(input_node.name(), &new_fake_quant_node);
-        AddNodeInput(min_var_node.name(), &new_fake_quant_node);
-        AddNodeInput(max_var_node.name(), &new_fake_quant_node);
-        CopyNodeAttr(fake_quant_node, "narrow_range", "narrow_range",
-                     &new_fake_quant_node);
-        CopyNodeAttr(fake_quant_node, "num_bits", "num_bits",
-                     &new_fake_quant_node);
-
-        new_nodes->push_back(new_fake_quant_node);
-        new_nodes->push_back(input_node);
-        new_nodes->push_back(min_var_node);
-        new_nodes->push_back(max_var_node);
-
-        return Status::OK();
-      },
-      {}, output_graph_def));
-  return Status::OK();
-}
-
-REGISTER_GRAPH_TRANSFORM("remove_ema", RemoveEMA);
-
-}  // namespace graph_transforms
-}  // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/remove_ema_test.cc b/tensorflow/tools/graph_transforms/remove_ema_test.cc
deleted file mode 100644
index 27db90e272..0000000000
--- a/tensorflow/tools/graph_transforms/remove_ema_test.cc
+++ /dev/null
@@ -1,121 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/cc/ops/const_op.h"
-#include "tensorflow/cc/ops/math_ops.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/public/session.h"
-#include "tensorflow/tools/graph_transforms/transform_utils.h"
-
-namespace tensorflow {
-namespace graph_transforms {
-
-// Declare transformations here, so we don't need a public header.
-Status FakeQuantizeTraining(const GraphDef& input_graph_def,
-                            const TransformFuncContext& context,
-                            GraphDef* output_graph_def);
-
-Status RemoveEMA(const GraphDef& input_graph_def,
-                 const TransformFuncContext& context,
-                 GraphDef* output_graph_def);
-
-Status QuantizeNodes(const GraphDef& input_graph_def,
-                     const TransformFuncContext& context,
-                     GraphDef* output_graph_def);
-
-class RemoveEMATest : public ::testing::Test {};
-
-TEST_F(RemoveEMATest, FakeQuant_RemoveEMA_QuantizeTraining) {
-  // Build a small graph.
-  auto root = tensorflow::Scope::NewRootScope();
-  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
-
-  Tensor a_data(DT_FLOAT, TensorShape({1, 1}));
-  test::FillIota<float>(&a_data, 1.0f);
-  Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
-
-  Tensor b_data(DT_FLOAT, TensorShape({1, 1}));
-  test::FillIota<float>(&b_data, 1.0f);
-  Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
-
-  Output matmul = MatMul(root.WithOpName("matmul"), a_const, b_const);
-  GraphDef graph_def;
-  TF_ASSERT_OK(root.ToGraphDef(&graph_def));
-
-  // (1) FakeQuantize the graph.
-  GraphDef fake_quantized_graph_def;
-  TransformFuncContext context;
-  TF_ASSERT_OK(
-      FakeQuantizeTraining(graph_def, context, &fake_quantized_graph_def));
-
-  // Test that the transformation resulted in a graph with more nodes.
-  EXPECT_GT(fake_quantized_graph_def.node_size(), graph_def.node_size());
-
-  // (2) Run the graph to initialize the newly added variables.
-  std::unique_ptr<Session> session(NewSession(SessionOptions()));
-  TF_ASSERT_OK(session->Create(fake_quantized_graph_def));
-  std::vector<Tensor> outputs;
-  TF_ASSERT_OK(session->Run({}, {"matmul"}, {}, &outputs));
-
-  // (3) Freeze the graph. Create a "frozen graph" that matches what we would
-  // expect if we actually froze the above graph.
-  // TODO(suharshs): Use a c++ freeze graph alternative, when one is available.
-  GraphDef frozen_graph_def;
-  for (const NodeDef& node : fake_quantized_graph_def.node()) {
-    if (node.op() == "Variable" || node.op() == "VariableV2") {
-      NodeDef const_node;
-      const_node.set_op("Const");
-      const_node.set_name(node.name());
-      SetNodeAttr("dtype", DT_FLOAT, &const_node);
-      Tensor tensor(DT_FLOAT, {});
-      tensor.flat<float>()(0) = 1.0f;
-      SetNodeTensorAttr<float>("value", tensor, &const_node);
-      *(frozen_graph_def.mutable_node()->Add()) = const_node;
-    } else {
-      *(frozen_graph_def.mutable_node()->Add()) = node;
-    }
-  }
-
-  // Test that freezing the graph resulted in a graph with the same number of
-  // nodes.
-  EXPECT_EQ(frozen_graph_def.node_size(), fake_quantized_graph_def.node_size());
-
-  // (4) RemoveEMA on the graph to make it compatible with QuantizeNodes.
-  GraphDef removed_ema_graph_def;
-  TF_ASSERT_OK(RemoveEMA(frozen_graph_def, context, &removed_ema_graph_def));
-
-  // Test that the transformation resulted in a graph with less nodes.
-  EXPECT_LT(removed_ema_graph_def.node_size(), frozen_graph_def.node_size());
-
-  // (5) QuantizeNodes and inspect the final graph.
-  // TODO(suharshs): Add a more thorough inspection of the structure of
-  // the output graph.
-  GraphDef quantized_graph_def;
-  TF_ASSERT_OK(
-      QuantizeNodes(removed_ema_graph_def, context, &quantized_graph_def));
-
-  // Test that the transformation resulted in a graph with more nodes.
-  EXPECT_GT(quantized_graph_def.node_size(), removed_ema_graph_def.node_size());
-
-  // Make sure that the FakeQuantizeWithMinMaxVars op has been removed.
-  for (const NodeDef& node : quantized_graph_def.node()) {
-    EXPECT_NE(node.op(), "FakeQuantWithMinMaxVars");
-  }
-}
-
-}  // namespace graph_transforms
-}  // namespace tensorflow
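For context on the deletions above: the removed transforms were registered as `fake_quantize_training` and `remove_ema`, so before this commit they were reachable through the Graph Transform Tool, e.g. via its Python wrapper, roughly as below. This is a hedged sketch only: the deleted `remove_ema_test` shows that the real pipeline also runs and freezes the graph between steps, which this snippet omits, and after this commit the two names are no longer registered, so the call would fail.

```python
import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph

# A trivial graph, mirroring the one built in the deleted tests.
g = tf.Graph()
with g.as_default():
    a = tf.constant([[1.0]], name="a")
    b = tf.constant([[1.0]], name="b")
    tf.matmul(a, b, name="matmul")

result = TransformGraph(
    g.as_graph_def(), inputs=[], outputs=["matmul"],
    transforms=["fake_quantize_training", "remove_ema", "quantize_nodes"])
```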