import dataclasses
import logging
import random
import uuid
from abc import abstractmethod
from collections import deque
from datetime import datetime, timedelta
from typing import Callable, Tuple, Any, Deque

from overcooked_simulator.game_items import Item, Plate, ItemInfo

log = logging.getLogger(__name__)

ORDER_CATEGORY = "Order"


@dataclasses.dataclass
class Order:
    meal: ItemInfo
    start_time: datetime
    max_duration: timedelta
    score_calc: Callable[[timedelta, ...], float]
    timed_penalties: list[
        Tuple[timedelta, float] | Tuple[timedelta, float, int, timedelta]
    ]
    expired_penalty: float
    uuid: str = dataclasses.field(default_factory=lambda: uuid.uuid4().hex)

    finished_info: dict[str, Any] = dataclasses.field(default_factory=dict)
    _timed_penalties: list[Tuple[datetime, float]] = dataclasses.field(
        default_factory=list
    )

    def order_time(self, env_time: datetime) -> timedelta:
        return self.start_time - env_time

    def create_penalties(self, env_time: datetime):
        for penalty_info in self.timed_penalties:
            match penalty_info:
                case (offset, penalty):
                    self._timed_penalties.append((env_time + offset, penalty))
                case (duration, penalty, number_repeat, offset):
                    self._timed_penalties.extend(
                        [
                            (env_time + offset + (duration * i), penalty)
                            for i in range(number_repeat)
                        ]
                    )


class OrderGeneration:
    """Interface for order-generation strategies.

    A subclass decides which orders exist initially (:meth:`init_orders`) and
    which new ones appear over time (:meth:`get_orders`). The class does not
    derive from ``abc.ABC``, so the ``@abstractmethod`` markers are not
    enforced when instantiating it.
    """

    def __init__(self, available_meals: dict[str, ItemInfo], **kwargs):
        # Only the ItemInfo values are kept; extra keyword arguments are
        # accepted and ignored here.
        meals = available_meals.values()
        self.available_meals: list[ItemInfo] = [*meals]

    @abstractmethod
    def init_orders(self, now) -> list[Order]:
        """Return the initial orders, created at time ``now``."""

    @abstractmethod
    def get_orders(
        self,
        passed_time: timedelta,
        now: datetime,
        new_finished_orders: list[Order],
        expired_orders: list[Order],
    ) -> list[Order]:
        """Return the orders to add at time ``now``.

        ``new_finished_orders`` and ``expired_orders`` are the orders that were
        served or expired since the previous call.
        """


def zero(item: ItemInfo, **kwargs) -> float:
    """Default ``expired_penalty_func``: yield ``0.0`` for every meal.

    ``item`` and any keyword arguments are ignored.
    """
    no_penalty = 0.0
    return no_penalty


@dataclasses.dataclass
class RandomOrderKwarg:
    """Configuration for ``RandomOrderGeneration``, built from ``kwargs["kwargs"]``."""

    # Number of orders created by ``init_orders``.
    num_start_meals: int
    # When true, the start orders get no time limit and serving an order
    # schedules the next one.
    sample_on_serving: bool
    # When true, the time until the next order is drawn via
    # ``sample_on_dur_func``.
    sample_on_dur: bool
    # ``{"func": <name of a random-module function>, "kwargs": {...}}``; the
    # drawn value is the number of seconds until the next order.
    sample_on_dur_func: dict
    # No timed order is created while this many orders are open.
    max_orders: int
    # Same shape as ``sample_on_dur_func``; the drawn value is the order
    # duration in seconds.
    duration_sample: dict
    # Called with ``meal``, ``duration``, ``now`` and ``kwargs``; returns the
    # order's score function.
    score_calc_gen_func: Callable[
        [ItemInfo, timedelta, datetime, Any],
        Callable[[timedelta, Order], float],
    ]
    # Passed to ``score_calc_gen_func`` as its ``kwargs`` argument.
    score_calc_gen_kwargs: dict
    # Called as ``expired_penalty_func(meal, **expired_penalty_kwargs)``.
    expired_penalty_func: Callable[[ItemInfo], float] = dataclasses.field(
        default=zero
    )
    expired_penalty_kwargs: dict = dataclasses.field(default_factory=dict)


