diff --git a/include/boost/compute/algorithm/detail/merge_sort_on_gpu.hpp b/include/boost/compute/algorithm/detail/merge_sort_on_gpu.hpp index d5e1a2d8c..a3d932741 100644 --- a/include/boost/compute/algorithm/detail/merge_sort_on_gpu.hpp +++ b/include/boost/compute/algorithm/detail/merge_sort_on_gpu.hpp @@ -170,8 +170,12 @@ inline size_t bitonic_block_sort(KeyIterator keys_first, k.decl("compare") << " = " << compare(k.var("sibling_key"), k.var("my_key")) << ";\n" << + k.decl("equal") << " = !(compare || " << + compare(k.var("my_key"), + k.var("sibling_key")) << ");\n" << k.decl("swap") << " = compare ^ (sibling_idx < lid) ^ direction;\n" << + "swap = equal ? false : swap;\n" << "my_key = swap ? sibling_key : my_key;\n"; if(sort_by_key) { @@ -220,8 +224,12 @@ inline size_t bitonic_block_sort(KeyIterator keys_first, k.decl("compare") << " = " << compare(k.var("sibling_key"), k.var("my_key")) << ";\n" << + k.decl("equal") << " = !(compare || " << + compare(k.var("my_key"), + k.var("sibling_key")) << ");\n" << k.decl("swap") << " = compare ^ (sibling_idx < lid);\n" << + "swap = equal ? false : swap;\n" << "my_key = swap ? sibling_key : my_key;\n"; if(sort_by_key) { diff --git a/test/test_sort.cpp b/test/test_sort.cpp index e2dac4f4c..3b3a76f73 100644 --- a/test/test_sort.cpp +++ b/test/test_sort.cpp @@ -340,6 +340,7 @@ BOOST_AUTO_TEST_CASE(sort_int2) host[size/4] = int2_(20.f, 0.f); host[(size*3)/4] = int2_(9.f, 0.f); host[size-3] = int2_(-10.0f, 0.f); + host[size/2+1] = int2_(-10.0f, -1.f); boost::compute::vector vector(size, context); boost::compute::copy(host.begin(), host.end(), vector.begin(), queue); @@ -356,9 +357,11 @@ BOOST_AUTO_TEST_CASE(sort_int2) ); boost::compute::copy(vector.begin(), vector.end(), host.begin(), queue); BOOST_CHECK_CLOSE(host[0][0], -10.f, 0.1); + BOOST_CHECK_CLOSE(host[1][0], -10.f, 0.1); BOOST_CHECK_CLOSE(host[(size - 3)][0], 9.f, 0.1); BOOST_CHECK_CLOSE(host[(size - 2)][0], 20.f, 0.1); BOOST_CHECK_CLOSE(host[(size - 1)][0], 100.f, 0.1); + BOOST_CHECK_NE(host[0], host[1]); } BOOST_AUTO_TEST_SUITE_END()