aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/bias_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_3.cu.cc4
-rw-r--r--tensorflow/core/kernels/resize_nearest_neighbor_op.cc2
-rw-r--r--tensorflow/core/kernels/sample_distorted_bounding_box_op.cc12
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.cc6
-rw-r--r--tensorflow/core/kernels/softplus_op.h6
-rw-r--r--tensorflow/core/kernels/softsign_op.h4
-rw-r--r--tensorflow/core/kernels/summary_op.cc4
-rw-r--r--tensorflow/core/kernels/training_ops.cc7
-rw-r--r--tensorflow/core/kernels/training_ops_gpu.cu.cc9
10 files changed, 29 insertions, 27 deletions
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
index bfb64b26c7..62c6ed31a0 100644
--- a/tensorflow/core/kernels/bias_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -104,7 +104,7 @@ __global__ void BiasGradNHWC_SharedAtomics(int32 nthreads,
T* bias_backprop, int32 bias_size) {
T* s_data = reinterpret_cast<T*>(s_buf);
for (int32 index = threadIdx.x; index < bias_size; index += blockDim.x) {
- s_data[index] = 0;
+ s_data[index] = T(0);
}
__syncthreads();
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index dbf096ac45..ccd983833d 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -269,7 +269,7 @@ __global__ void PadInputCustomKernelNHWC(int nthreads, const T* input,
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
output[output_index] = input[input_index];
} else {
- output[output_index] = 0;
+ output[output_index] = T(0);
}
}
}
@@ -295,7 +295,7 @@ __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
output[output_index] = input[input_index];
} else {
- output[output_index] = 0;
+ output[output_index] = T(0);
}
}
}
diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
index 61b89fb9a5..06eb59382f 100644
--- a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
+++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
@@ -138,7 +138,7 @@ class ResizeNearestNeighborOpGrad : public OpKernel {
for (int y = 0; y < out_height; ++y) {
for (int x = 0; x < out_width; ++x) {
for (int b = 0; b < batch_size; ++b) {
- output_data(b, y, x, c) = 0;
+ output_data(b, y, x, c) = T(0);
}
}
}
diff --git a/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc b/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc
index eb14009c63..79c6a43b19 100644
--- a/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc
+++ b/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc
@@ -363,11 +363,11 @@ class SampleDistortedBoundingBoxOp : public OpKernel {
typename TTypes<T, 1>::Tensor size_data = size->tensor<T, 1>();
typename TTypes<float, 3>::Tensor bboxes_data = bboxes->tensor<float, 3>();
- begin_data(0) = offset_height;
- size_data(0) = target_height;
+ begin_data(0) = T(offset_height);
+ size_data(0) = T(target_height);
- begin_data(1) = offset_width;
- size_data(1) = target_width;
+ begin_data(1) = T(offset_width);
+ size_data(1) = T(target_width);
bboxes_data(0, 0, 0) =
static_cast<float>(crop_rect.min_y_) / static_cast<float>(height);
@@ -379,8 +379,8 @@ class SampleDistortedBoundingBoxOp : public OpKernel {
static_cast<float>(crop_rect.max_x_) / static_cast<float>(width);
// Retain all of the channels.
- begin_data(2) = 0;
- size_data(2) = -1;
+ begin_data(2) = T(0);
+ size_data(2) = T(-1);
}
private:
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index d7995ac3cc..5d4a19da3f 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -394,12 +394,12 @@ class SparseSegmentReductionOpBase : public OpKernel {
out = L(0);
} else {
int r = num % 8;
- T m = 1;
+ T m(1);
if (is_mean_ && (num < 10)) {
- m = num;
+ m = T(num);
}
if (is_sqrtn_ && (num < 10)) {
- m = sqrt(num);
+ m = T(sqrt(num));
}
switch (r) {
case 2: {
diff --git a/tensorflow/core/kernels/softplus_op.h b/tensorflow/core/kernels/softplus_op.h
index 304b69a82f..928e64c338 100644
--- a/tensorflow/core/kernels/softplus_op.h
+++ b/tensorflow/core/kernels/softplus_op.h
@@ -34,8 +34,8 @@ struct Softplus {
void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
typename TTypes<T>::Tensor activations) {
activations.device(d) =
- (features > features.constant(30.f))
- .select(features, (features.exp() + features.constant(1.0f)).log());
+ (features > features.constant(T(30)))
+ .select(features, (features.exp() + features.constant(T(1))).log());
}
};
@@ -51,7 +51,7 @@ struct SoftplusGrad {
typename TTypes<T>::ConstTensor features,
typename TTypes<T>::Tensor backprops) {
backprops.device(d) =
- gradients / ((-features).exp() + features.constant(1.0f));
+ gradients / ((-features).exp() + features.constant(T(1)));
}
};
diff --git a/tensorflow/core/kernels/softsign_op.h b/tensorflow/core/kernels/softsign_op.h
index 36790a5874..9222a6686a 100644
--- a/tensorflow/core/kernels/softsign_op.h
+++ b/tensorflow/core/kernels/softsign_op.h
@@ -34,7 +34,7 @@ struct Softsign {
void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
typename TTypes<T>::Tensor activations) {
activations.device(d) =
- features / (features.abs() + features.constant(1.0f));
+ features / (features.abs() + features.constant(T(1)));
}
};
@@ -50,7 +50,7 @@ struct SoftsignGrad {
typename TTypes<T>::ConstTensor features,
typename TTypes<T>::Tensor backprops) {
backprops.device(d) =
- gradients / (features.abs() + features.constant(1.0f)).square();
+ gradients / (features.abs() + features.constant(T(1))).square();
}
};
diff --git a/tensorflow/core/kernels/summary_op.cc b/tensorflow/core/kernels/summary_op.cc
index 9fd5a4a6fc..16e5b0a0ff 100644
--- a/tensorflow/core/kernels/summary_op.cc
+++ b/tensorflow/core/kernels/summary_op.cc
@@ -52,7 +52,7 @@ class SummaryScalarOp : public OpKernel {
for (int i = 0; i < Ttags.size(); i++) {
Summary::Value* v = s.add_value();
v->set_tag(Ttags(i));
- v->set_simple_value(Tvalues(i));
+ v->set_simple_value(T(Tvalues(i)));
}
Tensor* summary_tensor = nullptr;
@@ -92,7 +92,7 @@ class SummaryHistoOp : public OpKernel {
errors::OutOfRange("Nan in summary histogram for: ", name()));
break;
}
- histo.Add(v);
+ histo.Add(static_cast<double>(v));
}
Summary s;
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index f761bf6dfc..d56aceb683 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -121,9 +121,10 @@ struct ApplyAdam<CPUDevice, T> {
typename TTypes<T>::ConstScalar beta2,
typename TTypes<T>::ConstScalar epsilon,
typename TTypes<T>::ConstFlat grad) {
- const T alpha = lr() * std::sqrt(1 - beta2_power()) / (1 - beta1_power());
- m.device(d) += (grad - m) * (1 - beta1());
- v.device(d) += (grad.square() - v) * (1 - beta2());
+ const T alpha =
+ lr() * std::sqrt(T(1) - beta2_power()) / (T(1) - beta1_power());
+ m.device(d) += (grad - m) * (T(1) - beta1());
+ v.device(d) += (grad.square() - v) * (T(1) - beta2());
var.device(d) -= (m * alpha) / (v.sqrt() + epsilon());
}
};
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index 22570ebd5a..6885300997 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -64,15 +64,16 @@ struct ApplyAdadelta<GPUDevice, T> {
bcast[0] = grad.dimension(0);
Eigen::Sizes<1> single;
- accum.device(d) =
- accum_update * rho.reshape(single).broadcast(bcast) +
- grad.square() * (grad.constant(1) - rho.reshape(single).broadcast(bcast));
+ accum.device(d) = accum_update * rho.reshape(single).broadcast(bcast) +
+ grad.square() * (grad.constant(T(1)) -
+ rho.reshape(single).broadcast(bcast));
const auto update =
(accum_update + epsilon.reshape(single).broadcast(bcast)).sqrt() *
(accum + epsilon.reshape(single).broadcast(bcast)).rsqrt() * grad;
accum_update.device(d) =
accum_update * rho.reshape(single).broadcast(bcast) +
- update.square() * (grad.constant(1) - rho.reshape(single).broadcast(bcast));
+ update.square() *
+ (grad.constant(T(1)) - rho.reshape(single).broadcast(bcast));
var.device(d) -= update * lr.reshape(single).broadcast(bcast);
}
};