import pandas as pd
import json
import re
import logging
from unidecode import unidecode
from langchain_community.vectorstores import FAISS
from groq import Groq
import os
from dotenv import load_dotenv
from pathlib import Path
from modules.processing.embeddings import get_hf_embeddings

load_dotenv()
EMBEDDINGS_MODEL = os.getenv("EMBEDDINGS_MODEL", "intfloat/e5-small-v2")
GROQ_MODEL_CIE10 = os.getenv("GROQ_MODEL_CIE10", "openai/gpt-oss-120b")

class CIE10Retriever:
    def __init__(
        self,
        faiss_index_path: str,
        estructura_json_path: str,
        groq_api_key: str = None,
        groq_model: str = GROQ_MODEL_CIE10  # "llama-3.3-70b-versatile"
    ):
        # Inicialización de embeddings e índice FAISS
        self.embeddings = get_hf_embeddings(EMBEDDINGS_MODEL)
        self.db = None
        self._index_load_error = None
        self.index_path = faiss_index_path
        self.db = self._cargar_indice(faiss_index_path)
        # Carga de la estructura jerárquica CIE-10
        self.estructura = self._cargar_estructura(estructura_json_path)
        self.normalizador = self._crear_normalizador()
        # Cliente Groq para generación de respuestas
        api_key = groq_api_key or os.getenv("GROQ_API_KEY")
        if not api_key:
            raise RuntimeError("Se requiere la variable de entorno GROQ_API_KEY o pasar groq_api_key.")
        self.groq_client = Groq(api_key=api_key)
        self.groq_model = groq_model

    def _cargar_indice(self, path: str):
        try:
            index_dir = Path(path)
            if not index_dir.exists():
                raise FileNotFoundError(f"Directorio no encontrado: {index_dir}")
            if not (index_dir / 'index.faiss').exists():
                raise FileNotFoundError(f"Archivo index.faiss no encontrado en: {index_dir}")
            return FAISS.load_local(str(index_dir), self.embeddings, allow_dangerous_deserialization=True)
        except Exception as e:
            self._index_load_error = e
            logging.error('Error cargando indice FAISS CIE-10: %s', str(e))
            return None

    def _require_db(self):
        if self.db is None:
            detalle = f"{self._index_load_error}" if self._index_load_error else 'desconocido'
            raise RuntimeError(
                'Indice FAISS CIE-10 no disponible. '
                f'Ruta: {self.index_path}. '
                f'Detalle: {detalle}. '
                'Genera el indice con: python modules/processing/cie10/Carga_estructura.py'
            )

    def _cargar_estructura(self, path: str) -> dict:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _crear_normalizador(self) -> dict:
        return {
            'sinonimos': {
                r"\b(trazo\s+sugestivo\b|fx|fract)": "fractura",
                r"\b(dcha|der)\b": "derecho",
                r"\b(izq|izqda)\b": "izquierdo",
                r"\b(tac)\b": "tomografía axial computarizada",
                r"\b(5to|quinto)\b": "5"
            },
            'reemplazos': {
                "QUEMADURA POR FRICCION": "QUEMADURA POR FRICCIÓN",
                "TRAZO SUGESTIVO DE FRACTURA": "FRACTURA"
            }
        }

    def normalizar_consulta(self, texto: str) -> str:
        texto = texto.upper()
        for patron, reemplazo in self.normalizador['reemplazos'].items():
            texto = texto.replace(patron, reemplazo)
        for patron, reemplazo in self.normalizador['sinonimos'].items():
            texto = re.sub(patron, reemplazo, texto, flags=re.IGNORECASE)
        return unidecode(texto).strip()

    def buscar(self, consulta: str, k: int = 50, lambda_param: float = 0.8, nivel_minimo: int = 3) -> dict:
        self._require_db()
        consulta_norm = self.normalizar_consulta(consulta)
        resultados = self.db.max_marginal_relevance_search(
            consulta_norm,
            k=k,
            lambda_param=lambda_param,
            filter={"nivel": {"$gte": nivel_minimo}}
        )
        return self._procesar_resultados(resultados)

    def _procesar_resultados(self, resultados) -> dict:
        jerarquia = {'bloques': [], 'categorias': [], 'subcategorias': [], 'detalles': []}
        for doc in resultados:
            metadata = doc.metadata
            nivel = metadata['nivel']
            entrada = {
                'codigo': metadata['codigo'],
                'descripcion': doc.page_content.split(":")[-1].strip(),
                'nivel': nivel,
                'ruta': metadata.get('ruta', ''),
                'score': self._calcular_score(nivel)
            }
            if nivel == 1:
                jerarquia['bloques'].append(entrada)
            elif nivel == 2:
                jerarquia['categorias'].append(entrada)
            else:
                jerarquia['subcategorias'].append(entrada)
            jerarquia['detalles'].append(entrada)
        return jerarquia

    def _calcular_score(self, nivel: int) -> float:
        return {1: 0.3, 2: 0.6, 3: 1.0}.get(nivel, 0)

    def obtener_ruta_completa(self, codigo: str) -> list:
        for bloque in self.estructura['bloques']:
            if codigo == bloque['codigo']:
                return [bloque]
            for categoria in bloque['categorias']:
                if codigo == categoria['codigo']:
                    return [bloque, categoria]
                for subcat in categoria['subcategorias']:
                    if codigo == subcat['codigo']:
                        return [bloque, categoria, subcat]
        return []

    def validar_codigo(self, codigo: str) -> bool:
        return bool(self.obtener_ruta_completa(codigo))

    def mejores_resultados(self, resultados: dict, top_n: int = 3) -> list:
        return sorted(
            resultados['detalles'],
            key=lambda x: x['score'],
            reverse=True
        )[:top_n]

    # ---------------- Nuevas funciones integradas ----------------
    def generar_respuesta_groq(
        self,
        consulta: str,
        contexto: str,
        temperature: float = 0.0
    ) -> str:
        # prompt = (
        #     "Eres un asistente médico experto en codificación CIE-10. "
        #     "Tu tarea es asignar el código CIE-10 correcto basándote en la consulta del usuario y en el contexto proporcionado.\n\n"
        #     "El contexto consiste en fragmentos de un corpus que relaciona códigos y descripciones médicas.\n"
        #     "Responde únicamente en el siguiente formato:\n\n"
        #     "Código: <código> - Descripción: <descripción>\n\n"
        #     "Si no puedes determinar un código con certeza, responde exactamente: 'No se pudo determinar el código CIE-10.'\n\n"
        #     f"Consulta: {consulta}\n\n"
        #     f"Contexto:\n{contexto}"
        # )
        prompt = (
            "Eres un asistente médico experto en codificación CIE‑10. "
            "Tu tarea es asignar los códigos CIE‑10 correctos basándote en la consulta del usuario y en el contexto proporcionado.\n\n"
            "Si la consulta contiene múltiples condiciones separadas por '+', debes responder **un código y descripción por cada condición**, en líneas separadas.\n\n"
            "El contexto consiste en fragmentos de un corpus que relaciona códigos y descripciones médicas.\n"
            "Responde únicamente en el siguiente formato, una línea por código:\n\n"
            "Código: <código> - Descripción: <descripción>\n\n"
            "Si alguna condición no se puede codificar con certeza, responde exactamente: "
            "'No se pudo determinar el código CIE‑10 para: <texto de la condición>'.\n\n"
            f"Consulta: {consulta}\n\n"
            f"Contexto:\n{contexto}"
        )
        response = self.groq_client.chat.completions.create(
            model=self.groq_model,
            temperature=temperature,
            top_p=0.1,
            messages=[
                {"role": "system", "content": "Eres un asistente médico experto."},
                {"role": "user", "content": prompt}
            ]
        )
        return response.choices[0].message.content.strip()

    def asignar_codigo_cie10(
        self,
        consulta: str,
        temperature: float = 0.0,
        k: int = 50,
        lambda_param: float = 0.8
    ) -> str:
        resultados = self.buscar(consulta, k=k, lambda_param=lambda_param)
        contexto = "".join(
            f"Código: {d['codigo']} - Descripción: {d['descripcion']}\n"
            for d in resultados['detalles']
        )
        return self.generar_respuesta_groq(consulta, contexto, temperature)


    def asignar_codigos_batch(
        self,
        texto_html: str,
        temperature: float = 0.0
    ) -> pd.DataFrame:
        """
        Extrae cada diagnóstico de los <li> en la sección Diagnósticos
        y les asigna un código CIE-10 usando RAG.
        """
        # 1. Extraer todos los <li> ... </li>
        items = re.findall(r"<b>Diagnósticos</b>.*?<ol>(.*?)</ol>", texto_html, flags=re.DOTALL)
        lista_html = items[0] if items else texto_html  # si no encuentra sección, procesa todo

        diagnosticos = re.findall(r"<li>(.*?)</li>", lista_html, flags=re.DOTALL)
        diagnosticos = [d.strip() for d in diagnosticos if d.strip()]

        # 2. Para cada diagnóstico, asignar código
        resultados = []
        for diag in diagnosticos:
            codigo = self.asignar_codigo_cie10(diag, temperature)
            resultados.append({
                'diagnostico': diag,
                'codigo_asignado': codigo
            })

        return pd.DataFrame(resultados)

    # def asignar_codigos_batch(
    #     self,
    #     texto_html: str,
    #     temperature: float = 0.0
    # ) -> pd.DataFrame:
    #     # Extraer diagnósticos numerados de un HTML simple
    #     items = re.findall(r"\d+\.\s*([^<]+)", texto_html)
    #     resultados = []
    #     for diag in items:
    #         codigo = self.asignar_codigo_cie10(diag.strip(), temperature)
    #         resultados.append({'diagnostico': diag.strip(), 'codigo_asignado': codigo})
    #     return pd.DataFrame(resultados)

# Ejemplo de uso si se ejecuta como script

if __name__ == "__main__":
    retriever = CIE10Retriever(
        faiss_index_path="faiss_principal_jerarquico",
        estructura_json_path="cie10_estructura_completa.json"
    )
    html_input = (
        "<p>1.Trauma en HOMBRO.<br>"
        "2. Trauma en pie izquierdo + herida.<br>"
        "3. Fractura en falange proximal de 1er dedo pie izquierdo.<br>"
        "4. Contusión del tobillo (diagnóstico presuntivo inicial).<br>"
        "5. Contusión de otras partes y de las no especificadas del pie (diagnóstico presuntivo inicial)."  
        "</p>"
    )
    #Trauma en hombro izquierdo
    #Trauma en tobillo izquierdo
    df_result = retriever.asignar_codigos_batch(html_input)
    print(df_result)
