// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project

#ifndef KOKKOS_NESTED_SORT_IMPL_HPP_
#define KOKKOS_NESTED_SORT_IMPL_HPP_

#include <Kokkos_Macros.hpp>
#ifdef KOKKOS_ENABLE_EXPERIMENTAL_CXX20_MODULES
import kokkos.core;
#else
#include <Kokkos_Core.hpp>
#endif

#include <type_traits>

namespace Kokkos {
namespace Experimental {
namespace Impl {

// Selector for the nested parallel policy used by sort_nested_impl:
//   NestedRange<true>  -> Kokkos::TeamVectorRange   (whole team participates)
//   NestedRange<false> -> Kokkos::ThreadVectorRange (one thread's vector lanes)
// The primary template carries no members; both bool values are given
// explicit specializations that provide create() and barrier().
template <bool TeamLevel>
struct NestedRange {
  // Intentionally empty: all behavior lives in the specializations.
};

// Team-level variant: iterations are spread with Kokkos::TeamVectorRange, and
// barrier() performs a full team_barrier() so that every thread of the team
// observes the results of one step before the next step starts.
template <>
struct NestedRange<true> {
  // Build a TeamVectorRange of `count` iterations for team handle `member`.
  template <typename TeamMember, typename SizeType>
  KOKKOS_FUNCTION static auto create(const TeamMember& member,
                                     SizeType count) {
    return Kokkos::TeamVectorRange(member, count);
  }

  // Synchronize all threads of `member`'s team.
  template <typename TeamMember>
  KOKKOS_FUNCTION static void barrier(const TeamMember& member) {
    member.team_barrier();
  }
};

// Thread-level variant: iterations are spread over the calling thread's
// vector lanes with Kokkos::ThreadVectorRange.
template <>
struct NestedRange<false> {
  // Build a ThreadVectorRange of `count` iterations for team handle `member`.
  template <typename TeamMember, typename SizeType>
  KOKKOS_FUNCTION static auto create(const TeamMember& member,
                                     SizeType count) {
    return Kokkos::ThreadVectorRange(member, count);
  }

  // Deliberately does nothing: the vector lanes of a single thread are
  // implicitly synchronized once the ThreadVectorRange parallel region ends,
  // so no explicit barrier is required between steps.
  template <typename TeamMember>
  KOKKOS_FUNCTION static void barrier(const TeamMember& /*member*/) {}
};

// In-place bitonic sort of keyView, optionally permuting valueView alongside
// it, executed cooperatively inside a team.
//
//   t         - team member handle handed to the nested range policy.
//   keyView   - 1-D view of keys, indexed as keyView(i) for i < extent(0).
//               After the call, comp(keyView(i + 1), keyView(i)) is false.
//   valueView - view permuted in lockstep with keyView (sort_by_key), or a
//               std::nullptr_t when only keys are sorted (plain sort).
//   comp      - comparator; comp(b, a) == true makes b move ahead of a.
//   The unnamed NestedRange argument exists only so useTeamLevel can be
//   deduced: true -> TeamVectorRange, false -> ThreadVectorRange.
//
// The key count need not be a power of two. The comparison network is built
// for the next power of two, and any comparison whose upper index is past the
// end is skipped, as if the padding slots held elements that never move.
template <class TeamMember, class KeyViewType, class ValueViewType,
          class Comparator, bool useTeamLevel>
KOKKOS_INLINE_FUNCTION void sort_nested_impl(
    const TeamMember& t, const KeyViewType& keyView,
    [[maybe_unused]] const ValueViewType& valueView, const Comparator& comp,
    const NestedRange<useTeamLevel>) {
  using Index  = typename KeyViewType::size_type;
  using Key    = typename KeyViewType::non_const_value_type;
  using Policy = NestedRange<useTeamLevel>;

  const Index count = keyView.extent(0);

  // paddedCount: smallest power of two >= count (1 when count <= 1).
  // numStages:   log2(paddedCount).
  // FIXME: ceiling power-of-two is a common thing to need - make it a utility
  Index paddedCount = 1;
  Index numStages   = 0;
  for (; paddedCount < count; paddedCount <<= 1) {
    ++numStages;
  }

  // Each step of the network launches paddedCount / 2 comparisons.
  const Index numPairs = paddedCount / 2;

  for (Index stage = 0; stage < numStages; ++stage) {
    for (Index step = 0; step <= stage; ++step) {
      // Boxes in this step span 2 * halfBox elements. The terminology comes
      // from https://commons.wikimedia.org/wiki/File:BitonicSort.svg
      const Index shift   = stage - step;
      const Index halfBox = Index(1) << shift;
      // Step 0 ("brown box"): compare mirrored positions within a box.
      // Later steps ("pink box"): compare positions halfBox apart.
      const bool mirrored = (step == 0);

      Kokkos::parallel_for(Policy::create(t, numPairs), [=](const Index pair) {
        const Index box    = pair >> shift;          // which box owns `pair`
        const Index first  = box << (shift + 1);     // box * 2 * halfBox
        const Index offset = pair - (box << shift);  // slot in the lower half
        const Index lo     = first + offset;
        const Index hi     = mirrored ? (first + (halfBox << 1) - 1 - offset)
                                      : (lo + halfBox);
        // hi > lo in both cases, so hi < count means both slots are real.
        if (hi >= count) return;
        Key loKey = keyView(lo);
        Key hiKey = keyView(hi);
        if (!comp(hiKey, loKey)) return;
        keyView(lo) = hiKey;
        keyView(hi) = loKey;
        if constexpr (!std::is_same_v<ValueViewType, std::nullptr_t>) {
          Kokkos::kokkos_swap(valueView(lo), valueView(hi));
        }
      });
      Policy::barrier(t);
    }
  }
}

}  // namespace Impl
}  // namespace Experimental
}  // namespace Kokkos
#endif
