#!/usr/bin/env python3
"""
Polymarket Two-Market Automated NO Market Maker - ASYNC

Mirrors AutomatedTrader (auto2.py) logic but uses the Polymarket CLOB API.

Handles 2-outcome event markets (e.g. "Will Team A win?" + "Will Team B win?")
where each outcome is a separate condition with its own YES/NO token pair.

The trader buys NO on both outcomes, capturing the spread when both fill.
"""

import sys
import time
import json
import asyncio
from typing import Optional, Dict
from datetime import datetime, timezone
from collections import deque

from app.modules.terminal.auto.poly_client import (
    PolyClient,
    PolyMarket,
)

# We trade the NO side of each market — so we use the NO token_id
# In Polymarket, buying NO is equivalent to selling YES shares.


class PolyAutoTrader:
    """Two-market NO trader for Polymarket event markets.

    Each cycle buys ``contract_increment`` NO shares on both markets (or, in
    higher-first mode, on one market at a time, alternating), keeping each
    resting BUY order repriced against the NO-side order book.

    Per-leg bookkeeping, all dicts keyed by condition_id:

    * ``cycle_start_resting`` - total quantity the leg has to fill this cycle.
    * ``current_increment``   - quantity filled so far this cycle.
    * ``cached_resting``      - unfilled size of the live order as last seen.
    * ``order_ids`` / ``last_prices`` - the live order and its price.

    An order that disappears from the open-orders list is treated as filled
    (``last_prices`` is kept so the fill price can be recorded).  A leg whose
    order we cancelled ourselves, or whose (re)placement failed, has both
    ``order_ids`` and ``last_prices`` set to ``None`` and is never counted as
    filled; ``update_orders_async`` re-quotes its remaining quantity.
    """

    # Tolerance for comparing prices that are sums/differences of floats
    # (e.g. 0.57 - 0.56 == 0.00999... in binary floating point).
    _PRICE_EPS = 1e-9

    def __init__(
        self,
        client: "PolyClient",
        market_1: "PolyMarket",
        market_2: "PolyMarket",
    ):
        self.client = client
        self.market_1_data = market_1
        self.market_2_data = market_2

        # Market identifiers
        self.market_1 = market_1.condition_id
        self.market_2 = market_2.condition_id

        # Token IDs: index 0 = YES, index 1 = NO
        self.m1_yes_token = market_1.token_ids[0] if len(market_1.token_ids) > 0 else ""
        self.m1_no_token = market_1.token_ids[1] if len(market_1.token_ids) > 1 else ""
        self.m2_yes_token = market_2.token_ids[0] if len(market_2.token_ids) > 0 else ""
        self.m2_no_token = market_2.token_ids[1] if len(market_2.token_ids) > 1 else ""

        # Tick size / neg-risk flag are taken from market 1 and used for both.
        self.tick_size = market_1.tick_size or "0.01"
        self.neg_risk = market_1.neg_risk

        # Orderbook state (NO side bids for each market)
        self.m1_no_bid: Optional[float] = None
        self.m1_no_bid_size: float = 0
        self.m1_no_second_bid: Optional[float] = None
        self.m2_no_bid: Optional[float] = None
        self.m2_no_bid_size: float = 0
        self.m2_no_second_bid: Optional[float] = None

        # Order tracking (keyed by condition_id)
        self.order_ids: Dict[str, Optional[str]] = {self.market_1: None, self.market_2: None}
        self.last_prices: Dict[str, Optional[float]] = {self.market_1: None, self.market_2: None}
        self.current_increment: Dict[str, float] = {self.market_1: 0, self.market_2: 0}
        self.cycle_start_resting: Dict[str, float] = {self.market_1: 0, self.market_2: 0}
        self.cached_resting: Dict[str, float] = {self.market_1: 0, self.market_2: 0}
        # NOTE(review): nothing in this class writes cached_position; it stays 0
        # unless updated externally, in which case the position-mismatch
        # rebalance in both_filled_async can never trigger.
        self.cached_position: Dict[str, float] = {self.market_1: 0, self.market_2: 0}
        self.fill_prices: Dict[str, Optional[float]] = {self.market_1: None, self.market_2: None}

        self.higher_first_mode = False
        self.active_market: Optional[str] = None

        # Control flags
        self.running = False
        self.active = False
        self.stopping = False
        self.paused = False
        self.waiting_for_manual_resume = False
        self.is_rebalancing = False
        self.contract_increment = 3

        # Single fire
        self.single_fire_mode = False
        self.single_fire_cycles_completed = 0

        # Fair value: rolling mean of the two NO best bids over the last 10 refreshes
        self.fair_value_enabled = False
        self.fair_value_history = deque(maxlen=10)
        self.current_fair_value = None

        # Jump mode
        self.jump_active: Dict[str, bool] = {self.market_1: False, self.market_2: False}
        self.jump_target: Dict[str, Optional[float]] = {self.market_1: None, self.market_2: None}

        # Redis + instance
        self.instance_id: Optional[int] = None
        self.redis_client = None

        # Session
        self.session_start_time: Optional[str] = None
        self.cycles_completed: int = 0

    def _get_markets(self):
        """Return both condition_ids in fixed order (market 1, market 2)."""
        return [self.market_1, self.market_2]

    def _markets_in_play(self):
        """Return the legs traded this cycle.

        In higher-first mode only the active market is traded; if none has been
        chosen yet one is selected from the current bids.
        """
        if self.higher_first_mode:
            if self.active_market is None:
                self.select_active_market()
            return [self.active_market]
        return [self.market_1, self.market_2]

    def _no_token_for_market(self, market_id: str) -> str:
        """Return the NO token id of ``market_id`` (market 2's for any non-market-1 id)."""
        if market_id == self.market_1:
            return self.m1_no_token
        return self.m2_no_token

    def _short_label(self, market_id: str) -> str:
        """Short display label: first 20 chars of the slug, else 12 of the id."""
        data = self.market_1_data if market_id == self.market_1 else self.market_2_data
        return data.slug[:20] if data.slug else market_id[:12]

    # ============================================================
    # MARKET DATA
    # ============================================================

    async def refresh_market_data_async(self):
        """Refresh NO-side order books, our resting sizes and the fair value.

        API failures are logged and leave the previous values in place so one
        bad call does not stop the loop.  An order that is no longer listed as
        open is dropped from ``order_ids`` with zero resting size (it is then
        counted as filled by ``check_fills``); its ``last_prices`` entry is kept
        so the fill price can be recorded.
        """
        for prefix, token_id in (("m1", self.m1_no_token), ("m2", self.m2_no_token)):
            try:
                ob = await asyncio.to_thread(self.client.get_orderbook, token_id)
                bids = ob.bids
                if bids:
                    best = bids[0].price
                    best_size = bids[0].size
                    second = bids[1].price if len(bids) > 1 else None
                else:
                    best, best_size, second = None, 0, None
            except Exception as e:
                print(f"\nOrderbook refresh failed ({prefix}): {e}")
                continue
            setattr(self, f"{prefix}_no_bid", best)
            setattr(self, f"{prefix}_no_bid_size", best_size)
            setattr(self, f"{prefix}_no_second_bid", second)

        # Open orders
        try:
            open_orders = await asyncio.to_thread(self.client.get_open_orders)
            api_orders = {o.get("id"): o for o in (open_orders or []) if isinstance(o, dict)}
            for market_id in self._get_markets():
                oid = self.order_ids[market_id]
                if oid and oid in api_orders:
                    order = api_orders[oid]
                    original = float(order.get("original_size", 0))
                    matched = float(order.get("size_matched", 0))
                    self.cached_resting[market_id] = original - matched
                elif oid:
                    # Order vanished from the book: treated as filled.  Keep
                    # last_prices so check_fills can record the fill price.
                    self.order_ids[market_id] = None
                    self.cached_resting[market_id] = 0
                else:
                    self.cached_resting[market_id] = 0
        except Exception as e:
            print(f"\nOpen-orders refresh failed: {e}")

        # Fair value
        if self.fair_value_enabled and self.m1_no_bid and self.m2_no_bid:
            self.fair_value_history.append((self.m1_no_bid + self.m2_no_bid) / 2)
            self.current_fair_value = sum(self.fair_value_history) / len(self.fair_value_history)

    def _get_bid_info(self, market_id: str):
        """Return ``(best_bid, best_bid_size, second_bid)`` for the NO side of ``market_id``."""
        if market_id == self.market_1:
            return self.m1_no_bid, self.m1_no_bid_size, self.m1_no_second_bid
        return self.m2_no_bid, self.m2_no_bid_size, self.m2_no_second_bid

    # ============================================================
    # PRICE LOGIC
    # ============================================================

    def check_target_price(self, market_id, bid, bid_size, second_bid, current_price, our_resting):
        """Return the price our NO order on ``market_id`` should rest at, or None.

        ``others_best`` is the best bid excluding our own size: when the top
        level is no larger than our resting size it is assumed to be ours alone
        (heuristic) and the second level is used instead.

        Jump mode quotes one tick above ``others_best`` (capped at 0.99) until
        someone else reaches that price, then switches itself off.  Join mode
        follows a higher top bid, stays at the top bid when others share it,
        and otherwise drops to the second level.
        """
        if bid is None:
            return None

        if bid_size > our_resting:
            others_best = bid
        elif second_bid is not None:
            others_best = second_bid
        else:
            others_best = None

        # JUMP MODE
        if self.jump_active.get(market_id, False):
            tick = float(self.tick_size)
            if self.jump_target.get(market_id) is None and others_best is not None:
                self.jump_target[market_id] = round(others_best + tick, 4)
            target = self.jump_target.get(market_id)
            if target and others_best and others_best >= target:
                self.jump_active[market_id] = False
                self.jump_target[market_id] = None
            elif target and target <= 0.99:
                return target

        # JOIN MODE
        if current_price is not None and bid > current_price:
            return bid
        if bid_size > our_resting:
            return bid
        if second_bid is not None:
            return second_bid
        return bid

    # ============================================================
    # HIGHER-FIRST MODE
    # ============================================================

    def select_active_market(self):
        """Make the market with the higher NO best bid active (ties go to market 1)."""
        m1_no = self.m1_no_bid or 0
        m2_no = self.m2_no_bid or 0

        if m1_no >= m2_no:
            self.active_market = self.market_1
        else:
            self.active_market = self.market_2

        other = self.market_2 if self.active_market == self.market_1 else self.market_1
        print(f"Higher-first: {self._short_label(self.active_market)} "
              f"(NO ${max(m1_no, m2_no):.2f}) over {self._short_label(other)} "
              f"(NO ${min(m1_no, m2_no):.2f})")

    # ============================================================
    # ORDER OPERATIONS
    # ============================================================

    async def _place_order(self, market_id: str, price: float, size: float) -> Optional[str]:
        """Place a BUY limit order on the NO token; return its id, or None on failure."""
        token_id = self._no_token_for_market(market_id)
        try:
            return await asyncio.to_thread(
                self.client.place_limit_order,
                token_id, price, size, "BUY",
                self.tick_size, self.neg_risk,
            )
        except Exception as e:
            print(f"{self._short_label(market_id)}: Order failed: {e}")
            return None

    async def cancel_all_orders_async(self):
        """Cancel orders via ``client.cancel_all`` and forget both legs' orders.

        Both legs end with no order id and no last price, so ``check_fills``
        does not mistake the cancellation for a fill.

        NOTE(review): ``cancel_all`` presumably cancels every open order on the
        account, not just these two markets - confirm before running several
        instances on one account.
        """
        try:
            await asyncio.to_thread(self.client.cancel_all)
        except Exception as e:
            print(f"Cancel error: {e}")

        for mid in self._get_markets():
            self.order_ids[mid] = None
            self.last_prices[mid] = None
            self.cached_resting[mid] = 0

    async def initialize_orders_async(self) -> bool:
        """Place a ``contract_increment`` order at the best bid on every idle leg.

        Legs that still have resting size are skipped.  Returns False if any
        leg had no bid or its order could not be placed.
        """
        success = True
        size = float(self.contract_increment)

        for market_id in self._markets_in_play():
            if self.cached_resting[market_id] > 0:
                continue

            bid, _, _ = self._get_bid_info(market_id)
            if bid is None:
                success = False
                continue

            oid = await self._place_order(market_id, bid, size)
            if oid:
                self.order_ids[market_id] = oid
                self.last_prices[market_id] = bid
                self.cycle_start_resting[market_id] = size
                # Known immediately; avoids a phantom fill if the next refresh fails.
                self.cached_resting[market_id] = size
                print(f"{self._short_label(market_id)}: Placed {size} NO @ ${bid:.2f}")
                await asyncio.sleep(0.2)
            else:
                success = False

        return success

    def check_fills(self):
        """Update ``current_increment`` from the leg target minus resting size.

        Legs with no order we could have been filled on (cancelled by us, or
        never successfully placed) are skipped.  ``fill_prices`` records the
        order price when a leg first reaches ``contract_increment``.
        """
        target = float(self.contract_increment)

        for market_id in self._markets_in_play():
            start = self.cycle_start_resting[market_id]
            if start == 0:
                continue
            if self.order_ids[market_id] is None and self.last_prices[market_id] is None:
                continue

            filled = start - self.cached_resting[market_id]
            if filled > self.current_increment[market_id]:
                if self.current_increment[market_id] < target <= filled:
                    self.fill_prices[market_id] = self.last_prices[market_id]
                self.current_increment[market_id] = filled

    async def both_filled_async(self) -> bool:
        """Return True when the current cycle is complete.

        Higher-first mode: the active leg reached ``contract_increment``.
        Otherwise both legs must have reached it and ``cached_position`` must
        match; a mismatch triggers ``rebalance_async`` and returns False.
        While rebalancing, completion waits for the rebalance order to finish.
        """
        if self.higher_first_mode:
            active = self._markets_in_play()[0]
            return self.current_increment[active] >= float(self.contract_increment)

        if self.is_rebalancing:
            if any(self.order_ids[m] and self.cached_resting[m] > 0 for m in self._get_markets()):
                return False
            m1_pos = self.cached_position[self.market_1]
            m2_pos = self.cached_position[self.market_2]
            if m1_pos == m2_pos:
                self.is_rebalancing = False
                return True
            await self.rebalance_async(m1_pos, m2_pos)
            return False

        target = float(self.contract_increment)
        if not all(self.current_increment[m] >= target for m in self._get_markets()):
            return False

        m1_pos = self.cached_position[self.market_1]
        m2_pos = self.cached_position[self.market_2]
        if m1_pos == m2_pos:
            return True

        await self.rebalance_async(m1_pos, m2_pos)
        return False

    async def rebalance_async(self, m1_pos, m2_pos):
        """Cancel everything and buy the position difference on the lagging market."""
        print(f"\nPosition mismatch: {self._short_label(self.market_1)}={m1_pos:.1f}, "
              f"{self._short_label(self.market_2)}={m2_pos:.1f}")
        await self.cancel_all_orders_async()

        lagging = self.market_1 if m1_pos < m2_pos else self.market_2
        diff = abs(m1_pos - m2_pos)

        for mid in self._get_markets():
            self.current_increment[mid] = 0
            self.order_ids[mid] = None
            self.last_prices[mid] = None
            self.cycle_start_resting[mid] = 0
            self.fill_prices[mid] = None

        bid, _, _ = self._get_bid_info(lagging)
        if bid:
            oid = await self._place_order(lagging, bid, diff)
            if oid:
                self.order_ids[lagging] = oid
                self.last_prices[lagging] = bid
                self.cycle_start_resting[lagging] = diff
                self.cached_resting[lagging] = diff
                self.is_rebalancing = True
                print(f"Rebalancing: {diff:.1f} NO @ ${bid:.2f} on {self._short_label(lagging)}")

    def _reset_cycle_state(self):
        """Clear per-cycle leg state; each leg's target becomes ``contract_increment``."""
        target = float(self.contract_increment)
        for mid in self._get_markets():
            self.current_increment[mid] = 0
            self.order_ids[mid] = None
            self.last_prices[mid] = None
            self.fill_prices[mid] = None
            self.cycle_start_resting[mid] = target

    async def start_new_cycle_async(self):
        """Count the finished cycle, reset leg state and place the next orders.

        Per-cycle state is reset before the single-fire check so that, after a
        single-fire pause, resuming starts a genuinely fresh cycle.
        """
        if not self.active:
            return

        self.cycles_completed += 1

        if self.higher_first_mode:
            self.active_market = self.market_2 if self.active_market == self.market_1 else self.market_1
            print(f"\nFilled — switching to {self._short_label(self.active_market)}")

        self._reset_cycle_state()

        if self.single_fire_mode:
            self.single_fire_cycles_completed += 1
            if self.single_fire_cycles_completed >= 1:
                print("Single fire complete — pausing")
                self.paused = True
                self.single_fire_mode = False
                self.single_fire_cycles_completed = 0
                self.waiting_for_manual_resume = True
                return

        if await self.initialize_orders_async():
            print("New cycle initialized")
        else:
            print("Failed to start new cycle")

    async def update_orders_async(self):
        """Keep each leg's order at the price chosen by ``check_target_price``.

        * A leg with no live order but unfilled quantity (failed placement or
          re-placement, or cancelled by us) is re-quoted for the remaining
          quantity, unless the trader is paused.
        * A live order whose target moved by at least one tick is cancelled and
          replaced for the remaining quantity only, so partial fills are not
          bought again.  If the cancel fails the old order is left untouched.
        """
        tick = float(self.tick_size)

        for market_id in self._markets_in_play():
            label = self._short_label(market_id)
            remaining = self.cycle_start_resting[market_id] - self.current_increment[market_id]
            oid = self.order_ids[market_id]
            last_price = self.last_prices[market_id]
            bid, bid_size, second_bid = self._get_bid_info(market_id)

            if oid is None:
                if remaining > 0 and not self.paused:
                    price = self.check_target_price(market_id, bid, bid_size, second_bid, None, 0)
                    if price:
                        new_oid = await self._place_order(market_id, price, remaining)
                        if new_oid:
                            self.order_ids[market_id] = new_oid
                            self.last_prices[market_id] = price
                            self.cached_resting[market_id] = remaining
                            print(f"\n{label}: Re-placed {remaining} NO @ ${price:.2f}")
                continue

            if not self.cached_resting[market_id]:
                continue

            target_price = self.check_target_price(
                market_id, bid, bid_size, second_bid,
                last_price, self.cached_resting[market_id]
            )
            if not target_price or not last_price:
                continue
            if abs(target_price - last_price) < tick - self._PRICE_EPS:
                continue
            if remaining <= 0:
                continue

            try:
                await asyncio.to_thread(self.client.cancel_order, oid)
            except Exception as e:
                # Old order may still be live: do not place a duplicate.
                print(f"\n{label}: Cancel failed, keeping order: {e}")
                continue
            self.order_ids[market_id] = None
            await asyncio.sleep(0.1)

            # current_increment + remaining == cycle_start_resting, so the
            # leg's target is unchanged by the replacement.
            new_oid = await self._place_order(market_id, target_price, remaining)
            if new_oid:
                self.order_ids[market_id] = new_oid
                direction = "^" if target_price > last_price else "v"
                print(f"\n{direction} {label}: "
                      f"${last_price:.2f} -> ${target_price:.2f}")
                self.last_prices[market_id] = target_price
                self.cached_resting[market_id] = remaining
            else:
                self.order_ids[market_id] = None
                self.last_prices[market_id] = None

    # ============================================================
    # REDIS COMMANDS
    # ============================================================

    async def process_redis_commands_async(self):
        """Drain JSON commands from ``trading:instance:<id>:command`` and apply them.

        NOTE(review): ``redis_client.lpop`` is called synchronously and blocks
        the event loop for the round trip.
        """
        if not self.redis_client or not self.instance_id:
            return

        command_key = f"trading:instance:{self.instance_id}:command"
        while True:
            cmd_data = self.redis_client.lpop(command_key)
            if not cmd_data:
                break

            try:
                cmd = json.loads(cmd_data)
                action = cmd.get("action")

                if action == "toggle_jump":
                    idx = cmd.get("market_index", 0)
                    markets = self._get_markets()
                    # Reject negative indices (which Python would wrap around).
                    if 0 <= idx < len(markets):
                        mid = markets[idx]
                        self.jump_active[mid] = not self.jump_active.get(mid, False)
                        if not self.jump_active[mid]:
                            self.jump_target[mid] = None
                        print(f"\nJump {'ON' if self.jump_active[mid] else 'OFF'} on {self._short_label(mid)}")

                elif action == "toggle_pause":
                    if self.paused or self.waiting_for_manual_resume:
                        self.paused = False
                        self.waiting_for_manual_resume = False
                    else:
                        self.paused = True

                elif action == "single_fire":
                    if self.paused or self.waiting_for_manual_resume:
                        self.single_fire_mode = True
                        self.single_fire_cycles_completed = 0
                        self.paused = False

                elif action == "toggle_fair_value":
                    self.fair_value_enabled = not self.fair_value_enabled
                    if not self.fair_value_enabled:
                        self.fair_value_history.clear()
                        self.current_fair_value = None

                elif action == "cancel_orders":
                    await self.cancel_all_orders_async()

                elif action == "stop":
                    self.stopping = True

                elif action == "force_stop":
                    await self.cancel_all_orders_async()
                    self.active = False
                    self.stopping = False

            except Exception as e:
                print(f"Error processing command: {e}")

    # ============================================================
    # STATUS
    # ============================================================

    def format_status_data(self, instance_id: int) -> dict:
        """Return a status dict for this instance (``pnl`` is a fixed placeholder)."""
        return {
            "id": instance_id,
            "status": "running",
            "platform": "polymarket",
            "position": self.cached_position.get(self.market_1, 0) + self.cached_position.get(self.market_2, 0),
            "pnl": "+$0.00",
            "current_increment": {
                "m1": {"filled": self.current_increment.get(self.market_1, 0), "total": self.contract_increment},
                "m2": {"filled": self.current_increment.get(self.market_2, 0), "total": self.contract_increment}
            },
        }

    def print_status(self):
        """Overwrite the current terminal line with a one-line status summary."""
        def p(v):
            return f"${v:.2f}" if v is not None else "N/A"

        mode = "PAUSED" if self.paused else "JOIN"
        rebal = " [REBAL]" if self.is_rebalancing else ""

        print(
            f"\r[{mode}]{rebal} {self._short_label(self.market_1)}: NO {p(self.m1_no_bid)} "
            f"rest:{self.cached_resting[self.market_1]:.1f} "
            f"cycle:{self.current_increment[self.market_1]}/{self.contract_increment} || "
            f"{self._short_label(self.market_2)}: NO {p(self.m2_no_bid)} "
            f"rest:{self.cached_resting[self.market_2]:.1f} "
            f"cycle:{self.current_increment[self.market_2]}/{self.contract_increment}",
            end=""
        )
        sys.stdout.flush()

    # ============================================================
    # CLOSE
    # ============================================================

    async def close_session(self):
        """Close the client, ignoring any error (best effort)."""
        try:
            self.client.close()
        except Exception:
            pass

    async def generate_session_report(self, instance_id, markets, script_type):
        """Placeholder: no report is generated."""
        pass

    # ============================================================
    # MAIN LOOP
    # ============================================================

    async def run_trading_instance(self, instance_id: int, markets: list):
        """Main loop: once per second process commands, refresh data and drive the cycle.

        ``instance_id`` is adopted as ``self.instance_id`` when none was set,
        so Redis commands for this instance are processed.  ``markets`` is
        accepted for interface compatibility and unused.  The client session is
        closed even if the loop raises.
        """
        if self.instance_id is None:
            self.instance_id = instance_id
        self.session_start_time = datetime.now(timezone.utc).isoformat()
        last_status = time.time()

        try:
            while self.running and self.active:
                await self.process_redis_commands_async()

                if not self.active:
                    break

                await self.refresh_market_data_async()
                self.check_fills()

                if not self.active:
                    break

                if self.stopping:
                    if await self.both_filled_async():
                        print("\nCycle completed — stopping")
                        await self.cancel_all_orders_async()
                        self.active = False
                        break
                elif self.paused:
                    await self.update_orders_async()
                elif self.waiting_for_manual_resume:
                    await self.initialize_orders_async()
                    self.waiting_for_manual_resume = False
                elif await self.both_filled_async():
                    if not self.active:
                        break
                    await self.start_new_cycle_async()
                else:
                    await self.update_orders_async()

                if time.time() - last_status >= 1:
                    self.print_status()
                    last_status = time.time()

                await asyncio.sleep(1)
        finally:
            await self.close_session()
            print("\nTrading stopped")