aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cwise_ops_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-31 22:07:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-31 23:17:46 -0700
commitabe9ab326625105adb3c9d46c027931aec947d1f (patch)
treed9fa7eb9a2fd9b37bc87f98cf353354391b9eb04 /tensorflow/core/kernels/cwise_ops_test.cc
parentc0637048dbc099eac1f75878b765220cd02ccfc0 (diff)
Merge changes from github.
Change: 128958134
Diffstat (limited to 'tensorflow/core/kernels/cwise_ops_test.cc')
-rw-r--r--tensorflow/core/kernels/cwise_ops_test.cc39
1 files changed, 23 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc
index 2cf51878ba..823e7e14ed 100644
--- a/tensorflow/core/kernels/cwise_ops_test.cc
+++ b/tensorflow/core/kernels/cwise_ops_test.cc
@@ -23,13 +23,14 @@ limitations under the License.
namespace tensorflow {
-// Creates a Graph which applies a unary "func" on a 3D float tensor
-// of "num" elements.
-static Graph* Unary(const string& func, int num) {
+// Creates a Graph which applies a unary "func" on a 3D tensor of
+// type T with "num" elements.
+template <typename T>
+static Graph* Unary(const string& func, int num, DataType dtype) {
Graph* g = new Graph(OpRegistry::Global());
- Tensor data(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)}));
+ Tensor data(dtype, TensorShape({64, 64, num / (64 * 64)}));
CHECK_GT(data.NumElements(), 0);
- data.flat<float>().setRandom();
+ data.flat<T>().setRandom();
test::graph::Unary(g, func, test::graph::Constant(g, data), 0);
return g;
}
@@ -40,17 +41,23 @@ static int RowsAndColsArg(int r, int c) { return r * kRows + c; }
static int RowsFromArg(int arg) { return (arg / kRows); }
static int ColsFromArg(int arg) { return (arg % kRows); }
-#define BM_UNARY(DEVICE, FUNC) \
- static void BM_##DEVICE##_##FUNC(int iters, int num) { \
- const int64 tot = static_cast<int64>(iters) * num; \
- testing::ItemsProcessed(tot); \
- testing::BytesProcessed(tot * sizeof(float)); \
- test::Benchmark(#DEVICE, Unary(#FUNC, num)).Run(iters); \
- } \
- BENCHMARK(BM_##DEVICE##_##FUNC)->Range(4 << 10, 1 << 20);
-
-BM_UNARY(cpu, Floor);
-BM_UNARY(gpu, Floor);
+#define BM_UNARY(DEVICE, FUNC, T, TYPE) \
+ static void BM_##DEVICE##_##FUNC##_##TYPE(int iters, int num) { \
+ const int64 tot = static_cast<int64>(iters) * num; \
+ testing::ItemsProcessed(tot); \
+ testing::BytesProcessed(tot * sizeof(T)); \
+ test::Benchmark(#DEVICE, Unary<T>(#FUNC, num, TYPE)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_##FUNC##_##TYPE)->Range(4 << 10, 1 << 20);
+
+BM_UNARY(cpu, Floor, float, DT_FLOAT);
+BM_UNARY(gpu, Floor, float, DT_FLOAT);
+BM_UNARY(cpu, Floor, double, DT_DOUBLE);
+BM_UNARY(gpu, Floor, double, DT_DOUBLE);
+BM_UNARY(cpu, Conj, std::complex<float>, DT_COMPLEX64);
+BM_UNARY(gpu, Conj, std::complex<float>, DT_COMPLEX64);
+BM_UNARY(cpu, Conj, std::complex<double>, DT_COMPLEX128);
+BM_UNARY(gpu, Conj, std::complex<double>, DT_COMPLEX128);
// data func scalar.
static Graph* BinaryScalar(int num, const string& func) {