aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-10 11:46:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 11:51:44 -0700
commitfea9d07d1e34d5330a13024cb42d9bc460869905 (patch)
tree7030def73b6289540767b93e70fdaea4ecf3c4dc
parent2db78d20af06f256b86889c3f7d202ae88d6a896 (diff)
Remove references to std::string in MKL-related code.
tensorflow::string is sometimes ::string and sometimes std::string, which makes code that uses both subtly dangerous. For example, FactoryKeyCreator::AddAsKey() has an overload for tensorflow::string but had many callsites passing a std::string, causing incorrect behavior on the google platform. PiperOrigin-RevId: 208244169
-rw-r--r--tensorflow/core/kernels/mkl_fused_batch_norm_op.cc16
-rw-r--r--tensorflow/core/kernels/mkl_pooling_ops_common.h16
-rw-r--r--tensorflow/core/kernels/mkl_tfconv_op.h7
3 files changed, 19 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index 0149e78db5..aa572fb0a3 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -899,8 +899,8 @@ class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
MklFusedBatchNormFwdPrimitiveFactory() {}
~MklFusedBatchNormFwdPrimitiveFactory() {}
- static std::string CreateKey(const MklBatchNormFwdParams& fwdParams) {
- std::string prefix = "bn_fwd";
+ static string CreateKey(const MklBatchNormFwdParams& fwdParams) {
+ string prefix = "bn_fwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(fwdParams.src_dims);
@@ -911,13 +911,13 @@ class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
}
MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) {
- std::string key = CreateKey(fwdParams);
+ string key = CreateKey(fwdParams);
return this->GetOp(key);
}
void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams,
MklPrimitive* op) {
- std::string key = CreateKey(fwdParams);
+ string key = CreateKey(fwdParams);
this->SetOp(key, op);
}
};
@@ -1122,8 +1122,8 @@ class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
MklFusedBatchNormBwdPrimitiveFactory() {}
~MklFusedBatchNormBwdPrimitiveFactory() {}
- static std::string CreateKey(const MklBatchNormBwdParams& bwdParams) {
- std::string prefix = "bn_bwd";
+ static string CreateKey(const MklBatchNormBwdParams& bwdParams) {
+ string prefix = "bn_bwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(bwdParams.src_dims);
@@ -1135,13 +1135,13 @@ class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
}
MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) {
- std::string key = CreateKey(bwdParams);
+ string key = CreateKey(bwdParams);
return this->GetOp(key);
}
void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams,
MklPrimitive* op) {
- std::string key = CreateKey(bwdParams);
+ string key = CreateKey(bwdParams);
this->SetOp(key, op);
}
};
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
index 9c516afbd0..3a3de1c58b 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.h
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -175,8 +175,8 @@ class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
// primitive op from reuse perspective.
// A pooling key is a string which concates key parameters
// as well as algorithm kind (max versus avg).
- static std::string CreateKey(const MklPoolingParams& fwdParams) {
- std::string prefix = "pooling_fwd";
+ static string CreateKey(const MklPoolingParams& fwdParams) {
+ string prefix = "pooling_fwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(fwdParams.src_dims);
@@ -190,12 +190,12 @@ class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
}
MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) {
- std::string key = CreateKey(fwdParams);
+ string key = CreateKey(fwdParams);
return this->GetOp(key);
}
void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive* op) {
- std::string key = CreateKey(fwdParams);
+ string key = CreateKey(fwdParams);
this->SetOp(key, op);
}
};
@@ -326,8 +326,8 @@ class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
// primitive op from reuse perspective.
// A pooling key is a string which concates key parameters
// as well as algorithm kind (max versus avg).
- static std::string CreateKey(const MklPoolingParams& bwdParams) {
- std::string prefix = "pooling_bwd";
+ static string CreateKey(const MklPoolingParams& bwdParams) {
+ string prefix = "pooling_bwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(bwdParams.src_dims);
@@ -341,12 +341,12 @@ class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
}
MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) {
- std::string key = CreateKey(bwdParams);
+ string key = CreateKey(bwdParams);
return this->GetOp(key);
}
void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive* op) {
- std::string key = CreateKey(bwdParams);
+ string key = CreateKey(bwdParams);
this->SetOp(key, op);
}
};
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h
index f4f0035f26..a9e92f6638 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.h
+++ b/tensorflow/core/kernels/mkl_tfconv_op.h
@@ -118,12 +118,11 @@ class MklToTfOp : public OpKernel {
CHECK(output_tensor->CopyFrom(input_tensor, output_shape));
}
} catch (mkldnn::error& e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + std::string(e.message) + ", in file " +
- std::string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
- errors::Aborted("Operation received an exception:", error_msg));
+ errors::Aborted("Operation received an exception: Status: ", e.status,
+ ", message: ", StringPiece(e.message), ", in file ",
+ __FILE__, ":", __LINE__));
}
}
#else