#include <algorithm>
#include <bit>
#include <cstdint>
#include <format>
#include <functional>
#include <iostream>
#include <numeric>
#include <random>
#include <ranges>
#include <string>
#include <type_traits>
#include <unordered_set>
#include <vector>

#include "testlib.h"

// Position on the line (start, finish, keys and locks).
using Coordinate = int32_t;
// Wide signed type for sums of coordinates, which may not fit into a Coordinate.
using PathLength = int64_t;
// Unsigned type for counts and indices.
using Index = uint32_t;
// Signed counterpart of Index, for places that need negative labels.
using SignedIndex = std::make_signed_t<Index>;

// Limits read from the command line in main():
// t - number of instances in one test, n - number of key/lock pairs, x - coordinates.
Index MIN_T;
Index MAX_T;
Index MIN_N;
Index MAX_N;
Coordinate MIN_X;
Coordinate MAX_X;

/// One key/lock pair on the line.
struct KeyLock {
    Coordinate key;   // position of the key
    Coordinate lock;  // position of the lock this key opens
};

/// Prints the pair as "key lock" (no trailing newline).
std::ostream& operator<<(std::ostream& out, const KeyLock key_lock) {
    out << key_lock.key << ' ' << key_lock.lock;
    return out;
}

struct Instance {
    std::vector<KeyLock> key_locks;
    Coordinate start, finish;
};

std::ostream& operator<<(std::ostream& out, const Instance& instance) {
    out << instance.key_locks.size() << ' ' << instance.start << ' ' << instance.finish << '\n';
    for (const KeyLock key_lock : instance.key_locks)
        out << key_lock << '\n';
    return out;
}

struct MultiInstance {
    std::vector<Instance> instances;
};

std::ostream& operator<<(std::ostream& out, const MultiInstance& multi_instance) {
    out << multi_instance.instances.size() << '\n';
    for (const Instance& instance : multi_instance.instances)
        out << instance;
    return out;
}

/// Samples `n_numbers` pairwise distinct integers uniformly at random from the segment [low; high].
///
/// @param low        smallest allowed value (inclusive)
/// @param high       largest allowed value (inclusive)
/// @param n_numbers  how many distinct values to draw
/// @param increasing if true, the result is sorted increasingly; otherwise its order is random
/// @return exactly `n_numbers` distinct values from [low; high]
///
/// Fails through `ensuref` when the segment is empty or holds fewer than `n_numbers` values.
/// Uses rejection sampling when at most half of the segment is requested; otherwise samples the
/// complement and returns every remaining value.
template<class Int>
std::vector<Int> sample_different_numbers(Int low, Int high, size_t n_numbers, bool increasing) {
    if (n_numbers == 0)
        return {};
    ensuref(low <= high, "%s", std::format("cannot sample from an empty segment [{}; {}]", low, high).c_str());
    // Offsets from `low` are computed in the unsigned counterpart of Int, so neither the segment size
    // nor `x - low` overflows when the segment is wider than the positive range of Int.
    using UInt = std::make_unsigned_t<Int>;
    const auto offset_of = [low](Int x) -> size_t {
        return static_cast<UInt>(static_cast<UInt>(x) - static_cast<UInt>(low));
    };
    const auto value_at = [low](size_t offset) -> Int {
        return static_cast<Int>(static_cast<UInt>(static_cast<UInt>(low) + static_cast<UInt>(offset)));
    };
    const size_t n_samples = offset_of(high) + 1;
    ensuref(n_samples >= n_numbers, "%s", std::format("cannot sample {} numbers from a segment [{}; {}] of size {} < {}", n_numbers, low, high, n_samples, n_numbers).c_str());
    // Every value of the segment whose flag is set, each exactly once, in increasing order.
    const auto flagged_values = [&value_at](const std::vector<bool>& flags, size_t count) {
        std::vector<Int> values;
        values.reserve(count);
        for (size_t offset = 0; offset < flags.size(); ++offset)
            if (flags[offset])
                values.push_back(value_at(offset));
        return values;
    };
    if (n_samples >= n_numbers * 2) {
        // way 1: draw values at random, rejecting repeats
        std::vector<Int> result;
        result.reserve(n_numbers);
        std::unordered_set<Int> result_set;
        while (result.size() < n_numbers) {
            const Int x = rnd.next(low, high);
            if (result_set.insert(x).second)
                result.push_back(x);
        }
        if (increasing) {
            // Sorting costs about n log n, a sweep over a bitmap of the segment about n_samples:
            // pick the cheaper one.
            if (n_samples >= n_numbers * (8 * sizeof(size_t) - 1 - std::countl_zero(n_numbers)))
                std::sort(result.begin(), result.end());
            else {
                std::vector<bool> chosen(n_samples);
                for (const Int x : result)
                    chosen[offset_of(x)] = true;
                result = flagged_values(chosen, n_numbers);
            }
        }
        return result;
    }
    // way 2: more than half of the segment is requested, so sample the complement (which always
    // takes way 1) and keep everything else
    std::vector<bool> chosen(n_samples, true);
    for (const Int x : sample_different_numbers(low, high, n_samples - n_numbers, false))
        chosen[offset_of(x)] = false;
    std::vector<Int> result = flagged_values(chosen, n_numbers);
    if (!increasing)
        shuffle(result.begin(), result.end());
    return result;
}

