aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-01 00:18:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-01 00:20:43 -0700
commit961a39346d8be33cff473f1e81498b887c155070 (patch)
treed1175e89f82bd60137cf9fb2ecbee64d4ac5e59c /tensorflow/stream_executor/dnn.h
parent54b20c4be0372fb14ec9a289e4d7de7f67c03ff6 (diff)
Unify error handling in CudnnSupport.
PiperOrigin-RevId: 198836479
Diffstat (limited to 'tensorflow/stream_executor/dnn.h')
-rw-r--r--tensorflow/stream_executor/dnn.h5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 3df5365c23..9eca5abe1a 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -469,6 +469,9 @@ enum class PadAlignment : int64 {
// Returns a string representation of the given padding alignment.
string PadAlignmentString(PadAlignment alignment);
+// Print alignment to str. Needed to use CHECK_EQ between two PadAlignments.
+std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment);
+
// Describes a convolution.
//
// Uses the named argument construction form:
@@ -710,7 +713,7 @@ class PoolingDescriptor {
class AlgorithmDesc {
public:
typedef int64 Index;
- AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(false) {}
+ AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {}
AlgorithmDesc(Index a, bool use_tensor_ops)
: algo_(a), tensor_ops_enabled_(use_tensor_ops) {}
bool is_default() const { return algo_ == kDefaultAlgorithm; }