import asyncio
import datetime
import json
import logging
import os
import secrets
import time
from functools import lru_cache

import httpx
import jwt
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends, Form, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, Response, JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel

load_dotenv()

# ========= Config =========
# Every setting is read from the environment (optionally populated from a .env
# file by load_dotenv() above); the second getenv argument is the fallback.
QDRANT_URL   = os.getenv("QDRANT_URL")
OLLAMA_URL   = os.getenv("OLLAMA_URL")
EMBED_MODEL  = os.getenv("EMBED_MODEL", "nomic-embed-text")
# Dimension used when this service creates a new Qdrant collection; it should
# match the output size of EMBED_MODEL.
VECTOR_SIZE  = int(os.getenv("VECTOR_SIZE", "768"))
PORT         = int(os.getenv("PORT", "8070"))

# Generation settings forwarded to Ollama as model options.
LLM_MODEL    = os.getenv("LLM_MODEL", "qwen2.5:7b-instruct-q4_K_M")
MAX_TOKENS   = int(os.getenv("MAX_TOKENS", "150"))
TEMPERATURE  = float(os.getenv("TEMPERATURE", "0.2"))
TOP_P        = float(os.getenv("TOP_P", "0.9"))

# Auth settings. The fallbacks are placeholders: anyone who knows SECRET_KEY can
# mint valid tokens, so override all of these in any real deployment.
SECRET_KEY   = os.getenv("SECRET_KEY", "cambia_este_secreto")
ADMIN_USER   = os.getenv("ADMIN_USER", "admin_user")
ADMIN_PASS   = os.getenv("ADMIN_PASS", "admin_password")
TOKEN_MIN    = int(os.getenv("TOKEN_MIN", "120"))   # JWT lifetime, in minutes
ALGORITHM    = "HS256"

# Fail fast with an actionable message: without these URLs the shared httpx
# clients cannot be built and the import would die with an opaque TypeError.
_missing_urls = [
    name
    for name, value in (("QDRANT_URL", QDRANT_URL), ("OLLAMA_URL", OLLAMA_URL))
    if not value
]
if _missing_urls:
    raise RuntimeError(
        f"Faltan variables de entorno obligatorias: {', '.join(_missing_urls)}"
    )

if SECRET_KEY == "cambia_este_secreto" or ADMIN_PASS == "admin_password":
    logging.getLogger(__name__).warning(
        "SECRET_KEY/ADMIN_PASS are using insecure default values; "
        "set them in the environment before deploying."
    )

app = FastAPI(title="Servidor IA Local Optimizado")

# --- CORS ---
# Allowed front-end origins come from ALLOW_ORIGINS (comma-separated), e.g.
#   ALLOW_ORIGINS=https://iaserver.socioturnos.com
# Unset or blank falls back to "*" (any origin), which matches the previous
# hard-coded behaviour and is only appropriate for development.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# discouraged by the CORS spec and Starlette's docs; pin exact domains in
# production.
ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],            # GET, POST, OPTIONS...
    allow_headers=["*"],            # Authorization, Content-Type...
    expose_headers=["*"],
    max_age=86400,
)

# --- Middleware: answer bare OPTIONS requests with 204 ---
# This middleware is registered after CORSMiddleware, so Starlette runs it
# *first*. Short-circuiting CORS preflights here would return a 204 without any
# Access-Control-* headers and browsers would block the cross-origin call, so
# preflights (OPTIONS + Origin + Access-Control-Request-Method) are passed on to
# CORSMiddleware, which answers them itself.
@app.middleware("http")
async def options_always_204(request: Request, call_next):
    """Reply 204 to non-preflight OPTIONS requests; forward everything else."""
    headers = request.headers
    is_options = request.method == "OPTIONS"
    is_preflight = (
        is_options
        and "origin" in headers
        and "access-control-request-method" in headers
    )
    if is_options and not is_preflight:
        return Response(status_code=204)
    return await call_next(request)

# Dependency that extracts the "Authorization: Bearer <token>" credentials.
bearer = HTTPBearer()

# ========= HTTP clients (async) =========
# Shared, connection-pooled clients; the relative paths used by the handlers
# (e.g. "/api/generate", "/collections/...") resolve against these base URLs.
OLLAMA = httpx.AsyncClient(base_url=OLLAMA_URL, timeout=25,
                           limits=httpx.Limits(max_keepalive_connections=10, max_connections=20))
QDRANT = httpx.AsyncClient(base_url=QDRANT_URL, timeout=10,
                           limits=httpx.Limits(max_keepalive_connections=10, max_connections=20))

# Collection names already confirmed (or created) in Qdrant during this process
# lifetime, so ensure_collection() can skip the round-trip.
KNOWN_COLLECTIONS = set()