/// Enumerates small instances in order of increasing weight and packs consecutive ones into tests.
///
/// The weight of an instance is the sum of all its 2n + 2 coordinates (start, finish, keys, locks)
/// plus n. Each test holds at most MAX_T instances and at most MAX_N key/lock pairs in total.
/// Returns the test with 0-based number "test_index" (command-line option). If the enumeration runs
/// out first, the last, possibly partial, test is returned.
///
/// @param ignore_reordering if true, instances that differ only in the order of their key/lock pairs
///        are produced once: keys are taken in increasing order, only the locks are permuted, and the
///        pairs are then shuffled. Otherwise every ordering of the coordinates is produced.
MultiInstance generate_min_by_sum(bool ignore_reordering) {
    MultiInstance ans;
    Index total_n_in_ans = 0;
    uint64_t filled = 0;
    const Index test_index = opt<Index>("test_index");
    // Appends `instance` to the test being built, or starts the next test when it does not fit.
    // Returns true when test number test_index is complete; it is then left in `ans`.
    auto proceed = [&](const Instance& instance) -> bool {
        total_n_in_ans += instance.key_locks.size();
        if (total_n_in_ans <= MAX_N && ans.instances.size() < MAX_T) {
            ans.instances.push_back(instance);
            return false;
        }
        if (filled == test_index)
            return true;
        ans.instances = {instance};
        total_n_in_ans = instance.key_locks.size();
        ++filled;
        return false;
    };
    // Smallest / largest possible sum of `count` pairwise distinct integers from [MIN_X; MAX_X].
    const auto min_coordinate_sum = [](PathLength count) -> PathLength {
        return count * MIN_X + count * (count - 1) / 2;
    };
    const auto max_coordinate_sum = [](PathLength count) -> PathLength {
        return count * MAX_X - count * (count - 1) / 2;
    };
    // Weight range [first_sum; last_sum] covering every instance with MIN_N <= n <= MAX_N.
    // Everything is computed in PathLength. Mixing Coordinate with Index here would wrap around
    // through unsigned arithmetic for negative MIN_X.
    bool any_size_fits = false;
    PathLength first_sum = 0, last_sum = 0;
    for (Index n = MIN_N; n <= MAX_N; ++n) {
        const PathLength count = 2 * static_cast<PathLength>(n) + 2;
        if (min_coordinate_sum(count) > max_coordinate_sum(count))
            break; // fewer than `count` distinct coordinates exist; larger n need even more
        const PathLength lowest = min_coordinate_sum(count) + n;
        const PathLength highest = max_coordinate_sum(count) + n;
        first_sum = any_size_fits ? std::min(first_sum, lowest) : lowest;
        last_sum = any_size_fits ? std::max(last_sum, highest) : highest;
        any_size_fits = true;
    }
    if (!any_size_fits)
        return ans;
    for (PathLength current_sum = first_sum; current_sum <= last_sum; ++current_sum) {
        for (Index n = MIN_N; n <= MAX_N; ++n) {
            // not the most optimal implementation, but ok
            const PathLength coordinate_sum = current_sum - n;
            const Index number_of_numbers = 2 * n + 2;
            if (coordinate_sum < min_coordinate_sum(number_of_numbers)) {
                // The lowest weight f(n) = min_coordinate_sum(2n + 2) + n grows by 2 * MIN_X + 4n + 6
                // when n grows by one. Once that step is positive, no larger n fits into current_sum either.
                if (2 * static_cast<PathLength>(MIN_X) + 4 * static_cast<PathLength>(n) + 6 > 0)
                    break;
                continue;
            }
            if (coordinate_sum > max_coordinate_sum(number_of_numbers))
                continue;
            // `coordinates` walks through all strictly increasing sequences of number_of_numbers values
            // in [MIN_X; MAX_X] whose sum is coordinate_sum, in lexicographic order.
            std::vector<Coordinate> coordinates(number_of_numbers, MIN_X - 1);
            const SignedIndex last = static_cast<SignedIndex>(number_of_numbers) - 1;
            // Iterative backtracking, so long runs of rejected candidates cannot overflow the stack.
            // Tries the next candidate at position i; each position starts just above the value before it.
            // Moves right on success and left when position i is exhausted. Returns true with a complete
            // sequence, or false once every sequence has been visited.
            const auto advance = [&](SignedIndex i) -> bool {
                while (i >= 0 && i <= last) {
                    const PathLength suffix_sum = coordinate_sum - std::accumulate(coordinates.begin(), coordinates.begin() + i, PathLength(0));
                    // coordinates[i] is the smallest of the remaining values, so at most their average
                    const PathLength maximum_number = std::min<PathLength>(MAX_X, suffix_sum / (static_cast<PathLength>(number_of_numbers) - i));
                    PathLength minimum_number = static_cast<PathLength>(coordinates[i]) + 1;
                    if (i == last)
                        minimum_number = std::max(minimum_number, suffix_sum); // the last value is forced
                    if (minimum_number > maximum_number) {
                        --i;
                        continue;
                    }
                    coordinates[i] = static_cast<Coordinate>(minimum_number);
                    if (i < last)
                        coordinates[i + 1] = coordinates[i];
                    ++i;
                }
                return i > last;
            };
            // After a complete sequence the search resumes at the last position, which then backtracks.
            // Restarting at position 0 would only produce one sequence per value of the first coordinate.
            for (bool found = advance(0); found; found = advance(last)) {
                if (ignore_reordering) {
                    std::vector<bool> key_mask(number_of_numbers);
                    for (Index i = number_of_numbers - n; i < number_of_numbers; ++i)
                        key_mask[i] = true;
                    do {
                        std::vector<Coordinate> permuted_coordinates, keys;
                        for (Index i = 0; i < number_of_numbers; ++i)
                            if (key_mask[i])
                                keys.push_back(coordinates[i]);
                            else
                                permuted_coordinates.push_back(coordinates[i]);
                        do {
                            Instance instance{
                                .key_locks = std::vector<KeyLock>(n),
                                .start = permuted_coordinates[0],
                                .finish = permuted_coordinates[1],
                            };
                            for (Index i = 0; i < n; ++i)
                                instance.key_locks[i] = {keys[i], permuted_coordinates[i + 2]};
                            shuffle(instance.key_locks.begin(), instance.key_locks.end());
                            if (proceed(instance))
                                return ans;
                        } while (std::next_permutation(permuted_coordinates.begin(), permuted_coordinates.end()));
                    } while (std::next_permutation(key_mask.begin(), key_mask.end()));
                } else {
                    std::vector<Coordinate> permuted_coordinates = coordinates;
                    do {
                        Instance instance{
                            .key_locks = std::vector<KeyLock>(n),
                            .start = permuted_coordinates[0],
                            .finish = permuted_coordinates[1],
                        };
                        for (Index i = 0; i < n; ++i)
                            instance.key_locks[i] = {permuted_coordinates[2 * i + 2], permuted_coordinates[2 * i + 3]};
                        if (proceed(instance))
                            return ans;
                    } while (std::next_permutation(permuted_coordinates.begin(), permuted_coordinates.end()));
                }
            }
        }
    }
    // every instance has been enumerated; return the last (possibly partial) test
    return ans;
}

