# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Parameterization settings for the default configuration."""

import logging
from dataclasses import dataclass
from pathlib import Path

from pydantic import BaseModel, Field

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

from graphrag.config.defaults import graphrag_config_defaults
from graphrag.prompts.index.extract_graph import GRAPH_EXTRACTION_PROMPT


@dataclass
class ExtractGraphPrompts:
    """Container for the prompt text used during graph extraction.

    Attributes:
        extraction_prompt: The full text of the extraction prompt.
    """

    extraction_prompt: str


class ExtractGraphConfig(BaseModel):
    """Configuration section for entity extraction.

    Every field defaults to the matching value under
    ``graphrag_config_defaults.extract_graph``.
    """

    # ID of the completion model used for graph extraction.
    completion_model_id: str = Field(
        description="The completion model ID to use for graph extraction.",
        default=graphrag_config_defaults.extract_graph.completion_model_id,
    )
    # Name of the model singleton instance.
    model_instance_name: str = Field(
        description="The model singleton instance name. This primarily affects the cache storage partitioning.",
        default=graphrag_config_defaults.extract_graph.model_instance_name,
    )
    # Path to a custom prompt file; a falsy value selects the built-in prompt.
    prompt: str | None = Field(
        description="The entity extraction prompt to use.",
        default=graphrag_config_defaults.extract_graph.prompt,
    )
    # Entity types to use for extraction.
    entity_types: list[str] = Field(
        description="The entity extraction entity types to use.",
        default=graphrag_config_defaults.extract_graph.entity_types,
    )
    # Maximum number of entity gleanings.
    max_gleanings: int = Field(
        description="The maximum number of entity gleanings to use.",
        default=graphrag_config_defaults.extract_graph.max_gleanings,
    )

    def resolved_prompts(self) -> ExtractGraphPrompts:
        """Get the resolved graph extraction prompts.

        When ``prompt`` is set to a non-empty value it is treated as a file
        path and the file's contents (decoded as UTF-8) become the extraction
        prompt. Otherwise the built-in ``GRAPH_EXTRACTION_PROMPT`` is used.

        Returns:
            An ``ExtractGraphPrompts`` holding the resolved prompt text.

        Raises:
            OSError: If the custom prompt file cannot be opened or read.
            UnicodeDecodeError: If the custom prompt file is not valid UTF-8.
        """
        # Progress is reported through the module logger only; library code
        # must not write to stdout, which belongs to the calling application.
        if self.prompt:
            extraction_prompt = Path(self.prompt).read_text(encoding="utf-8")
            logger.info("Using custom extraction prompt from %s", self.prompt)
        else:
            extraction_prompt = GRAPH_EXTRACTION_PROMPT
            logger.info("Using default extraction prompt")

        return ExtractGraphPrompts(
            extraction_prompt=extraction_prompt,
        )