@app.on_event("shutdown")
async def _close_http_clients():
    """Close the shared HTTP clients so pooled connections are released cleanly."""
    try:
        await OLLAMA.aclose()
    finally:
        await QDRANT.aclose()

# ========= MODELOS =========
class ChatIn(BaseModel):
    # JSON body for POST /chat and POST /chat/stream.
    question: str                       # user question; embedded for search and placed in the LLM prompt
    collection: str = "faqs_public"     # Qdrant collection searched for context
    # NOTE(review): the three optional fields below are accepted but not read by
    # the /chat handlers in this file (no search filter is built from them) —
    # confirm whether filtering by audience/domain/channel was intended.
    audience: str | None = None
    domain: str | None = None
    channel: str | None = None

class TokenIn(BaseModel):
    # JSON body for POST /token: credentials checked against ADMIN_USER/ADMIN_PASS.
    username: str
    password: str

# ========= UTILES =========
def create_token(data: dict, expire_minutes: int = TOKEN_MIN) -> str:
    """Sign a JWT carrying *data* plus an ``exp`` (expiry) claim.

    Args:
        data: claims to embed (e.g. ``{"sub": username}``); copied, not mutated.
        expire_minutes: lifetime in minutes. The default is TOKEN_MIN as read
            when this function was defined.

    Returns:
        The encoded token, signed with SECRET_KEY using ALGORITHM.
    """
    payload = data.copy()
    # Timezone-aware UTC "now": datetime.utcnow() is deprecated since Python 3.12.
    expires_at = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
        minutes=expire_minutes
    )
    payload["exp"] = expires_at
    return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)