/// Splits `sum` into `addends` positive integers forming a random composition of `sum`.
///
/// Stars and bars: out of the `sum - 1` gaps between unit "stars", `addends - 1` randomly chosen gaps
/// become cuts. Each choice of cut positions is equally likely, provided `shuffle` is uniform.
///
/// @return `addends` positive values summing to `sum` (empty when `sum == 0`)
/// Fails through `ensuref` if `addends` is negative, exceeds `sum`, or is zero while `sum` is positive.
template<class Int>
std::vector<Int> separate_into_positive_sum(Int sum, Int addends) {
    ensuref(addends >= 0, "%s", std::format("cannot decompose {} into {} addends: the number of addends is negative", sum, addends).c_str());
    ensuref(sum >= addends, "%s", std::format("cannot decompose {} into {} addends: the sum should be at least the number of addends", sum, addends).c_str());
    ensuref(sum == 0 || addends > 0, "%s", std::format("cannot decompose {} into {} addends: the sum of zero addends cannot be positive", sum, addends).c_str());
    if (sum == 0)
        return {};
    std::vector<bool> is_cut(sum - 1, false);
    for (Int k = 0; k + 1 < addends; ++k)
        is_cut[k] = true;
    shuffle(is_cut.begin(), is_cut.end());
    std::vector<Int> parts(1, Int(1));
    for (const bool starts_new_part : is_cut) {
        if (!starts_new_part) {
            ++parts.back();
            continue;
        }
        parts.push_back(1);
    }
    return parts;
}

