#!/usr/bin/env python3
"""
Polymarket Single Market (Interval) Automated YES/NO Market Maker - ASYNC

Mirrors SMTrader (auto1.py) and TurbineSMTrader (turbine_auto1.py) logic
but uses the Polymarket CLOB API via py-clob-client SDK.

Supports 5-minute and 15-minute crypto interval markets on Polygon.

Key differences from Kalshi SMTrader:
  - Prices are decimals 0.00–1.00 (not cents 1–99)
  - Sizes are float shares (not integer contracts)
  - Orders identified by orderID string
  - Market auto-rotates via Gamma API slug discovery
  - EIP-712 wallet signing for orders, HMAC for auth
"""

import sys
import time
import json
import asyncio
from typing import Optional, Dict, List
from datetime import datetime, timezone
from collections import deque

from app.modules.terminal.auto.poly_client import (
    PolyClient,
    PolyMarket,
    PolyOrderBook,
    PolyOrderBookLevel,
)

# Constants
# Seconds between checks for rotation to the next interval market.
MARKET_POLL_SECONDS: int = 5
# Once fewer than this many seconds remain, resting orders are pulled.
END_OF_MARKET_STOP_SECONDS: int = 60


class PolySMTrader:
    """Single-market YES/NO trader for Polymarket interval markets.

    Each cycle rests a limit BUY of ``contract_increment`` shares on the YES
    token and on the NO token of one market. It re-prices those orders as the
    books move and starts a new cycle once both legs report filled. When the
    current market is no longer active, the trader rotates to the earliest
    ending active market returned by ``client.find_interval_markets``.
    """

    SIDES = ["yes", "no"]

    # Slack for float price comparisons. Binary floats cannot represent most
    # cent prices exactly (``0.29 - 0.28 < 0.01`` is True in Python), so a
    # one-tick move must be detected with a small tolerance.
    _PRICE_EPSILON = 1e-9

    # Market length assumed when the interval string cannot be parsed
    # (15 minutes, the value that used to be hard-coded for every interval).
    _DEFAULT_INTERVAL_SECONDS = 900

    def __init__(self, client: "PolyClient", asset: str = "btc", interval: str = "15m"):
        """Store the client and market selector; no network calls are made here.

        Args:
            client: Polymarket client used for discovery, books and orders.
            asset: Asset symbol passed to ``find_interval_markets`` (lower-cased).
            interval: Interval label such as ``"5m"`` or ``"15m"``.
        """
        self.client = client
        self.asset = asset.lower()
        self.interval = interval

        # Market state
        self.market: Optional[PolyMarket] = None
        self.condition_id: str = ""
        self.yes_token_id: str = ""
        self.no_token_id: str = ""
        self.end_time: int = 0  # unix seconds
        self.tick_size: str = "0.01"
        self.neg_risk: bool = False

        # Orderbook (top two bids, best ask) per side
        self.yes_bid: Optional[float] = None
        self.yes_bid_size: float = 0
        self.yes_second_bid: Optional[float] = None
        self.yes_ask: Optional[float] = None
        self.no_bid: Optional[float] = None
        self.no_bid_size: float = 0
        self.no_second_bid: Optional[float] = None
        self.no_ask: Optional[float] = None

        # Order tracking, keyed by side ("yes"/"no")
        self.order_ids: Dict[str, Optional[str]] = {"yes": None, "no": None}
        self.last_prices: Dict[str, Optional[float]] = {"yes": None, "no": None}
        self.current_increment: Dict[str, float] = {"yes": 0, "no": 0}  # shares filled this cycle
        self.cycle_start_resting: Dict[str, float] = {"yes": 0, "no": 0}  # leg size this cycle
        self.cached_resting: Dict[str, float] = {"yes": 0, "no": 0}  # unfilled size on book
        self.cached_position: Dict[str, float] = {"yes": 0, "no": 0}
        self.fill_prices: Dict[str, Optional[float]] = {"yes": None, "no": None}

        # Control flags
        self.running = False
        self.active = False
        self.stopping = False
        self.paused = False
        self.waiting_for_manual_resume = False
        self.is_rebalancing = False
        self.market_expiring = False
        self.contract_increment = 3  # shares per side per cycle

        # Single fire
        self.single_fire_mode = False
        self.single_fire_cycles_completed = 0

        # Fair value
        self.fair_value_enabled = False
        self.fair_value_history = deque(maxlen=10)
        self.current_fair_value = None

        # Jump mode, tracked per "<condition_id>:<side>" key so the YES and NO
        # legs each keep their own target price.
        self.jump_active: Dict[str, bool] = {}
        self.jump_target: Dict[str, Optional[float]] = {}

        # One-side-first
        self.one_side_first_mode = False
        self.active_side: Optional[str] = None

        # Interval mode (auto-rotate to next market in series)
        self.interval_mode = True  # Always true for poly interval
        self.interval_repeating = True  # Default to auto-rotate

        # Redis + instance tracking
        self.instance_id: Optional[int] = None
        self.redis_client = None

        # Session tracking
        self.session_start_time: Optional[str] = None
        self.cycles_completed: int = 0

    def _get_markets(self):
        """Return the list of markets this trader works on (just the current one)."""
        return [self.condition_id]

    def _token_for_side(self, side: str) -> str:
        """Map ``"yes"`` to the YES token id; anything else to the NO token id."""
        return self.yes_token_id if side == "yes" else self.no_token_id

    def _trading_sides(self) -> List[str]:
        """Sides currently traded: only ``active_side`` in one-side-first mode.

        NOTE(review): in one-side-first mode with ``active_side`` unset this
        yields ``[None]`` and callers raise ``KeyError`` (unchanged behavior).
        """
        return [self.active_side] if self.one_side_first_mode else self.SIDES

    def _jump_key(self, side: str) -> str:
        """Key for per-market, per-side jump-mode state."""
        return f"{self.condition_id}:{side}"

    # ============================================================
    # MARKET DISCOVERY
    # ============================================================

    @classmethod
    def _interval_seconds(cls, interval: str) -> int:
        """Return the nominal length of *interval* ("30s", "5m", "1h") in seconds.

        Unparseable values fall back to ``_DEFAULT_INTERVAL_SECONDS``.
        """
        units = {"s": 1, "m": 60, "h": 3600}
        text = (interval or "").strip().lower()
        if len(text) >= 2 and text[-1] in units and text[:-1].isdigit():
            return int(text[:-1]) * units[text[-1]]
        return cls._DEFAULT_INTERVAL_SECONDS

    @staticmethod
    def _market_sort_key(market):
        """Sort key: earliest end date first, markets with no end date last.

        End dates are compared as strings, which orders them chronologically
        only when they share one ISO-8601 format (as the original code assumed).
        """
        end_date = market.end_date
        return (not end_date, end_date or "")

    def _parse_end_time(self, end_date: Optional[str]) -> int:
        """Convert an ISO-8601 end date to unix seconds.

        A missing or unparseable date falls back to "now + one interval" so
        the trader never inherits a stale end time from a previous market.
        """
        fallback = int(time.time()) + self._interval_seconds(self.interval)
        if not end_date:
            return fallback
        try:
            end_dt = datetime.fromisoformat(end_date.replace("Z", "+00:00"))
        except (AttributeError, TypeError, ValueError):
            return fallback
        if end_dt.tzinfo is None:
            # NOTE(review): assumes naive timestamps from the API are UTC.
            end_dt = end_dt.replace(tzinfo=timezone.utc)
        return int(end_dt.timestamp())

    def _apply_market(self, market: "PolyMarket") -> None:
        """Point all market-specific fields at *market*.

        The first token id is taken as YES and the second as NO.
        """
        token_ids = market.token_ids or []
        self.market = market
        self.condition_id = market.condition_id
        self.yes_token_id = token_ids[0] if len(token_ids) > 0 else ""
        self.no_token_id = token_ids[1] if len(token_ids) > 1 else ""
        self.tick_size = market.tick_size
        self.neg_risk = market.neg_risk
        self.end_time = self._parse_end_time(market.end_date)

    def discover_market(self) -> bool:
        """Select the active interval market with the earliest end date.

        Returns:
            True if a market was selected, False if none was found or the
            lookup failed (errors are printed, not raised).
        """
        try:
            markets = self.client.find_interval_markets(self.asset, self.interval)
            if not markets:
                print(f"No active {self.asset} {self.interval} markets found")
                return False

            active = sorted(
                (m for m in markets if m.active), key=self._market_sort_key
            )
            if not active:
                print(f"No active {self.asset} {self.interval} markets")
                return False

            self._apply_market(active[0])

            slug_short = self.market.slug[:30] if self.market.slug else self.condition_id[:16]
            print(f"Market: {slug_short}... | {(self.market.question or '')[:50]}")
            return True
        except Exception as e:
            print(f"Market discovery error: {e}")
            return False

    async def check_market_transition(self) -> bool:
        """Rotate to the next active market once the current one is inactive.

        On rotation all open orders are cancelled, market fields are replaced
        and per-cycle order tracking is reset.

        Returns:
            True if the trader switched markets, otherwise False.
        """
        try:
            # Run the blocking client call off the event loop, like every
            # other client call in this class.
            markets = await asyncio.to_thread(
                self.client.find_interval_markets, self.asset, self.interval
            )
            markets = markets or []
            candidates = [
                m for m in markets if m.active and m.condition_id != self.condition_id
            ]
            if not candidates:
                return False

            # Stay put while the current market is still reported active.
            current = next((m for m in markets if m.condition_id == self.condition_id), None)
            if current and current.active:
                return False

            new_market = min(candidates, key=self._market_sort_key)
            label = new_market.slug[:30] if new_market.slug else new_market.condition_id[:16]

            print(f"\n New market: {label}...")
            await self.cancel_all_orders_async()

            self._apply_market(new_market)
            self.market_expiring = False
            self._reset_order_tracking()

            print(f"Now trading: {label}...")
            return True
        except Exception as e:
            print(f"Market check error: {e}")
        return False

    def _reset_order_tracking(self):
        """Clear all per-side order, fill and position bookkeeping."""
        self.order_ids = {"yes": None, "no": None}
        self.last_prices = {"yes": None, "no": None}
        self.current_increment = {"yes": 0, "no": 0}
        self.cycle_start_resting = {"yes": 0, "no": 0}
        self.cached_resting = {"yes": 0, "no": 0}
        self.cached_position = {"yes": 0, "no": 0}
        self.fill_prices = {"yes": None, "no": None}

    def seconds_remaining(self) -> int:
        """Seconds until ``end_time``, floored at 0."""
        return max(0, self.end_time - int(time.time()))

    # ============================================================
    # MARKET DATA
    # ============================================================

    def _apply_orderbook(self, side: str, book: "PolyOrderBook") -> None:
        """Copy the top two bids and the best ask of *book* into *side*'s fields.

        NOTE(review): assumes ``book.bids[0]`` is the best (highest) bid and
        ``book.asks[0]`` the best ask. Confirm PolyClient sorts levels that way.
        """
        bids = book.bids or []
        asks = book.asks or []
        best_bid = bids[0].price if bids else None
        best_bid_size = bids[0].size if bids else 0
        second_bid = bids[1].price if len(bids) > 1 else None
        best_ask = asks[0].price if asks else None

        # Assign only after every value is computed so a malformed book
        # never leaves the side half-updated.
        if side == "yes":
            self.yes_bid = best_bid
            self.yes_bid_size = best_bid_size
            self.yes_second_bid = second_bid
            self.yes_ask = best_ask
        else:
            self.no_bid = best_bid
            self.no_bid_size = best_bid_size
            self.no_second_bid = second_bid
            self.no_ask = best_ask

    async def refresh_market_data_async(self):
        """Refresh both orderbooks, resting sizes of our orders, and fair value.

        Refresh failures are printed and the previous values are kept. A
        tracked order that no longer appears among open orders is forgotten
        (its resting size drops to 0).
        """
        for side in self.SIDES:
            try:
                book = await asyncio.to_thread(
                    self.client.get_orderbook, self._token_for_side(side)
                )
                self._apply_orderbook(side, book)
            except Exception as e:
                print(f"\n{side.upper()}: orderbook refresh failed: {e}")

        try:
            open_orders = await asyncio.to_thread(self.client.get_open_orders)
            api_orders = {o.get("id"): o for o in (open_orders or []) if isinstance(o, dict)}
            for side in self.SIDES:
                oid = self.order_ids[side]
                order = api_orders.get(oid) if oid else None
                if order is not None:
                    matched = float(order.get("size_matched", 0) or 0)
                    original = float(order.get("original_size", 0) or 0)
                    self.cached_resting[side] = original - matched
                elif oid:
                    # Order gone - filled or cancelled
                    self.order_ids[side] = None
                    self.last_prices[side] = None
                    self.cached_resting[side] = 0
                else:
                    self.cached_resting[side] = 0
        except Exception as e:
            print(f"\nOpen-order refresh failed: {e}")

        # Rolling mean of (YES bid + NO bid) / 2 over the last 10 samples.
        # NOTE(review): not read anywhere else in this class.
        if self.fair_value_enabled and self.yes_bid is not None and self.no_bid is not None:
            self.fair_value_history.append((self.yes_bid + self.no_bid) / 2)
            self.current_fair_value = sum(self.fair_value_history) / len(self.fair_value_history)

    # ============================================================
    # ORDER OPERATIONS
    # ============================================================

    async def _place_order(self, side: str, price: float, size: float) -> Optional[str]:
        """Place a limit BUY on *side*'s token; return the order id or None on error."""
        token_id = self._token_for_side(side)
        try:
            oid = await asyncio.to_thread(
                self.client.place_limit_order,
                token_id, price, size, "BUY",
                self.tick_size, self.neg_risk,
            )
            return oid
        except Exception as e:
            print(f"{side.upper()}: Order failed: {e}")
            return None

    async def _cancel_order(self, order_id: str) -> bool:
        """Cancel one order; return the client's result, or False on error."""
        try:
            return await asyncio.to_thread(self.client.cancel_order, order_id)
        except Exception as e:
            print(f"\nCancel {order_id} failed: {e}")
            return False

    async def cancel_all_orders_async(self):
        """Call ``client.cancel_all`` and forget tracked orders on both sides.

        NOTE(review): ``cycle_start_resting`` is left as-is, so on the next
        ``check_fills()`` a cancelled leg reads as filled (start - 0). Confirm
        this is intended.
        """
        try:
            await asyncio.to_thread(self.client.cancel_all)
        except Exception as e:
            print(f"Cancel error: {e}")

        for side in self.SIDES:
            self.order_ids[side] = None
            self.last_prices[side] = None
            self.cached_resting[side] = 0

    # ============================================================
    # BID/PRICE LOGIC
    # ============================================================

    def get_bid_info(self, side: str):
        """Return ``(best_bid, best_bid_size, second_bid)`` for *side*."""
        if side == "yes":
            return self.yes_bid, self.yes_bid_size, self.yes_second_bid
        return self.no_bid, self.no_bid_size, self.no_second_bid

    def get_market_spread(self) -> Optional[float]:
        """Spread as decimal (e.g. 0.05 = 5 cents): ``1 - yes_bid - no_bid``."""
        if self.yes_bid is not None and self.no_bid is not None:
            return round(1.0 - self.yes_bid - self.no_bid, 4)
        return None

    def check_target_price(self, side, bid, bid_size, second_bid, current_price, our_resting) -> Optional[float]:
        """Determine the price (decimal 0.00-1.00) *side*'s order should rest at.

        Jump mode (per side) bids one tick above the best price from other
        participants until they catch up. Join mode rests at the best bid,
        or at the second bid when the best level appears to be only our own
        size.

        Returns:
            Target price, or None when the side has no bid.
        """
        if bid is None:
            return None

        # Best price not attributable to our own resting size.
        if bid_size > our_resting:
            others_best = bid
        elif second_bid is not None:
            others_best = second_bid
        else:
            others_best = None

        # JUMP MODE (state is per side so YES and NO targets never mix)
        key = self._jump_key(side)
        if self.jump_active.get(key, False):
            tick = float(self.tick_size)
            if self.jump_target.get(key) is None and others_best is not None:
                self.jump_target[key] = round(others_best + tick, 4)
            target = self.jump_target.get(key)
            if target and others_best and others_best >= target:
                # Others caught up with the jump price: fall back to join.
                self.jump_active[key] = False
                self.jump_target[key] = None
            elif target and target <= 0.99:
                return target

        # JOIN MODE
        if current_price is not None and bid > current_price:
            return bid
        if bid_size > our_resting:
            return bid
        if second_bid is not None:
            return second_bid
        return bid

    # ============================================================
    # TRADING CYCLE
    # ============================================================

    async def initialize_orders_async(self) -> bool:
        """Place a ``contract_increment``-share bid at the best bid on each traded side.

        Sides that already have resting size are skipped.

        Returns:
            False if any side had no bid or its order failed, else True.
        """
        success = True

        for side in self._trading_sides():
            if self.cached_resting[side] > 0:
                continue

            bid, _, _ = self.get_bid_info(side)
            if bid is None:
                success = False
                continue

            price = bid
            size = float(self.contract_increment)

            oid = await self._place_order(side, price, size)
            if oid:
                self.order_ids[side] = oid
                self.last_prices[side] = price
                self.cycle_start_resting[side] = size
                self.cached_resting[side] = size
                print(f"{side.upper()}: Placed {size} @ ${price:.2f}")
                await asyncio.sleep(0.2)
            else:
                success = False

        return success

    def check_fills(self):
        """Update ``current_increment`` from how much resting size has disappeared.

        NOTE(review): ``cached_position`` is never written by this class other
        than resets to 0. As a result, ``both_filled_async`` always sees equal
        positions and ``rebalance_async`` is never reached. TODO: populate it
        from fills or from the API.
        """
        target = float(self.contract_increment)
        for side in self._trading_sides():
            start = self.cycle_start_resting[side]
            if start == 0:
                continue

            filled = start - self.cached_resting[side]
            if filled > self.current_increment[side]:
                # Record the price the first time this leg reaches its target.
                if self.current_increment[side] < target and filled >= target:
                    self.fill_prices[side] = self.last_prices[side]
                self.current_increment[side] = filled

    async def both_filled_async(self) -> bool:
        """Return True when the cycle is complete and positions are balanced.

        May place a rebalancing order (via ``rebalance_async``) and return
        False when YES/NO positions differ.
        """
        if self.one_side_first_mode:
            target = float(self.contract_increment)
            return self.current_increment[self.active_side] >= target

        if self.is_rebalancing:
            if any(self.order_ids[s] and self.cached_resting[s] > 0 for s in self.SIDES):
                return False
            yes_pos = self.cached_position["yes"]
            no_pos = self.cached_position["no"]
            if yes_pos == no_pos:
                self.is_rebalancing = False
                return True
            await self.rebalance_async(yes_pos, no_pos)
            return False

        target = float(self.contract_increment)
        if any(self.current_increment[side] < target for side in self.SIDES):
            return False

        yes_pos = self.cached_position["yes"]
        no_pos = self.cached_position["no"]
        if yes_pos == no_pos:
            return True

        await self.rebalance_async(yes_pos, no_pos)
        return False

    async def rebalance_async(self, yes_pos, no_pos):
        """Cancel everything and bid the position difference on the lagging side."""
        print(f"\nPosition mismatch: YES={yes_pos:.1f}, NO={no_pos:.1f}")
        await self.cancel_all_orders_async()

        lagging_side = "yes" if yes_pos < no_pos else "no"
        diff = abs(yes_pos - no_pos)

        self._reset_order_tracking()

        bid, _, _ = self.get_bid_info(lagging_side)
        if bid:
            oid = await self._place_order(lagging_side, bid, diff)
            if oid:
                self.order_ids[lagging_side] = oid
                self.last_prices[lagging_side] = bid
                self.cycle_start_resting[lagging_side] = diff
                self.is_rebalancing = True
                print(f"Rebalancing: {diff:.1f} {lagging_side.upper()} @ ${bid:.2f}")

    async def start_new_cycle_async(self):
        """Count the finished cycle and place fresh orders for the next one.

        In single-fire mode this pauses instead and waits for a manual resume.
        """
        if not self.active:
            return

        self.cycles_completed += 1

        if self.single_fire_mode:
            self.single_fire_cycles_completed += 1
            if self.single_fire_cycles_completed >= 1:
                print("Single fire complete - pausing")
                self.paused = True
                self.single_fire_mode = False
                self.single_fire_cycles_completed = 0
                self.waiting_for_manual_resume = True
                return

        # Cancel anything still resting from the finished cycle before its id
        # is dropped. Otherwise the order stays live but untracked, and its
        # stale resting size makes initialize_orders_async skip that side.
        for side in self.SIDES:
            oid = self.order_ids[side]
            if oid and self.cached_resting[side] > 0:
                await self._cancel_order(oid)
            self.cached_resting[side] = 0

        self.current_increment = {"yes": 0, "no": 0}
        self.order_ids = {"yes": None, "no": None}
        self.last_prices = {"yes": None, "no": None}
        self.fill_prices = {"yes": None, "no": None}
        self.cycle_start_resting = {"yes": 0, "no": 0}

        if await self.initialize_orders_async():
            print("New cycle initialized")
        else:
            print("Failed to start new cycle")

    async def update_orders_async(self):
        """Cancel and re-place resting orders whose target moved by >= one tick.

        Only the unfilled remainder of each leg (leg size minus shares already
        filled) is re-placed, so re-pricing after a partial fill never buys
        more than the leg's size. ``cycle_start_resting`` is therefore left
        unchanged.
        """
        tick = float(self.tick_size)

        for side in self._trading_sides():
            if not self.cached_resting[side]:
                continue

            last_price = self.last_prices[side]
            bid, bid_size, second_bid = self.get_bid_info(side)
            target_price = self.check_target_price(
                side, bid, bid_size, second_bid,
                last_price, self.cached_resting[side]
            )
            if not target_price or not last_price:
                continue

            # Epsilon: e.g. 0.29 - 0.28 evaluates to just under 0.01.
            if abs(target_price - last_price) < tick - self._PRICE_EPSILON:
                continue

            size = self.cycle_start_resting[side] - self.current_increment[side]
            if size <= self._PRICE_EPSILON:
                continue

            # Cancel old
            oid = self.order_ids[side]
            if oid:
                await self._cancel_order(oid)
                self.order_ids[side] = None
                await asyncio.sleep(0.1)

            # Place new
            new_oid = await self._place_order(side, target_price, size)
            if new_oid:
                self.order_ids[side] = new_oid
                direction = "^" if target_price > last_price else "v"
                print(f"\n{direction} {side.upper()}: ${last_price:.2f} -> ${target_price:.2f}")
                self.last_prices[side] = target_price
                self.cached_resting[side] = size
            else:
                self.order_ids[side] = None
                self.last_prices[side] = None

    # ============================================================
    # REDIS COMMANDS
    # ============================================================

    async def process_redis_commands_async(self):
        """Drain ``trading:instance:<id>:command`` and apply each JSON command.

        Does nothing unless both ``redis_client`` and ``instance_id`` are set.
        Malformed commands are printed and skipped.
        """
        if not self.redis_client or not self.instance_id:
            return

        command_key = f"trading:instance:{self.instance_id}:command"
        while True:
            cmd_data = self.redis_client.lpop(command_key)
            if not cmd_data:
                break

            try:
                cmd = json.loads(cmd_data)
                action = cmd.get("action")

                if action == "toggle_jump":
                    # Flip jump mode for both legs of the current market at once.
                    keys = [self._jump_key(s) for s in self.SIDES]
                    enabled = not any(self.jump_active.get(k, False) for k in keys)
                    for k in keys:
                        self.jump_active[k] = enabled
                        self.jump_target[k] = None
                    print(f"\nJump {'ON' if enabled else 'OFF'}")

                elif action == "toggle_pause":
                    if self.paused or self.waiting_for_manual_resume:
                        self.paused = False
                        self.waiting_for_manual_resume = False
                    else:
                        self.paused = True

                elif action == "single_fire":
                    if self.paused or self.waiting_for_manual_resume:
                        self.single_fire_mode = True
                        self.single_fire_cycles_completed = 0
                        self.paused = False

                elif action == "toggle_fair_value":
                    self.fair_value_enabled = not self.fair_value_enabled
                    if not self.fair_value_enabled:
                        self.fair_value_history.clear()
                        self.current_fair_value = None

                elif action == "cancel_orders":
                    await self.cancel_all_orders_async()

                elif action == "stop":
                    self.stopping = True

                elif action == "force_stop":
                    await self.cancel_all_orders_async()
                    self.active = False
                    self.stopping = False

            except Exception as e:
                print(f"Error processing command: {e}")

    # ============================================================
    # STATUS
    # ============================================================

    def format_status_data(self, instance_id: int) -> dict:
        """Build the status dict for *instance_id* (pnl is a fixed placeholder)."""
        return {
            "id": instance_id,
            "status": "running",
            "platform": "polymarket",
            "asset": self.asset.upper(),
            "interval": self.interval,
            "market_slug": self.market.slug[:30] if self.market else "",
            "condition_id": self.condition_id[:16] + "..." if self.condition_id else "",
            "time_remaining": self.seconds_remaining(),
            "position": self.cached_position.get("yes", 0) + self.cached_position.get("no", 0),
            "pnl": "+$0.00",
            "current_increment": {
                "yes": {"filled": self.current_increment.get("yes", 0), "total": self.contract_increment},
                "no": {"filled": self.current_increment.get("no", 0), "total": self.contract_increment}
            },
        }

    def print_status(self):
        """Overwrite the current terminal line with quotes, resting size and time left."""
        def p(v):
            return f"${v:.2f}" if v is not None else "N/A"

        secs = self.seconds_remaining()
        m, s = divmod(secs, 60)
        mode = "PAUSED" if self.paused else ("EXPIRING" if self.market_expiring else "JOIN")

        print(
            f"\r[{mode}] YES: bid{p(self.yes_bid)} ask{p(self.yes_ask)} rest:{self.cached_resting['yes']:.1f}"
            f" || NO: bid{p(self.no_bid)} ask{p(self.no_ask)} rest:{self.cached_resting['no']:.1f}"
            f" | {m}:{s:02d}",
            end=""
        )
        sys.stdout.flush()

    # ============================================================
    # CLOSE
    # ============================================================

    async def close_session(self):
        """Close the client, ignoring errors."""
        try:
            self.client.close()
        except Exception:
            pass

    async def generate_session_report(self, instance_id, markets, script_type):
        """Placeholder - Polymarket doesn't use Base reports yet."""
        pass

    # ============================================================
    # MAIN TRADING LOOP
    # ============================================================

    async def run_trading_instance(self, instance_id: int, markets: list):
        """Main loop: runs every ~0.5 s while ``running`` and ``active`` are set.

        Each tick, in order: apply commands, check rotation every
        ``MARKET_POLL_SECONDS``, pull orders near expiry, refresh data, then
        act according to the stop/pause/resume/cycle state.

        NOTE(review): *instance_id* and *markets* are not used in this body.
        """
        self.session_start_time = datetime.now(timezone.utc).isoformat()
        last_status = time.time()
        last_market_check = time.time()

        while self.running and self.active:
            await self.process_redis_commands_async()

            if not self.active:
                break

            # Check for market rotation
            if time.time() - last_market_check >= MARKET_POLL_SECONDS:
                transitioned = await self.check_market_transition()
                last_market_check = time.time()
                if transitioned and not self.paused:
                    await self.initialize_orders_async()

            # Check market expiry
            secs = self.seconds_remaining()
            if secs < END_OF_MARKET_STOP_SECONDS and not self.market_expiring:
                self.market_expiring = True
                if any(self.order_ids[s] for s in self.SIDES):
                    print(f"\nMarket expiring in {secs}s -- pulling orders")
                    await self.cancel_all_orders_async()

            await self.refresh_market_data_async()
            self.check_fills()

            if not self.active:
                break

            if self.market_expiring:
                pass  # Wait for market transition
            elif self.stopping:
                if await self.both_filled_async():
                    await self.cancel_all_orders_async()
                    self.active = False
                    break
            elif self.paused:
                await self.update_orders_async()
            elif self.waiting_for_manual_resume:
                await self.initialize_orders_async()
                self.waiting_for_manual_resume = False
            elif await self.both_filled_async():
                if not self.active:
                    break
                await self.start_new_cycle_async()
            else:
                await self.update_orders_async()

            if time.time() - last_status >= 1:
                self.print_status()
                last_status = time.time()

            await asyncio.sleep(0.5)

        await self.close_session()
        print("\nTrading stopped")
