aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r--tensorflow/stream_executor/dnn.h15
1 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index a7449c2df4..9abfa1db6a 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -713,15 +713,23 @@ class PoolingDescriptor {
class AlgorithmDesc {
public:
typedef int64 Index;
- AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {}
+ AlgorithmDesc()
+ : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true), scratch_size_(0) {}
AlgorithmDesc(Index a, bool use_tensor_ops)
- : algo_(a), tensor_ops_enabled_(use_tensor_ops) {}
+ : algo_(a), tensor_ops_enabled_(use_tensor_ops), scratch_size_(0) {}
+ AlgorithmDesc(Index a, bool use_tensor_ops, size_t scratch_size)
+ : algo_(a),
+ tensor_ops_enabled_(use_tensor_ops),
+ scratch_size_(scratch_size) {}
bool is_default() const { return algo_ == kDefaultAlgorithm; }
bool tensor_ops_enabled() const { return tensor_ops_enabled_; }
Index algo_id() const { return algo_; }
+ size_t scratch_size() const { return scratch_size_; }
+ void set_scratch_size(size_t val) { scratch_size_ = val; }
bool operator==(const AlgorithmDesc& other) const {
return this->algo_ == other.algo_ &&
- this->tensor_ops_enabled_ == other.tensor_ops_enabled_;
+ this->tensor_ops_enabled_ == other.tensor_ops_enabled_ &&
+ this->scratch_size_ == other.scratch_size_;
}
uint64 hash() const;
@@ -729,6 +737,7 @@ class AlgorithmDesc {
enum { kDefaultAlgorithm = -1 };
Index algo_;
bool tensor_ops_enabled_;
+ size_t scratch_size_;
};
// Describes the result from a perf experiment.