aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.h
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-07 16:39:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-07 16:39:53 -0700
commit5fe19ce7e879b1d5eb99d8c3b36e0b185a6ac5e6 (patch)
tree6e919930b0ced69a726873826ac7ad2f404ff003 /tensorflow/stream_executor/dnn.h
parent0cae77919613b15ec5ba4db167966ba21e969fd8 (diff)
parent6cc83f55cd6fbc5af0fd6f1e8220bf9dd392306c (diff)
Merge pull request #20708 from ROCmSoftwarePlatform:upstream-staging-stream-executor-algorithmconfig-profileresult
PiperOrigin-RevId: 207801599
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r--tensorflow/stream_executor/dnn.h15
1 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index a7449c2df4..9abfa1db6a 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -713,15 +713,23 @@ class PoolingDescriptor {
class AlgorithmDesc {
public:
typedef int64 Index;
- AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {}
+ AlgorithmDesc()
+ : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true), scratch_size_(0) {}
AlgorithmDesc(Index a, bool use_tensor_ops)
- : algo_(a), tensor_ops_enabled_(use_tensor_ops) {}
+ : algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(0) {}
+ AlgorithmDesc(Index a, bool use_tensor_ops, size_t scratch_size)
+ : algo_(a),
+ tensor_ops_enabled_(use_tensor_ops),
+ scratch_size_(scratch_size) {}
bool is_default() const { return algo_ == kDefaultAlgorithm; }
bool tensor_ops_enabled() const { return tensor_ops_enabled_; }
Index algo_id() const { return algo_; }
+ size_t scratch_size() const { return scratch_size_; }
+ void set_scratch_size(size_t val) { scratch_size_ = val; }
bool operator==(const AlgorithmDesc& other) const {
return this->algo_ == other.algo_ &&
- this->tensor_ops_enabled_ == other.tensor_ops_enabled_;
+ this->tensor_ops_enabled_ == other.tensor_ops_enabled_ &&
+ this->scratch_size_ == other.scratch_size_;
}
uint64 hash() const;
@@ -729,6 +737,7 @@ class AlgorithmDesc {
enum { kDefaultAlgorithm = -1 };
Index algo_;
bool tensor_ops_enabled_;
+ size_t scratch_size_;
};
// Describes the result from a perf experiment.