-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 13
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc | 36
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc | 125
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h | 55
-rw-r--r--  tensorflow/contrib/timeseries/examples/lstm.py | 17
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/ar_model.py | 44
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/math_utils.py | 3
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/model.py | 63
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py | 4
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py | 56
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py | 1
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py | 3
-rw-r--r--  tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc | 17
-rw-r--r--  tensorflow/python/estimator/training.py | 8
-rw-r--r--  tensorflow/python/estimator/training_test.py | 18
16 files changed, 346 insertions(+), 123 deletions(-)
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index a2969d23d6..fa6e5b2313 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -543,6 +543,7 @@ cc_library(
],
deps = [
":ir_emission_utils",
+ ":parallel_task_assignment",
":shape_partition",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
@@ -653,6 +654,18 @@ tf_cc_test(
)
cc_library(
+ name = "parallel_task_assignment",
+ srcs = ["parallel_task_assignment.cc"],
+ hdrs = ["parallel_task_assignment.h"],
+ deps = [
+ ":ir_emission_utils",
+ ":shape_partition",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_cost_analysis",
+ ],
+)
+
+cc_library(
name = "cpu_options",
srcs = ["cpu_options.cc"],
hdrs = ["cpu_options.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
index 8c827efefc..2cd0aa7880 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -109,10 +110,11 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
HloModule* module) {
VLOG(1) << "RunParallelTaskAssignment max_parallelism_: " << max_parallelism_;
bool changed = false;
- // Run cost analysis on entry computation.
- HloCostAnalysis cost_analysis(shape_size_);
+ // Initialize ParallelTaskAssignment.
+ ParallelTaskAssignment parallel_task_assignment(max_parallelism_, shape_size_,
+ module);
+ // Assign parallel tasks to HLOs in entry computation.
HloComputation* computation = module->entry_computation();
- Status cost_status = computation->root_instruction()->Accept(&cost_analysis);
for (auto* instruction : computation->instructions()) {
// Currently, we do not assign parallel tasks to instructions with at least
// one of the following properties:
@@ -135,8 +137,8 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
}
// Calculate target parallel task count in [1, max_parallelism_].
- const int64 target_parallel_task_count = GetTargetParallelTaskCount(
- cost_status.ok() ? &cost_analysis : nullptr, instruction);
+ const int64 target_parallel_task_count =
+ parallel_task_assignment.GetTargetParallelTaskCount(instruction);
if (target_parallel_task_count == 1) {
continue;
}
@@ -159,30 +161,6 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
return changed;
}
-int64 ParallelizationPreparation::GetTargetParallelTaskCount(
- const HloCostAnalysis* cost_analysis, HloInstruction* instruction) {
- // Default to a simple cost model based on hlo size and typical L2 cache size.
- // Note that 'cost_analysis' can be 'nullptr' if HloCostAnalysis returns an
- // error status (likely because HLOs like CustomCall are not yet implemented
- // in the HloCostAnalysis).
- int64 instruction_cost = shape_size_(instruction->shape());
- int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size.
- if (cost_analysis != nullptr) {
- // Calculate the instruction cost in cycles.
- // TODO(29630486) Improve on this linear cost model.
- // Consider making 'min_cost_per_thread' be a function of the target
- // bandwidth limit for instructions with low arithmetic complexity.
- instruction_cost = 1 * cost_analysis->flop_count(*instruction) +
- 2 * cost_analysis->transcendental_count(*instruction) +
- 10 * cost_analysis->bytes_accessed(*instruction);
- // Minimum per-thread cost is 100us of work on a 2GHz core.
- min_cost_per_thread = 100000;
- }
- // Return target parallel task count in [1, max_parallelism_].
- return std::min(max_parallelism_,
- std::max(1LL, instruction_cost / min_cost_per_thread));
-}
-
bool ParallelizationPreparation::OutlineParallelizableInstruction(
HloInstruction* instruction) {
if (instruction->outer_dimension_partitions().empty()) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h
index d53fc46150..87be758ef5 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h
@@ -55,12 +55,6 @@ class ParallelizationPreparation : public HloPassInterface {
// Returns true on success or error status otherwise.
StatusOr<bool> RunParallelTaskAssignment(HloModule* module);
- // Returns the target parallel task count for 'instruction'.
- // Utilizes 'cost_analysis' if non-null.
- // Otherwise defaults to a simple HLO output size-based cost model.
- int64 GetTargetParallelTaskCount(const HloCostAnalysis* cost_analysis,
- HloInstruction* instruction);
-
// Outlines 'instruction' from entry computation, if it had
// been assigned parallel tasks in an earlier pass through the computation.
// Returns true if 'instruction' was successfully outlined, false otherwise.
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
new file mode 100644
index 0000000000..d4b5e41f50
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -0,0 +1,125 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
+
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+
+namespace xla {
+namespace cpu {
+
+class SimpleCostModel : public ParallelCostModel {
+ public:
+ SimpleCostModel(const int64 max_parallelism,
+ const HloCostAnalysis::ShapeSizeFunction& shape_size)
+ : max_parallelism_(max_parallelism), shape_size_(shape_size) {}
+ ~SimpleCostModel() override {}
+
+ int64 GetParallelTaskCount(HloInstruction* instruction) override {
+ // Simple cost model based on hlo size and typical L2 cache size.
+ const int64 instruction_cost = shape_size_(instruction->shape());
+ const int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size.
+ // Return target parallel task count in [1, max_parallelism_].
+ return std::min(max_parallelism_,
+ std::max(1LL, instruction_cost / min_cost_per_thread));
+ }
+
+ private:
+ const int64 max_parallelism_;
+ const HloCostAnalysis::ShapeSizeFunction shape_size_;
+};
+
+class DefaultCostModel : public ParallelCostModel {
+ public:
+ DefaultCostModel(const int64 max_parallelism,
+ std::unique_ptr<HloCostAnalysis> cost_analysis)
+ : max_parallelism_(max_parallelism),
+ cost_analysis_(std::move(cost_analysis)) {}
+ ~DefaultCostModel() override {}
+
+ int64 GetParallelTaskCount(HloInstruction* instruction) override {
+ // Calculate the instruction cost in cycles.
+ // TODO(29630486) Improve on this linear cost model.
+ // Consider making 'min_cost_per_thread' be a function of the target
+ // bandwidth limit for instructions with low arithmetic complexity.
+ const int64 instruction_cost =
+ 1 * cost_analysis_->flop_count(*instruction) +
+ 2 * cost_analysis_->transcendental_count(*instruction) +
+ 10 * cost_analysis_->bytes_accessed(*instruction);
+ // Minimum per-thread cost is 100us of work on a 2GHz core.
+ const int64 min_cost_per_thread = 100000;
+ // Return target parallel task count in [1, max_parallelism_].
+ return std::min(max_parallelism_,
+ std::max(1LL, instruction_cost / min_cost_per_thread));
+ }
+
+ private:
+ const int64 max_parallelism_;
+ const std::unique_ptr<HloCostAnalysis> cost_analysis_;
+};
+
+
+ParallelTaskAssignment::ParallelTaskAssignment(
+ const int64 max_parallelism,
+ const HloCostAnalysis::ShapeSizeFunction& shape_size,
+ HloModule* module) {
+ VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
+ // Run cost analysis on 'module'.
+ auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size);
+ HloComputation* computation = module->entry_computation();
+ Status status = computation->root_instruction()->Accept(cost_analysis.get());
+ if (status.ok()) {
+ // Set default cost model based on 'cost_analysis'.
+ cost_model_.reset(new DefaultCostModel(max_parallelism,
+ std::move(cost_analysis)));
+ } else {
+ // Fall back to a simple cost model based on hlo size and L2 cache size.
+ // Note that HloCostAnalysis can return an error status (likely because
+ // HLOs like CustomCall are not yet implemented in HloCostAnalysis).
+ cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size));
+ }
+}
+
+int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
+ HloInstruction* instruction) {
+ // Currently, we do not assign parallel tasks to instructions with at least
+ // one of the following properties:
+ // *) Internal threading (library calls to kConv, kDot, and kCustomCall).
+ // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot).
+ // *) Tuple-shaped.
+ // TODO(b/27458679) Parallelize instructions which are skipped here.
+ if (instruction->opcode() == HloOpcode::kParameter ||
+ instruction->opcode() == HloOpcode::kConstant ||
+ instruction->opcode() == HloOpcode::kCall ||
+ instruction->opcode() == HloOpcode::kCustomCall ||
+ instruction->opcode() == HloOpcode::kSelectAndScatter ||
+ (instruction->opcode() == HloOpcode::kConvolution &&
+ PotentiallyImplementedAsEigenConvolution(*instruction)) ||
+ PotentiallyImplementedAsEigenDot(*instruction) ||
+ (instruction->opcode() == HloOpcode::kFusion &&
+ instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) ||
+ ShapeUtil::IsTuple(instruction->shape())) {
+ return 1;
+ }
+ // Consult 'cost_model_' to compute target parallel task count.
+ return cost_model_->GetParallelTaskCount(instruction);
+}
+
+} // namespace cpu
+} // namespace xla
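Both cost models above end in the same clamp; only the cost estimate feeding it differs. A minimal Python sketch of that arithmetic (the 1/2/10 weights and both thresholds are copied from the code above; the instruction counts are made-up illustrations):

    def parallel_task_count(instruction_cost, min_cost_per_thread, max_parallelism):
        # Clamp the target task count to [1, max_parallelism], as both models do.
        return min(max_parallelism, max(1, instruction_cost // min_cost_per_thread))

    # SimpleCostModel: cost is the HLO output size, threshold a 256KB L2 cache.
    print(parallel_task_count(4 << 20, 256 << 10, max_parallelism=8))  # 16 tasks, clamped -> 8

    # DefaultCostModel: cycle estimate from HloCostAnalysis counters.
    flops, transcendentals, bytes_accessed = 500_000, 10_000, 64_000   # hypothetical counts
    cost = 1 * flops + 2 * transcendentals + 10 * bytes_accessed       # = 1,160,000
    print(parallel_task_count(cost, 100_000, max_parallelism=8))       # 11 tasks, clamped -> 8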
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
new file mode 100644
index 0000000000..15f065a3ad
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
+
+#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+namespace xla {
+namespace cpu {
+
+// Simple interface for different parallel cost model implementations.
+class ParallelCostModel {
+ public:
+ virtual ~ParallelCostModel() = default;
+ virtual int64 GetParallelTaskCount(HloInstruction* instruction) = 0;
+};
+
+// ParallelTaskAssignment computes parallel task counts for HLOs in 'module'.
+class ParallelTaskAssignment {
+ public:
+ // 'max_parallelism': the maximum parallel task count per instruction.
+ // 'shape_size': shape size function used by HloCostAnalysis during parallel
+ // task assignment.
+ // 'module': the containing HloModule.
+ ParallelTaskAssignment(
+ const int64 max_parallelism,
+ const HloCostAnalysis::ShapeSizeFunction& shape_size,
+ HloModule* module);
+ ~ParallelTaskAssignment() {}
+
+ // Computes and returns the target parallel task count for 'instruction'.
+ int64 GetTargetParallelTaskCount(HloInstruction* instruction);
+
+ private:
+ std::unique_ptr<ParallelCostModel> cost_model_;
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py
index 6bab06f56c..3ba823f638 100644
--- a/tensorflow/contrib/timeseries/examples/lstm.py
+++ b/tensorflow/contrib/timeseries/examples/lstm.py
@@ -106,16 +106,6 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
for state_element
in self._lstm_cell.zero_state(batch_size=1, dtype=self.dtype)])
- def _transform(self, data):
- """Normalize data based on input statistics to encourage stable training."""
- mean, variance = self._input_statistics.overall_feature_moments
- return (data - mean) / variance
-
- def _de_transform(self, data):
- """Transform data back to the input scale."""
- mean, variance = self._input_statistics.overall_feature_moments
- return data * variance + mean
-
def _filtering_step(self, current_times, current_values, state, predictions):
"""Update model state based on observations.
@@ -140,7 +130,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
state_from_time, prediction, lstm_state = state
with tf.control_dependencies(
[tf.assert_equal(current_times, state_from_time)]):
- transformed_values = self._transform(current_values)
+ # Subtract the mean and divide by the standard deviation of the series.
+ # Slightly more efficient if done for a whole window (using the
+ # normalize_features argument to SequentialTimeSeriesModel).
+ transformed_values = self._scale_data(current_values)
# Use mean squared error across features for the loss.
predictions["loss"] = tf.reduce_mean(
(prediction - transformed_values) ** 2, axis=-1)
@@ -156,7 +149,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
inputs=previous_observation_or_prediction, state=lstm_state)
next_prediction = self._predict_from_lstm_output(lstm_output)
new_state_tuple = (current_times, next_prediction, new_lstm_state)
- return new_state_tuple, {"mean": self._de_transform(next_prediction)}
+ return new_state_tuple, {"mean": self._scale_back_data(next_prediction)}
def _imputation_step(self, current_times, state):
"""Advance model state across a gap."""
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
index 7452dc7dc3..267a5f88da 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
@@ -89,8 +89,6 @@ class ARModel(model.TimeSeriesModel):
self.hidden_layer_sizes = hidden_layer_sizes
self.window_size = self.input_window_size + self.output_window_size
self.loss = loss
- self.stats_means = None
- self.stats_sigmas = None
super(ARModel, self).__init__(
num_features=num_features)
assert num_time_buckets > 0
@@ -106,32 +104,6 @@ class ARModel(model.TimeSeriesModel):
assert len(self._periods) or self.input_window_size
assert output_window_size > 0
- def scale_data(self, data):
- """Scale data according to stats."""
- if self._input_statistics is not None:
- return (data - self.stats_means) / self.stats_sigmas
- else:
- return data
-
- def scale_back_data(self, data):
- if self._input_statistics is not None:
- return (data * self.stats_sigmas) + self.stats_means
- else:
- return data
-
- def scale_back_variance(self, var):
- if self._input_statistics is not None:
- return var * self.stats_sigmas * self.stats_sigmas
- else:
- return var
-
- def initialize_graph(self, input_statistics=None):
- super(ARModel, self).initialize_graph(input_statistics=input_statistics)
- if self._input_statistics:
- self.stats_means, variances = (
- self._input_statistics.overall_feature_moments)
- self.stats_sigmas = math_ops.sqrt(variances)
-
def get_start_state(self):
# State which matches the format we'll return later. Typically this will not
# be used by the model directly, but the shapes and dtypes should match so
@@ -388,8 +360,8 @@ class ARModel(model.TimeSeriesModel):
predicted_covariance = array_ops.ones_like(predicted_mean)
# Transform and scale the mean and covariance appropriately.
- predicted_mean = self.scale_back_data(predicted_mean)
- predicted_covariance = self.scale_back_variance(predicted_covariance)
+ predicted_mean = self._scale_back_data(predicted_mean)
+ predicted_covariance = self._scale_back_variance(predicted_covariance)
return {"mean": predicted_mean,
"covariance": predicted_covariance}
@@ -418,7 +390,7 @@ class ARModel(model.TimeSeriesModel):
times_feature=TrainEvalFeatures.TIMES,
window_size=self.window_size,
times_shape=times.get_shape()))
- values = self.scale_data(values)
+ values = self._scale_data(values)
if self.input_window_size > 0:
input_values = values[:, :self.input_window_size, :]
else:
@@ -435,14 +407,14 @@ class ARModel(model.TimeSeriesModel):
# (observed - predicted) ** 2.
# Note that this affects only evaluation; the training loss is unaffected.
loss = self.loss_op(
- self.scale_back_data(targets),
- {"mean": self.scale_back_data(prediction_ops["mean"])})
+ self._scale_back_data(targets),
+ {"mean": self._scale_back_data(prediction_ops["mean"])})
else:
loss = self.loss_op(targets, prediction_ops)
# Scale back the prediction.
- prediction = self.scale_back_data(prediction)
- covariance = self.scale_back_variance(covariance)
+ prediction = self._scale_back_data(prediction)
+ covariance = self._scale_back_variance(covariance)
return model.ModelOutputs(
loss=loss,
@@ -565,7 +537,7 @@ class ARModel(model.TimeSeriesModel):
new_state_times.set_shape((None, self.input_window_size))
new_state_values = array_ops.concat(
[previous_state_values,
- self.scale_data(values)], axis=1)[:, -self.input_window_size:, :]
+ self._scale_data(values)], axis=1)[:, -self.input_window_size:, :]
new_state_values.set_shape((None, self.input_window_size,
self.num_features))
else:
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
index c70da3e082..23452a81c3 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
@@ -936,8 +936,7 @@ class InputStatisticsFromMiniBatch(object):
start_time = variable_scope.get_variable(
name="start_time",
dtype=dtypes.int64,
- initializer=init_ops.zeros_initializer(),
- shape=[],
+ initializer=dtypes.int64.max,
trainable=False)
total_observation_count = variable_scope.get_variable(
name="total_observation_count",
diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py
index f2ef8d2211..b32b5c5494 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/model.py
@@ -80,6 +80,8 @@ class TimeSeriesModel(object):
self.dtype = dtype
self._input_statistics = None
self._graph_initialized = False
+ self._stats_means = None
+ self._stats_sigmas = None
# TODO(allenl): Move more of the generic machinery for generating and
# predicting into TimeSeriesModel, and possibly share it between generate()
@@ -120,6 +122,38 @@ class TimeSeriesModel(object):
"""
self._graph_initialized = True
self._input_statistics = input_statistics
+ if self._input_statistics:
+ self._stats_means, variances = (
+ self._input_statistics.overall_feature_moments)
+ self._stats_sigmas = math_ops.sqrt(variances)
+
+ def _scale_data(self, data):
+ """Scale data according to stats (input scale -> model scale)."""
+ if self._input_statistics is not None:
+ return (data - self._stats_means) / self._stats_sigmas
+ else:
+ return data
+
+ def _scale_variance(self, variance):
+ """Scale variances according to stats (input scale -> model scale)."""
+ if self._input_statistics is not None:
+ return variance / self._input_statistics.overall_feature_moments.variance
+ else:
+ return variance
+
+ def _scale_back_data(self, data):
+ """Scale back data according to stats (model scale -> input scale)."""
+ if self._input_statistics is not None:
+ return (data * self._stats_sigmas) + self._stats_means
+ else:
+ return data
+
+ def _scale_back_variance(self, variance):
+ """Scale back variances according to stats (model scale -> input scale)."""
+ if self._input_statistics is not None:
+ return variance * self._input_statistics.overall_feature_moments.variance
+ else:
+ return variance
def _check_graph_initialized(self):
if not self._graph_initialized:
@@ -304,6 +338,7 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
train_output_names,
predict_output_names,
num_features,
+ normalize_features=False,
dtype=dtypes.float32,
exogenous_feature_columns=None,
exogenous_update_condition=None,
@@ -316,6 +351,12 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
predict_output_names: A list of products/predictions returned from
_prediction_step.
num_features: Number of features for the time series
+ normalize_features: Boolean. If True, `values` are passed normalized to
+ the model (via self._scale_data). Scaling is done for the whole window
+ as a batch, which is slightly more efficient than scaling inside the
+ window loop. The model must then define _scale_back_predictions, which
+ may use _scale_back_data or _scale_back_variance to return predictions
+ to the input scale.
dtype: The floating point datatype to use.
exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
objects. See `TimeSeriesModel`.
@@ -344,9 +385,25 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
self._exogenous_update_condition = exogenous_update_condition
self._train_output_names = train_output_names
self._predict_output_names = predict_output_names
+ self._normalize_features = normalize_features
self._static_unrolling_window_size_threshold = (
static_unrolling_window_size_threshold)
+ def _scale_back_predictions(self, predictions):
+ """Return a window of predictions to input scale.
+
+ Args:
+ predictions: A dictionary mapping from prediction names to Tensors.
+ Returns:
+ A dictionary with values corrected for input normalization (e.g. with
+ self._scale_back_data and possibly self._scale_back_variance). May be a
+ mutated version of the argument.
+ """
+ raise NotImplementedError(
+ "SequentialTimeSeriesModel normalized input data"
+ " (normalize_features=True), but no method was provided to transform "
+ "the predictions back to the input scale.")
+
@abc.abstractmethod
def _filtering_step(self, current_times, current_values, state, predictions):
"""Compute a single-step loss for a batch of data.
@@ -524,6 +581,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
self._check_graph_initialized()
times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtype=dtypes.int64)
values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)
+ if self._normalize_features:
+ values = self._scale_data(values)
exogenous_regressors = self._process_exogenous_features(
times=times,
features={key: value for key, value in features.items()
@@ -556,6 +615,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
# Since we have window-level additions to the loss, its per-step value is
# misleading, so we avoid returning it.
del outputs["loss"]
+ if self._normalize_features:
+ outputs = self._scale_back_predictions(outputs)
return per_observation_loss, state, outputs
def predict(self, features):
@@ -583,6 +644,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
times=predict_times, state=start_state,
state_update_fn=_call_prediction_step,
outputs=self._predict_output_names)
+ if self._normalize_features:
+ predictions = self._scale_back_predictions(predictions)
return predictions
class _FakeTensorArray(object):
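With normalize_features=True a subclass receives pre-scaled values in its window loop and must override _scale_back_predictions. A hedged sketch of a minimal subclass honoring that contract (the override body mirrors StateSpaceModel._scale_back_predictions later in this diff; the rest is an illustrative skeleton, not part of this commit):

    class _NormalizedModel(SequentialTimeSeriesModel):
      """Illustrative skeleton; abstract _filtering_step/_prediction_step/
      _imputation_step implementations are omitted here."""

      def __init__(self, num_features):
        super(_NormalizedModel, self).__init__(
            train_output_names=["mean"],
            predict_output_names=["mean", "covariance"],
            num_features=num_features,
            normalize_features=True)  # values arrive pre-scaled via _scale_data

      def _scale_back_predictions(self, predictions):
        # Model scale -> input scale before results reach the user.
        predictions["mean"] = self._scale_back_data(predictions["mean"])
        predictions["covariance"] = self._scale_back_variance(
            predictions["covariance"])
        return predictions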
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py
index b9d3f55c39..56167c4f01 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py
@@ -57,7 +57,9 @@ class AdderStateSpaceModel(state_space_model.StateSpaceModel):
# TODO(allenl): Better support for multivariate series here.
initial_value = array_ops.stack([
math_ops.reduce_mean(
- self._input_statistics.series_start_moments.mean), 0.
+ self._scale_data(
+ self._input_statistics.series_start_moments.mean)),
+ 0.
])
return initial_value + variable_scope.get_variable(
name="prior_state_mean",
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py
index 6a9660b400..6257002647 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py
@@ -232,6 +232,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
+ filtering_postprocessor_names),
predict_output_names=["mean", "covariance"],
num_features=configuration.num_features,
+ normalize_features=True,
dtype=configuration.dtype,
exogenous_feature_columns=configuration.exogenous_feature_columns,
exogenous_update_condition=configuration.exogenous_update_condition,
@@ -309,15 +310,10 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
_, _, priors_from_time = state
times = ops.convert_to_tensor(times)
priors_from_time = ops.convert_to_tensor(priors_from_time)
- with ops.control_dependencies([
- control_flow_ops.Assert(
- math_ops.reduce_all(priors_from_time <= times[:, 0]),
- [priors_from_time, times[:, 0]],
- summarize=100)
- ]):
- times = array_ops.identity(times)
intra_batch_gaps = array_ops.reshape(times[:, 1:] - times[:, :-1], [-1])
- starting_gaps = times[:, 0] - priors_from_time
+ # Ignore negative starting gaps, since there will be transient start times
+ # as input statistics are computed.
+ starting_gaps = math_ops.maximum(times[:, 0] - priors_from_time, 0)
# Pre-define transition matrices raised to powers (and their sums) for every
# gap in this window. This avoids duplicate computation (for example many
# steps will use the transition matrix raised to the first power) and
@@ -369,20 +365,15 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
Imputed model state corresponding to the `state` argument.
"""
estimated_state, estimated_state_var, previous_times = state
- catchup_times = current_times - previous_times
- non_negative_assertion = control_flow_ops.Assert(
- math_ops.reduce_all(catchup_times >= 0), [
- "Negative imputation interval", catchup_times, current_times,
- previous_times
- ],
- summarize=100)
- with ops.control_dependencies([non_negative_assertion]):
- transition_matrices, transition_noise_sums = ( # pylint: disable=unbalanced-tuple-unpacking
- self._cached_transition_powers_and_sums(catchup_times))
- estimated_state = self._kalman_filter.predict_state_mean(
- estimated_state, transition_matrices)
- estimated_state_var = self._kalman_filter.predict_state_var(
- estimated_state_var, transition_matrices, transition_noise_sums)
+ # Ignore negative imputation intervals due to transient start time
+ # estimates.
+ catchup_times = math_ops.maximum(current_times - previous_times, 0)
+ transition_matrices, transition_noise_sums = ( # pylint: disable=unbalanced-tuple-unpacking
+ self._cached_transition_powers_and_sums(catchup_times))
+ estimated_state = self._kalman_filter.predict_state_mean(
+ estimated_state, transition_matrices)
+ estimated_state_var = self._kalman_filter.predict_state_var(
+ estimated_state_var, transition_matrices, transition_noise_sums)
return (estimated_state, estimated_state_var,
previous_times + catchup_times)
@@ -437,6 +428,13 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
outputs=predictions)
return (filtered_state, predictions)
+ def _scale_back_predictions(self, predictions):
+ """Return a window of predictions to input scale."""
+ predictions["mean"] = self._scale_back_data(predictions["mean"])
+ predictions["covariance"] = self._scale_back_variance(
+ predictions["covariance"])
+ return predictions
+
def _prediction_step(self, current_times, state):
"""Make a prediction based on `state`.
@@ -458,7 +456,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
"""
estimated_state, estimated_state_var, previous_times = state
advanced_to_current_assert = control_flow_ops.Assert(
- math_ops.reduce_all(math_ops.equal(current_times, previous_times)),
+ math_ops.reduce_all(math_ops.less_equal(current_times, previous_times)),
["Attempted to predict without imputation"])
with ops.control_dependencies([advanced_to_current_assert]):
observation_model = self.get_broadcasted_observation_model(current_times)
@@ -475,6 +473,9 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
(self.num_features,)))
predicted_obs_var.set_shape(current_times.get_shape().concatenate(
(self.num_features, self.num_features)))
+ # Not scaled back to input scale, since this also feeds into the loss.
+ # Instead, predictions are scaled back before being returned to the user
+ # in _scale_back_predictions.
predictions = {
"mean": predicted_obs,
"covariance": predicted_obs_var}
@@ -722,7 +723,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
# Make sure initial latent value uncertainty is at least on the same
# scale as noise in the data.
covariance_multiplier = math_ops.reduce_max(
- self._input_statistics.series_start_moments.variance)
+ self._scale_variance(
+ self._input_statistics.series_start_moments.variance))
return base_covariance * gen_math_ops.maximum(
covariance_multiplier, 1.0)
else:
@@ -920,7 +922,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
self.get_noise_transform(), dtype=self.dtype)
state_noise_dimension = state_noise_transform.get_shape()[1].value
if self._input_statistics is not None:
- feature_variance = self._input_statistics.series_start_moments.variance
+ feature_variance = self._scale_variance(
+ self._input_statistics.series_start_moments.variance)
initial_transition_noise_scale = math_ops.log(
gen_math_ops.maximum(
math_ops.reduce_mean(feature_variance) / math_ops.cast(
@@ -945,7 +948,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
if self._input_statistics is not None:
# Get variance across the first few values in each batch for each
# feature, for an initial observation noise (over-)estimate.
- feature_variance = self._input_statistics.series_start_moments.variance
+ feature_variance = self._scale_variance(
+ self._input_statistics.series_start_moments.variance)
else:
feature_variance = None
if feature_variance is not None:
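Both imputation hunks above trade a runtime Assert for clamping, so a transiently over-estimated start time degrades to a zero-length imputation instead of failing the graph. The scalar logic, with made-up times:

    import numpy as np

    times = np.array([[3, 4, 5]])
    priors_from_time = np.array([7])  # transient over-estimate of the start time

    # Before: Assert(priors_from_time <= times[:, 0]) would abort here.
    # After: the negative gap clamps to 0 and imputation becomes a no-op.
    starting_gaps = np.maximum(times[:, 0] - priors_from_time, 0)
    print(starting_gaps)  # -> [0]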
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
index 7c8f81ec51..ca57715e2b 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
@@ -605,6 +605,7 @@ class TimeDependentStateSpaceModel(state_space_model.StateSpaceModel):
super(TimeDependentStateSpaceModel, self).__init__(
configuration=state_space_model.StateSpaceModelConfiguration(
use_observation_noise=False,
+ transition_covariance_initial_log_scale_bias=5.,
static_unrolling_window_size_threshold=
static_unrolling_window_size_threshold))
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py
index 110ba9738f..1afc58cfb2 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py
@@ -182,7 +182,8 @@ class VARMA(state_space_model.StateSpaceModel):
# modeled as transition noise in VARMA, we set its initial value based on a
# slight over-estimate empirical observation noise.
if self._input_statistics is not None:
- feature_variance = self._input_statistics.series_start_moments.variance
+ feature_variance = self._scale_variance(
+ self._input_statistics.series_start_moments.variance)
initial_transition_noise_scale = math_ops.log(
math_ops.maximum(
math_ops.reduce_mean(feature_variance), minimum_initial_variance))
diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
index b1ec35e268..6d25556770 100644
--- a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
+++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
@@ -39,8 +39,8 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
// x is from the feed.
const int batch_size = tensor_size < 0 ? 1 : tensor_size;
- Output x =
- RandomNormal(s.WithOpName("x"), {batch_size, 1}, DataType::DT_FLOAT);
+ Output x = RandomNormal(s.WithOpName("x").WithDevice("/CPU:0"),
+ {batch_size, 1}, DataType::DT_FLOAT);
// Create stages.
std::vector<Output> last_stage;
@@ -64,16 +64,19 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
}
if (insert_queue) {
- FIFOQueue queue(s.WithOpName("queue"), {DataType::DT_FLOAT});
- QueueEnqueue enqueue(s.WithOpName("enqueue"), queue, last_stage);
- QueueDequeue dequeue(s.WithOpName("dequeue"), queue, {DataType::DT_FLOAT});
- QueueClose cancel(s.WithOpName("cancel"), queue,
+ FIFOQueue queue(s.WithOpName("queue").WithDevice("/CPU:0"),
+ {DataType::DT_FLOAT});
+ QueueEnqueue enqueue(s.WithOpName("enqueue").WithDevice("/CPU:0"), queue,
+ last_stage);
+ QueueDequeue dequeue(s.WithOpName("dequeue").WithDevice("/CPU:0"), queue,
+ {DataType::DT_FLOAT});
+ QueueClose cancel(s.WithOpName("cancel").WithDevice("/CPU:0"), queue,
QueueClose::CancelPendingEnqueues(true));
last_stage = {dequeue[0]};
}
// Create output.
- AddN output(s.WithOpName("y"), last_stage);
+ AddN output(s.WithOpName("y").WithDevice("/CPU:0"), last_stage);
GraphDef def;
TF_CHECK_OK(s.ToGraphDef(&def));
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 166b7b20ed..953e970eea 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -438,14 +438,18 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
'`estimator.config` must have task_type set. This usually means '
'TF_CONFIG environment is not set correctly.')
- # TODO(xiejw): error out if evaluator index is more than 0.
-
if config.task_type == 'local':
raise ValueError(
'`task.type` in TF_CONFIG cannot be `local`. Leaving `cluster` and '
'`task` properties in TF_CONFIG absent triggers train and evaluate '
'`Estimator` locally (non-distributed).')
+ if (config.task_type == run_config_lib.TaskType.EVALUATOR and
+ config.task_id > 0):
+ raise ValueError(
+ 'For distributed training, there can only be one `evaluator` task '
+ '(with task id 0). Given task id {}'.format(config.task_id))
+
# For task type foo, call executor.run_foo.
available_tasks = [x for x in dir(executor) if x.startswith('run_')
and x != 'run_local'
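For reference, the TF_CONFIG shapes this new check accepts and rejects (a sketch with made-up hosts; the rejected form matches the regression test added below):

    import json
    import os

    # Accepted: the single evaluator must have task id 0.
    os.environ['TF_CONFIG'] = json.dumps({
        'cluster': {'chief': ['host0:0']},
        'task': {'type': 'evaluator', 'index': 0},
    })

    # Rejected after this change: train_and_evaluate raises ValueError
    # ("there can only be one `evaluator` task (with task id 0)").
    os.environ['TF_CONFIG'] = json.dumps({
        'cluster': {'chief': ['host0:0']},
        'task': {'type': 'evaluator', 'index': 1},
    })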
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index c474004dab..e4c400ca7f 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -71,6 +71,8 @@ _INVALID_EMPTY_EVAL_RESULT_ERR = (
_INVALID_EVAL_RESULT_TYPE_ERR = '`Estimator.evaluate` should return dict.'
_MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR = (
'Internal error: `Estimator.evaluate` result should have `global_step`')
+_INVALID_EVAL_TASK_ID_ERR = (
+ 'there can only be one `evaluator` task .*with task id 0')
_TF_CONFIG_FOR_CHIEF = {
'cluster': {
@@ -128,7 +130,7 @@ _TF_CONFIG_FOR_EVALUATOR = {
},
'task': {
'type': run_config_lib.TaskType.EVALUATOR,
- 'index': 1
+ 'index': 0
}
}
@@ -351,6 +353,20 @@ class TrainAndEvaluteTest(test.TestCase):
_TF_CONFIG_FOR_EVALUATOR))
self.assertEqual(1, mock_executor.call_task['evaluator'])
+ def test_error_out_if_evaluator_task_id_is_non_zero(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host0:0'],
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.EVALUATOR,
+ 'index': 1
+ }
+ }
+ with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR):
+ self._test_run_task_in_distributed_training(
+ run_config=_create_run_config_with_cluster_spec(tf_config))
+
def test_run_local(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.config = run_config_lib.RunConfig()