author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-09 01:59:50 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-04-09 02:02:12 -0700
commit    5c469e6bafb479ef110b2f02f070507a3711664d (patch)
tree      57decbd46a5dc60c770c5a31e152b549e74e5758 /tensorflow/contrib/nccl
parent    1eea5ad3f9a622411117f7208d308055b0707d0f (diff)
Enabling fp16 for NCCL 1 and 2.
PiperOrigin-RevId: 192096789
Diffstat (limited to 'tensorflow/contrib/nccl')
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_manager.cc       |   2
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_manager_test.cc  | 214
-rw-r--r--  tensorflow/contrib/nccl/ops/nccl_ops.cc               |  14
-rw-r--r--  tensorflow/contrib/nccl/python/ops/nccl_ops_test.py   |   2
4 files changed, 127 insertions, 105 deletions
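
For context, a minimal usage sketch of what this change enables (not part of the commit): with DT_HALF accepted by the NCCL ops and the kernel's type mapping, an fp16 all-reduce can be driven from Python. The module path and the all_sum helper are assumptions based on the tf.contrib.nccl wrappers of this era, not something introduced by this diff.

# Hypothetical sketch: fp16 all-reduce across two GPUs via tf.contrib.nccl.
import numpy as np
import tensorflow as tf
from tensorflow.contrib import nccl

tensors = []
for i in range(2):
  with tf.device('/gpu:%d' % i):
    # Each device contributes a half-precision tensor (DT_HALF).
    tensors.append(tf.constant(np.ones((2, 3)), dtype=tf.float16))

summed = nccl.all_sum(tensors)  # one fp16 result tensor per device

with tf.Session() as sess:
  print(sess.run(summed))
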
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
index 913935b382..b9b482a698 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
@@ -76,6 +76,8 @@ struct NcclManager::Communicator {
namespace {
ncclDataType_t ToNcclType(DataType t) {
switch (t) {
+ case DT_HALF:
+ return ncclHalf;
case DT_FLOAT:
return ncclFloat;
case DT_DOUBLE:
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
index 985b2bae25..06ca65e33a 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
@@ -48,35 +48,9 @@ static std::vector<BaseGPUDevice*> GetGPUDevices() {
return gpus;
}
+template <typename Scalar>
class NcclManagerTest : public ::testing::Test {
- protected:
- static void SetUpTestCase() {
- setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
- devices = new std::vector<BaseGPUDevice*>(GetGPUDevices());
- CHECK(!devices->empty());
- LOG(ERROR) << "Running test with " << devices->size() << " gpus";
- }
- static void TearDownTestCase() {
- for (auto device : *devices) delete device;
- delete devices;
- }
-
- static Allocator* gpu_allocator(BaseGPUDevice* device) {
- return device->GetStepAllocator(AllocatorAttributes(),
- nullptr /* step_resource_manager */);
- }
-
- static std::vector<BaseGPUDevice*>* devices;
-
- template <typename Scalar>
- perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
- const Scalar* cuda_memory) {
- perftools::gputools::DeviceMemoryBase wrapped(
- const_cast<Scalar*>(cuda_memory));
- perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
- return typed;
- }
-
+ public:
// A single all-reduce to apply.
struct TestCase {
string key;
@@ -89,42 +63,52 @@ class NcclManagerTest : public ::testing::Test {
int num_completed = 0;
};
+ static void SetUpTestCase() {
+ setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
+ devices_ = new std::vector<BaseGPUDevice*>(GetGPUDevices());
+ CHECK(!devices_->empty());
+ LOG(ERROR) << "Running test with " << devices_->size() << " gpus";
+ }
+
+ static void TearDownTestCase() {
+ for (auto device : *devices_) delete device;
+ delete devices_;
+ }
+
TestCase* MakeTestCase(int num_ranks, ncclRedOp_t reduction_op,
TensorShape shape, float value_offset) {
TestCase* test_case = new TestCase();
- test_case->expected = Tensor(DT_FLOAT, shape);
+ test_case->expected = Tensor(data_type_, shape);
if (reduction_op == ncclProd) {
- test::FillFn<float>(&test_case->expected, [](int) { return 1; });
+ test::FillFn<Scalar>(&test_case->expected,
+ [](int) { return static_cast<Scalar>(1); });
} else if (reduction_op == ncclSum) {
- test::FillFn<float>(&test_case->expected, [](int) { return 0; });
+ test::FillFn<Scalar>(&test_case->expected,
+ [](int) { return static_cast<Scalar>(0); });
} else if (reduction_op == ncclMax) {
- test::FillFn<float>(&test_case->expected, [](int) {
- return -1 * std::numeric_limits<float>::max();
- });
+ test::FillFn<Scalar>(&test_case->expected, [](int) { return -max_; });
} else if (reduction_op == ncclMin) {
- test::FillFn<float>(&test_case->expected, [](int) {
- return std::numeric_limits<float>::max();
- });
+ test::FillFn<Scalar>(&test_case->expected, [](int) { return max_; });
} else {
LOG(FATAL) << "Invalid reduction_op " << reduction_op;
}
- int mult = 1;
- for (int i = 0; i < num_ranks; ++i) {
- auto* device = devices->at(i % devices->size());
+ float value_scale = 0.01; // Small scale to avoid fp16 overflow.
+ for (int rank = 0; rank < num_ranks; ++rank) {
+ auto* device = GetDevice(rank);
auto* stream = device->tensorflow_gpu_device_info()->stream;
- Tensor in_cpu(DT_FLOAT, shape);
- test::FillFn<float>(&in_cpu, [mult, value_offset](int index) {
- return value_offset + (index + 1) * mult;
+ Tensor in_cpu(data_type_, shape);
+ test::FillFn<Scalar>(&in_cpu, [&](int index) {
+ return static_cast<Scalar>((index + 1) * value_scale + value_offset);
});
for (int j = 0; j < shape.num_elements(); ++j) {
- auto in_val = in_cpu.flat<float>()(j);
- auto out_expr = test_case->expected.flat<float>();
+ auto in_val = in_cpu.flat<Scalar>()(j);
+ auto out_expr = test_case->expected.template flat<Scalar>();
if (reduction_op == ncclProd) {
- out_expr(j) *= in_val;
+ out_expr(j) = out_expr(j) * in_val;
} else if (reduction_op == ncclSum) {
- out_expr(j) += in_val;
+ out_expr(j) = out_expr(j) + in_val;
} else if (reduction_op == ncclMax) {
if (in_val > out_expr(j)) {
out_expr(j) = in_val;
@@ -136,26 +120,18 @@ class NcclManagerTest : public ::testing::Test {
}
}
- mult *= 10;
- test_case->ins.emplace_back(gpu_allocator(device), DT_FLOAT, shape);
- test_case->outs.emplace_back(gpu_allocator(device), DT_FLOAT, shape);
+ value_scale *= 10;
+ test_case->ins.emplace_back(GpuAllocator(device), data_type_, shape);
+ test_case->outs.emplace_back(GpuAllocator(device), data_type_, shape);
const Tensor& in_gpu = test_case->ins.back();
- auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<float>().data());
- stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<float>().data(),
+ auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<Scalar>().data());
+ stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<Scalar>().data(),
in_cpu.TotalBytes());
}
return test_case;
}
- NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) {
- return [this, test_case](Status s) {
- mutex_lock l(test_case->mu);
- ++test_case->num_completed;
- test_case->final_status.Update(s);
- };
- }
-
void VerifyResults(const string& case_label, TestCase* test_case) {
// Wait for the done callback to be called.
{
@@ -168,41 +144,84 @@ class NcclManagerTest : public ::testing::Test {
test_case->mu.unlock();
}
// Copy memory to host and verify.
- for (int i = 0; i < test_case->outs.size(); ++i) {
- auto* device = devices->at(i % devices->size());
+ for (int rank = 0; rank < test_case->outs.size(); ++rank) {
+ auto* device = GetDevice(rank);
auto* stream = device->tensorflow_gpu_device_info()->stream;
- const Tensor& out_gpu = test_case->outs[i];
- Tensor out_cpu(DT_FLOAT, out_gpu.shape());
- auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<float>().data());
- stream->ThenMemcpy(out_cpu.flat<float>().data(), out_gpu_mem,
+ const Tensor& out_gpu = test_case->outs[rank];
+ Tensor out_cpu(data_type_, out_gpu.shape());
+ auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<Scalar>().data());
+ stream->ThenMemcpy(out_cpu.flat<Scalar>().data(), out_gpu_mem,
out_cpu.TotalBytes());
SE_ASSERT_OK(stream->BlockHostUntilDone());
- test::ExpectTensorEqual<float>(test_case->expected, out_cpu);
+ test::ExpectTensorNear<Scalar>(test_case->expected, out_cpu, 0.01);
}
}
+
+ NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) {
+ return [this, test_case](Status s) {
+ mutex_lock l(test_case->mu);
+ ++test_case->num_completed;
+ test_case->final_status.Update(s);
+ };
+ }
+
+ static BaseGPUDevice* GetDevice(size_t rank) {
+ return devices_->at(rank % devices_->size());
+ }
+
+ private:
+ static Allocator* GpuAllocator(BaseGPUDevice* device) {
+ return device->GetStepAllocator(AllocatorAttributes(),
+ nullptr /* step_resource_manager */);
+ }
+
+ static perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
+ const Scalar* cuda_memory) {
+ perftools::gputools::DeviceMemoryBase wrapped(
+ const_cast<Scalar*>(cuda_memory));
+ perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
+ return typed;
+ }
+
+ private:
+ static std::vector<BaseGPUDevice*>* devices_;
+ static const DataType data_type_;
+ static const Scalar max_;
};
-std::vector<BaseGPUDevice*>* NcclManagerTest::devices = nullptr;
+
+template <typename Scalar>
+std::vector<BaseGPUDevice*>* NcclManagerTest<Scalar>::devices_ = nullptr;
+template <typename Scalar>
+const DataType NcclManagerTest<Scalar>::data_type_ =
+ DataTypeToEnum<Scalar>::value;
+template <typename Scalar>
+const Scalar NcclManagerTest<Scalar>::max_ =
+ Eigen::NumTraits<Scalar>::highest();
+
+// Instantiate tests for float and half.
+using TypeList = ::testing::Types<float, Eigen::half>;
+TYPED_TEST_CASE(NcclManagerTest, TypeList);
// Test basic sum reduction.
-TEST_F(NcclManagerTest, BasicSumReduction) {
+TYPED_TEST(NcclManagerTest, BasicSumReduction) {
const int num_ranks = 3;
for (int op = 0; op < 4; ++op) {
ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(op);
- std::unique_ptr<TestCase> test_case(
- MakeTestCase(num_ranks, reduction_op, TensorShape({2, 3}), 0));
- for (int device_num = 0; device_num < num_ranks; ++device_num) {
- auto* device = devices->at(device_num % devices->size());
+ std::unique_ptr<typename TestFixture::TestCase> test_case(
+ this->MakeTestCase(num_ranks, reduction_op, TensorShape({2, 3}), 0.0f));
+ for (int rank = 0; rank < num_ranks; ++rank) {
+ auto* device = this->GetDevice(rank);
auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
auto* stream = device->tensorflow_gpu_device_info()->stream;
NcclManager::instance()->AddToAllReduce(
num_ranks, "allreduce", reduction_op, device->executor(),
- device->gpu_id(), event_mgr, stream, &test_case->ins[device_num],
- &test_case->outs[device_num], CreateDoneCallback(test_case.get()));
+ device->gpu_id(), event_mgr, stream, &test_case->ins[rank],
+ &test_case->outs[rank], this->CreateDoneCallback(test_case.get()));
}
LOG(ERROR) << "Verifying results";
- VerifyResults("test_case", test_case.get());
+ this->VerifyResults("test_case", test_case.get());
}
}
@@ -213,7 +232,7 @@ TEST_F(NcclManagerTest, BasicSumReduction) {
// with num_ranks > devices->size(), for some GPUs (e.g. K20m).
// To test the higher settings, increase num_ranks,
// num_collectives_per_iteration and time_limit_micros.
-TEST_F(NcclManagerTest, MultipleCallers) {
+TYPED_TEST(NcclManagerTest, MultipleCallers) {
const int num_ranks = 1; // 2;
const int num_collectives_per_iteration = 1; // 1000;
const int num_threads = 3;
@@ -223,49 +242,49 @@ TEST_F(NcclManagerTest, MultipleCallers) {
srand(Env::Default()->NowMicros());
for (;;) {
- std::vector<std::pair<int, int>> case_and_device_num;
- std::vector<std::unique_ptr<TestCase>> test_cases;
+ std::vector<std::pair<int, int>> case_and_rank;
+ std::vector<std::unique_ptr<typename TestFixture::TestCase>> test_cases;
for (int i = 0; i < num_collectives_per_iteration; ++i) {
- test_cases.emplace_back(
- MakeTestCase(num_ranks, ncclSum,
- TensorShape({100, i % 5 + 1, i % 3 + 1}), i + 0.1 * i));
+ test_cases.emplace_back(this->MakeTestCase(
+ num_ranks, ncclSum, TensorShape({100, i % 5 + 1, i % 3 + 1}),
+ 1.1f * i));
for (int j = 0; j < num_ranks; ++j) {
- case_and_device_num.emplace_back(i, j);
+ case_and_rank.emplace_back(i, j);
}
}
- for (int i = 0; i < num_ranks; ++i) {
- auto* device = devices->at(i % devices->size());
+ for (int rank = 0; rank < num_ranks; ++rank) {
+ auto* device = this->GetDevice(rank);
auto* stream = device->tensorflow_gpu_device_info()->stream;
SE_ASSERT_OK(stream->BlockHostUntilDone());
}
- std::shuffle(case_and_device_num.begin(), case_and_device_num.end(),
+ std::shuffle(case_and_rank.begin(), case_and_rank.end(),
std::mt19937(std::random_device()()));
- mutex mu; // guards case_and_device_num.
+ mutex mu; // guards case_and_rank.
std::unique_ptr<thread::ThreadPool> pool(
new thread::ThreadPool(Env::Default(), "test", num_threads));
- const int to_schedule = case_and_device_num.size();
+ const int to_schedule = case_and_rank.size();
for (int i = 0; i < to_schedule; ++i) {
auto fn = [&]() {
- int device_num;
+ int rank;
int test_num;
{
mutex_lock l(mu);
- test_num = case_and_device_num.back().first;
- device_num = case_and_device_num.back().second;
- case_and_device_num.pop_back();
+ test_num = case_and_rank.back().first;
+ rank = case_and_rank.back().second;
+ case_and_rank.pop_back();
}
- auto* device = devices->at(device_num % devices->size());
+ auto* device = this->GetDevice(rank);
auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
auto* stream = device->tensorflow_gpu_device_info()->stream;
- TestCase* test_case = test_cases[test_num].get();
+ typename TestFixture::TestCase* test_case = test_cases[test_num].get();
NcclManager::instance()->AddToAllReduce(
num_ranks, strings::StrCat("allreduce", test_num), ncclSum,
device->executor(), device->gpu_id(), event_mgr, stream,
- &test_case->ins[device_num], &test_case->outs[device_num],
- CreateDoneCallback(test_case));
+ &test_case->ins[rank], &test_case->outs[rank],
+ this->CreateDoneCallback(test_case));
};
pool->Schedule(fn);
}
@@ -274,7 +293,8 @@ TEST_F(NcclManagerTest, MultipleCallers) {
LOG(ERROR) << "Verifying results for " << num_collectives_per_iteration
<< " collectives";
for (int i = 0; i < test_cases.size(); ++i) {
- VerifyResults(strings::StrCat("collective", i), test_cases[i].get());
+ this->VerifyResults(strings::StrCat("collective", i),
+ test_cases[i].get());
}
int64 delta = Env::Default()->NowMicros() - start;
diff --git a/tensorflow/contrib/nccl/ops/nccl_ops.cc b/tensorflow/contrib/nccl/ops/nccl_ops.cc
index 8eb804c2e9..a353a34b80 100644
--- a/tensorflow/contrib/nccl/ops/nccl_ops.cc
+++ b/tensorflow/contrib/nccl/ops/nccl_ops.cc
@@ -25,7 +25,7 @@ REGISTER_OP("NcclAllReduce")
.Input("input: T")
.Output("data: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
@@ -51,7 +51,7 @@ REGISTER_OP("NcclReduce")
.Input("input: num_devices * T")
.Output("data: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape)
@@ -69,7 +69,7 @@ reduction: the reduction operation to perform.
REGISTER_OP("_NcclReduceSend")
.Input("input: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
@@ -92,7 +92,7 @@ REGISTER_OP("_NcclReduceRecv")
.Input("input: T")
.Output("data: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
@@ -118,7 +118,7 @@ shared_name: Identifier that is shared between ops of the same reduce.
REGISTER_OP("NcclBroadcast")
.Input("input: T")
.Output("output: T")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("shape: shape")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape)
@@ -135,7 +135,7 @@ shape: The shape of the input tensor.
REGISTER_OP("_NcclBroadcastSend")
.Input("input: T")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
@@ -157,7 +157,7 @@ shared_name: Identifier that is shared between ops of the same broadcast.
REGISTER_OP("_NcclBroadcastRecv")
.Input("shape: int32")
.Output("output: T")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
index 98fe394c5b..423a8689ae 100644
--- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
+++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
@@ -72,7 +72,7 @@ class NcclTestCase(test.TestCase):
two.
device_sets: Tuple of virtual devices to run test on.
"""
- for dtype in [np.float32, np.int32, np.int64, np.float64]:
+ for dtype in [np.float16, np.float32, np.int32, np.int64, np.float64]:
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
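
Aside (not part of the diff): the C++ test above switches from ExpectTensorEqual to ExpectTensorNear with a 0.01 tolerance and scales inputs by a small value_scale "to avoid fp16 overflow", because half precision both rounds aggressively and saturates early. A quick NumPy illustration of the rounding side, for reference only:

# Illustration only: 0.01 is not exactly representable in float16, and
# accumulating many half-precision values compounds the rounding error.
import numpy as np

x = np.float16(0.01)
print(np.float32(x) - 0.01)            # representation error, roughly 1e-5

vals = np.full(1000, 0.01, dtype=np.float16)
print(vals.sum(dtype=np.float16))      # fp16 accumulation
print(vals.astype(np.float32).sum())   # fp32 reference differs slightly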