/// Makes `instance` solvable by swapping the key and lock positions of some pairs.
///
/// All points are sorted by coordinate, and a reachable segment is grown from the start:
/// - a key point can always be entered;
/// - a lock point can be entered only if the coordinate of its key lies inside the current segment.
/// When neither neighbour can be entered, one blocking lock is chosen (uniformly via `rnd` if both
/// sides are blocked). The key and lock of its pair are swapped, and the point is entered as a key.
/// The loop stops as soon as the next point to enter is the finish.
///
/// NOTE(review): assumes all 2n + 2 coordinates are pairwise distinct, because the start and the
/// finish are located purely by coordinate. The generators in this file produce distinct values;
/// confirm for new callers.
void fix_feasibility(Instance& instance) {
    const size_t N = instance.key_locks.size() * 2 + 2;
    // (coordinate, label): label j >= 0 is the key of pair j, label -1 - j is the lock of pair j.
    // The start gets the key-like label n, so it is always enterable. The finish gets the lock-like
    // label -1 - n; that label is never used as a pair index, because meeting the finish ends the loop.
    std::vector<std::pair<Coordinate, SignedIndex>> interesting_points(N);
    interesting_points[0] = { instance.start, static_cast<SignedIndex>(instance.key_locks.size()) };
    interesting_points[1] = { instance.finish, -1 - static_cast<SignedIndex>(instance.key_locks.size()) };
    for (SignedIndex j = 0; j < instance.key_locks.size(); ++j) {
        interesting_points[2 * j + 2] = { instance.key_locks[j].key, j };
        interesting_points[2 * j + 3] = { instance.key_locks[j].lock, -1 - j };
    }
    std::sort(interesting_points.begin(), interesting_points.end());
    // [segment.first; segment.second] are indices into interesting_points; the segment starts at the start point
    std::pair<Index, Index> segment;
    for (Index j = 0; j < N; ++j)
        if (interesting_points[j].first == instance.start)
            segment = {j, j};
    while (true) {
        // try to go left
        bool locked_left = false;
        if (segment.first > 0) {
            Index new_point = segment.first - 1;
            if (interesting_points[new_point].first == instance.finish)
                break;
            if (interesting_points[new_point].second >= 0) {
                --segment.first;
                continue;
            }
            Index lock_index = -1 - interesting_points[new_point].second;
            Coordinate where_key = instance.key_locks[lock_index].key;
            // the lock opens if its key was already collected, i.e. lies inside the segment
            if (where_key >= interesting_points[segment.first].first && where_key <= interesting_points[segment.second].first) {
                --segment.first;
                continue;
            }
            locked_left = true;
        }
        // try to go right
        bool locked_right = false;
        if (segment.second < N - 1) {
            Index new_point = segment.second + 1;
            if (interesting_points[new_point].first == instance.finish)
                break;
            if (interesting_points[new_point].second >= 0) {
                ++segment.second;
                continue;
            }
            Index lock_index = -1 - interesting_points[new_point].second;
            Coordinate where_key = instance.key_locks[lock_index].key;
            if (where_key >= interesting_points[segment.first].first && where_key <= interesting_points[segment.second].first) {
                ++segment.second;
                continue;
            }
            locked_right = true;
        }
        // ok, we failed. We have to unlock one of the sides
        // At least one side is blocked here: the segment never contains the finish, so it cannot span
        // all N points, and a side that is not at the boundary either extended or set its locked flag.
        bool we_will_unlock_right = locked_right;
        if (locked_left && locked_right)
            we_will_unlock_right = rnd.next(0, 1);
        Index lock_index;
        if (we_will_unlock_right)
            lock_index = -1 - interesting_points[++segment.second].second;
        else
            lock_index = -1 - interesting_points[--segment.first].second;
        // interesting_points is deliberately left with stale labels after this swap, which is harmless.
        // The entered point already lies inside the segment. The pair's old key position, still labelled
        // as a key, is now a lock whose key is inside the segment, so entering it freely is correct.
        std::swap(instance.key_locks[lock_index].key, instance.key_locks[lock_index].lock);
    }
}

