#!/usr/bin/env python3
"""
Swarm Protocol Connector for LuckBot.

Connects to the Swarm platform (swarm.perkos.xyz) via:
  - Firestore REST API for messaging (primary — always works)
  - Hub WSS for real-time messages (optional — falls back to polling)

The Swarm platform uses Firebase project "lucky-st". The Node.js plugin
uses the Firebase client SDK directly; we use the Firestore REST API
with the public API key (matching the same security model).
"""

import asyncio
import json
import logging
import re
from datetime import datetime, timezone
from typing import Optional
from urllib.parse import urlparse

import httpx

try:
    import websockets
    HAS_WEBSOCKETS = True
except ImportError:
    HAS_WEBSOCKETS = False

from trading_core import (
    execute_tool as core_execute_tool,
    run_claude_agentic_loop,
    get_trading_context,
)
from roles import get_role_prompt, get_role_name

logger = logging.getLogger("luckyst.swarm")

# Firebase settings for the Swarm platform (Firebase project "lucky-st").
# These are the public client-side values, the same ones the Node.js plugin uses.
FIREBASE_PROJECT_ID = "lucky-st"
FIREBASE_API_KEY = "AIzaSyAwsFqFmZpw2QN0ZR1UmpgsC4ApTqHmoOM"

# Root URL of the Firestore REST API for this project's default database.
FIRESTORE_BASE = (
    "https://firestore.googleapis.com/v1"
    f"/projects/{FIREBASE_PROJECT_ID}"
    "/databases/(default)/documents"
)

# Seconds to wait between two Firestore polls.
POLL_INTERVAL = 5


