#include "utility.h"

#ifndef PARALLEL_SORT_STL_H
#define PARALLEL_SORT_STL_H

namespace internal
{
	// Counts the qsort3w() calls that actually partitioned a range (two or
	// more elements) since the last parallel_sort() reset it. It is only
	// written by this file; every write goes through "omp atomic" because
	// qsort3w() runs inside concurrently executing tasks.
	// NOTE(review): this is a non-inline variable definition in a header, so
	// including the header from more than one translation unit duplicates the
	// symbol. Left as-is to keep the existing name and linkage for callers.
	std::size_t g_depth = 0L;

	// Ranges whose first-to-last distance is at least this value are split
	// into one task per side inside a taskgroup; smaller ranges handle both
	// sides in a single task.
	const std::size_t cutoff = 1000000L;

	// Three-way (less / equivalent / greater) quicksort of the CLOSED range
	// [_First, _Last] -- _Last designates the last element, not one past it.
	//
	// _First, _Last : random-access iterators delimiting the range.
	// compare       : strict weak ordering, compare(a, b) == "a sorts before b".
	//
	// The recursive calls are issued as OpenMP tasks, so the function is meant
	// to be called from inside a parallel region (see parallel_sort()); when
	// the pragmas are ignored it is an ordinary recursive quicksort. The sort
	// is not stable.
	template<class RanIt, class _Pred>
	void qsort3w(RanIt _First, RanIt _Last, _Pred compare)
	{
		// Zero or one element: nothing to do (and nothing is counted).
		if (_First >= _Last) return;

		// Tasks on different threads update the shared counter concurrently.
		#pragma omp atomic
		g_depth++;

		typedef typename std::iterator_traits<RanIt>::difference_type diff_t;
		const diff_t span = std::distance(_First, _Last);

		// Median-of-three pivot selection. Always taking the first element
		// degrades to quadratic time and linear recursion depth on sorted or
		// reverse-sorted input.
		RanIt mid = _First + span / 2;
		RanIt median = _First;
		if (compare(*mid, *_First))
		{
			if (compare(*_Last, *mid)) median = mid;            // last < mid < first
			else if (compare(*_Last, *_First)) median = _Last;  // mid <= last < first
			// else: mid < first <= last, the first element is the median
		}
		else
		{
			if (compare(*_Last, *_First)) median = _First;      // last < first <= mid
			else if (compare(*_Last, *mid)) median = _Last;     // first <= last < mid
			else median = mid;                                  // first <= mid <= last
		}

		// The partition loop below requires the pivot element to sit at _First.
		if (median != _First) std::iter_swap(_First, median);
		typename std::iterator_traits<RanIt>::value_type pivot = *_First;

		// Loop invariant:
		//   [_First, lower)  sorts before the pivot
		//   [lower, scan)    equivalent to the pivot
		//   (upper, _Last]   sorts after the pivot
		//   [scan, upper]    not yet examined
		RanIt lower = _First;
		RanIt upper = _Last;
		RanIt scan = _First + 1;
		while (scan <= upper)
		{
			if (compare(*scan, pivot))
			{
				std::iter_swap(lower, scan);
				++lower;
				++scan;
			}
			else if (compare(pivot, *scan))
			{
				// The element swapped in from the back is unexamined, so
				// 'scan' deliberately stays where it is.
				std::iter_swap(upper, scan);
				--upper;
			}
			else
			{
				++scan;
			}
		}

		// A side needs sorting only if at least one element landed in it.
		const bool has_lower = (lower != _First);
		const bool has_upper = (upper != _Last);

		if (static_cast<std::size_t>(span) >= internal::cutoff)
		{
			// Large range: sort both sides as sibling tasks and wait for them
			// (and their descendants) before returning.
			#pragma omp taskgroup
			{
				#pragma omp task untied mergeable
				if (has_lower)
					qsort3w(_First, lower - 1, compare);

				#pragma omp task untied mergeable
				if (has_upper)
					qsort3w(upper + 1, _Last, compare);
			}
		}
		else
		{
			// Small range: one task handles both sides back to back. It is not
			// waited for here; it captures its iterators by value and must be
			// completed by a barrier/taskgroup further up (parallel_sort()
			// relies on the barrier closing its parallel region).
			#pragma omp task untied mergeable
			{
				if (has_lower)
					qsort3w(_First, lower - 1, compare);

				if (has_upper)
					qsort3w(upper + 1, _Last, compare);
			}
		}
	}

