about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author Yifei Feng <1192265+yifeif@users.noreply.github.com> 2018-03-07 23:42:50 -0800
committer GitHub <noreply@github.com> 2018-03-07 23:42:50 -0800
commit f8363dc424f78ec06c9fe2faee7623624aa0392e (patch)
tree 9da5395dbd3707994964ac03b37f36b5545ce37a
parent 9d867e0c34ea34ac74ebdab2cdcfc5b8c61fed25 (diff)
parent 9cdfd3878935fb6c3c2a5da7f65ee0db6c751170 (diff)
Merge pull request #17536 from yifeif/branch_188272354
Branch 188272354
-rw-r--r-- tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py 132
-rw-r--r-- tensorflow/contrib/bayesflow/python/ops/hmc.py 1
-rw-r--r-- tensorflow/contrib/bayesflow/python/ops/hmc_impl.py 217
-rw-r--r-- tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py 4
-rw-r--r-- tensorflow/contrib/data/__init__.py 2
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/BUILD 1
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py 154
-rw-r--r-- tensorflow/contrib/data/python/ops/BUILD 1
-rw-r--r-- tensorflow/contrib/data/python/ops/readers.py 180
-rw-r--r-- tensorflow/contrib/eager/python/checkpointable_utils.py 5
-rw-r--r-- tensorflow/contrib/eager/python/checkpointable_utils_test.py 27
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/BUILD 5
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/quantization_util.h 78
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc 45
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/types.h 16
-rw-r--r-- tensorflow/contrib/lite/toco/BUILD 1
-rw-r--r-- tensorflow/contrib/lite/toco/model.h 18
-rw-r--r-- tensorflow/contrib/lite/toco/tooling_util.h 65
-rw-r--r-- tensorflow/contrib/tpu/python/tpu/datasets.py 2
-rw-r--r-- tensorflow/contrib/tpu/python/tpu/datasets_test.py 2
-rw-r--r-- tensorflow/core/kernels/BUILD 1
-rw-r--r-- tensorflow/core/kernels/function_ops.cc 30
-rw-r--r-- tensorflow/python/__init__.py 10
-rw-r--r-- tensorflow/python/eager/context.py 9
-rw-r--r-- tensorflow/python/framework/ops.py 69
-rw-r--r-- tensorflow/python/kernel_tests/py_func_test.py 18
-rw-r--r-- tensorflow/python/lib/core/py_util.cc 59
-rw-r--r-- tensorflow/python/ops/script_ops.py 3
-rw-r--r-- tensorflow/tools/api/golden/tensorflow.pbtxt 8
-rw-r--r-- tensorflow/tools/graph_transforms/BUILD 4
-rw-r--r-- tensorflow/tools/graph_transforms/fake_quantize_training.cc 51
-rw-r--r-- tensorflow/tools/graph_transforms/fake_quantize_training_test.cc 63
-rw-r--r-- tensorflow/tools/graph_transforms/remove_ema.cc 146
-rw-r--r-- tensorflow/tools/graph_transforms/remove_ema_test.cc 121
34 files changed, 636 insertions, 912 deletions
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
index 819095a060..dabadfc7b6 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py
@@ -462,138 +462,6 @@ class HMCTest(test.TestCase):
def testKernelLeavesTargetInvariant3(self):
self._kernel_leaves_target_invariant_wrapper(3)
- def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims,
- sess, feed_dict=None):
- counter = collections.Counter()
-
- def proposal_log_prob(x):
- counter["proposal_calls"] += 1
- event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
- return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi),
- axis=event_dims)
-
- def target_log_prob(x):
- counter["target_calls"] += 1
- event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
- return self._log_gamma_log_prob(x, event_dims)
-
- if feed_dict is None:
- feed_dict = {}
-
- num_steps = 200
-
- _, ais_weights, _ = hmc.sample_annealed_importance_chain(
- proposal_log_prob_fn=proposal_log_prob,
- num_steps=num_steps,
- target_log_prob_fn=target_log_prob,
- step_size=0.5,
- current_state=init,
- num_leapfrog_steps=2,
- seed=45)
-
- # We have three calls because the calculation of `ais_weights` entails
- # another call to the `convex_combined_log_prob_fn`. We could refactor
- # things to avoid this, if needed (eg, b/72994218).
- self.assertAllEqual(dict(target_calls=3, proposal_calls=3), counter)
-
- event_shape = array_ops.shape(init)[independent_chain_ndims:]
- event_size = math_ops.reduce_prod(event_shape)
-
- log_true_normalizer = (
- -self._shape_param * math_ops.log(self._rate_param)
- + math_ops.lgamma(self._shape_param))
- log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype)
-
- log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights)
- - np.log(num_steps))
-
- ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer)
- ais_weights_size = array_ops.size(ais_weights)
- standard_error = math_ops.sqrt(
- _reduce_variance(ratio_estimate_true)
- / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype))
-
- [
- ratio_estimate_true_,
- log_true_normalizer_,
- log_estimated_normalizer_,
- standard_error_,
- ais_weights_size_,
- event_size_,
- ] = sess.run([
- ratio_estimate_true,
- log_true_normalizer,
- log_estimated_normalizer,
- standard_error,
- ais_weights_size,
- event_size,
- ], feed_dict)
-
- logging_ops.vlog(1, " log_true_normalizer: {}\n"
- " log_estimated_normalizer: {}\n"
- " ais_weights_size: {}\n"
- " event_size: {}\n".format(
- log_true_normalizer_,
- log_estimated_normalizer_,
- ais_weights_size_,
- event_size_))
- self.assertNear(ratio_estimate_true_.mean(), 1., 4. * standard_error_)
-
- def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims):
- """Tests that AIS yields reasonable estimates of normalizers."""
- with self.test_session(graph=ops.Graph()) as sess:
- x_ph = array_ops.placeholder(np.float32, name="x_ph")
- initial_draws = np.random.normal(size=[30, 2, 1])
- self._ais_gets_correct_log_normalizer(
- x_ph,
- independent_chain_ndims,
- sess,
- feed_dict={x_ph: initial_draws})
-
- def testAIS1(self):
- self._ais_gets_correct_log_normalizer_wrapper(1)
-
- def testAIS2(self):
- self._ais_gets_correct_log_normalizer_wrapper(2)
-
- def testAIS3(self):
- self._ais_gets_correct_log_normalizer_wrapper(3)
-
- def testSampleAIChainSeedReproducibleWorksCorrectly(self):
- with self.test_session(graph=ops.Graph()) as sess:
- independent_chain_ndims = 1
- x = np.random.rand(4, 3, 2)
-
- def proposal_log_prob(x):
- event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
- return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi),
- axis=event_dims)
-
- def target_log_prob(x):
- event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x))
- return self._log_gamma_log_prob(x, event_dims)
-
- ais_kwargs = dict(
- proposal_log_prob_fn=proposal_log_prob,
- num_steps=200,
- target_log_prob_fn=target_log_prob,
- step_size=0.5,
- current_state=x,
- num_leapfrog_steps=2,
- seed=53)
-
- _, ais_weights0, _ = hmc.sample_annealed_importance_chain(
- **ais_kwargs)
-
- _, ais_weights1, _ = hmc.sample_annealed_importance_chain(
- **ais_kwargs)
-
- [ais_weights0_, ais_weights1_] = sess.run([
- ais_weights0, ais_weights1])
-
- self.assertAllClose(ais_weights0_, ais_weights1_,
- atol=1e-5, rtol=1e-5)
-
def testNanRejection(self):
"""Tests that an update that yields NaN potentials gets rejected.
diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/bayesflow/python/ops/hmc.py
index 7fd5652c5c..c8a5a195d3 100644
--- a/tensorflow/contrib/bayesflow/python/ops/hmc.py
+++ b/tensorflow/contrib/bayesflow/python/ops/hmc.py
@@ -24,7 +24,6 @@ from tensorflow.python.util import all_util
_allowed_symbols = [
"sample_chain",
- "sample_annealed_importance_chain",
"kernel",
]
diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
index 82693c2b7b..66afcc7497 100644
--- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py
@@ -15,7 +15,6 @@
"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.
@@sample_chain
-@@sample_annealed_importance_chain
@@kernel
"""
@@ -38,7 +37,6 @@ from tensorflow.python.ops.distributions import util as distributions_util
__all__ = [
"sample_chain",
- "sample_annealed_importance_chain",
"kernel",
]
@@ -330,221 +328,6 @@ def sample_chain(
return functional_ops.scan(**scan_kwargs)
-def sample_annealed_importance_chain(
- proposal_log_prob_fn,
- num_steps,
- target_log_prob_fn,
- current_state,
- step_size,
- num_leapfrog_steps,
- seed=None,
- name=None):
- """Runs annealed importance sampling (AIS) to estimate normalizing constants.
-
- This function uses Hamiltonian Monte Carlo to sample from a series of
- distributions that slowly interpolates between an initial "proposal"
- distribution:
-
- `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`
-
- and the target distribution:
-
- `exp(target_log_prob_fn(x) - target_log_normalizer)`,
-
- accumulating importance weights along the way. The product of these
- importance weights gives an unbiased estimate of the ratio of the
- normalizing constants of the initial distribution and the target
- distribution:
-
- `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.
-
- Note: `proposal_log_prob_fn` and `target_log_prob_fn` are called exactly three
- times (although this may be reduced to two times, in the future).
-
- #### Examples:
-
- ##### Estimate the normalizing constant of a log-gamma distribution.
-
- ```python
- tfd = tf.contrib.distributions
-
- # Run 100 AIS chains in parallel
- num_chains = 100
- dims = 20
- dtype = np.float32
-
- proposal = tfd.MultivatiateNormalDiag(
- loc=tf.zeros([dims], dtype=dtype))
-
- target = tfd.TransformedDistribution(
- distribution=tfd.Gamma(concentration=dtype(2),
- rate=dtype(3)),
- bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()),
- event_shape=[dims])
-
- chains_state, ais_weights, kernels_results = (
- hmc.sample_annealed_importance_chain(
- proposal_log_prob_fn=proposal.log_prob,
- num_steps=1000,
- target_log_prob_fn=target.log_prob,
- step_size=0.2,
- current_state=proposal.sample(num_chains),
- num_leapfrog_steps=2))
-
- log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
- - np.log(num_chains))
- log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.)
- ```
-
- ##### Estimate marginal likelihood of a Bayesian regression model.
-
- ```python
- tfd = tf.contrib.distributions
-
- def make_prior(dims, dtype):
- return tfd.MultivariateNormalDiag(
- loc=tf.zeros(dims, dtype))
-
- def make_likelihood(weights, x):
- return tfd.MultivariateNormalDiag(
- loc=tf.tensordot(weights, x, axes=[[0], [-1]]))
-
- # Run 100 AIS chains in parallel
- num_chains = 100
- dims = 10
- dtype = np.float32
-
- # Make training data.
- x = np.random.randn(num_chains, dims).astype(dtype)
- true_weights = np.random.randn(dims).astype(dtype)
- y = np.dot(x, true_weights) + np.random.randn(num_chains)
-
- # Setup model.
- prior = make_prior(dims, dtype)
- def target_log_prob_fn(weights):
- return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)
-
- proposal = tfd.MultivariateNormalDiag(
- loc=tf.zeros(dims, dtype))
-
- weight_samples, ais_weights, kernel_results = (
- hmc.sample_annealed_importance_chain(
- num_steps=1000,
- proposal_log_prob_fn=proposal.log_prob,
- target_log_prob_fn=target_log_prob_fn
- current_state=tf.zeros([num_chains, dims], dtype),
- step_size=0.1,
- num_leapfrog_steps=2))
- log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
- - np.log(num_chains))
- ```
-
- Args:
- proposal_log_prob_fn: Python callable that returns the log density of the
- initial distribution.
- num_steps: Integer number of Markov chain updates to run. More
- iterations means more expense, but smoother annealing between q
- and p, which in turn means exponentially lower variance for the
- normalizing constant estimator.
- target_log_prob_fn: Python callable which takes an argument like
- `current_state` (or `*current_state` if it's a list) and returns its
- (possibly unnormalized) log-density under the target distribution.
- current_state: `Tensor` or Python `list` of `Tensor`s representing the
- current state(s) of the Markov chain(s). The first `r` dimensions index
- independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
- step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
- for the leapfrog integrator. Must broadcast with the shape of
- `current_state`. Larger step sizes lead to faster progress, but too-large
- step sizes make rejection exponentially more likely. When possible, it's
- often helpful to match per-variable step sizes to the standard deviations
- of the target distribution in each variable.
- num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
- for. Total progress per HMC step is roughly proportional to `step_size *
- num_leapfrog_steps`.
- seed: Python integer to seed the random number generator.
- name: Python `str` name prefixed to Ops created by this function.
- Default value: `None` (i.e., "hmc_sample_annealed_importance_chain").
-
- Returns:
- next_state: `Tensor` or Python list of `Tensor`s representing the
- state(s) of the Markov chain(s) at the final iteration. Has same shape as
- input `current_state`.
- ais_weights: Tensor with the estimated weight(s). Has shape matching
- `target_log_prob_fn(current_state)`.
- kernel_results: `collections.namedtuple` of internal calculations used to
- advance the chain.
- """
- def make_convex_combined_log_prob_fn(iter_):
- def _fn(*args):
- p = proposal_log_prob_fn(*args)
- t = target_log_prob_fn(*args)
- dtype = p.dtype.base_dtype
- beta = (math_ops.cast(iter_ + 1, dtype)
- / math_ops.cast(num_steps, dtype))
- return (1. - beta) * p + beta * t
- return _fn
-
- with ops.name_scope(
- name, "hmc_sample_annealed_importance_chain",
- [num_steps, current_state, step_size, num_leapfrog_steps, seed]):
- with ops.name_scope("initialize"):
- [
- current_state,
- step_size,
- current_log_prob,
- current_grads_log_prob,
- ] = _prepare_args(
- make_convex_combined_log_prob_fn(iter_=0),
- current_state,
- step_size,
- description="convex_combined_log_prob")
- num_steps = ops.convert_to_tensor(
- num_steps,
- dtype=dtypes.int32,
- name="num_steps")
- num_leapfrog_steps = ops.convert_to_tensor(
- num_leapfrog_steps,
- dtype=dtypes.int32,
- name="num_leapfrog_steps")
- def _loop_body(iter_, ais_weights, current_state, kernel_results):
- """Closure which implements `tf.while_loop` body."""
- current_state_parts = (list(current_state)
- if _is_list_like(current_state)
- else [current_state])
- # TODO(b/72994218): Consider refactoring things to avoid this unecessary
- # call.
- ais_weights += ((target_log_prob_fn(*current_state_parts)
- - proposal_log_prob_fn(*current_state_parts))
- / math_ops.cast(num_steps, ais_weights.dtype))
- return [iter_ + 1, ais_weights] + list(kernel(
- make_convex_combined_log_prob_fn(iter_),
- current_state,
- step_size,
- num_leapfrog_steps,
- seed,
- kernel_results.current_target_log_prob,
- kernel_results.current_grads_target_log_prob))
-
- while_loop_kwargs = dict(
- cond=lambda iter_, *args: iter_ < num_steps,
- body=_loop_body,
- loop_vars=[
- np.int32(0), # iter_
- array_ops.zeros_like(current_log_prob), # ais_weights
- current_state,
- _make_dummy_kernel_results(current_state,
- current_log_prob,
- current_grads_log_prob),
- ])
- if seed is not None:
- while_loop_kwargs["parallel_iterations"] = 1
-
- [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop(
- **while_loop_kwargs)[1:] # Lop-off "iter_".
-
- return [current_state, ais_weights, kernel_results]
-
-
def kernel(target_log_prob_fn,
current_state,
step_size,
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 91874f9b5c..300b19733e 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -147,7 +147,9 @@ class TPUClusterResolver(ClusterResolver):
if service is None and should_resolve:
if not _GOOGLE_API_CLIENT_INSTALLED:
raise ImportError('googleapiclient must be installed before using the '
- 'TPU cluster resolver')
+ 'TPU cluster resolver. Execute: `pip install '
+ '--upgrade google-api-python-client` to install with '
+ 'pip.')
self._service = discovery.build(
'tpu', 'v1alpha1',
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 54921aeec6..9212b69700 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -31,6 +31,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@enumerate_dataset
@@group_by_window
@@ignore_errors
+@@make_batched_features_dataset
@@make_saveable_from_iterator
@@map_and_batch
@@padded_batch_and_drop_remainder
@@ -66,6 +67,7 @@ from tensorflow.contrib.data.python.ops.grouping import group_by_window
from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
+from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
from tensorflow.contrib.data.python.ops.readers import read_batch_features
from tensorflow.contrib.data.python.ops.readers import SqlDataset
from tensorflow.contrib.data.python.ops.resampling import rejection_resample
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index a157acc020..107625ab75 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -297,6 +297,7 @@ py_test(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:iterator_ops",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 6efe97444a..15bd55bf64 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -21,6 +21,8 @@ import gzip
import os
import zlib
+import numpy as np
+
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
@@ -262,12 +264,19 @@ class ReadBatchFeaturesTest(test.TestCase):
self._num_records = 7
self.test_filenames = self._createFiles()
- def _read_batch_features(self, filenames, num_epochs, batch_size):
+ def _read_batch_features(self,
+ filenames,
+ num_epochs,
+ batch_size,
+ reader_num_threads=1,
+ parser_num_threads=1,
+ shuffle=False,
+ shuffle_seed=None):
self.filenames = filenames
self.num_epochs = num_epochs
self.batch_size = batch_size
- return readers.read_batch_features(
+ return readers.make_batched_features_dataset(
file_pattern=self.filenames,
batch_size=self.batch_size,
features={
@@ -276,8 +285,12 @@ class ReadBatchFeaturesTest(test.TestCase):
"keywords": parsing_ops.VarLenFeature(dtypes.string)
},
reader=core_readers.TFRecordDataset,
- randomize_input=False,
- num_epochs=self.num_epochs)
+ num_epochs=self.num_epochs,
+ shuffle=shuffle,
+ shuffle_seed=shuffle_seed,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
def _record(self, f, r):
example = example_pb2.Example(features=feature_pb2.Features(
@@ -312,24 +325,35 @@ class ReadBatchFeaturesTest(test.TestCase):
writer.close()
return filenames
- def _next_actual_batch(self, sess):
- file_op = self.outputs["file"]
- keywords_indices_op = self.outputs["keywords"].indices
- keywords_values_op = self.outputs["keywords"].values
- keywords_dense_shape_op = self.outputs["keywords"].dense_shape
- record_op = self.outputs["record"]
+ def _run_actual_batch(self, outputs, sess):
+ file_op = outputs["file"]
+ keywords_indices_op = outputs["keywords"].indices
+ keywords_values_op = outputs["keywords"].values
+ keywords_dense_shape_op = outputs["keywords"].dense_shape
+ record_op = outputs["record"]
return sess.run([
file_op, keywords_indices_op, keywords_values_op,
keywords_dense_shape_op, record_op
])
- def _next_expected_batch(self, file_indices, batch_size, num_epochs):
+ def _next_actual_batch(self, sess):
+ return self._run_actual_batch(self.outputs, sess)
+
+ def _next_expected_batch(self,
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length=1):
def _next_record(file_indices):
for j in file_indices:
for i in range(self._num_records):
yield j, i
+ def _next_record_interleaved(file_indices, cycle_length):
+ return self._interleave([_next_record([i]) for i in file_indices],
+ cycle_length)
+
file_batch = []
keywords_batch_indices = []
keywords_batch_values = []
@@ -337,7 +361,11 @@ class ReadBatchFeaturesTest(test.TestCase):
record_batch = []
batch_index = 0
for _ in range(num_epochs):
- for record in _next_record(file_indices):
+ if cycle_length == 1:
+ next_records = _next_record(file_indices)
+ else:
+ next_records = _next_record_interleaved(file_indices, cycle_length)
+ for record in next_records:
f = record[0]
r = record[1]
file_batch.append(f)
@@ -365,14 +393,41 @@ class ReadBatchFeaturesTest(test.TestCase):
[len(file_batch), keywords_batch_max_len], record_batch
]
- def _verify_records(self, sess, batch_size, file_index=None, num_epochs=1):
+ def _interleave(self, iterators, cycle_length):
+ pending_iterators = iterators
+ open_iterators = []
+ num_open = 0
+ for i in range(cycle_length):
+ if pending_iterators:
+ open_iterators.append(pending_iterators.pop(0))
+ num_open += 1
+
+ while num_open:
+ for i in range(min(cycle_length, len(open_iterators))):
+ if open_iterators[i] is None:
+ continue
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ if pending_iterators:
+ open_iterators[i] = pending_iterators.pop(0)
+ else:
+ open_iterators[i] = None
+ num_open -= 1
+
+ def _verify_records(self,
+ sess,
+ batch_size,
+ file_index=None,
+ num_epochs=1,
+ interleave_cycle_length=1):
if file_index is not None:
file_indices = [file_index]
else:
file_indices = range(self._num_files)
- for expected_batch in self._next_expected_batch(file_indices, batch_size,
- num_epochs):
+ for expected_batch in self._next_expected_batch(
+ file_indices, batch_size, num_epochs, interleave_cycle_length):
actual_batch = self._next_actual_batch(sess)
for i in range(len(expected_batch)):
self.assertAllEqual(expected_batch[i], actual_batch[i])
@@ -435,6 +490,75 @@ class ReadBatchFeaturesTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
+ def testReadWithFusedShuffleRepeatDataset(self):
+ num_epochs = 5
+ total_records = num_epochs * self._num_records
+ for batch_size in [1, 2]:
+ # Test that shuffling with same seed produces the same result.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ outputs1 = self._read_batch_features(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5)
+ outputs2 = self._read_batch_features(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5)
+ for _ in range(total_records // batch_size):
+ batch1 = self._run_actual_batch(outputs1, sess)
+ batch2 = self._run_actual_batch(outputs2, sess)
+ for i in range(len(batch1)):
+ self.assertAllEqual(batch1[i], batch2[i])
+
+ # Test that shuffling with different seeds produces a different order.
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ outputs1 = self._read_batch_features(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5)
+ outputs2 = self._read_batch_features(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=15)
+ all_equal = True
+ for _ in range(total_records // batch_size):
+ batch1 = self._run_actual_batch(outputs1, sess)
+ batch2 = self._run_actual_batch(outputs2, sess)
+ for i in range(len(batch1)):
+ all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
+ self.assertFalse(all_equal)
+
+ def testParallelReadersAndParsers(self):
+ num_epochs = 5
+ for batch_size in [1, 2]:
+ for reader_num_threads in [2, 4]:
+ for parser_num_threads in [2, 4]:
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ self.outputs = self._read_batch_features(
+ filenames=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads)
+ self._verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 1c26296d62..fd871ef5ce 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -67,6 +67,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":dataset_ops",
+ ":shuffle_ops",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 57f3010277..b346bed3e6 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -17,7 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.contrib.data.python.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -25,12 +28,150 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
+from tensorflow.python.util import deprecation
+def make_batched_features_dataset(file_pattern,
+ batch_size,
+ features,
+ reader=core_readers.TFRecordDataset,
+ reader_args=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=10000,
+ shuffle_seed=None,
+ prefetch_buffer_size=1,
+ reader_num_threads=1,
+ parser_num_threads=2,
+ sloppy_ordering=False):
+ """Returns a `Dataset` of feature dictionaries from `Example` protos.
+
+ Example:
+
+ ```
+ serialized_examples = [
+ features {
+ feature { key: "age" value { int64_list { value: [ 0 ] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
+ feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
+ },
+ features {
+ feature { key: "age" value { int64_list { value: [] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
+ feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
+ }
+ ]
+ ```
+
+ We can use arguments:
+
+ ```
+ features: {
+ "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
+ "gender": FixedLenFeature([], dtype=tf.string),
+ "kws": VarLenFeature(dtype=tf.string),
+ }
+ ```
+
+ And the expected output is:
+
+ ```python
+ {
+ "age": [[0], [-1]],
+ "gender": [["f"], ["f"]],
+ "kws": SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0]],
+ values=["code", "art", "sports"],
+ dense_shape=[2, 2]),
+ }
+ ```
+
+ Args:
+ file_pattern: List of files or patterns of file paths containing
+ `Example` records. See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of consecutive elements of this
+ dataset to combine in a single batch.
+ features: A `dict` mapping feature keys to `FixedLenFeature` or
+ `VarLenFeature` values. See `tf.parse_example`.
+ reader: A function or class that can be
+ called with a `filenames` tensor and (optional) `reader_args` and returns
+ a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
+ reader_args: Additional arguments to pass to the reader class.
+ num_epochs: Integer specifying the number of times to read through the
+ dataset. If None, cycles through the dataset forever. Defaults to `None`.
+ shuffle: A boolean, indicates whether the input should be shuffled. Defaults
+ to `True`.
+ shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity
+ ensures better shuffling but would increase memory usage and startup time.
+ shuffle_seed: Randomization seed to use for shuffling.
+ prefetch_buffer_size: Number of feature batches to prefetch in order to
+ improve performance. Recommended value is the number of batches consumed
+ per training step (default is 1).
+ reader_num_threads: Number of threads used to read `Example` records. If >1,
+ the results will be interleaved.
+ parser_num_threads: Number of threads to use for parsing `Example` tensors
+ into a dictionary of `Feature` tensors.
+ sloppy_ordering: If `True`, reading performance will be improved at
+ the cost of non-deterministic ordering. If `False`, the order of elements
+ produced is deterministic prior to shuffling (elements are still
+ randomized if `shuffle=True`. Note that if the seed is set, then order
+ of elements after shuffling is deterministic). Defaults to `False`.
+
+ Returns:
+ A dataset of `dict` elements. Each `dict` maps feature keys to
+ `Tensor` or `SparseTensor` objects.
+ """
+ # Create dataset of all matching filenames
+ if shuffle:
+ dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=True)
+ else:
+ # TODO(b/73959787): Use Dataset.list_files() once ordering is deterministic.
+ filenames = _get_file_names(file_pattern, shuffle)
+ dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+
+ # Read `Example` records from files as tensor objects.
+ if reader_args is None:
+ reader_args = []
+
+ # Read files sequentially (if reader_num_threads=1) or in parallel
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ lambda filename: reader(filename, *reader_args),
+ cycle_length=reader_num_threads,
+ sloppy=sloppy_ordering))
+
+ # Extract values if the `Example` tensors are stored as key-value tuples.
+ if dataset.output_types == (dtypes.string, dtypes.string):
+ dataset = dataset.map(lambda _, v: v)
+
+ # Apply dataset repeat and shuffle transformations.
+ repeat_dataset = (num_epochs != 1)
+ if repeat_dataset and shuffle:
+ # Use fused shuffle_and_repeat operation for better performance
+ dataset = dataset.apply(
+ shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
+ shuffle_seed))
+ elif repeat_dataset:
+ dataset = dataset.repeat(num_epochs)
+ elif shuffle:
+ dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed)
+
+ dataset = dataset.batch(batch_size)
+
+ # Parse `Example` tensors to a dictionary of `Feature` tensors.
+ dataset = dataset.map(
+ lambda x: parsing_ops.parse_example(x, features),
+ num_parallel_calls=parser_num_threads)
+ dataset = dataset.prefetch(prefetch_buffer_size)
+ return dataset
+
+
+@deprecation.deprecated(None,
+ "Use `tf.contrib.data.make_batched_features_dataset`")
def read_batch_features(file_pattern,
batch_size,
features,
- reader,
+ reader=core_readers.TFRecordDataset,
reader_args=None,
randomize_input=True,
num_epochs=None,
@@ -84,43 +225,38 @@ def read_batch_features(file_pattern,
dataset to combine in a single batch.
features: A `dict` mapping feature keys to `FixedLenFeature` or
`VarLenFeature` values. See `tf.parse_example`.
- reader: A function or class that can be called with a `filenames` tensor
- and (optional) `reader_args` and returns a `Dataset` of Examples.
+ reader: A function or class that can be
+ called with a `filenames` tensor and (optional) `reader_args` and returns
+ a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
reader_args: Additional arguments to pass to the reader class.
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever.
- capacity: Capacity of the ShuffleDataset. A large capacity ensures better
+ capacity: Buffer size of the ShuffleDataset. A large capacity ensures better
shuffling but would increase memory usage and startup time.
-
Returns:
A dict from keys in features to `Tensor` or `SparseTensor` objects.
"""
- filenames = _get_file_names(file_pattern, randomize_input)
- if reader_args:
- dataset = reader(filenames, *reader_args)
- else:
- dataset = reader(filenames)
- if dataset.output_types == (dtypes.string, dtypes.string):
- dataset = dataset.map(lambda _, v: v)
- if num_epochs != 1:
- dataset = dataset.repeat(num_epochs)
- if randomize_input:
- dataset = dataset.shuffle(capacity)
- dataset = dataset.batch(batch_size)
- dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features))
- dataset = dataset.prefetch(1)
+ dataset = make_batched_features_dataset(
+ file_pattern,
+ batch_size,
+ features,
+ reader=reader,
+ reader_args=reader_args,
+ shuffle=randomize_input,
+ num_epochs=num_epochs,
+ shuffle_buffer_size=capacity)
iterator = dataset.make_one_shot_iterator()
outputs = iterator.get_next()
return outputs
-def _get_file_names(file_pattern, randomize_input):
+def _get_file_names(file_pattern, shuffle):
"""Parse list of file names from pattern, optionally shuffled.
Args:
file_pattern: File glob pattern, or list of glob patterns.
- randomize_input: Whether to shuffle the order of file names.
+ shuffle: Whether to shuffle the order of file names.
Returns:
List of file names matching `file_pattern`.
@@ -141,7 +277,7 @@ def _get_file_names(file_pattern, randomize_input):
raise ValueError("No files match %s." % file_pattern)
# Sort files so it will be deterministic for unit tests.
- if not randomize_input:
+ if not shuffle:
file_names = sorted(file_names)
return file_names
diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py
index 1fa150f3c6..d07121df63 100644
--- a/tensorflow/contrib/eager/python/checkpointable_utils.py
+++ b/tensorflow/contrib/eager/python/checkpointable_utils.py
@@ -493,8 +493,9 @@ class NameBasedSaverStatus(_LoadStatus):
"""Load the name-based training checkpoint using a new `tf.train.Saver`."""
if session is None and not context.executing_eagerly():
session = ops.get_default_session()
- saver_lib.Saver(self._object_saver._global_variable_names()).restore( # pylint: disable=protected-access
- sess=session, save_path=self._save_path)
+ with ops.device("/cpu:0"):
+ saver_lib.Saver(self._object_saver._global_variable_names()).restore( # pylint: disable=protected-access
+ sess=session, save_path=self._save_path)
def initialize_or_restore(self, session=None):
"""Alias for `run_restore_ops`."""
diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py
index fd9fc098b3..2054878bf8 100644
--- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py
+++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py
@@ -993,20 +993,21 @@ class CheckpointCompatibilityTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testLoadFromNameBasedSaver(self):
"""Save a name-based checkpoint, load it using the object-based API."""
- save_path = self._write_name_based_checkpoint()
- root = self._initialized_model()
- self._set_sentinels(root)
- with self.assertRaises(AssertionError):
+ with test_util.device(use_gpu=True):
+ save_path = self._write_name_based_checkpoint()
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ with self.assertRaises(AssertionError):
+ self._check_sentinels(root)
+ object_saver = checkpointable_utils.CheckpointableSaver(root)
+ status = object_saver.restore(save_path)
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
+ status.run_restore_ops()
+ self._check_sentinels(root)
+ self._set_sentinels(root)
+ status.initialize_or_restore()
self._check_sentinels(root)
- object_saver = checkpointable_utils.CheckpointableSaver(root)
- status = object_saver.restore(save_path)
- with self.assertRaises(AssertionError):
- status.assert_consumed()
- status.run_restore_ops()
- self._check_sentinels(root)
- self._set_sentinels(root)
- status.initialize_or_restore()
- self._check_sentinels(root)
# TODO(allenl): Test for the core name-based saver loading object-based
# checkpoints once object-based checkpointing is in core.
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index c7290c2aaa..aa3957bee1 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -213,7 +213,10 @@ cc_library(
"compatibility.h",
"quantization_util.h",
],
- deps = [":round"],
+ deps = [
+ ":round",
+ ":types",
+ ],
)
cc_test(
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
index b84d2f9ee1..f7706c7938 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
@@ -15,10 +15,88 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_
+#include <cmath>
#include <cstdint>
+#include <limits>
+
+#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
+// Given the min and max values of a float array, return
+// reasonable quantization parameters to use for this array.
+template <typename T>
+QuantizationParams ChooseQuantizationParams(double rmin, double rmax) {
+ const T qmin = std::numeric_limits<T>::min();
+ const T qmax = std::numeric_limits<T>::max();
+ const double qmin_double = qmin;
+ const double qmax_double = qmax;
+ // 0 should always be a representable value. Let's assume that the initial
+ // min,max range contains 0.
+ TFLITE_CHECK_LE(rmin, 0.);
+ TFLITE_CHECK_GE(rmax, 0.);
+ if (rmin == rmax) {
+ // Special case where the min,max range is a point. Should be {0}.
+ TFLITE_CHECK_EQ(rmin, 0.);
+ TFLITE_CHECK_EQ(rmax, 0.);
+ QuantizationParams quantization_params;
+ quantization_params.zero_point = 0;
+ quantization_params.scale = 0.;
+ return quantization_params;
+ }
+
+ // General case.
+ //
+ // First determine the scale.
+ const double scale = (rmax - rmin) / (qmax_double - qmin_double);
+
+ // Zero-point computation.
+ // First the initial floating-point computation. The zero-point can be
+ // determined from solving an affine equation for any known pair
+ // (real value, corresponding quantized value).
+ // We know two such pairs: (rmin, qmin) and (rmax, qmax).
+ // The arithmetic error on the zero point computed from either pair
+ // will be roughly machine_epsilon * (sum of absolute values of terms)
+ // so we want to use the variant that adds the smaller terms.
+ const double zero_point_from_min = qmin_double - rmin / scale;
+ const double zero_point_from_max = qmax_double - rmax / scale;
+ const double zero_point_from_min_error =
+ std::abs(qmin_double) + std::abs(rmin / scale);
+ const double zero_point_from_max_error =
+ std::abs(qmax_double) + std::abs(rmax / scale);
+
+ const double zero_point_double =
+ zero_point_from_min_error < zero_point_from_max_error
+ ? zero_point_from_min
+ : zero_point_from_max;
+
+ // Now we need to nudge the zero point to be an integer
+ // (our zero points are integer, and this is motivated by the requirement
+ // to be able to represent the real value "0" exactly as a quantized value,
+ // which is required in multiple places, for example in Im2col with SAME
+ // padding).
+ T nudged_zero_point = 0;
+ if (zero_point_double < qmin_double) {
+ nudged_zero_point = qmin;
+ } else if (zero_point_double > qmax_double) {
+ nudged_zero_point = qmax;
+ } else {
+ nudged_zero_point = static_cast<T>(round(zero_point_double));
+ }
+ // The zero point should always be in the range of quantized value,
+ // [qmin, qmax].
+ TFLITE_CHECK_GE(nudged_zero_point, qmin);
+ TFLITE_CHECK_LE(nudged_zero_point, qmax);
+
+ // Finally, store the result nudged quantization params.
+ QuantizationParams quantization_params;
+ quantization_params.zero_point = nudged_zero_point;
+ quantization_params.scale = scale;
+ return quantization_params;
+}
+
// Decompose a double multiplier into a Q0.31 int32 representation of its
// significand, and shift representation of NEGATIVE its exponent ---
// this is intended as a RIGHT-shift.
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
index 19b1b408ec..4ae2085c30 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -22,6 +22,51 @@ namespace {
using ::testing::Pair;
+// Example taken from http://www.tensorflow.org/performance/quantization
+//
+// Quantized | Float
+// --------- | -----
+// 0 | -10.0
+// 255 | 30.0
+// 128 | 10.0
+TEST(QuantizationUtilTest, ChooseQuantizationParams) {
+ QuantizationParams qp = ChooseQuantizationParams<uint8>(-10.0, 30.0);
+ EXPECT_NEAR(qp.scale, 0.156863, 1e-5);
+ EXPECT_EQ(qp.zero_point, 64);
+}
+
+TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMinBoundary) {
+ QuantizationParams qp = ChooseQuantizationParams<uint8>(0.0, 30.0);
+ EXPECT_NEAR(qp.scale, 0.117647, 1e-5);
+ EXPECT_EQ(qp.zero_point, 0);
+}
+
+TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroNotInRange) {
+ // Assumption is that zero is within the range.
+ EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, 30.0), "");
+}
+
+TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangePositive) {
+ // Assumption is that zero is within the range.
+ EXPECT_DEATH(ChooseQuantizationParams<uint8>(30.0, 30.0), "");
+}
+
+TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangeZero) {
+ QuantizationParams qp = ChooseQuantizationParams<uint8>(0.0, 0.0);
+ EXPECT_NEAR(qp.scale, 0.0, 1e-5);
+ EXPECT_EQ(qp.zero_point, 0);
+}
+
+TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) {
+ QuantizationParams qp = ChooseQuantizationParams<uint8>(-10.0, 0.0);
+ EXPECT_NEAR(qp.scale, 0.039216, 1e-5);
+ EXPECT_EQ(qp.zero_point, 255);
+}
+
+TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
+ EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), "");
+}
+
TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) {
auto quantize = [](double d) {
int32_t q;
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index afe131b06e..293538fcbb 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -21,6 +21,22 @@ namespace tflite {
enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
+// Quantization parameters, determining the mapping of quantized values
+// to real values (i.e. determining how quantized values are mathematically
+// interpreted).
+//
+// The correspondence is as follows:
+//
+// real_value = scale * (quantized_value - zero_point);
+//
+// In other words, zero_point designates which quantized value corresponds to
+// the real 0 value, and scale designates the difference between the real values
+// corresponding to consecutive quantized values differing by 1.
+struct QuantizationParams {
+ int32 zero_point = 0;
+ double scale = 0.0;
+};
+
template <int N>
struct Dims {
int sizes[N];
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 845bc0460f..031db2bd7c 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -329,6 +329,7 @@ cc_library(
":toco_graphviz_dump_options",
":toco_port",
":types_proto_cc",
+ "//tensorflow/contrib/lite/kernels/internal:quantization_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@protobuf_archive//:protobuf_headers",
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index cd3eb06602..3fa0089cba 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -29,6 +29,8 @@ limitations under the License.
namespace toco {
+using tflite::QuantizationParams;
+
enum class OperatorType {
kNone,
// General-purpose neural network operators.
@@ -1463,22 +1465,6 @@ inline bool operator<(const Alloc& a, const Alloc& b) {
return a.start < b.start;
}
-// Quantization parameters, determining the mapping of quantized values
-// to real values (i.e. determining how quantized values are mathematically
-// interpreted).
-//
-// The correspondence is as follows:
-//
-// real_value = scale * (quantized_value - zero_point);
-//
-// In other words, zero_point designates which quantized value corresponds to
-// the real 0 value, and scale designates the difference between the real values
-// corresponding to consecutive quantized values differing by 1.
-struct QuantizationParams {
- int32 zero_point = 0;
- double scale = 0.;
-};
-
class Shape {
public:
// For Shape, we stick to half-way encapsulation for now:
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index d5796486c5..05360e3b0a 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -28,6 +28,7 @@ limitations under the License.
#if TOCO_SUPPORT_PORTABLE_PROTOS
#include "third_party/protobuf/src/google/protobuf/text_format.h"
#endif // TOCO_SUPPORT_PORTABLE_PROTOS
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
@@ -149,71 +150,11 @@ template <ArrayDataType A>
void GetQuantizationParamsFromMinMax(const MinMax& minmax,
QuantizationParams* quantization_params) {
using Integer = DataType<A>;
- const Integer qmin = std::numeric_limits<Integer>::min();
- const Integer qmax = std::numeric_limits<Integer>::max();
- const double qmin_double = qmin;
- const double qmax_double = qmax;
const double rmin = minmax.min;
const double rmax = minmax.max;
- // 0 should always be a representable value. Let's assume that the initial
- // min,max range contains 0.
- CHECK_LE(rmin, 0.);
- CHECK_GE(rmax, 0.);
- if (rmin == rmax) {
- // Special case where the min,max range is a point. Should be {0}.
- CHECK_EQ(rmin, 0.);
- CHECK_EQ(rmax, 0.);
- quantization_params->zero_point = 0;
- quantization_params->scale = 0.;
- return;
- }
- // General case.
- //
- // First determine the scale.
- const double scale = (rmax - rmin) / (qmax_double - qmin_double);
-
- // Zero-point computation.
- // First the initial floating-point computation. The zero-point can be
- // determined from solving an affine equation for any known pair
- // (real value, corresponding quantized value).
- // We know two such pairs: (rmin, qmin) and (rmax, qmax).
- // The arithmetic error on the zero point computed from either pair
- // will be roughly machine_epsilon * (sum of absolute values of terms)
- // so we want to use the variant that adds the smaller terms.
- const double zero_point_from_min = qmin_double - rmin / scale;
- const double zero_point_from_max = qmax_double - rmax / scale;
- const double zero_point_from_min_error =
- std::abs(qmin_double) + std::abs(rmin / scale);
- const double zero_point_from_max_error =
- std::abs(qmax_double) + std::abs(rmax / scale);
-
- const double zero_point_double =
- zero_point_from_min_error < zero_point_from_max_error
- ? zero_point_from_min
- : zero_point_from_max;
-
- // Now we need to nudge the zero point to be an integer
- // (our zero points are integer, and this is motivated by the requirement
- // to be able to represent the real value "0" exactly as a quantized value,
- // which is required in multiple places, for example in Im2col with SAME
- // padding).
- Integer nudged_zero_point = 0;
- if (zero_point_double < qmin_double) {
- nudged_zero_point = qmin;
- } else if (zero_point_double > qmax_double) {
- nudged_zero_point = qmax;
- } else {
- nudged_zero_point = static_cast<Integer>(std::round(zero_point_double));
- }
- // The zero point should always be in the range of quantized value,
- // [qmin, qmax].
- CHECK_GE(nudged_zero_point, qmin);
- CHECK_LE(nudged_zero_point, qmax);
-
- // Finally, store the result nudged quantization params.
- quantization_params->zero_point = nudged_zero_point;
- quantization_params->scale = scale;
+ *quantization_params =
+ ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax);
}
void CheckIsReadyForQuantization(const Model& model);
diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py
index 51b67bd6fa..465c668fd8 100644
--- a/tensorflow/contrib/tpu/python/tpu/datasets.py
+++ b/tensorflow/contrib/tpu/python/tpu/datasets.py
@@ -117,7 +117,7 @@ def StreamingFilesDataset(files,
file_reader_job = file_reader_job or 'coordinator'
- worker_job = worker_job or 'tpu_worker'
+ worker_job = worker_job or 'worker'
if filename_shuffle_buffer_size is None:
filename_shuffle_buffer_size = 4096
diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py
index 6e6a7ce809..918cf0ed8e 100644
--- a/tensorflow/contrib/tpu/python/tpu/datasets_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py
@@ -44,7 +44,7 @@ class DatasetsTest(test.TestCase):
self._cluster_def = cluster_pb2.ClusterDef()
worker_job = self._cluster_def.job.add()
- worker_job.name = 'tpu_worker'
+ worker_job.name = 'worker'
worker_job.tasks[0] = self._worker.target[len('grpc://'):]
coord_job = self._cluster_def.job.add()
coord_job.name = 'coordinator'
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 10f4d42147..27a96217fd 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1951,6 +1951,7 @@ tf_kernel_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
],
)
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index e3c78d6b70..7c302e2fc2 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/graph/gradients.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -317,6 +318,8 @@ class RemoteCallOp : public AsyncOpKernel {
if (cached_entry != handle_cache_.end()) {
handle = cached_entry->second;
} else {
+ port::Tracing::TraceMe activity(strings::StrCat(
+ "RemoteCall: Instantiate: ", func_.name(), " on ", target_device));
OP_REQUIRES_OK_ASYNC(
ctx,
lib->Instantiate(func_.name(), AttrSlice(&attr_values),
@@ -344,21 +347,24 @@ class RemoteCallOp : public AsyncOpKernel {
args.push_back(argument);
}
auto* rets = new std::vector<Tensor>;
- lib->Run(opts, handle, args, rets, [rets, done, ctx](const Status& status) {
- if (!status.ok()) {
- ctx->SetStatus(status);
- } else {
- for (size_t i = 0; i < rets->size(); ++i) {
- ctx->set_output(i, (*rets)[i]);
- }
- }
- delete rets;
- done();
- });
+ auto* trace = new port::Tracing::TraceMe(strings::StrCat(
+ "RemoteCall: Run: ", func_.name(), " on ", target_device));
+ lib->Run(opts, handle, args, rets,
+ [rets, trace, done, ctx](const Status& status) {
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ } else {
+ for (size_t i = 0; i < rets->size(); ++i) {
+ ctx->set_output(i, (*rets)[i]);
+ }
+ }
+ delete rets;
+ delete trace;
+ done();
+ });
}
private:
- string target_;
NameAttrList func_;
mutex mu_;
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index d6715fa522..5a9cd7531d 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -139,6 +139,10 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import tensor_array_ops
+# Eager execution
+from tensorflow.python.eager.context import executing_eagerly
+from tensorflow.python.framework.ops import enable_eager_execution
+
# Symbols whitelisted for export without documentation.
# TODO(cwhipkey): review these and move to contrib, expose through
# documentation, or remove.
@@ -290,6 +294,12 @@ _allowed_symbols.extend([
'MONOLITHIC_BUILD',
])
+# Eager execution
+_allowed_symbols.extend([
+ 'enable_eager_execution',
+ 'executing_eagerly',
+])
+
# Remove all extra symbols that don't have a docstring or are not explicitly
# referenced in the whitelist.
remove_undocumented(__name__, _allowed_symbols, [
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 5d13aada63..87d3ed880a 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import is_in_graph_mode
from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util.tf_export import tf_export
GRAPH_MODE = 0
EAGER_MODE = 1
@@ -518,8 +519,14 @@ def internal_operation_seed():
return context()._internal_operation_seed() # pylint: disable=protected-access
+@tf_export("executing_eagerly")
def executing_eagerly():
- """Returns True if the current thread has eager execution enabled."""
+ """Returns True if the current thread has eager execution enabled.
+
+ Eager execution is typically enabled via @{tf.enable_eager_execution},
+ but may also be enabled within the context of a Python function via
+ tf.contrib.eager.py_func.
+ """
return context().executing_eagerly()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 8ff247fdb1..f5dde3a358 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -5169,41 +5169,60 @@ def init_scope():
yield
+@tf_export("enable_eager_execution")
def enable_eager_execution(config=None, device_policy=None):
- """Enables, for the rest of the lifetime of this program, eager execution.
+ """Enables eager execution for the lifetime of this program.
- If not called immediately on startup risks creating breakage and bugs.
+ Eager execution provides an imperative interface to TensorFlow. With eager
+ execution enabled, TensorFlow functions execute operations immediately (as
+ opposed to adding to a graph to be executed later in a @{tf.Session}) and
+ return concrete values (as opposed to symbolic references to a node in a
+ computational graph).
- Example:
+ For example:
```python
- tfe.enable_eager_execution()
+ tf.enable_eager_execution()
# After eager execution is enabled, operations are executed as they are
- # defined and `Tensor`s hold concrete values, which can be accessed as
- # `numpy.ndarray`s through the `numpy()` method.
+ # defined and Tensor objects hold concrete values, which can be accessed as
+ # numpy.ndarray objects through the numpy() method.
assert tf.multiply(6, 7).numpy() == 42
```
+ Eager execution cannot be enabled after TensorFlow APIs have been used to
+ create or execute graphs. It is typically recommended to invoke this function
+ at program startup and not in a library (as most libraries should be usable
+ both with and without eager execution).
+
Args:
- config: (Optional.) A `ConfigProto` protocol buffer with configuration
- options for the Context. Note that a lot of these options may be
- currently unimplemented or irrelevant when eager execution is enabled.
- device_policy: (Optional.) What policy to use when trying to run an
- operation on a device with inputs which are not on that device.
+ config: (Optional.) A @{tf.ConfigProto} to use to configure the environment
+ in which operations are executed. Note that @{tf.ConfigProto} is also
+ used to configure graph execution (via @{tf.Session}) and many options
+ within `tf.ConfigProto` are not implemented (or are irrelevant) when
+ eager execution is enabled.
+ device_policy: (Optional.) Policy controlling how operations requiring
+ inputs on a specific device (e.g., GPU 0) handle inputs on a different
+ device (e.g. GPU 1 or CPU).
Valid values:
- tfe.DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
- correct.
- tfe.DEVICE_PLACEMENT_WARN: copies the tensors which are not on the
- right device but raises a warning.
- tfe.DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might
- hide performance problems.
- tfe.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
- raising errors on the other ones.
+
+ - tf.contrib.eager.DEVICE_PLACEMENT_EXPLICIT: raises an error if the
+ placement is not correct.
+
+ - tf.contrib.eager.DEVICE_PLACEMENT_WARN: copies the tensors which are not
+ on the right device but logs a warning.
+
+ - tf.contrib.eager.DEVICE_PLACEMENT_SILENT: silently copies the tensors.
+ Note that this may hide performance problems as there is no notification
+ provided when operations are blocked on the tensor being copied between
+ devices.
+
+ - tf.contrib.eager.DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies
+ int32 tensors, raising errors on the other ones.
Raises:
- ValueError: If trying to create a context after using graph operations
- or if trying to create a context with nontrivial options which differ
- from those of the existing context.
+ ValueError: If eager execution is enabled after creating/executing a
+ TensorFlow graph, or if options provided conflict with a previous call
+ to this function.
"""
if config is not None and not isinstance(config, config_pb2.ConfigProto):
raise TypeError(
@@ -5213,7 +5232,7 @@ def enable_eager_execution(config=None, device_policy=None):
context.DEVICE_PLACEMENT_SILENT,
context.DEVICE_PLACEMENT_SILENT_FOR_INT32):
raise ValueError(
- "device_policy must be one of None, tfe.DEVICE_PLACEMENT_*"
+ "device_policy must be one of None, tf.contrib.eager.DEVICE_PLACEMENT_*"
)
# pylint: disable=protected-access
if context._default_mode == context.GRAPH_MODE:
@@ -5222,7 +5241,7 @@ def enable_eager_execution(config=None, device_policy=None):
_default_graph_stack._global_default_graph is not None)
if graph_mode_has_been_used:
raise ValueError(
- "tfe.enable_eager_execution has to be called at program startup.")
+ "tf.enable_eager_execution must be called at program startup.")
context._default_mode = context.EAGER_MODE
if context._context is None:
context._context = context.Context(config=config,
@@ -5245,7 +5264,7 @@ def enable_eager_execution(config=None, device_policy=None):
context._context._device_policy))
else:
raise ValueError(
- "tfe.enable_eager_execution has to be called at program startup.")
+ "tf.enable_eager_execution must be called at program startup.")
def eager_run(main=None, argv=None):
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index 63203a0043..36142801d6 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import re
+
import numpy as np
from six.moves import queue
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -356,12 +358,22 @@ class PyFuncTest(test.TestCase):
def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
- def raise_exception():
+ def inner_exception():
raise py_exp("blah") # pylint: disable=not-callable
+ def raise_exception():
+ inner_exception()
+
+ expected_regexp = r": blah.*" # Error at the top
+ expected_regexp += r"in raise_exception.*" # Stacktrace outer
+ expected_regexp += r"in inner_exception.*" # Stacktrace inner
+ expected_regexp += r": blah" # Stacktrace of raise
+ def expected_error_check(exception):
+ return re.search(expected_regexp, str(exception), re.DOTALL)
+
if eager:
if context.executing_eagerly():
- with self.assertRaisesRegexp(tf_exp, "blah"):
+ with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
f = script_ops.eager_py_func(raise_exception, [], [])
return
else:
@@ -370,7 +382,7 @@ class PyFuncTest(test.TestCase):
f = script_ops.py_func(raise_exception, [], [])
with self.test_session():
- with self.assertRaisesRegexp(tf_exp, "blah"):
+ with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
self.evaluate(f)
def testExceptionHandling(self):
diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc
index 2635694e23..00cbf0c532 100644
--- a/tensorflow/python/lib/core/py_util.cc
+++ b/tensorflow/python/lib/core/py_util.cc
@@ -41,6 +41,55 @@ const char* ClassName(PyObject* py) {
} // end namespace
+// Best-effort: appends the lines from traceback.format_exception to *out.
+void TryAppendTraceback(PyObject* ptype, PyObject* pvalue, PyObject* ptraceback,
+ string* out) {
+ // The "traceback" module is assumed to be imported already by script_ops.py.
+ PyObject* tb_module = PyImport_AddModule("traceback");
+
+ if (!tb_module) {
+ return;
+ }
+
+ PyObject* format_exception =
+ PyObject_GetAttrString(tb_module, "format_exception");
+
+ if (!format_exception) {
+ return;
+ }
+
+ if (!PyCallable_Check(format_exception)) {
+ Py_DECREF(format_exception);
+ return;
+ }
+
+ PyObject* ret_val = PyObject_CallFunctionObjArgs(format_exception, ptype,
+ pvalue, ptraceback, nullptr);
+ Py_DECREF(format_exception);
+
+ if (!ret_val) {
+ return;
+ }
+
+ if (!PyList_Check(ret_val)) {
+ Py_DECREF(ret_val);
+ return;
+ }
+
+ Py_ssize_t n = PyList_GET_SIZE(ret_val);
+ for (Py_ssize_t i = 0; i < n; ++i) {
+ PyObject* v = PyList_GET_ITEM(ret_val, i);
+#if PY_MAJOR_VERSION < 3
+ strings::StrAppend(out, PyString_AS_STRING(v), "\n");
+#else
+ strings::StrAppend(out, PyUnicode_AsUTF8(v), "\n");
+#endif
+ }
+
+ // Release the list returned by format_exception.
+ Py_DECREF(ret_val);
+}
+
string PyExceptionFetch() {
CHECK(PyErr_Occurred())
<< "Must only call PyExceptionFetch after an exception.";
@@ -52,14 +101,20 @@ string PyExceptionFetch() {
string err = ClassName(ptype);
if (pvalue) {
PyObject* str = PyObject_Str(pvalue);
+
if (str) {
#if PY_MAJOR_VERSION < 3
- strings::StrAppend(&err, ": ", PyString_AS_STRING(str));
+ strings::StrAppend(&err, ": ", PyString_AS_STRING(str), "\n");
#else
- strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str));
+ strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str), "\n");
#endif
Py_DECREF(str);
+ } else {
+ strings::StrAppend(&err, "(unknown error message)\n");
}
+
+ TryAppendTraceback(ptype, pvalue, ptraceback, &err);
+
Py_DECREF(pvalue);
}
Py_DECREF(ptype);
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index 529eebe769..fb59bbba5e 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -25,6 +25,9 @@ from __future__ import print_function
import threading
+# Used by py_util.cc to get tracebacks.
+import traceback # pylint: disable=unused-import
+
import numpy as np
import six
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index fb0862c016..123d67fd9b 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -969,6 +969,10 @@ tf_module {
argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None"
}
member_method {
+ name: "enable_eager_execution"
+ argspec: "args=[\'config\', \'device_policy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "encode_base64"
argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
@@ -985,6 +989,10 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "executing_eagerly"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "exp"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 4fe4fc3b13..b7d7fac315 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -91,7 +91,6 @@ cc_library(
srcs = [
"add_default_attributes.cc",
"backports.cc",
- "fake_quantize_training.cc",
"flatten_atrous.cc",
"fold_batch_norms.cc",
"fold_constants_lib.cc",
@@ -105,7 +104,6 @@ cc_library(
"remove_attribute.cc",
"remove_control_dependencies.cc",
"remove_device.cc",
- "remove_ema.cc",
"remove_nodes.cc",
"rename_attribute.cc",
"rename_op.cc",
@@ -148,7 +146,6 @@ tf_cc_test(
srcs = [
"add_default_attributes_test.cc",
"backports_test.cc",
- "fake_quantize_training_test.cc",
"flatten_atrous_test.cc",
"fold_batch_norms_test.cc",
"fold_constants_test.cc",
@@ -161,7 +158,6 @@ tf_cc_test(
"quantize_weights_test.cc",
"remove_attribute_test.cc",
"remove_device_test.cc",
- "remove_ema_test.cc",
"remove_nodes_test.cc",
"rename_attribute_test.cc",
"rename_op_test.cc",
diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training.cc b/tensorflow/tools/graph_transforms/fake_quantize_training.cc
deleted file mode 100644
index 61aecc6e16..0000000000
--- a/tensorflow/tools/graph_transforms/fake_quantize_training.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include "tensorflow/core/graph/quantize_training.h"
-#include "tensorflow/tools/graph_transforms/transform_utils.h"
-
-namespace tensorflow {
-namespace graph_transforms {
-
-// EXPERIMENTAL: This can change without warning.
-// Rewrites the GraphDef for quantized training.
-// Rewrites the forward pass to include the precision loss with quantization so
-// the model can learn to deal with such loss and achieve better accuracy when
-// it is quantized later for inference.
-// Quantization range information is collected in FakeQuantizeWithMinMaxVars
-// ops.
-//
-// TODO(suharshs): Provide instructions on converting the resulting graph for
-// inference.
-// TODO(suharshs): Implement this using the GTT rather than calling the old
-// prototype function.
-Status FakeQuantizeTraining(const GraphDef& input_graph_def,
- const TransformFuncContext& context,
- GraphDef* output_graph_def) {
- // TODO(suharshs): Make num_bits a parameter.
- const int32 num_bits = 8;
- // TODO(suharshs): Make quantization op a parameter?
- const string quant_op_type = "FakeQuantWithMinMaxVars";
-
- return DoQuantizeTrainingOnGraphDef(input_graph_def, num_bits, quant_op_type,
- output_graph_def);
-}
-
-REGISTER_GRAPH_TRANSFORM("fake_quantize_training", FakeQuantizeTraining);
-
-} // namespace graph_transforms
-} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc b/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
deleted file mode 100644
index 5e4ab209e9..0000000000
--- a/tensorflow/tools/graph_transforms/fake_quantize_training_test.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/cc/ops/const_op.h"
-#include "tensorflow/cc/ops/math_ops.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/tools/graph_transforms/transform_utils.h"
-
-namespace tensorflow {
-namespace graph_transforms {
-
-// Declare here, so we don't need a public header.
-Status FakeQuantizeTraining(const GraphDef& input_graph_def,
- const TransformFuncContext& context,
- GraphDef* output_graph_def);
-
-class FakeQuantizeTrainingTest : public ::testing::Test {};
-
-// For now, since the fake_quantize_training transform just calls the
-// quantize_training rewrite from tensorflow/core/graph/quantize_training.h,
-// we just test that the graph has been changed by the transform.
-// TODO(suharshs): Once we implement the fake_quantize_training transform
-// using the GTT, write proper tests of the transform here.
-TEST_F(FakeQuantizeTrainingTest, TransformOccurred) {
- auto root = tensorflow::Scope::DisabledShapeInferenceScope();
- using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
-
- Tensor a_data(DT_FLOAT, TensorShape());
- test::FillIota<float>(&a_data, 1.0f);
- Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
-
- Tensor b_data(DT_FLOAT, TensorShape());
- test::FillIota<float>(&b_data, 1.0f);
- Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
-
- Output matmul = MatMul(root.WithOpName("matmul"), a_const, b_const);
- GraphDef graph_def;
- TF_ASSERT_OK(root.ToGraphDef(&graph_def));
-
- GraphDef result;
- TransformFuncContext context;
- TF_ASSERT_OK(FakeQuantizeTraining(graph_def, context, &result));
-
- // Test that the transformation resulted in a graph with more nodes.
- EXPECT_GT(result.node_size(), graph_def.node_size());
-}
-
-} // namespace graph_transforms
-} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/remove_ema.cc b/tensorflow/tools/graph_transforms/remove_ema.cc
deleted file mode 100644
index 22e2626702..0000000000
--- a/tensorflow/tools/graph_transforms/remove_ema.cc
+++ /dev/null
@@ -1,146 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#define EIGEN_USE_THREADS
-
-#include "tensorflow/tools/graph_transforms/transform_utils.h"
-
-namespace tensorflow {
-namespace graph_transforms {
-
-// EXPERIMENTAL: This can change without warning.
-// Given a graph that has gone through the FakeQuantizeTraining transform and
-// has been frozen afterwards, RemoveEMA simplifies the FakeQuantize estimated
-// moving average subgraphs to make it compatible with the QuantizeNodes
-// transform.
-Status RemoveEMA(const GraphDef& input_graph_def,
- const TransformFuncContext& context,
- GraphDef* output_graph_def) {
- TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
- input_graph_def, // clang-format off
- {"FakeQuantWithMinMaxVars",
- {
- {"*"},
- {"Assign",
- {
- {"Const"},
- {"Merge",
- {
- {"Switch",
- {
- {"Min",
- {
- {"*"},
- {"Range",
- {
- {"*"},
- {"*"},
- {"*"},
- }
- }
- }
- },
- {"IsVariableInitialized"}
- }
- },
- {"Sub",
- {
- {"Const"},
- {"Mul",
- {
- {"Sub"},
- {"Sub",
- {
- {"Const"},
- {"Const"}
- }
- }
- }
- }
- }
- }
- }
- }
- }
- },
- {"Assign",
- {
- {"Const"},
- {"Merge",
- {
- {"Switch",
- {
- {"Max"},
- {"IsVariableInitialized"}
- }
- },
- {"Sub",
- {
- {"Const"},
- {"Mul",
- {
- {"Sub"},
- {"Sub",
- {
- {"Const"},
- {"Const"}
- }
- }
- }
- }
- }
- }
- }
- }
- }
- },
- }
- }, // clang-format on
- [](const NodeMatch& match, const std::set<string>& input_nodes,
- const std::set<string>& output_nodes,
- std::vector<NodeDef>* new_nodes) {
- const NodeDef& fake_quant_node = match.node;
- const NodeDef& input_node = match.inputs[0].node;
- const NodeDef& min_var_node = match.inputs[1].inputs[0].node;
- const NodeDef& max_var_node = match.inputs[2].inputs[0].node;
-
- // Make a new FakeQuantizeWithMinMaxVars operation that uses constants
- // for its min/max arguments rather than an entire EMA subgraph.
- NodeDef new_fake_quant_node;
- new_fake_quant_node.set_op(fake_quant_node.op());
- new_fake_quant_node.set_name(fake_quant_node.name());
- AddNodeInput(input_node.name(), &new_fake_quant_node);
- AddNodeInput(min_var_node.name(), &new_fake_quant_node);
- AddNodeInput(max_var_node.name(), &new_fake_quant_node);
- CopyNodeAttr(fake_quant_node, "narrow_range", "narrow_range",
- &new_fake_quant_node);
- CopyNodeAttr(fake_quant_node, "num_bits", "num_bits",
- &new_fake_quant_node);
-
- new_nodes->push_back(new_fake_quant_node);
- new_nodes->push_back(input_node);
- new_nodes->push_back(min_var_node);
- new_nodes->push_back(max_var_node);
-
- return Status::OK();
- },
- {}, output_graph_def));
- return Status::OK();
-}
-
-REGISTER_GRAPH_TRANSFORM("remove_ema", RemoveEMA);
-
-} // namespace graph_transforms
-} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/remove_ema_test.cc b/tensorflow/tools/graph_transforms/remove_ema_test.cc
deleted file mode 100644
index 27db90e272..0000000000
--- a/tensorflow/tools/graph_transforms/remove_ema_test.cc
+++ /dev/null
@@ -1,121 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/cc/ops/const_op.h"
-#include "tensorflow/cc/ops/math_ops.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/public/session.h"
-#include "tensorflow/tools/graph_transforms/transform_utils.h"
-
-namespace tensorflow {
-namespace graph_transforms {
-
-// Declare transformations here, so we don't need a public header.
-Status FakeQuantizeTraining(const GraphDef& input_graph_def,
- const TransformFuncContext& context,
- GraphDef* output_graph_def);
-
-Status RemoveEMA(const GraphDef& input_graph_def,
- const TransformFuncContext& context,
- GraphDef* output_graph_def);
-
-Status QuantizeNodes(const GraphDef& input_graph_def,
- const TransformFuncContext& context,
- GraphDef* output_graph_def);
-
-class RemoveEMATest : public ::testing::Test {};
-
-TEST_F(RemoveEMATest, FakeQuant_RemoveEMA_QuantizeTraining) {
- // Build a small graph.
- auto root = tensorflow::Scope::NewRootScope();
- using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
-
- Tensor a_data(DT_FLOAT, TensorShape({1, 1}));
- test::FillIota<float>(&a_data, 1.0f);
- Output a_const = Const(root.WithOpName("a"), Input::Initializer(a_data));
-
- Tensor b_data(DT_FLOAT, TensorShape({1, 1}));
- test::FillIota<float>(&b_data, 1.0f);
- Output b_const = Const(root.WithOpName("b"), Input::Initializer(b_data));
-
- Output matmul = MatMul(root.WithOpName("matmul"), a_const, b_const);
- GraphDef graph_def;
- TF_ASSERT_OK(root.ToGraphDef(&graph_def));
-
- // (1) FakeQuantize the graph.
- GraphDef fake_quantized_graph_def;
- TransformFuncContext context;
- TF_ASSERT_OK(
- FakeQuantizeTraining(graph_def, context, &fake_quantized_graph_def));
-
- // Test that the transformation resulted in a graph with more nodes.
- EXPECT_GT(fake_quantized_graph_def.node_size(), graph_def.node_size());
-
- // (2) Run the graph to initialize the newly added variables.
- std::unique_ptr<Session> session(NewSession(SessionOptions()));
- TF_ASSERT_OK(session->Create(fake_quantized_graph_def));
- std::vector<Tensor> outputs;
- TF_ASSERT_OK(session->Run({}, {"matmul"}, {}, &outputs));
-
- // (3) Freeze the graph. Create a "frozen graph" that matches what we would
- // expect if we actually froze the above graph.
- // TODO(suharshs): Use a c++ freeze graph alternative, when one is available.
- GraphDef frozen_graph_def;
- for (const NodeDef& node : fake_quantized_graph_def.node()) {
- if (node.op() == "Variable" || node.op() == "VariableV2") {
- NodeDef const_node;
- const_node.set_op("Const");
- const_node.set_name(node.name());
- SetNodeAttr("dtype", DT_FLOAT, &const_node);
- Tensor tensor(DT_FLOAT, {});
- tensor.flat<float>()(0) = 1.0f;
- SetNodeTensorAttr<float>("value", tensor, &const_node);
- *(frozen_graph_def.mutable_node()->Add()) = const_node;
- } else {
- *(frozen_graph_def.mutable_node()->Add()) = node;
- }
- }
-
- // Test that freezing the graph resulted in a graph with the same number of
- // nodes.
- EXPECT_EQ(frozen_graph_def.node_size(), fake_quantized_graph_def.node_size());
-
- // (4) RemoveEMA on the graph to make it compatible with QuantizeNodes.
- GraphDef removed_ema_graph_def;
- TF_ASSERT_OK(RemoveEMA(frozen_graph_def, context, &removed_ema_graph_def));
-
- // Test that the transformation resulted in a graph with less nodes.
- EXPECT_LT(removed_ema_graph_def.node_size(), frozen_graph_def.node_size());
-
- // (5) QuantizeNodes and inspect the final graph.
- // TODO(suharshs): Add a more thorough inspection of the structure of
- // the output graph.
- GraphDef quantized_graph_def;
- TF_ASSERT_OK(
- QuantizeNodes(removed_ema_graph_def, context, &quantized_graph_def));
-
- // Test that the transformation resulted in a graph with more nodes.
- EXPECT_GT(quantized_graph_def.node_size(), removed_ema_graph_def.node_size());
-
- // Make sure that the FakeQuantizeWithMinMaxVars op has been removed.
- for (const NodeDef& node : quantized_graph_def.node()) {
- EXPECT_NE(node.op(), "FakeQuantWithMinMaxVars");
- }
-}
-
-} // namespace graph_transforms
-} // namespace tensorflow