def verify_token(creds: HTTPAuthorizationCredentials = Depends(bearer)) -> dict:
    """FastAPI dependency: decode and validate the request's Bearer JWT.

    The accepted algorithm list is pinned to ALGORITHM so a token cannot choose
    its own verification algorithm.

    Returns:
        The decoded claims.

    Raises:
        HTTPException(401): the token is malformed, badly signed or expired
            (any PyJWT validation error).
    """
    token = creds.credentials
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as e:
        # Only PyJWT's own errors mean "bad token"; any other exception is a
        # server-side bug and should not be disguised as an auth failure.
        raise HTTPException(
            status_code=401,
            detail=f"Token inválido: {e}",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e

@lru_cache(maxsize=2048)
def _cache_key(text: str) -> str:
    return text

async def _embed(text: str):
    """Return the embedding vector for *text* from Ollama's /api/embeddings.

    Uses EMBED_MODEL and a 12 s per-request timeout, which overrides the
    client's default.

    Raises:
        HTTPException(500): Ollama answered with a non-200 status, or the reply
            carried no (or an empty) ``embedding``.
    """
    r = await OLLAMA.post("/api/embeddings", json={"model": EMBED_MODEL, "prompt": text}, timeout=12)
    if r.status_code != 200:
        raise HTTPException(500, f"Error embedding: {r.text}")
    emb = r.json().get("embedding")
    # Fail here rather than hand a null/empty vector to Qdrant, where it would
    # only surface later as a confusing search/upsert error.
    if not emb:
        raise HTTPException(500, f"Error embedding: respuesta sin vector: {r.text}")
    return emb

async def ensure_collection(collection: str):
    if collection in KNOWN_COLLECTIONS:
        return
    info = await QDRANT.get(f"/collections/{collection}")
    if info.status_code == 404:
        cr = await QDRANT.put(
            f"/collections/{collection}",
            json={"vectors": {"size": VECTOR_SIZE, "distance": "Cosine"}}
        )
        if cr.status_code >= 300:
            raise HTTPException(500, f"Error creando colección: {cr.text}")
    elif info.status_code >= 300:
        raise HTTPException(500, f"Error consultando colección: {info.text}")
    KNOWN_COLLECTIONS.add(collection)

# ========= STREAM =========
async def stream_llm(prompt: str):
    """Stream the LLM completion for *prompt* as text fragments.

    Posts to Ollama's /api/generate with ``stream: True`` and yields the
    non-empty ``response`` field of each newline-delimited JSON line.

    Raises:
        HTTPException(500): Ollama answered with a non-200 status. This is
            raised before the first fragment is yielded.
    """
    req_body = {
        "model": LLM_MODEL,
        "prompt": prompt,
        "stream": True,
        # keep_alive is a top-level /api/generate field, not a model option;
        # nested under "options" it is not applied.
        "keep_alive": "60m",
        "options": {
            "num_predict": MAX_TOKENS,
            "temperature": TEMPERATURE,
            "top_p": TOP_P,
        },
    }
    async with OLLAMA.stream("POST", "/api/generate", json=req_body) as r:
        if r.status_code != 200:
            text = await r.aread()
            raise HTTPException(500, f"Error LLM: {text.decode(errors='replace')}")
        async for line in r.aiter_lines():
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                # Skip a malformed line rather than aborting the whole stream.
                continue
            chunk = data.get("response", "") if isinstance(data, dict) else ""
            if chunk:
                yield chunk

# ========= WARMUP =========
@app.on_event("startup")
async def on_startup():
    """Warm up Ollama in the background so the first real request is faster.

    Sends one embedding request and one 2-token generation. This is
    best-effort: failures are logged and never block or abort startup.
    """
    async def _warm():
        try:
            await _embed("warmup")
            r = await OLLAMA.post(
                "/api/generate",
                json={
                    "model": LLM_MODEL, "prompt": "ok", "stream": False,
                    # keep_alive is a top-level /api/generate field, not an option.
                    "keep_alive": "60m",
                    "options": {"num_predict": 2},
                },
                timeout=15
            )
            r.raise_for_status()
        except Exception as exc:  # best-effort: warmup must never crash the app
            logging.getLogger(__name__).warning("Ollama warmup failed (ignored): %r", exc)

    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task can be garbage-collected mid-flight.
    app.state.warmup_task = asyncio.create_task(_warm())

# ========= ENDPOINTS =========
@app.get("/")
async def root():
    """Health check: report that the service is up and which LLM it serves."""
    status = {"status": "OK", "model": LLM_MODEL}
    return status

# /token accepts POST and OPTIONS (in case a proxy inspects the allowed methods).
@app.api_route("/token", methods=["POST", "OPTIONS"])
async def token(body: TokenIn | None = None, request: Request = None):
    """Exchange the admin username/password for a signed JWT.

    *request* is injected by FastAPI from its ``Request`` annotation; the
    ``None`` default only keeps direct ``token(body)`` calls working.

    Returns:
        ``{"token": <jwt>}`` for valid credentials, or an empty 204 for OPTIONS.

    Raises:
        HTTPException(401): credentials are missing or wrong.
    """
    # Decide on the HTTP method, not on the body being absent: a POST with no
    # body must not be answered with a 204 "success".
    if request is not None and request.method == "OPTIONS":
        return Response(status_code=204)
    if body is None:
        raise HTTPException(401, "Credenciales inválidas")
    # Constant-time comparison so response timing does not reveal how much of a
    # guess matched. Compare UTF-8 bytes (compare_digest rejects non-ASCII str),
    # and always evaluate both checks to avoid short-circuiting on the username.
    user_ok = secrets.compare_digest(body.username.encode("utf-8"), ADMIN_USER.encode("utf-8"))
    pass_ok = secrets.compare_digest(body.password.encode("utf-8"), ADMIN_PASS.encode("utf-8"))
    if user_ok and pass_ok:
        return {"token": create_token({"sub": body.username})}
    raise HTTPException(401, "Credenciales inválidas")

@app.post("/ingest_text")
async def ingest_text(
    collection: str = Form(...),
    audience: str   = Form(...),
    domain: str     = Form(...),
    channels: str   = Form(...),
    text: str       = Form(...),
    _user: dict     = Depends(verify_token)
):
    """Embed *text* and store it as one point in the Qdrant collection.

    Form fields:
        collection: target collection (created on first use).
        audience, domain: stored verbatim in the point payload.
        channels: comma-separated list, e.g. ``"web, whatsapp"``.
        text: the content to index.

    Returns:
        ``{"ok": True, "id": <point id>}``.

    Raises:
        HTTPException(400): *text* is blank.
        HTTPException(500): embedding or the Qdrant upsert failed.
    """
    if not text.strip():
        raise HTTPException(400, "El texto está vacío")
    await ensure_collection(collection)
    emb = await _embed(text)
    # Millisecond timestamp used as the point id.
    # NOTE(review): two ingests within the same millisecond get the same id, and
    # presumably the later PUT overwrites the earlier point.
    pid = int(time.time() * 1000)
    # Trim whitespace and drop empty entries: "web, app," -> ["web", "app"].
    channel_list = [c.strip() for c in channels.split(",") if c.strip()]
    up = await QDRANT.put(
        f"/collections/{collection}/points",
        json={"points": [{"id": pid, "vector": emb, "payload": {
            "text": text, "audience": audience, "domain": domain, "channels": channel_list
        }}]}
    )
    # Report Qdrant failures instead of claiming success for a point that was
    # never stored.
    if up.status_code >= 300:
        raise HTTPException(500, f"Error guardando punto: {up.text}")
    return {"ok": True, "id": pid}

# ========= CHAT NORMAL =========
@app.post("/chat")
async def chat(body: ChatIn, _user: dict = Depends(verify_token)):
    """Answer a question with retrieval-augmented generation (non-streaming).

    Embeds the question and fetches the 4 nearest points from
    ``body.collection``. Their ``text`` payloads (capped at 2000 characters)
    go into the prompt, and the LLM is asked for a complete answer.

    Returns:
        ``{"answer": <text>}``.

    Raises:
        HTTPException(500): embedding, vector search or generation failed.
    """
    await ensure_collection(body.collection)
    query_vec = await _embed(body.question)

    payload = {
        "vector": query_vec,
        "limit": 4,
        "with_payload": True,
        "params": {"hnsw_ef": 64, "exact": False}
    }

    # NOTE(review): body.audience/domain/channel are not turned into a search
    # filter here — confirm whether that was intended.
    res = await QDRANT.post(f"/collections/{body.collection}/points/search", json=payload)
    # A failed search would otherwise silently yield an empty context and a
    # misleading "no information" answer.
    if res.status_code != 200:
        raise HTTPException(500, f"Error buscando contexto: {res.text}")
    hits = res.json().get("result") or []
    # Skip hits without a payload or without a "text" field instead of failing.
    texts = ((h.get("payload") or {}).get("text") for h in hits)
    ctx = "\n".join(t for t in texts if t)[:2000]

    prompt = (
        "Eres el asistente virtual de PIISA Industrial Park. Responde en español de forma completa y amable. "
        "No inventes datos; si no tienes la información, sugiere escribir a soporte@gspiisa.net o llamar al 809-957-2020 ext.2245.\n\n"
        f"Contexto:\n{ctx}\n\n"
        f"Pregunta: {body.question}\n\nRespuesta:"
    )

    r = await OLLAMA.post("/api/generate", json={
        "model": LLM_MODEL,
        "prompt": prompt,
        "stream": False,
        # keep_alive is a top-level /api/generate field; under "options" it is not applied.
        "keep_alive": "60m",
        "options": {"num_predict": MAX_TOKENS, "temperature": TEMPERATURE, "top_p": TOP_P}
    })
    if r.status_code != 200:
        raise HTTPException(500, f"Error LLM: {r.text}")

    # "or" guards against an explicit null response field.
    return {"answer": (r.json().get("response") or "").strip()}

# ========= CHAT STREAM =========
@app.post("/chat/stream")
async def chat_stream(body: ChatIn, _user: dict = Depends(verify_token)):
    """Answer a question with retrieval-augmented generation, streamed as text.

    Same retrieval and prompt as /chat, but the LLM output is relayed fragment
    by fragment from stream_llm() as ``text/plain``.

    Raises:
        HTTPException(500): embedding, vector search or the start of generation
            failed. All of these are raised before any byte is sent, so the
            client gets a real error status instead of a truncated 200.
    """
    await ensure_collection(body.collection)
    query_vec = await _embed(body.question)

    payload = {
        "vector": query_vec,
        "limit": 4,
        "with_payload": True,
        "params": {"hnsw_ef": 64, "exact": False}
    }

    res = await QDRANT.post(f"/collections/{body.collection}/points/search", json=payload)
    # A failed search would otherwise silently yield an empty context.
    if res.status_code != 200:
        raise HTTPException(500, f"Error buscando contexto: {res.text}")
    hits = res.json().get("result") or []
    # Skip hits without a payload or without a "text" field instead of failing.
    texts = ((h.get("payload") or {}).get("text") for h in hits)
    ctx = "\n".join(t for t in texts if t)[:2000]

    prompt = (
        "Eres el asistente virtual de PIISA Industrial Park. Responde en español de forma completa y amable. "
        "No inventes datos; si no tienes la información, sugiere escribir a soporte@gspiisa.net o llamar al 809-957-2020 ext.2245.\n\n"
        f"Contexto:\n{ctx}\n\n"
        f"Pregunta: {body.question}\n\nRespuesta:"
    )

    # Pull the first fragment now. stream_llm() checks Ollama's status before
    # its first yield, so an upstream error raises here, while a 500 can still
    # be sent. Starlette sends the response headers before iterating the body,
    # so a later exception could only cut the connection.
    chunks = stream_llm(prompt)
    try:
        first = await anext(chunks)
    except StopAsyncIteration:
        first = None

    async def _relay():
        try:
            if first is not None:
                yield first
            async for piece in chunks:
                yield piece
        finally:
            # Release the upstream HTTP stream even if streaming stops early.
            await chunks.aclose()

    return StreamingResponse(_relay(), media_type="text/plain")