/// Generates `t` instances with `sum_n` key/lock pairs in total, each built around a zig-zag chain of
/// pairs. Presumably the chain is meant to force a long walk between the two ends of [MIN_X; MAX_X]
/// (TODO confirm).
///
/// An instance with `size` pairs gets `units = min(random_units, size)` random pairs. The other
/// `edge_size = size - units` pairs form the chain:
///  * 2 * edge_size + 2 increasing values are sampled from [0; 2 * edge_size + 1 + inaccuracy]. The
///    lowest SMALLER of them are shifted towards MIN_X, the rest towards MAX_X;
///  * finish is the smallest value. Pair j gets its lock at the (j + 2)-th smallest value and its key
///    at the (j + 1)-th largest, with key and lock swapped for odd j;
///  * the start and the random pairs are placed in the gap between the two clusters. The middle chain
///    value is also one of the candidate positions.
/// Afterwards fix_feasibility() is applied and the pairs are shuffled. With probability 1/2 the
/// instance is then mirrored around the middle of [MIN_X; MAX_X].
MultiInstance generate_random_longest(const Index t, const Index sum_n, const Coordinate inaccuracy, Index random_units) {
    MultiInstance ans;
    ans.instances.resize(t);
    std::vector<Index> sizes = separate_into_positive_sum(sum_n, t);
    // Reflection across the middle of [MIN_X; MAX_X], computed in 64 bits so MIN_X + MAX_X cannot overflow.
    const auto mirror = [](Coordinate x) {
        return static_cast<Coordinate>(static_cast<PathLength>(MIN_X) + MAX_X - x);
    };
    for (Index i = 0; i < t; ++i) {
        Instance& instance = ans.instances[i];
        instance.key_locks.resize(sizes[i]);
        // An instance may have fewer pairs than random_units. Clamp, otherwise the unsigned
        // subtraction below wraps around to a huge edge_size.
        const Index units = std::min(random_units, sizes[i]);
        const Index edge_size = sizes[i] - units;
        const Index N_edge = 2 * edge_size + 2;
        const Coordinate LOW = 0, HIGH = N_edge - 1 + inaccuracy;
        std::vector<Coordinate> edge_numbers = sample_different_numbers<Coordinate>(LOW, HIGH, N_edge, true);
        // number of chain values placed near MIN_X: edge_size + 1 rounded up to an odd number
        const Index SMALLER = (edge_size + 1) | 1;
        for (Index j = 0; j < SMALLER; ++j)
            edge_numbers[j] += MIN_X - LOW;
        for (Index j = SMALLER; j < N_edge; ++j)
            edge_numbers[j] += MAX_X - HIGH;
        instance.finish = edge_numbers[0];
        for (Index j = 0; j < edge_size; ++j) {
            instance.key_locks[j].key = edge_numbers[N_edge - 1 - j];
            instance.key_locks[j].lock = edge_numbers[j + 1];
            if (j % 2)
                std::swap(instance.key_locks[j].key, instance.key_locks[j].lock);
        }
        // the gap strictly between the two clusters
        const Coordinate INNER_LOW = edge_numbers[SMALLER - 1] + 1;
        const Coordinate INNER_HIGH = edge_numbers[SMALLER] - 1;
        std::vector<Coordinate> inner_numbers = sample_different_numbers<Coordinate>(INNER_LOW, INNER_HIGH, 2 * units, false);
        inner_numbers.push_back(edge_numbers[edge_size + 1]);
        shuffle(inner_numbers.begin(), inner_numbers.end());
        instance.start = inner_numbers[0];
        for (Index j = 0; j < units; ++j) {
            instance.key_locks[edge_size + j].key = inner_numbers[2 * j + 1];
            instance.key_locks[edge_size + j].lock = inner_numbers[2 * j + 2];
        }
        fix_feasibility(instance);
        shuffle(instance.key_locks.begin(), instance.key_locks.end());
        if (rnd.next(0, 1)) {
            instance.start = mirror(instance.start);
            instance.finish = mirror(instance.finish);
            for (KeyLock& key_lock : instance.key_locks) {
                key_lock.key = mirror(key_lock.key);
                key_lock.lock = mirror(key_lock.lock);
            }
        }
    }
    return ans;
}