class RandomOrderGeneration(OrderGeneration):
    """Order generator that picks meals at random.

    Meals are drawn with ``random.choices`` / ``random.choice`` from
    ``available_meals``. Order durations and the delay until the next order
    are drawn by calling the ``random`` module function named in the config
    (``duration_sample`` / ``sample_on_dur_func``). The configuration is read
    from ``kwargs["kwargs"]`` and parsed into a ``RandomOrderKwarg``.
    """

    def __init__(self, available_meals: dict[str, ItemInfo], **kwargs):
        super().__init__(available_meals, **kwargs)
        self.kwargs: RandomOrderKwarg = RandomOrderKwarg(**kwargs["kwargs"])
        # Time at which the next timed order is due; datetime.max disables it.
        self.next_order_time: datetime | None = datetime.max
        # Number of orders this generator counts as currently open.
        self.number_cur_orders = 0
        self.needed_orders: int = 0
        """For the sample on dur but when it was restricted due to max order number."""

    def init_orders(self, now) -> list[Order]:
        """Create the initial orders at ``now``.

        Sets the open-order counter to ``num_start_meals``, schedules the first
        timed order when ``sample_on_dur`` is set, and returns
        ``num_start_meals`` randomly chosen orders. When ``sample_on_serving``
        is set, these orders are created without a time limit.
        """
        self.number_cur_orders = self.kwargs.num_start_meals
        if self.kwargs.sample_on_dur:
            self.create_random_next_time_delta(now)
        return self.create_orders_for_meals(
            random.choices(self.available_meals, k=self.kwargs.num_start_meals),
            now,
            self.kwargs.sample_on_serving,
        )

    def get_orders(
        self,
        passed_time: timedelta,
        now: datetime,
        new_finished_orders: list[Order],
        expired_orders: list[Order],
    ) -> list[Order]:
        """Return the new orders to open at ``now``.

        The open-order counter is first reduced by the finished and expired
        orders. Then the first matching case applies:

        * ``sample_on_serving`` and at least one order finished: schedule the
          next order via ``create_random_next_time_delta`` and return nothing.
        * ``needed_orders`` is non-zero: create one order per finished order.
        * ``next_order_time`` has passed: create one order, or, if
          ``max_orders`` is reached, increment ``needed_orders`` instead.

        ``passed_time`` is not used.
        """
        self.number_cur_orders -= len(new_finished_orders)
        self.number_cur_orders -= len(expired_orders)
        if self.kwargs.sample_on_serving:
            if new_finished_orders:
                self.create_random_next_time_delta(now)
                return []
        if self.needed_orders:
            # NOTE(review): this creates len(new_finished_orders) orders even
            # when that exceeds needed_orders. It also returns early when
            # nothing finished, so slots freed only by expired orders are not
            # refilled and the next_order_time check below is skipped until an
            # order is served. Confirm this is intended.
            self.needed_orders -= len(new_finished_orders)
            self.needed_orders = max(self.needed_orders, 0)
            self.number_cur_orders += len(new_finished_orders)
            return self.create_orders_for_meals(
                random.choices(self.available_meals, k=len(new_finished_orders)),
                now,
            )
        if self.next_order_time <= now:
            if self.number_cur_orders >= self.kwargs.max_orders:
                # At capacity: record the owed order. next_order_time is not
                # advanced here, so it stays due.
                self.needed_orders += 1
            else:
                if self.kwargs.sample_on_dur:
                    self.create_random_next_time_delta(now)
                else:
                    # Disable the timer; only create_random_next_time_delta
                    # re-arms it.
                    self.next_order_time = datetime.max
                self.number_cur_orders += 1
                return self.create_orders_for_meals(
                    [random.choice(self.available_meals)],
                    now,
                )
        return []

    def create_orders_for_meals(
        self, meals: list[ItemInfo], now: datetime, no_time_limit: bool = False
    ) -> list[Order]:
        """Create one order per entry of ``meals``, all starting at ``now``.

        With ``no_time_limit`` the duration is ``datetime.max - now``, so the
        deadline ``start_time + max_duration`` equals ``datetime.max``.
        Otherwise the duration in seconds is drawn by calling the ``random``
        function named by ``duration_sample["func"]`` with
        ``duration_sample["kwargs"]``. The score function and expired penalty
        come from the configured generator functions; no timed penalties are
        attached.
        """
        orders = []
        for meal in meals:
            if no_time_limit:
                duration = datetime.max - now
            else:
                duration = timedelta(
                    seconds=getattr(random, self.kwargs.duration_sample["func"])(
                        **self.kwargs.duration_sample["kwargs"]
                    )
                )
            log.info(f"Create order for meal {meal} with duration {duration}")
            orders.append(
                Order(
                    meal=meal,
                    start_time=now,
                    max_duration=duration,
                    score_calc=self.kwargs.score_calc_gen_func(
                        meal=meal,
                        duration=duration,
                        now=now,
                        kwargs=self.kwargs.score_calc_gen_kwargs,
                    ),
                    timed_penalties=[],
                    expired_penalty=self.kwargs.expired_penalty_func(
                        meal, **self.kwargs.expired_penalty_kwargs
                    ),
                )
            )

        return orders

    def create_random_next_time_delta(self, now: datetime):
        """Set ``next_order_time`` to ``now`` plus a sampled number of seconds.

        The sample is drawn by calling the ``random`` function named by
        ``sample_on_dur_func["func"]`` with ``sample_on_dur_func["kwargs"]``.
        """
        self.next_order_time = now + timedelta(
            seconds=getattr(random, self.kwargs.sample_on_dur_func["func"])(
                **self.kwargs.sample_on_dur_func["kwargs"]
            )
        )
        # Logs the absolute time of the next order, not a delay.
        log.info(f"Next order in {self.next_order_time}")


