aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.h
diff options
context:
space:
mode:
authorGravatar Yangzihao Wang <yangzihao@google.com>2017-08-01 22:10:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-01 22:14:12 -0700
commitdb596594b5653b43fcb558a4753b39904bb62cbd (patch)
treed62d0611711a8f97fb942ae3986f8b7cf571d0f7 /tensorflow/stream_executor/dnn.h
parentb9ac2d7eb17022a677597a1f88a65d8d26278088 (diff)
Allows cudnn convolution algorithms to proceed when only find one type of algorithm (between with scratch or without scratch).
Remove valid_ and set_is_valid() in dnn::ProfileResult, instead use (algorithm_ != kDefaultAlgorithm) to see if the result is valid. PiperOrigin-RevId: 163931258
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r--tensorflow/stream_executor/dnn.h8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index f97deb7222..0a0ad7d9fb 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -673,20 +673,20 @@ constexpr AlgorithmType kDefaultAlgorithm = -1;
// Describes the result from a perf experiment.
//
// Arguments:
-// is_valid: indicates whether a valid measurement was obtained.
// algorithm: returns the exact algorithm that was used.
// elapsed_time_in_ms: returns the measured elapsed time in milliseconds.
class ProfileResult {
public:
- bool is_valid() const { return is_valid_; }
- void set_is_valid(bool val) { is_valid_ = val; }
+ bool is_valid() const {
+ return (algorithm_ != kDefaultAlgorithm &&
+ elapsed_time_in_ms_ != std::numeric_limits<float>::max());
+ }
AlgorithmType algorithm() const { return algorithm_; }
void set_algorithm(AlgorithmType val) { algorithm_ = val; }
float elapsed_time_in_ms() const { return elapsed_time_in_ms_; }
void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; }
private:
- bool is_valid_ = false;
AlgorithmType algorithm_ = kDefaultAlgorithm;
float elapsed_time_in_ms_ = std::numeric_limits<float>::max();
};