/// Generates `t` instances with `sum_n` key/lock pairs in total (a random positive split).
/// Within each instance all coordinates are distinct and drawn from [MIN_X; MAX_X]. With probability
/// `feasible`, an instance is passed through fix_feasibility().
MultiInstance generate_random(const Index t, const Index sum_n, double feasible) {
    MultiInstance ans;
    ans.instances.resize(t);
    const std::vector<Index> sizes = separate_into_positive_sum(sum_n, t);
    for (Index i = 0; i < t; ++i) {
        Instance& current = ans.instances[i];
        const Index pairs = sizes[i];
        current.key_locks.resize(pairs);
        // layout: start, finish, then key/lock of each pair
        const std::vector<Coordinate> points = sample_different_numbers(MIN_X, MAX_X, pairs * 2 + 2, false);
        current.start = points[0];
        current.finish = points[1];
        for (Index j = 0; j < pairs; ++j) {
            KeyLock& pair = current.key_locks[j];
            pair.key = points[2 * j + 2];
            pair.lock = points[2 * j + 3];
        }
        const bool make_feasible = rnd.next() < feasible;
        if (make_feasible)
            fix_feasibility(current);
    }
    return ans;
}

/// Enumerates, for every n in [MIN_N; MAX_N], all assignments of 2n + 2 fixed increasing coordinates
/// to keys, locks, start and finish. Keys take increasing positions so that reorderings of pairs are
/// not repeated. Consecutive instances are packed into tests of at most MAX_T instances and MAX_N pairs,
/// and test number "test_index" (command-line option) is returned. If the enumeration runs out, the
/// last (possibly partial) test is returned.
///
/// @param disperse if true, fresh random increasing coordinates from [MIN_X; MAX_X] are sampled for
///        every instance. Otherwise consecutive integers are used, centred near 0 and clamped into
///        [MIN_X; MAX_X].
MultiInstance generate_min_by_permutation(bool disperse) {
    MultiInstance ans;
    Index total_n_in_ans = 0;
    uint64_t filled = 0;
    const Index test_index = opt<Index>("test_index");
    // Adds `instance` to the current test, or starts the next one when it does not fit.
    // Returns true when test number test_index is finished; it is then left in `ans`.
    auto proceed = [&](const Instance& instance) -> bool {
        const Index pairs = instance.key_locks.size();
        total_n_in_ans += pairs;
        const bool fits = total_n_in_ans <= MAX_N && ans.instances.size() < MAX_T;
        if (fits) {
            ans.instances.push_back(instance);
            return false;
        }
        if (filled == test_index)
            return true;
        ans.instances.assign(1, instance);
        total_n_in_ans = pairs;
        ++filled;
        return false;
    };
    // Moves `order` to the next lexicographic permutation whose first `prefix` entries increase.
    // Returns false when no such permutation is left.
    auto next_with_increasing_prefix = [](std::vector<Index>& order, Index prefix) -> bool {
        while (std::next_permutation(order.begin(), order.end()))
            if (std::is_sorted(order.begin(), order.begin() + prefix))
                return true;
        return false;
    };
    for (Index n = MIN_N; n <= MAX_N; ++n) {
        const Index total = 2 * n + 2;
        // order[0..n) -> keys, order[n..2n) -> locks, order[2n] -> start, order[2n + 1] -> finish
        std::vector<Index> order(total);
        std::iota(order.begin(), order.end(), Index{0});
        const Coordinate base = std::clamp(-static_cast<Coordinate>(total) / 2, MIN_X, MAX_X - static_cast<Coordinate>(total) + 1);
        std::vector<Coordinate> positions(total);
        std::iota(positions.begin(), positions.end(), base);
        bool more = true;
        while (more) {
            Instance instance;
            instance.key_locks.resize(n);
            if (disperse)
                positions = sample_different_numbers(MIN_X, MAX_X, total, true);
            for (Index k = 0; k < n; ++k) {
                instance.key_locks[k].key = positions[order[k]];
                instance.key_locks[k].lock = positions[order[n + k]];
            }
            instance.start = positions[order[2 * n]];
            instance.finish = positions[order[2 * n + 1]];
            shuffle(instance.key_locks.begin(), instance.key_locks.end());
            if (proceed(instance))
                return ans;
            more = next_with_increasing_prefix(order, n);
        }
    }
    return ans;
}