class SwarmConnector:
    """
    Swarm Protocol client using Firestore polling.

    Flow:
      1. Update agent status to "online" in Firestore
      2. Get assigned projects/channels from agent doc
      3. Poll Firestore for new messages every POLL_INTERVAL seconds
      4. Route messages through Claude agentic loop
      5. Send responses back via Firestore

    Optionally tries Hub WSS for real-time delivery (falls back to polling).

    Per-channel "last seen" watermarks are timestamps normalized to
    ``YYYY-MM-DDTHH:MM:SS.ffffffZ`` (UTC, fixed width). This lets plain string
    comparison order them chronologically, whatever the number of fractional
    digits or the UTC-offset notation used by the source.
    """

    # Fixed-width UTC layout shared by watermarks and the timestamps we write.
    _TS_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"

    # Date, time, optional fractional seconds, optional "Z" / ±HH:MM offset.
    _TS_RE = re.compile(
        r"(\d{4}-\d{2}-\d{2})[Tt ](\d{2}:\d{2}:\d{2})(?:\.(\d+))?([Zz]|[+-]\d{2}:?\d{2})?",
        re.ASCII,
    )

    # An @mention is "@" followed by a word character and not preceded by one.
    # This excludes e-mail addresses such as "me@example.com".
    _MENTION_RE = re.compile(r"(?<!\w)@\w")

    def __init__(
        self,
        hub_url: str,
        org_id: str,
        agent_name: str,
        agent_type: str,
        api_key: str,
        agent_id: str,
    ):
        self.hub_url = hub_url.rstrip("/")
        self.org_id = org_id
        self.agent_name = agent_name
        self.agent_type = agent_type
        self.api_key = api_key
        self.agent_id = agent_id

        # Derive WSS URL from hub URL
        parsed = urlparse(self.hub_url)
        ws_scheme = "wss" if parsed.scheme == "https" else "ws"
        self.ws_base = f"{ws_scheme}://{parsed.netloc}"

        # Hub JWT (for WSS, if available)
        self.hub_jwt: Optional[str] = None
        # NOTE(review): stored but never used to refresh hub_jwt in this class.
        self.hub_refresh_token: Optional[str] = None
        self.hub_available = False

        # Per-channel poll state: {channel_id: normalized last-seen timestamp}
        self.poll_state: dict[str, str] = {}

        # Channels whose pre-existing history has already been marked as seen.
        self._known_channels: set[str] = set()

        # Per-channel chat history (capped at 20 messages)
        self.chat_histories: dict[str, list] = {}

        # WebSocket connection
        self._ws = None
        self._running = True
        self._http = None  # Shared httpx client

    # ========================================================================
    # HTTP CLIENT
    # ========================================================================

    def _get_http(self) -> "httpx.AsyncClient":
        """Return the shared HTTP client (15 s timeout), creating it if missing or closed."""
        if self._http is None or self._http.is_closed:
            self._http = httpx.AsyncClient(timeout=15)
        return self._http

    # ========================================================================
    # MESSAGE ROUTING
    # ========================================================================

    def _should_respond(self, content: str) -> bool:
        """Decide whether this agent should answer ``content``.

        Rules (case-insensitive):
        - No @mention in the message -> every agent responds.
        - The message @mentions this agent -> respond. Accepted forms are
          ``@name``, ``@name`` with hyphens removed, and ``@name`` with
          hyphens replaced by spaces.
        - The message @mentions only other agents -> don't respond.

        A mention must match a whole name, so ``@luck`` does not match
        ``@luckbot``. An ``@`` inside a word (e.g. an e-mail address) is not
        a mention.
        """
        text_lower = content.lower()

        if not self._MENTION_RE.search(text_lower):
            return True  # No @mention — everyone responds

        base = self.agent_name.lower()
        my_names = {base, base.replace("-", ""), base.replace("-", " ")}
        for name in my_names:
            if not name:
                continue
            # The name must not be followed by further name characters.
            pattern = r"(?<!\w)@" + re.escape(name) + r"(?![\w-])"
            if re.search(pattern, text_lower):
                return True

        # There's an @mention but it's not us — don't respond
        return False

    # ========================================================================
    # TIMESTAMPS
    # ========================================================================

    @classmethod
    def _utc_now_iso(cls) -> str:
        """Return the current UTC time in the normalized watermark format."""
        return datetime.now(timezone.utc).strftime(cls._TS_FORMAT)

    @classmethod
    def _normalize_ts(cls, value) -> str:
        """Normalize an RFC 3339-style timestamp string for ordering.

        Accepts any number of fractional digits (truncated to microseconds)
        and a ``Z``, ``±HH:MM`` or ``±HHMM`` offset. A missing offset is
        treated as UTC. Returns ``YYYY-MM-DDTHH:MM:SS.ffffffZ`` in UTC, or
        ``""`` when ``value`` is not a parseable string.
        """
        if not isinstance(value, str):
            return ""
        match = cls._TS_RE.fullmatch(value.strip())
        if not match:
            return ""
        date_part, time_part, fraction, tz = match.groups()
        micros = (fraction or "")[:6].ljust(6, "0")
        if not tz or tz in ("Z", "z"):
            offset = "+00:00"
        elif ":" in tz:
            offset = tz
        else:
            offset = f"{tz[:3]}:{tz[3:]}"
        try:
            parsed = datetime.fromisoformat(f"{date_part}T{time_part}.{micros}{offset}")
        except ValueError:
            return ""
        return parsed.astimezone(timezone.utc).strftime(cls._TS_FORMAT)

    def _advance_watermark(self, channel_id: str, ts: str) -> None:
        """Move the channel's last-seen timestamp forward to ``ts`` (never backwards)."""
        if ts and ts > self.poll_state.get(channel_id, ""):
            self.poll_state[channel_id] = ts

    def _prime_channel(self, channel_id: str, messages: list[dict]) -> None:
        """Mark every message currently in ``channel_id`` as already seen.

        Advances the watermark to the newest timestamp in ``messages`` and
        records the channel as known, so its existing history is not answered.
        """
        newest = max(
            (self._normalize_ts(m.get("createdAt")) for m in messages),
            default="",
        )
        self._advance_watermark(channel_id, newest)
        self._known_channels.add(channel_id)

    # ========================================================================
    # FIRESTORE REST API
    # ========================================================================

    def _fs_url(self, path: str) -> str:
        """Build Firestore REST API URL with API key."""
        return f"{FIRESTORE_BASE}/{path}?key={FIREBASE_API_KEY}"

    @staticmethod
    def _fs_val(field: dict):
        """Decode one Firestore REST ``Value`` object into a Python value.

        Handles string, integer, double, boolean, timestamp (returned as its
        RFC 3339 string), reference (returned as the document path string),
        array, map and null values. Unknown value kinds decode to ``None``.
        """
        if "stringValue" in field:
            return field["stringValue"]
        if "integerValue" in field:
            return int(field["integerValue"])
        if "doubleValue" in field:
            return float(field["doubleValue"])
        if "booleanValue" in field:
            return field["booleanValue"]
        if "timestampValue" in field:
            return field["timestampValue"]
        if "referenceValue" in field:
            return field["referenceValue"]
        if "arrayValue" in field:
            return [SwarmConnector._fs_val(v) for v in field["arrayValue"].get("values", [])]
        if "mapValue" in field:
            return {k: SwarmConnector._fs_val(v) for k, v in field["mapValue"].get("fields", {}).items()}
        return None  # nullValue and any kind not handled above

    @staticmethod
    def _to_fs_value(value) -> dict:
        """Encode a Python value as a Firestore REST ``Value`` object.

        ``bool`` is checked before ``int`` because ``bool`` is a subclass of
        ``int``. Dicts become maps and lists/tuples become arrays, encoded
        recursively. Any other type is stored as its ``str()``.
        """
        if value is None:
            return {"nullValue": None}
        if isinstance(value, bool):
            return {"booleanValue": value}
        if isinstance(value, int):
            return {"integerValue": str(value)}
        if isinstance(value, float):
            return {"doubleValue": value}
        if isinstance(value, str):
            return {"stringValue": value}
        if isinstance(value, dict):
            return {"mapValue": {"fields": {
                str(k): SwarmConnector._to_fs_value(v) for k, v in value.items()
            }}}
        if isinstance(value, (list, tuple)):
            return {"arrayValue": {"values": [SwarmConnector._to_fs_value(v) for v in value]}}
        return {"stringValue": str(value)}

    async def _fs_get_doc(self, collection: str, doc_id: str) -> Optional[dict]:
        """Fetch one document and return its decoded fields.

        Returns None if the document does not exist or on any HTTP or
        transport error.
        """
        url = self._fs_url(f"{collection}/{doc_id}")
        try:
            resp = await self._get_http().get(url)
            if resp.status_code == 404:
                return None
            if resp.status_code >= 400:
                logger.warning(f"Firestore GET {collection}/{doc_id} failed ({resp.status_code}): {resp.text[:200]}")
                return None
            fields = resp.json().get("fields", {})
            return {k: self._fs_val(v) for k, v in fields.items()}
        except Exception as e:
            logger.error(f"Firestore GET {collection}/{doc_id} error: {e}")
            return None

    async def _fs_update_doc(self, collection: str, doc_id: str, fields: dict) -> bool:
        """Update only the given top-level fields of a Firestore document.

        The request's ``updateMask`` lists exactly the keys of ``fields``, and
        every key is also encoded into the body. This matters because
        Firestore deletes masked fields that are missing from the body.
        Returns True on success and False on any HTTP or transport error.
        """
        if not fields:
            # A PATCH without an update mask would replace the whole document.
            return True

        mask = "&".join(f"updateMask.fieldPaths={k}" for k in fields)
        path = f"{collection}/{doc_id}"
        url = f"{self._fs_url(path)}&{mask}"
        body = {"fields": {k: self._to_fs_value(v) for k, v in fields.items()}}
        try:
            resp = await self._get_http().patch(url, json=body)
            if resp.status_code >= 400:
                logger.warning(f"Firestore PATCH {collection}/{doc_id} failed ({resp.status_code}): {resp.text[:200]}")
                return False
            return True
        except Exception as e:
            logger.error(f"Firestore PATCH error: {e}")
            return False

    async def _fs_create_doc(self, collection: str, fields: dict) -> bool:
        """Create a document with an auto-generated ID in ``collection``.

        ``createdAt`` is always set to the local UTC time and overrides any
        caller-supplied value. Returns True on success, False on error.
        """
        url = self._fs_url(collection)

        fs_fields = {k: self._to_fs_value(v) for k, v in fields.items()}
        # Client-side clock, not a Firestore server timestamp.
        fs_fields["createdAt"] = {"timestampValue": self._utc_now_iso()}

        body = {"fields": fs_fields}
        try:
            resp = await self._get_http().post(url, json=body)
            if resp.status_code >= 400:
                logger.error(f"Firestore POST {collection} failed ({resp.status_code}): {resp.text[:300]}")
                return False
            return True
        except Exception as e:
            logger.error(f"Firestore POST error: {e}")
            return False

    async def _fs_query(self, collection: str, field: str, op: str, value, order_by: Optional[str] = None) -> list[dict]:
        """Run a single-filter structured query against ``collection``.

        Args:
            collection: Collection ID to query.
            field: Field path to filter on.
            op: Firestore field-filter operator, e.g. ``"EQUAL"``.
            value: Comparison value, encoded with ``_to_fs_value``.
            order_by: Optional field path to sort by, ascending.

        Returns:
            The decoded documents, each with its document ID under ``"_id"``.
            Returns an empty list on any HTTP or transport error.
        """
        # FIRESTORE_BASE ends in ".../documents"; runQuery hangs off that path.
        url = f"{FIRESTORE_BASE}:runQuery?key={FIREBASE_API_KEY}"

        structured_query = {
            "from": [{"collectionId": collection}],
            "where": {
                "fieldFilter": {
                    "field": {"fieldPath": field},
                    "op": op,
                    "value": self._to_fs_value(value),
                }
            },
        }
        if order_by:
            structured_query["orderBy"] = [{"field": {"fieldPath": order_by}, "direction": "ASCENDING"}]

        body = {"structuredQuery": structured_query}
        try:
            resp = await self._get_http().post(url, json=body)
            if resp.status_code >= 400:
                logger.warning(f"Firestore query {collection} failed ({resp.status_code}): {resp.text[:200]}")
                return []

            results = []
            for item in resp.json():
                doc = item.get("document")
                if not doc:
                    continue  # e.g. result entries that carry no document
                parsed = {k: self._fs_val(v) for k, v in doc.get("fields", {}).items()}
                # Document ID is the last segment of the resource name
                parsed["_id"] = doc["name"].split("/")[-1]
                results.append(parsed)
            return results
        except Exception as e:
            logger.error(f"Firestore query error: {e}")
            return []

    # ========================================================================
    # AGENT STATUS
    # ========================================================================

    async def _update_agent_status(self, status: str = "online"):
        """Write ``status`` and a ``lastSeen`` UTC time string to the agent document."""
        ok = await self._fs_update_doc("agents", self.agent_id, {
            "status": status,
            "lastSeen": self._utc_now_iso(),
        })
        if ok:
            logger.info(f"Agent status: {status}")

    # ========================================================================
    # MESSAGE POLLING (primary method)
    # ========================================================================

    async def _get_agent_channels(self) -> list[tuple[str, str, str]]:
        """Get channels assigned to this agent via project membership.

        Returns a list of (channel_id, channel_name, project_name). The list
        is empty if the agent document is missing or has no ``projectIds``.
        """
        agent = await self._fs_get_doc("agents", self.agent_id)
        if not agent:
            logger.warning("Agent not found in Firestore")
            return []

        project_ids = agent.get("projectIds") or []
        if not project_ids:
            logger.debug("No projects assigned to this agent")
            return []

        channels = []
        for project_id in project_ids:
            project = await self._fs_get_doc("projects", project_id)
            project_name = project.get("name", project_id) if project else project_id

            for ch in await self._fs_query("channels", "projectId", "EQUAL", project_id):
                channels.append((ch.get("_id", ""), ch.get("name", "Channel"), project_name))

        return channels

    async def _poll_messages(self):
        """Poll Firestore once for new user messages in all assigned channels.

        For each channel:
        - A channel not seen before has its current history marked as seen,
          with no replies, in the same way ``_poll_loop`` does at startup.
        - Otherwise, user messages newer than the channel watermark are
          collected. The watermark then advances past every fetched message,
          and the new messages are answered oldest first.

        If answering one message fails, the error is logged and the remaining
        messages are still processed.
        """
        channels = await self._get_agent_channels()
        if not channels:
            return

        for channel_id, channel_name, project_name in channels:
            messages = await self._fs_query("messages", "channelId", "EQUAL", channel_id)

            if channel_id not in self._known_channels:
                self._prime_channel(channel_id, messages)
                logger.info(f"[{project_name}] #{channel_name} — new channel, existing history marked as seen")
                continue

            last_seen = self.poll_state.get(channel_id, "")
            newest = last_seen
            new_messages = []

            for msg in messages:
                ts = self._normalize_ts(msg.get("createdAt"))
                if ts > newest:
                    newest = ts

                # Skip our own messages and agent messages
                sender_id = msg.get("senderId", "")
                if sender_id == self.agent_id or msg.get("senderType", "") == "agent":
                    continue

                # A message without a usable timestamp cannot be compared with
                # the watermark, so it would be answered again on every poll.
                if not ts:
                    continue

                # Skip already-seen messages
                if last_seen and ts <= last_seen:
                    continue

                content = msg.get("content") or msg.get("text") or ""
                if not isinstance(content, str) or not content:
                    continue

                new_messages.append({
                    "content": content,
                    "sender_name": msg.get("senderName") or sender_id or "unknown",
                    "channel_id": channel_id,
                    "channel_name": channel_name,
                    "project_name": project_name,
                    "ts": ts,
                })

            self._advance_watermark(channel_id, newest)

            # Without orderBy, Firestore returns documents in document-ID
            # order. Sort so replies follow the order messages were sent.
            new_messages.sort(key=lambda m: m["ts"])

            for msg in new_messages:
                logger.info(f"[{msg['project_name']}] #{msg['channel_name']} — {msg['sender_name']}: {msg['content'][:80]}")
                if not self._should_respond(msg["content"]):
                    logger.debug("Skipping — message addressed to another agent")
                    continue
                try:
                    await self._handle_user_message(
                        channel_id=msg["channel_id"],
                        content=msg["content"],
                        sender_name=msg["sender_name"],
                        project_name=msg["project_name"],
                        channel_name=msg["channel_name"],
                    )
                except Exception as e:
                    logger.error(f"Failed to handle message in #{msg['channel_name']}: {e}")

    async def _poll_loop(self):
        """Main polling loop — check Firestore for new messages periodically."""
        logger.info(f"Firestore polling started (every {POLL_INTERVAL}s)")

        # First poll: mark all current messages as seen (don't respond to history).
        try:
            channels = await self._get_agent_channels()
            for channel_id, _, _ in channels:
                messages = await self._fs_query("messages", "channelId", "EQUAL", channel_id)
                self._prime_channel(channel_id, messages)
            logger.info(f"Initialized poll state for {len(channels)} channel(s)")
        except Exception as e:
            # Channels not primed here are primed on first sight by _poll_messages.
            logger.error(f"Poll state initialization failed: {e}")

        while self._running:
            try:
                await self._poll_messages()
            except Exception as e:
                logger.error(f"Poll error: {e}")
            await asyncio.sleep(POLL_INTERVAL)

    # ========================================================================
    # MESSAGE HANDLING
    # ========================================================================

    async def _handle_user_message(
        self,
        channel_id: str,
        content: str,
        sender_name: str,
        project_name: str,
        channel_name: str,
    ):
        """Process a user message through the Claude agentic loop and respond.

        Exceptions from the trading context or the agentic loop propagate to
        the caller.
        """
        history = self.chat_histories.get(channel_id, [])

        # Tool executor for the Swarm context (no Telegram-only features)
        async def swarm_tool_executor(tool_name: str, tool_input: dict) -> str:
            return await core_execute_tool(
                tool_name, tool_input,
                chat_id=None,
                jump_monitor_fn=None,
                stop_monitor_fn=None,
                schedule_fn=None,
                sessions=None,
            )

        # Role-specific system prompt based on agent type
        role_prompt = get_role_prompt(self.agent_type)
        role_name = get_role_name(self.agent_type)

        trading_ctx = await get_trading_context()
        extra = f"\n\nCurrent state:\n{trading_ctx}"
        extra += f"\n\nInterface: Swarm Protocol (project: {project_name}, channel: #{channel_name})"
        extra += f"\nMessage from: {sender_name}"
        extra += f"\nYour role: {role_name} ({self.agent_name})"
        extra += "\nNote: Scheduling and jump monitor features require the Telegram interface."

        reply, updated_history = await run_claude_agentic_loop(
            user_message=content,
            chat_history=history,
            tool_executor=swarm_tool_executor,
            system_prompt_extra=extra,
            system_prompt_override=role_prompt,
        )

        # Keep the last 20 history entries.
        # NOTE(review): depending on run_claude_agentic_loop's history format,
        # this slice may split related entries (e.g. tool call/result) — confirm.
        self.chat_histories[channel_id] = updated_history[-20:]

        if not reply:
            logger.warning(f"Empty reply for channel {channel_id} — nothing sent")
            return

        await self._send_message(channel_id, reply)

    async def _send_message(self, channel_id: str, content: str):
        """Send a message by writing to Firestore messages collection."""
        ok = await self._fs_create_doc("messages", {
            "channelId": channel_id,
            "senderId": self.agent_id,
            "senderName": self.agent_name,
            "senderType": "agent",
            "content": content,
            "orgId": self.org_id,
        })
        if ok:
            logger.info(f"Reply sent to channel {channel_id}")
        else:
            logger.error(f"Failed to send reply to channel {channel_id}")

    # ========================================================================
    # HUB WSS (optional — falls back to polling)
    # ========================================================================

    async def _try_hub_auth(self) -> bool:
        """Try to authenticate with the Hub for WSS. Returns False if Hub unavailable."""
        url = f"{self.hub_url}/auth/token"
        try:
            resp = await self._get_http().post(url, json={
                "agentId": self.agent_id,
                "apiKey": self.api_key,
            })
            if resp.status_code >= 400:
                logger.info(f"Hub auth unavailable ({resp.status_code}) — using Firestore polling")
                return False
            data = resp.json()
            self.hub_jwt = data.get("token") or data.get("accessToken")
            self.hub_refresh_token = data.get("refreshToken")
            if self.hub_jwt:
                logger.info("Hub auth successful — WSS available")
                return True
            return False
        except Exception as e:
            logger.info(f"Hub unreachable ({e}) — using Firestore polling")
            return False

    async def _handle_wss_message(self, message) -> None:
        """Handle one raw Hub WSS frame (JSON text).

        Non-"message" frames, our own or other agents' messages, and empty
        content are ignored. The channel watermark is advanced so the poller
        skips this message. The same ``_should_respond`` rule used by the
        poller is applied before answering.
        """
        data = json.loads(message)
        if data.get("type") != "message":
            return
        if data.get("senderId") == self.agent_id:
            return
        if data.get("senderType") == "agent":
            return

        content = data.get("content", "")
        if not isinstance(content, str) or not content:
            return

        channel_id = data.get("channelId", "default")
        sender_name = data.get("senderName", "unknown")
        project_name = data.get("projectName", "Project")
        channel_name = data.get("channelName", channel_id)

        # Advance (never rewind) the watermark so the poller doesn't double-process.
        ts = self._normalize_ts(data.get("createdAt")) or self._utc_now_iso()
        self._advance_watermark(channel_id, ts)

        if not self._should_respond(content):
            logger.debug("Skipping — message addressed to another agent")
            return

        await self._handle_user_message(
            channel_id=channel_id,
            content=content,
            sender_name=sender_name,
            project_name=project_name,
            channel_name=channel_name,
        )

    async def _wss_loop(self):
        """Optional WSS loop with exponential reconnect backoff (1 s to 60 s).

        Returns immediately if websockets is not installed or no Hub JWT was
        obtained. Waits before every reconnect, including after a clean close.
        Cancellation propagates to the caller.

        NOTE(review): the JWT from startup is reused for every reconnect; if
        Hub tokens expire, reconnects will keep failing — confirm.
        """
        if not HAS_WEBSOCKETS or not self.hub_jwt:
            return

        backoff = 1
        while self._running:
            reason = "connection closed"
            try:
                ws_url = f"{self.ws_base}?token={self.hub_jwt}"
                async with websockets.connect(
                    ws_url, ping_interval=30, ping_timeout=10, close_timeout=5,
                ) as ws:
                    self._ws = ws
                    backoff = 1
                    logger.info("Swarm Hub WSS connected (real-time mode)")

                    async for message in ws:
                        try:
                            await self._handle_wss_message(message)
                        except Exception as e:
                            logger.error(f"WSS message handling error: {e}")
            except Exception as e:
                reason = str(e)
            finally:
                self._ws = None

            if not self._running:
                break
            logger.debug(f"WSS disconnected: {reason}. Polling still active. Retry in {backoff}s")
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, 60)

    # ========================================================================
    # LIFECYCLE
    # ========================================================================

    async def run(self):
        """Main entry point.

        Marks the agent online and tries Hub auth. It then runs the Firestore
        poll loop, plus the WSS loop if Hub auth succeeded, until they finish
        or are cancelled. On exit it marks the agent offline and closes the
        HTTP client.
        """
        logger.info(f"Starting Swarm connector: {self.agent_name} -> {self.hub_url}")

        await self._update_agent_status("online")

        # Hub auth is optional — WSS is a bonus on top of polling
        self.hub_available = await self._try_hub_auth()

        tasks = [self._poll_loop()]
        if self.hub_available:
            tasks.append(self._wss_loop())

        try:
            await asyncio.gather(*tasks)
        except asyncio.CancelledError:
            logger.info("Swarm connector shutting down")
        finally:
            self._running = False
            await self._update_agent_status("offline")
            if self._http and not self._http.is_closed:
                await self._http.aclose()

    async def stop(self):
        """Graceful shutdown: stop loops, close WSS, mark offline, close HTTP."""
        self._running = False
        ws = self._ws
        if ws is not None:
            # close() is safe on an already-closed connection; avoid relying
            # on a `.closed` attribute, which newer websockets versions lack.
            try:
                await ws.close()
            except Exception as e:
                logger.debug(f"WSS close error: {e}")
        try:
            await self._update_agent_status("offline")
        except Exception as e:
            logger.debug(f"Offline status update failed: {e}")
        if self._http and not self._http.is_closed:
            await self._http.aclose()