def simple_score_calc_gen_func(
    meal: ItemInfo, duration: timedelta, now: datetime, kwargs: dict, **other_kwargs
) -> Callable:
    """Build a score function that pays a fixed score per meal name.

    Args:
        meal: The ordered meal (unused; the score is looked up from the order
            passed to the returned function).
        duration: The order duration (unused).
        now: The order creation time (unused).
        kwargs: Must contain ``"scores"`` (meal name -> score) and ``"other"``
            (score for meals not listed in ``"scores"``).
        **other_kwargs: Ignored.

    Returns:
        A ``score_calc(relative_order_time, order)`` callable.

    Raises:
        KeyError: If ``"scores"`` or ``"other"`` is missing from ``kwargs``.
    """
    scores = kwargs["scores"]
    other = kwargs["other"]

    def score_calc(relative_order_time: timedelta, order: Order) -> float:
        # Fixed score: the time taken to serve the order is not considered.
        return scores.get(order.meal.name, other)

    return score_calc


def simple_expired_penalty(item: ItemInfo, default: float, **kwargs) -> float:
    """Expired-penalty function returning the configured constant ``default``.

    ``item`` and any extra keyword arguments are ignored.
    """
    penalty = default
    return penalty


class OrderAndScoreManager:
    """Track open orders, serve meals against them and keep the score.

    Orders come from the ``OrderGeneration`` class given in
    ``order_config["order_gen_class"]``. :meth:`progress` feeds finished and
    expired orders back to that generator, applies expiry and timed penalties,
    and removes expired orders.
    """

    def __init__(self, order_config, available_meals: dict[str, ItemInfo]):
        self.score = 0
        self.order_gen: OrderGeneration = order_config["order_gen_class"](
            available_meals=available_meals, kwargs=order_config["kwargs"]
        )
        self.kwargs_for_func = order_config["kwargs"]
        # Falsy to reject meals without a matching order, otherwise a callable
        # ``meal -> (accept, score)``.
        self.serving_not_ordered_meals = order_config["serving_not_ordered_meals"]
        self.available_meals = available_meals
        self.open_orders: Deque[Order] = deque()

        # for logs or history in the future
        # TODO log who / which player served which meal -> for split scores
        self.served_meals: list[Tuple[Item, datetime]] = []
        # Orders served / expired since the last ``progress`` call; handed to
        # the order generator on the next call.
        self.last_finished = []
        # Earliest deadline or timed-penalty time among the open orders.
        # ``progress`` only scans the orders once this is reached or when new
        # orders arrive.
        self.next_relevant_time = datetime.max
        self.last_expired = []

    def update_next_relevant_time(self):
        """Recompute the earliest deadline / timed-penalty time of open orders.

        Falls back to ``datetime.max`` when no order is open.
        """
        next_relevant_time = datetime.max
        for order in self.open_orders:
            next_relevant_time = min(
                next_relevant_time, order.start_time + order.max_duration
            )
            for penalty_time, _ in order._timed_penalties:
                next_relevant_time = min(next_relevant_time, penalty_time)
        self.next_relevant_time = next_relevant_time

    def serve_meal(self, item: Item, env_time: datetime) -> bool:
        """Serve ``item`` if it is a plate holding a known meal.

        The first open order for the same meal is completed and scored with
        its ``score_calc``. If no order matches, ``serving_not_ordered_meals``
        (when set) decides whether the meal is accepted and how it scores.

        Args:
            item: The item handed in for serving.
            env_time: Current environment time.

        Returns:
            True if the item was accepted, False otherwise.
        """
        if not isinstance(item, Plate):
            log.info(f"Do not serve item {item}")
            return False
        meal = item.get_potential_meal()
        if meal is None or meal.name not in self.available_meals:
            log.info(f"Do not serve item {item}")
            return False

        found = self.find_order_for_meal(meal)
        if found is None:
            if not self.serving_not_ordered_meals:
                log.info(f"Do not serve meal {meal.name} because it is not ordered")
                return False
            accept, score = self.serving_not_ordered_meals(meal)
            if accept:
                log.info(f"Serving meal without order {meal.name} with score {score}")
                self.score += score
                self.served_meals.append((meal, env_time))
            return accept

        order, index = found
        score = order.score_calc(
            relative_order_time=env_time - order.start_time,
            order=order,
        )
        self.score += score
        order.finished_info = {
            "end_time": env_time,
            "score": score,
        }
        log.info(f"Serving meal {meal.name} with order with score {score}")
        self.last_finished.append(order)
        del self.open_orders[index]
        self.served_meals.append((meal, env_time))
        return True

    def increment_score(self, score: int):
        """Add ``score`` to the current score."""
        self.score += score
        log.debug(f"Score: {self.score}")

    def create_init_orders(self, env_time):
        """Open the generator's initial orders and start tracking their deadlines."""
        init_orders = self.order_gen.init_orders(env_time)
        self.open_orders.extend(init_orders)
        # Without this, next_relevant_time stays datetime.max, and the initial
        # orders would never expire unless some later order triggered a rescan.
        self.update_next_relevant_time()

    def progress(self, passed_time: timedelta, now: datetime):
        """Advance order handling to ``now``.

        Asks the generator for new orders (passing the orders served / expired
        since the last call). Then, if new orders arrived or
        ``next_relevant_time`` is reached, it:

        * adds ``expired_penalty`` for and removes every expired order;
        * subtracts every due timed penalty;
        * records the expired orders for the next call;
        * recomputes ``next_relevant_time``.
        """
        new_orders = self.order_gen.get_orders(
            passed_time=passed_time,
            now=now,
            new_finished_orders=self.last_finished,
            expired_orders=self.last_expired,
        )
        self.open_orders.extend(new_orders)
        self.last_finished = []
        self.last_expired = []
        if not new_orders and self.next_relevant_time > now:
            return

        expired_indices = []
        for index, order in enumerate(self.open_orders):
            if now >= order.start_time + order.max_duration:
                self.score += order.expired_penalty
                expired_indices.append(index)
            self._apply_due_penalties(order, now)

        # Delete from the back so earlier indices stay valid.
        expired_orders = []
        for index in reversed(expired_indices):
            expired_orders.append(self.open_orders[index])
            del self.open_orders[index]
        self.last_expired = expired_orders

        self.update_next_relevant_time()

    def _apply_due_penalties(self, order: Order, now: datetime):
        """Subtract and drop every timed penalty of ``order`` due at or before ``now``."""
        # _timed_penalties holds the absolute (time, penalty) pairs built by
        # Order.create_penalties; timed_penalties is only the raw config.
        remaining = []
        for penalty_time, penalty in order._timed_penalties:
            if penalty_time <= now:
                self.score -= penalty
            else:
                remaining.append((penalty_time, penalty))
        order._timed_penalties[:] = remaining

    def find_order_for_meal(self, meal) -> Tuple[Order, int] | None:
        """Return the first open order for ``meal`` and its index, or None."""
        for index, order in enumerate(self.open_orders):
            if order.meal.name == meal.name:
                return order, index
        return None

    def setup_penalties(self, new_orders: list[Order], env_time: datetime):
        """Expand the timed penalties of ``new_orders`` relative to ``env_time``."""
        for order in new_orders:
            order.create_penalties(env_time)

    def order_state(self) -> list[dict]:
        """Return a serializable description of every open order."""
        return [
            {
                "id": order.uuid,
                "category": ORDER_CATEGORY,
                "meal": order.meal.name,
                "start_time": order.start_time.isoformat(),
                "max_duration": order.max_duration.total_seconds(),
            }
            for order in self.open_orders
        ]


if __name__ == "__main__":
    import yaml

    # Demo: load an order config, replace the placeholder entries with the
    # Python objects they stand for, and print the result as YAML.
    order_config = yaml.safe_load(
        """orders:
  kwargs:
    duration_sample:
      func: uniform
      kwargs:
        a: 30
        b: 50
    max_orders: 5
    num_start_meals: 3
    sample_on_dur: false
    sample_on_dur_func:
      func: uniform
      kwargs:
        a: 30
        b: 50
    sample_on_serving: true
    score_calc_gen_func: null
    score_calc_gen_kwargs:
      other: 0
      scores:
        Burger: 15
        OnionSoup: 10
        Salad: 5
        TomatoSoup: 10
  order_gen_class: ~
  serving_not_ordered_meals: null"""
    )
    order_config["orders"]["order_gen_class"] = RandomOrderGeneration
    order_config["orders"]["kwargs"]["score_calc_gen_func"] = simple_score_calc_gen_func
    print(yaml.dump(order_config))