	// Sorts the half-open range [_First, _Last) with 'compare' using the
	// task-based qsort3w() above. Resets g_depth first. Not stable.
	//
	// Despite the template parameter name the iterators must be random-access
	// (the code uses '-', '>=' and iterator arithmetic).
	template<class BidirIt, class _Pred >
	void parallel_sort(BidirIt _First, BidirIt _Last, _Pred compare)
	{
		// parallel_stable_sort() calls this from several threads at once.
		#pragma omp atomic write
		g_depth = 0L;

		// An empty range must not form '_Last - 1' (an iterator before begin).
		if (_First >= _Last) return;

		// One thread seeds the recursion; the rest of the team picks up the
		// tasks it spawns. The barrier ending the parallel region guarantees
		// that every task has finished on return.
		#pragma omp parallel num_threads(12)
		#pragma omp master
			internal::qsort3w(_First, _Last - 1, compare);
	}

	// Two-level sort of [_First, _Last): orders the range by comp_key, then
	// orders every run of key-equivalent elements by comp_vals.
	//
	// NOTE: despite the name, the original relative order of elements that
	// are equivalent under both comparators is not preserved.
	//
	// comp_key  : strict weak ordering on the key part of an element.
	// comp_vals : strict weak ordering applied inside each equal-key run.
	template<class BidirIt, class _CompKey, class _CompVals> 
	void parallel_stable_sort(BidirIt _First, BidirIt _Last,
		_CompKey comp_key, _CompVals comp_vals)
	{
		if (_First >= _Last) return;

		// The key sort has to be complete before run boundaries are searched,
		// so it is called directly instead of being wrapped in a task that
		// nothing waits for.
		internal::parallel_sort(_First, _Last, comp_key);

		// Offsets (relative to _First) at which a new key run starts, plus
		// the end offset as a sentinel.
		std::vector<std::size_t> pv;
		pv.push_back(0);

		// '<' rather than '!=' keeps the loop in canonical form for OpenMP
		// versions older than 5.0.
		#pragma omp parallel for
		for (BidirIt _FwdIt = _First + 1; _FwdIt < _Last; ++_FwdIt)
		{
			if (comp_key(*_FwdIt, *(_FwdIt - 1)) || comp_key(*(_FwdIt - 1), *_FwdIt))
			{
				// The vector is shared by all threads: serialize the append.
				#pragma omp critical(internal_parallel_stable_sort_pv)
				pv.push_back(static_cast<std::size_t>(std::distance(_First, _FwdIt)));
			}
		}

		pv.push_back(static_cast<std::size_t>(std::distance(_First, _Last)));

		// Threads append boundaries in arbitrary order. The list is small and
		// nearly sorted, so a plain std::sort is used instead of opening
		// another parallel region for it.
		std::sort(pv.begin(), pv.end());

		// Each pair of consecutive offsets delimits one equal-key run; runs
		// are independent and are sorted concurrently.
		const std::ptrdiff_t runs = static_cast<std::ptrdiff_t>(pv.size()) - 1;
		#pragma omp parallel for
		for (std::ptrdiff_t run = 0; run < runs; ++run)
			internal::parallel_sort(_First + pv[run], _First + pv[run + 1], comp_vals);
	}

	// Sequential counterpart of parallel_stable_sort(): std::sort by comp_key,
	// then std::sort every run of key-equivalent elements by comp_vals.
	// An empty range is a no-op. Iterators must be random-access.
	template<class BidirIt, class _CompKey, class _CompVals>
	void sequential_stable_sort(BidirIt _First, BidirIt _Last,
		_CompKey comp_key, _CompVals comp_vals)
	{
		// Without this guard '_First + 1' would already be past _Last and the
		// '!=' loop below would run off the end of the range.
		if (_First >= _Last) return;

		std::sort(_First, _Last, comp_key);

		BidirIt run_begin = _First;
		for (BidirIt _FwdIt = _First + 1; _FwdIt != _Last; ++_FwdIt)
		{
			// A key change closes the current run [run_begin, _FwdIt).
			if (comp_key(*_FwdIt, *(_FwdIt - 1)) || comp_key(*(_FwdIt - 1), *_FwdIt))
			{
				std::sort(run_begin, _FwdIt, comp_vals);
				run_begin = _FwdIt;
			}
		}

		// The last run extends to the end of the range.
		std::sort(run_begin, _Last, comp_vals);
	}
}

#endif // PARALLEL_SORT_STL_H