diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-07 16:39:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-07 16:39:53 -0700 |
commit | 5fe19ce7e879b1d5eb99d8c3b36e0b185a6ac5e6 (patch) | |
tree | 6e919930b0ced69a726873826ac7ad2f404ff003 /tensorflow/stream_executor/dnn.h | |
parent | 0cae77919613b15ec5ba4db167966ba21e969fd8 (diff) | |
parent | 6cc83f55cd6fbc5af0fd6f1e8220bf9dd392306c (diff) |
Merge pull request #20708 from ROCmSoftwarePlatform:upstream-staging-stream-executor-algorithmconfig-profileresult
PiperOrigin-RevId: 207801599
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r-- | tensorflow/stream_executor/dnn.h | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index a7449c2df4..9abfa1db6a 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -713,15 +713,23 @@ class PoolingDescriptor { class AlgorithmDesc { public: typedef int64 Index; - AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {} + AlgorithmDesc() + : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true), scratch_size_(0) {} AlgorithmDesc(Index a, bool use_tensor_ops) - : algo_(a), tensor_ops_enabled_(use_tensor_ops) {} + : algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(0) {} + AlgorithmDesc(Index a, bool use_tensor_ops, size_t scratch_size) + : algo_(a), + tensor_ops_enabled_(use_tensor_ops), + scratch_size_(scratch_size) {} bool is_default() const { return algo_ == kDefaultAlgorithm; } bool tensor_ops_enabled() const { return tensor_ops_enabled_; } Index algo_id() const { return algo_; } + size_t scratch_size() const { return scratch_size_; } + void set_scratch_size(size_t val) { scratch_size_ = val; } bool operator==(const AlgorithmDesc& other) const { return this->algo_ == other.algo_ && - this->tensor_ops_enabled_ == other.tensor_ops_enabled_; + this->tensor_ops_enabled_ == other.tensor_ops_enabled_ && + this->scratch_size_ == other.scratch_size_; } uint64 hash() const; @@ -729,6 +737,7 @@ class AlgorithmDesc { enum { kDefaultAlgorithm = -1 }; Index algo_; bool tensor_ops_enabled_; + size_t scratch_size_; }; // Describes the result from a perf experiment. |