"""
Async PostgreSQL writer for Gemini prediction market orderbook data.
Consumes from an asyncio.Queue so DB writes never bottleneck the stream.
"""

import asyncio
import logging
import re
from typing import Any

import asyncpg

logger = logging.getLogger("gemini.db")

DEFAULT_DSN = "postgresql://postgres:postgres@localhost:5432/gemini_orderbooks"


def _sanitize_table_name(series: str) -> str:
    clean = re.sub(r"[^a-zA-Z0-9]", "_", series).lower()
    return f"orderbook_{clean}"


class OrderbookWriter:
    """Persists orderbook snapshot records to Postgres.

    Producers put record dicts onto ``queue``; :meth:`consumer_loop` drains
    the queue and writes each record into a per-series table, creating the
    table (and its indexes) on first use. DB writes therefore never run on
    the websocket/stream path.
    """

    def __init__(self, dsn: str = DEFAULT_DSN, queue: asyncio.Queue | None = None):
        self.dsn = dsn
        # Shared with producers; created here if the caller doesn't supply one.
        self.queue: asyncio.Queue = queue or asyncio.Queue()
        self._pool: asyncpg.Pool | None = None
        # Tables already CREATEd in this process, so we skip the DDL round-trip.
        self._known_tables: set[str] = set()

    async def start(self):
        """Create the connection pool. Must be called before any write."""
        self._pool = await asyncpg.create_pool(self.dsn, min_size=2, max_size=10)
        logger.info("DB pool created")

    async def stop(self):
        """Close the pool. Safe to call even if start() never ran."""
        if self._pool:
            await self._pool.close()
            logger.info("DB pool closed")

    def _require_pool(self) -> asyncpg.Pool:
        """Return the pool, failing loudly if start() was never awaited."""
        if self._pool is None:
            raise RuntimeError("OrderbookWriter.start() must be awaited before writing")
        return self._pool

    @staticmethod
    def _as_float(value: Any) -> float | None:
        """Coerce an orderbook field to float for a NUMERIC column.

        None and "" map to NULL. Unlike the previous truthiness test, a
        legitimate 0 price/size is preserved instead of being dropped to NULL.
        """
        if value is None or value == "":
            return None
        return float(value)

    async def _ensure_table(self, table: str):
        """Idempotently create ``table`` plus its indexes, once per process.

        ``table`` must already be sanitized (see _sanitize_table_name) since
        it is interpolated directly into the DDL.
        """
        if table in self._known_tables:
            return
        async with self._require_pool().acquire() as conn:
            await conn.execute(f"""
                CREATE TABLE IF NOT EXISTS {table} (
                    id              BIGSERIAL PRIMARY KEY,
                    ts              TIMESTAMPTZ NOT NULL DEFAULT now(),
                    instrument      TEXT NOT NULL,
                    event_ticker    TEXT NOT NULL,
                    contract_ticker TEXT NOT NULL,
                    market_interval TEXT,
                    best_bid        NUMERIC,
                    best_bid_size   NUMERIC,
                    best_ask        NUMERIC,
                    best_ask_size   NUMERIC
                );
            """)
            await conn.execute(f"""
                CREATE INDEX IF NOT EXISTS idx_{table}_ts ON {table} (ts);
            """)
            await conn.execute(f"""
                CREATE INDEX IF NOT EXISTS idx_{table}_instrument ON {table} (instrument);
            """)
        self._known_tables.add(table)
        logger.info("Ensured table %s", table)

    async def write_one(self, record: dict[str, Any]):
        """Insert a single snapshot record into its series table.

        Raises KeyError if a required field (series/ts/instrument/...) is
        missing, and RuntimeError if start() was never awaited.
        """
        table = _sanitize_table_name(record["series"])
        await self._ensure_table(table)

        async with self._require_pool().acquire() as conn:
            await conn.execute(
                f"""
                INSERT INTO {table}
                    (ts, instrument, event_ticker, contract_ticker,
                     market_interval, best_bid, best_bid_size,
                     best_ask, best_ask_size)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
                """,
                record["ts"],
                record["instrument"],
                record["event_ticker"],
                record["contract_ticker"],
                record.get("market_interval"),
                self._as_float(record.get("best_bid")),
                self._as_float(record.get("best_bid_size")),
                self._as_float(record.get("best_ask")),
                self._as_float(record.get("best_ask_size")),
            )

    async def consumer_loop(self):
        """Drain the queue forever, writing rows to Postgres.

        Blocks for the first item, then opportunistically drains up to 200
        more so bursts are handled in one pass. A failed record is logged and
        skipped so one bad row never stalls the stream.
        """
        logger.info("DB consumer loop started")
        while True:
            batch: list[dict] = []
            item = await self.queue.get()
            batch.append(item)

            # Soak up whatever else is already queued (bounded to 200 so a
            # firehose can't starve the event loop between awaits).
            for _ in range(200):
                try:
                    batch.append(self.queue.get_nowait())
                except asyncio.QueueEmpty:
                    break

            for rec in batch:
                try:
                    await self.write_one(rec)
                except Exception:
                    # Best-effort: log and move on; the stream must keep flowing.
                    logger.exception("Failed to write record: %s", rec.get("instrument"))

            # Account for every item we pulled, written or not, so
            # queue.join() can still complete.
            for _ in batch:
                self.queue.task_done()

            if len(batch) > 50:
                logger.info("Wrote batch of %d rows (queue depth: %d)", len(batch), self.queue.qsize())