/// Generator entry point: reads the options through testlib's `opt` and prints one test to stdout.
///
/// Common options: type, min_t, max_t, min_n, max_n, min_x, max_x. Per-type options:
///   min_by_sum, min_by_sum_ignoring_reordering - test_index (read inside the generator);
///   min_by_permutation - disperse (and test_index);
///   random_max_test - feasible;
///   random_longest_test - inaccuracy, random_units.
int main(int argc, char* argv[]) {
    registerGen(argc, argv, 1);
    const std::string type = opt<std::string>("type");
    MIN_T = opt<Index>("min_t");
    MAX_T = opt<Index>("max_t");
    MIN_N = opt<Index>("min_n");
    MAX_N = opt<Index>("max_n");
    MIN_X = opt<Coordinate>("min_x");
    MAX_X = opt<Coordinate>("max_x");
    MultiInstance multi_instance;
    if (type == "min_by_sum") {
        multi_instance = generate_min_by_sum(false);
    } else if (type == "min_by_permutation") {
        const bool disperse = opt<bool>("disperse");
        multi_instance = generate_min_by_permutation(disperse);
    } else if (type == "min_by_sum_ignoring_reordering") {
        multi_instance = generate_min_by_sum(true);
    } else if (type == "random_max_test") {
        const double feasible = opt<double>("feasible");
        multi_instance = generate_random(std::min(MAX_T, MAX_N), MAX_N, feasible);
    } else if (type == "random_longest_test") {
        const Coordinate inaccuracy = opt<Coordinate>("inaccuracy");
        // random_units is a count of pairs: read and keep it as Index, the type the generator takes
        const Index random_units = opt<Index>("random_units");
        multi_instance = generate_random_longest(std::min(MAX_T, MAX_N), MAX_N, inaccuracy, random_units);
    } else {
        ensuref(false, "%s", std::format("Unknown test type: \"{}\"", type).c_str());
    }
    std::cout << multi_instance;
}
