#!/usr/bin/env python3
"""
Core market making functions - ASYNC support added
"""

import os
import sys
import time
import json
import asyncio
from typing import Optional, Dict
from collections import deque
from datetime import datetime, timezone
from app.modules.terminal.auto.kalshi_api import KalshiAPITrader


class BaseMarketMaker(KalshiAPITrader):
    """Base market maker with shared logic"""
    
    def __init__(self, api_key: str, api_secret: str, config):
        """Initialise the API trader base, then reset all market-making state.

        Args:
            api_key: Passed through unchanged to ``KalshiAPITrader.__init__``.
            api_secret: Passed through unchanged to ``KalshiAPITrader.__init__``.
            config: Passed through unchanged to ``KalshiAPITrader.__init__``.
        """
        super().__init__(api_key, api_secret, config)

        # Lifecycle / control flags, all cleared at construction time.
        self.running = self.stopping = self.active = False
        self.paused = self.waiting_for_manual_resume = False
        self.is_rebalancing = False
        self.contract_increment = 3

        # "Single fire" mode state (enabled via the single_fire command).
        self.single_fire_mode = False
        self.single_fire_cycles_completed = 0

        # Rolling fair-value estimate over the 10 most recent samples.
        self.fair_value_enabled = False
        self.fair_value_history = deque(maxlen=10)
        self.current_fair_value = None

        # Per-market jump mode: on/off flag and the price (in cents) it targets.
        self.jump_active: Dict[str, bool] = {}
        self.jump_target: Dict[str, Optional[int]] = {}

        # Remote control plumbing; both must be set for Redis commands to be read.
        self.instance_id: Optional[int] = None
        self.redis_client = None

        # Session bookkeeping consumed by generate_session_report().
        self.session_start_time: Optional[str] = None
        self.cycles_completed: int = 0
    
    def toggle_jump(self, market_id: str):
        """Toggle jump mode for a market"""
        self.jump_active[market_id] = not self.jump_active.get(market_id, False)
        if not self.jump_active[market_id]:
            self.jump_target[market_id] = None

        status = "ON" if self.jump_active[market_id] else "OFF"
        market_label = market_id[-7:] if len(market_id) > 7 else market_id
        print(f"\n{'🔼' if self.jump_active[market_id] else '🔽'} {market_label}: Jump {status}")
    
    def update_fair_value(self, yes_bid: Optional[float], no_bid: Optional[float]):
        """Update fair value with rolling average"""
        if yes_bid is not None and no_bid is not None:
            midpoint = (yes_bid + no_bid) / 2
            self.fair_value_history.append(midpoint)
            
            if len(self.fair_value_history) > 0:
                self.current_fair_value = sum(self.fair_value_history) / len(self.fair_value_history)
    
    async def process_redis_commands_async(self):
        """Process commands from Redis - ASYNC"""
        if not self.redis_client or not self.instance_id:
            return

        command_key = f"trading:instance:{self.instance_id}:command"
        max_commands = 10
        for _ in range(max_commands):
            cmd_data = self.redis_client.lpop(command_key)
            if not cmd_data:
                break

            try:
                cmd = json.loads(cmd_data)
                action = cmd.get("action")

                if action == "toggle_jump":
                    market_idx = cmd.get("market_index", 0)
                    markets = getattr(self, '_get_markets', lambda: [])()
                    if markets and market_idx < len(markets):
                        self.toggle_jump(markets[market_idx])

                elif action == "toggle_pause":
                    if self.paused or self.waiting_for_manual_resume:
                        print("\n▶️  Resuming trading")
                        self.paused = False
                        self.waiting_for_manual_resume = False
                    else:
                        print("\n⏸️  Pausing trading")
                        self.paused = True

                elif action == "single_fire":
                    if not (self.paused or self.waiting_for_manual_resume):
                        print("\n⚠️  Already in continuous mode")
                    else:
                        print("\n🎯 Single fire mode")
                        self.single_fire_mode = True
                        self.single_fire_cycles_completed = 0
                        self.paused = False

                elif action == "toggle_fair_value":
                    self.fair_value_enabled = not self.fair_value_enabled
                    if not self.fair_value_enabled:
                        self.fair_value_history.clear()
                        self.current_fair_value = None
                    status = "ON" if self.fair_value_enabled else "OFF"
                    print(f"\n📊 Fair Value: {status}")

                elif action == "cancel_orders":
                    print("\n🚫 Cancelling all orders")
                    # Cancel orders on both markets - need async cancel
                    cancel_tasks = []
                    markets = getattr(self, '_get_markets', lambda: [])()
                    order_ids = getattr(self, 'order_ids', {})
                    for m in markets:
                        oid = order_ids.get(m)
                        if oid:
                            cancel_tasks.append(self.cancel_order(oid))
                            order_ids[m] = None
                    # For single market (auto1), check yes/no sides
                    for side in ["yes", "no"]:
                        oid = order_ids.get(side)
                        if oid:
                            cancel_tasks.append(self.cancel_order(oid))
                            order_ids[side] = None
                    if cancel_tasks:
                        await asyncio.gather(*cancel_tasks, return_exceptions=True)
                    # Reset last prices so new orders can be placed
                    last_prices = getattr(self, 'last_prices', {})
                    for k in last_prices:
                        last_prices[k] = None
                    print("✓ All orders cancelled")

                elif action == "accept_next_market":
                    # Interval mode: user accepted transition to next market
                    next_market = cmd.get("next_market")
                    # Validate market ticker format (alphanumeric, hyphens, max 100 chars)
                    import re
                    if next_market and not re.match(r'^[A-Za-z0-9_\-]{1,100}$', next_market):
                        print(f"⚠️  Invalid next_market format: {next_market[:50]}")
                        next_market = None
                    if next_market and hasattr(self, 'interval_waiting_for_next'):
                        self.interval_waiting_for_next = False
                        print(f"\n✓ Accepted next market: {next_market}")
                        # transition_to_market is async and defined on SMTrader
                        if hasattr(self, 'transition_to_market'):
                            await self.transition_to_market(next_market)
                            self.paused = False

                elif action == "decline_next_market":
                    # Interval mode: user declined — stop trading
                    if hasattr(self, 'interval_waiting_for_next'):
                        self.interval_waiting_for_next = False
                        self.active = False
                        print("\n⏹️ Declined next market — stopping")

                elif action == "stop":
                    print("\n⏸️  Stop requested - will complete current cycle")
                    self.stopping = True

                elif action == "force_stop":
                    print("\n⛔ Force stop - cancelling all orders")
                    cancel_tasks = []
                    markets = getattr(self, '_get_markets', lambda: [])()
                    order_ids = getattr(self, 'order_ids', {})
                    for m in markets:
                        oid = order_ids.get(m)
                        if oid:
                            cancel_tasks.append(self.cancel_order(oid))
                    for side in ["yes", "no"]:
                        oid = order_ids.get(side)
                        if oid:
                            cancel_tasks.append(self.cancel_order(oid))
                    if cancel_tasks:
                        await asyncio.gather(*cancel_tasks, return_exceptions=True)
                    self.active = False
                    self.stopping = False
                    print("✓ Orders cancelled, stopping")

            except Exception as e:
                print(f"Error processing command: {e}")

    def process_redis_commands(self):
        """Process commands from Redis - SYNC (for standalone scripts)"""
        if not self.redis_client or not self.instance_id:
            return

        command_key = f"trading:instance:{self.instance_id}:command"
        max_commands = 10
        for _ in range(max_commands):
            cmd_data = self.redis_client.lpop(command_key)
            if not cmd_data:
                break

            try:
                cmd = json.loads(cmd_data)
                action = cmd.get("action")

                if action == "toggle_jump":
                    market_idx = cmd.get("market_index", 0)
                    markets = getattr(self, '_get_markets', lambda: [])()
                    if markets and market_idx < len(markets):
                        self.toggle_jump(markets[market_idx])

                elif action == "toggle_pause":
                    if self.paused or self.waiting_for_manual_resume:
                        print("\n▶️  Resuming trading")
                        self.paused = False
                        self.waiting_for_manual_resume = False
                    else:
                        print("\n⏸️  Pausing trading")
                        self.paused = True

                elif action == "single_fire":
                    if not (self.paused or self.waiting_for_manual_resume):
                        print("\n⚠️  Already in continuous mode")
                    else:
                        print("\n🎯 Single fire mode")
                        self.single_fire_mode = True
                        self.single_fire_cycles_completed = 0
                        self.paused = False

                elif action == "toggle_fair_value":
                    self.fair_value_enabled = not self.fair_value_enabled
                    if not self.fair_value_enabled:
                        self.fair_value_history.clear()
                        self.current_fair_value = None
                    status = "ON" if self.fair_value_enabled else "OFF"
                    print(f"\n📊 Fair Value: {status}")

                elif action == "cancel_orders":
                    print("\n🚫 Cancelling all orders")
                    self.cancel_all_orders()
                    print("✓ All orders cancelled")

                elif action == "stop":
                    print("\n⏸️  Stop requested - will complete current cycle")
                    self.stopping = True

                elif action == "force_stop":
                    print("\n⛔ Force stop received - cancelling orders immediately")
                    self.cancel_all_orders()
                    self.active = False
                    self.stopping = False

            except Exception as e:
                print(f"Error processing command: {e}")
    
    def check_target_price(self, side: str, bid: float, bid_size: int, 
                          second_bid: Optional[float], current_price: Optional[float],
                          our_resting: int, other_bid: Optional[float] = None,
                          market_id: Optional[str] = None) -> Optional[float]:
        """Determine target price with bump mode support"""
        if bid is None:
            return None
        
        bid_cents = round(bid * 100)
        current_cents = round(current_price * 100) if current_price else None
        
        if bid_size > our_resting:
            others_best_cents = bid_cents
        elif second_bid:
            others_best_cents = round(second_bid * 100)
        else:
            others_best_cents = None
        
        # JUMP MODE (only in join mode, requires market_id)
        if market_id and self.jump_active.get(market_id, False):
            if self.jump_target[market_id] is None and others_best_cents:
                self.jump_target[market_id] = others_best_cents + 1

            target = self.jump_target[market_id]
            if target and others_best_cents and others_best_cents >= target:
                market_label = market_id[-7:] if len(market_id) > 7 else market_id
                print(f"\n⚠️ {market_label}: Jump disabled - others at/above ${target/100:.2f}")
                self.jump_active[market_id] = False
                self.jump_target[market_id] = None
            elif target and target <= 99:
                return target / 100

        # JOIN MODE
        if current_cents and bid_cents > current_cents:
            return bid
        if bid_size > our_resting:
            return bid
        if second_bid:
            return second_bid
        return bid
    
    def trading_loop(self):
        """Main trading loop (synchronous).

        Loops while both ``running`` and ``active`` are true, sleeping 1 s per
        iteration. Each iteration applies queued Redis commands, refreshes
        market data and checks fills. It then takes exactly one action, chosen
        in this priority order:

        1. ``stopping``: if ``both_filled()``, cancel all orders, clear
           ``active`` and exit; otherwise print a waiting message.
        2. ``paused``: call ``update_orders()``.
        3. ``waiting_for_manual_resume``: call ``initialize_orders()``, then
           clear the flag whatever it returned.
        4. ``both_filled()``: call ``start_new_cycle()``.
        5. Otherwise: call ``update_orders()``.

        Prints "✓ Trading stopped" when the loop exits.
        """
        last_status = time.time()
        
        while self.running and self.active:
            self.process_redis_commands()
            
            # Commands such as force_stop clear `active`; bail out before refreshing.
            if not self.active:
                break
            
            self.refresh_market_data()
            self.check_fills()

            # Re-check: refresh/fill handling may also have deactivated trading.
            if not self.active:
                break
            
            if self.stopping:
                # Graceful stop: only exit once the current cycle has fully filled.
                if self.both_filled():
                    print("\n✓ Current cycle completed - stopping")
                    self.cancel_all_orders()
                    self.active = False
                    break
                print("\r⏳ Waiting for current cycle to complete before stopping...", end="")
                sys.stdout.flush()
            elif self.paused:
                # NOTE(review): paused still calls update_orders(); confirm this is
                # intended (e.g. it only manages existing orders, no new cycle).
                self.update_orders()
            elif self.waiting_for_manual_resume:
                if self.initialize_orders():
                    print("✓ Orders initialized")
                # Cleared even if initialize_orders() returned falsy.
                self.waiting_for_manual_resume = False
            elif self.both_filled():
                if not self.active:
                    break
                self.start_new_cycle()
            else:
                self.update_orders()
            
            # With the 1 s sleep below, this fires on essentially every iteration.
            if time.time() - last_status >= 1:
                self.print_status()
                last_status = time.time()
            
            time.sleep(1)
        
        print("\n✓ Trading stopped")
    
    async def generate_session_report(self, instance_id: int, markets: list, script_type: str):
        """
        Generate and publish an encrypted session report to Base.

        Only runs if LUCKBOT_BASE_ADDRESS and LUCKBOT_BASE_PRIVATE_KEY are set.
        Wrapped in try/except so it never crashes the stop flow.
        """
        base_address = os.environ.get("LUCKBOT_BASE_ADDRESS")
        base_key = os.environ.get("LUCKBOT_BASE_PRIVATE_KEY")
        if not base_address or not base_key:
            return

        try:
            from agentic.syndicate.session_summary import build_session_summary
            from agentic.syndicate.report_crypto import encrypt_report, ReportKeyManager
            from agentic.syndicate.chain_writer import write_report_to_chain

            print("\n📝 Generating session report...")

            # Build summary
            summary = await build_session_summary(
                kalshi_api=self,
                instance_id=instance_id,
                markets=markets,
                script_type=script_type,
                start_time=self.session_start_time or datetime.now(timezone.utc).isoformat(),
                cycles_completed=self.cycles_completed,
            )

            print(f"📊 {summary.to_text()}")

            # Encrypt
            encrypted = encrypt_report(summary.to_json())
            chain_bytes = encrypted.to_chain_bytes()

            # Write to Base
            print(f"⛓️  Writing report to Base ({len(chain_bytes)} bytes)...")
            result = await write_report_to_chain(chain_bytes, base_key, base_address)

            if result.success and result.tx_hash:
                print(f"✓ Report published: {result.tx_hash}")

                # Store AES key in Redis for access control
                if self.redis_client:
                    manager = ReportKeyManager(self.redis_client)
                    manager.store_report_key(
                        tx_hash=result.tx_hash,
                        aes_key=encrypted.key,
                        author_address=base_address,
                        instance_id=instance_id,
                        markets=markets,
                        timestamp=str(int(datetime.now(timezone.utc).timestamp())),
                    )
                    print("✓ Report key stored in Redis")
            else:
                print(f"⚠️ Report write failed: {result.error}")

        except Exception as e:
            print(f"⚠️ Session report error (non-fatal): {e}")

    def stop_trading(self):
        """Stop trading"""
        if not self.active:
            print("⚠️  Not currently trading")
            return
        print("\n⏸️  Stopping... waiting for current cycle to complete")
        self.stopping = True