# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Base classes for generating questions based on previously asked questions and most recent context data."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from graphrag_llm.tokenizer import Tokenizer

from graphrag.query.context_builder.builders import (
    GlobalContextBuilder,
    LocalContextBuilder,
)

if TYPE_CHECKING:
    from graphrag_llm.completion import LLMCompletion


@dataclass
class QuestionResult:
    """Structured result returned by a question generator.

    Instances are produced by ``BaseQuestionGen.generate`` /
    ``BaseQuestionGen.agenerate`` implementations. The field order below
    defines the positional ``__init__`` signature generated by
    ``@dataclass``, so do not reorder fields.
    """

    # Generated questions (presumably one string per question — confirm
    # against implementations).
    response: list[str]
    # Context data associated with this result, either as plain text or as a
    # dict keyed by string.
    context_data: str | dict[str, Any]
    # Time taken to produce the result; units are not visible here
    # (presumably seconds — TODO confirm).
    completion_time: float
    # Number of LLM calls made, as reported by the implementation.
    llm_calls: int
    # Number of prompt tokens consumed, as reported by the implementation.
    prompt_tokens: int


class BaseQuestionGen(ABC):
    """Abstract interface for question generators.

    A question generator pairs a completion model with a context builder and
    proposes questions from previously asked questions plus context data.
    This base class only stores the shared collaborators and configuration;
    concrete subclasses implement ``generate`` and ``agenerate``.
    """

    def __init__(
        self,
        model: "LLMCompletion",
        context_builder: GlobalContextBuilder | LocalContextBuilder,
        tokenizer: Tokenizer | None = None,
        model_params: dict[str, Any] | None = None,
        context_builder_params: dict[str, Any] | None = None,
    ) -> None:
        """Capture the model, context builder, tokenizer and parameter dicts.

        Args:
            model: Completion model used by subclasses to generate questions.
            context_builder: Global or local context builder used to assemble
                context for the model.
            tokenizer: Tokenizer to use. When falsy (e.g. ``None``), the
                model's own ``tokenizer`` attribute is used instead.
            model_params: Extra model parameters. A falsy value (``None`` or
                an empty dict) is replaced with a fresh empty dict.
            context_builder_params: Extra context-builder parameters. A falsy
                value is replaced with a fresh empty dict.
        """
        self.model, self.context_builder = model, context_builder

        # Fallbacks are decided by truthiness rather than ``is None``; the
        # model's tokenizer is only looked up when no usable one is given.
        self.tokenizer = tokenizer if tokenizer else model.tokenizer

        # An empty-dict argument is swapped for a new, unshared empty dict.
        self.model_params = model_params if model_params else {}
        self.context_builder_params = (
            context_builder_params if context_builder_params else {}
        )

    @abstractmethod
    async def generate(
        self,
        question_history: list[str],
        context_data: str | None,
        question_count: int,
        **kwargs,
    ) -> QuestionResult:
        """Produce questions from the question history and context.

        Declared as a coroutine here, exactly like ``agenerate``.

        Args:
            question_history: Previously asked questions.
            context_data: Context text, or ``None`` (presumably the subclass
                then builds context itself — confirm per implementation).
            question_count: Requested number of questions.
            **kwargs: Implementation-specific options.

        Returns:
            A ``QuestionResult`` describing the generated questions.
        """

    @abstractmethod
    async def agenerate(
        self,
        question_history: list[str],
        context_data: str | None,
        question_count: int,
        **kwargs,
    ) -> QuestionResult:
        """Asynchronously produce questions from history and context.

        Args:
            question_history: Previously asked questions.
            context_data: Context text, or ``None`` (presumably the subclass
                then builds context itself — confirm per implementation).
            question_count: Requested number of questions.
            **kwargs: Implementation-specific options.

        Returns:
            A ``QuestionResult`` describing the generated questions.
        """
