from celery import Celery
import time
import json
import asyncio
import redis.asyncio as redis
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy import update

from app.config import get_settings

# Resolve application settings once at import time; the Celery app and every
# helper below read their URLs and keys from this object.
settings = get_settings()


async def _update_instance_status(instance_id: int, status: str):
    """Persist a new status for one trading instance.

    A dedicated async engine is created for the call (the configured
    ``postgresql://`` URL is rewritten to use the ``asyncpg`` driver) and is
    always disposed afterwards, including when the UPDATE or COMMIT fails, so
    a failed status write does not leak the connection pool.

    Args:
        instance_id: Primary key of the ``TradingInstance`` row to update.
        status: Value written to the row's ``status`` column.

    Raises:
        Exception: Whatever the database layer raises; the engine is disposed
            before the exception propagates.
    """
    # Imported lazily, as in the rest of this module's project-local imports.
    from app.modules.terminal.auto.models import TradingInstance

    db_url = settings.database_url.replace("postgresql://", "postgresql+asyncpg://")
    engine = create_async_engine(db_url)
    try:
        factory = async_sessionmaker(engine, class_=AsyncSession)
        async with factory() as session:
            await session.execute(
                update(TradingInstance)
                .where(TradingInstance.id == instance_id)
                .values(status=status)
            )
            await session.commit()
    finally:
        # Release pooled connections even if the statement or commit raised.
        await engine.dispose()
    print(f"✓ Instance #{instance_id} status → {status}")

# Celery application for the trading tasks. The broker URL and the result
# backend URL both come from the application settings.
celery = Celery(
    "trading_tasks",
    backend=settings.redis_url,
    broker=settings.CELERY_BROKER_URL,
)


async def get_user_credentials(user_id: int):
    """Load and decrypt a user's trading credentials from their Redis session.

    The JSON stored under ``user:{user_id}:credentials`` is read, and the
    platform-specific secret fields are passed through
    ``CryptoService.decrypt``. The Redis client is closed even when the read
    fails.

    Args:
        user_id: Identifier used to build the Redis session key.

    Returns:
        A dict with a ``"platform"`` key plus the decrypted secrets for that
        platform:

        * ``"turbine"``: ``turbine_private_key``
        * ``"polymarket"``: ``poly_private_key``
        * anything else (treated as ``"kalshi"``): ``api_key`` and
          ``private_key`` (decrypted from the stored ``rsa_key`` field)

    Raises:
        ValueError: If no session data exists under the key.
        KeyError: If the stored payload lacks a field required by its platform.
    """
    redis_client = redis.Redis.from_url(settings.redis_url, decode_responses=True)
    try:
        session_data = await redis_client.get(f"user:{user_id}:credentials")
    finally:
        # Always release the connection, even if the GET itself failed.
        await redis_client.close()

    if not session_data:
        raise ValueError("Session expired. Credentials not found in Redis.")

    data = json.loads(session_data)

    from app.modules.terminal.auto.crypto import CryptoService
    crypto = CryptoService(master_key=settings.SECRET_KEY)

    # A payload without a "platform" field is treated as Kalshi.
    platform = data.get("platform", "kalshi")

    # Turbine and Polymarket each store a single encrypted private key under
    # a platform-specific field name.
    single_key_fields = {
        "turbine": "turbine_private_key",
        "polymarket": "poly_private_key",
    }
    field = single_key_fields.get(platform)
    if field is not None:
        return {"platform": platform, field: crypto.decrypt(data[field])}

    # Any other platform value falls through to the Kalshi layout.
    return {
        "platform": "kalshi",
        "api_key": crypto.decrypt(data["api_key"]),
        "private_key": crypto.decrypt(data["rsa_key"]),
    }


def _run_status_update(instance_id, status):
    """Write *status* for the instance using a private, always-closed event loop."""
    status_loop = asyncio.new_event_loop()
    try:
        status_loop.run_until_complete(_update_instance_status(instance_id, status))
    finally:
        # Close the loop even when the database update raised.
        status_loop.close()


def _release_task_resources(loop, trader, redis_client_async):
    """Best-effort close of the trader's sync Redis client and the async Redis client.

    Failures are logged and swallowed so that cleanup never masks the task's
    real outcome. *trader* and *redis_client_async* may be ``None`` when the
    task failed before they were created.
    """
    sync_client = getattr(trader, "redis_client", None)
    if sync_client is not None:
        try:
            sync_client.close()
        except Exception as exc:
            print(f"⚠️ Failed to close trader Redis client: {type(exc).__name__}: {exc}")

    if redis_client_async is not None:
        try:
            loop.run_until_complete(redis_client_async.close())
        except Exception as exc:
            print(f"⚠️ Failed to close async Redis client: {type(exc).__name__}: {exc}")


@celery.task(bind=True, name='trading_tasks.start_trading_instance')
def start_trading_instance(self, instance_id: int, user_id: int, script_type: str, markets: list, config: dict):
    """Run one trading instance to completion inside a Celery worker.

    The task loads the user's credentials, builds the trader for the selected
    platform (Turbine, Polymarket, or Kalshi by default), runs it on a private
    event loop until it stops, and then records ``'stopped'`` in the database.
    On any exception the status is recorded as ``'error'`` instead.

    Redis clients and event loops are released on every exit path.

    Args:
        instance_id: Database id of the trading instance; also used in the
            Redis status/update channel names.
        user_id: User whose credentials are loaded from Redis.
        script_type: ``"single_market"``/``"auto1"`` selects the single-market
            Kalshi trader; any other value selects the two-market trader.
            Only consulted on the Kalshi path.
        markets: Market identifiers. Kalshi tickers, or Polymarket condition
            ids in event mode. Not used for Turbine / Polymarket interval mode.
        config: Per-instance options. ``"platform"`` overrides the platform
            stored with the credentials.

    Returns:
        ``{'status': 'completed'}`` on a clean stop, otherwise
        ``{'status': 'error', 'error': '<ExceptionType>: <message>'}``. The
        traceback is only printed server-side, never returned.
    """
    import redis as sync_redis

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Tracked from the start so the ``finally`` block can release whatever
    # was actually created before a failure.
    trader = None
    redis_client_async = None

    try:
        credentials = loop.run_until_complete(get_user_credentials(user_id))
        redis_client_async = redis.Redis.from_url(settings.redis_url, decode_responses=True)
        platform = config.get("platform", credentials.get("platform", "kalshi"))

        # ── TURBINE ──────────────────────────────────────────────
        if platform == "turbine":
            from app.modules.terminal.auto.turbine_auto1 import TurbineSMTrader
            from app.modules.terminal.auto.turbine_client import TurbineClient

            private_key = credentials["turbine_private_key"]
            asset = config.get("turbine_asset", "BTC")

            # Register API credentials. A failure here is deliberately
            # non-fatal: the client is still built with ``None`` API keys.
            api_key_id = None
            api_private_key = None
            try:
                creds = TurbineClient.request_api_credentials(
                    host="https://api.turbinefi.com",
                    private_key=private_key,
                )
                api_key_id = creds["api_key_id"]
                api_private_key = creds["api_private_key"]
            except Exception as e:
                print(f"⚠️ Turbine API cred registration: {e}")

            client = TurbineClient(
                host="https://api.turbinefi.com",
                chain_id=137,
                private_key=private_key,
                api_key_id=api_key_id,
                api_private_key=api_private_key,
            )

            trader = TurbineSMTrader(client=client, asset=asset, contract_usdc=1.0)
            trader.redis_client = sync_redis.Redis.from_url(settings.redis_url, decode_responses=True)
            trader.instance_id = instance_id
            trader.contract_increment = config.get("contract_increment", 3)
            trader.running = True
            trader.active = True

            # Polymarket price feed config
            if config.get("price_feed") == "poly_mid":
                trader.price_feed = "poly_mid"
                # Convert spread from cents (e.g. 10) to 1e6 units (100_000)
                spread_cents = config.get("rolling_avg_spread", 10)
                trader.poly_mid_spread = int(spread_cents * 10_000)

            # Discover initial market
            if not trader.discover_market():
                raise ValueError(f"No active {asset} market found on Turbine")

            loop.run_until_complete(trader.refresh_market_data_async())
            loop.run_until_complete(trader.initialize_orders_async())

            async def turbine_trading_loop():
                await trader.run_trading_instance(instance_id, [asset])

            loop.run_until_complete(turbine_trading_loop())

        # ── POLYMARKET ────────────────────────────────────────────
        elif platform == "polymarket":
            from app.modules.terminal.auto.poly_client import PolyClient

            private_key = credentials["poly_private_key"]
            poly_market_type = config.get("poly_market_type", "interval")
            poly_asset = config.get("poly_asset", "btc")
            poly_interval = config.get("poly_interval", "15m")

            client = PolyClient(private_key=private_key)

            if poly_market_type == "interval":
                # Single market (interval) mode
                from app.modules.terminal.auto.polymarket_auto1 import PolySMTrader

                trader = PolySMTrader(client=client, asset=poly_asset, interval=poly_interval)
                trader.redis_client = sync_redis.Redis.from_url(settings.redis_url, decode_responses=True)
                trader.instance_id = instance_id
                trader.contract_increment = config.get("contract_increment", 3)
                trader.running = True
                trader.active = True

                if not trader.discover_market():
                    raise ValueError(f"No active {poly_asset} {poly_interval} market found on Polymarket")

                loop.run_until_complete(trader.refresh_market_data_async())
                loop.run_until_complete(trader.initialize_orders_async())

                async def poly_interval_loop():
                    await trader.run_trading_instance(instance_id, [poly_asset])

                loop.run_until_complete(poly_interval_loop())

            else:
                # Event (2-market) mode
                from app.modules.terminal.auto.polymarket_auto2 import PolyAutoTrader

                if not markets:
                    # Fail with a readable message instead of a bare IndexError.
                    raise ValueError("Polymarket event mode requires market condition ids")

                # Resolve market data from condition_ids
                m1_data = client.get_market_by_condition(markets[0])
                m2_data = client.get_market_by_condition(markets[1]) if len(markets) > 1 else None

                if not m1_data:
                    raise ValueError(f"Polymarket market not found: {markets[0]}")
                if not m2_data:
                    raise ValueError(f"Polymarket market not found: {markets[1] if len(markets) > 1 else 'N/A'}")

                trader = PolyAutoTrader(client=client, market_1=m1_data, market_2=m2_data)
                trader.redis_client = sync_redis.Redis.from_url(settings.redis_url, decode_responses=True)
                trader.instance_id = instance_id
                trader.contract_increment = config.get("contract_increment", 3)
                trader.higher_first_mode = (config.get("market_priority") == "expensive")
                trader.running = True
                trader.active = True

                loop.run_until_complete(trader.refresh_market_data_async())
                if trader.higher_first_mode:
                    trader.select_active_market()
                loop.run_until_complete(trader.initialize_orders_async())

                async def poly_event_loop():
                    await trader.run_trading_instance(instance_id, markets)

                loop.run_until_complete(poly_event_loop())

        # ── KALSHI ───────────────────────────────────────────────
        else:
            single_market = script_type in ("single_market", "auto1")

            if single_market:
                from app.modules.terminal.auto.auto1 import SMTrader as AutomatedTrader
            else:
                from app.modules.terminal.auto.auto2 import AutomatedTrader

            from app.modules.terminal.auto.kalshi_api import Config

            if not markets:
                # Fail with a readable message instead of a bare IndexError.
                raise ValueError("At least one Kalshi market ticker is required")

            # Initialize trader based on script type
            if single_market:
                trader = AutomatedTrader(
                    api_key=credentials["api_key"],
                    api_secret=credentials["private_key"],
                    market_id=markets[0],
                    config=Config()
                )
            else:
                trader = AutomatedTrader(
                    api_key=credentials["api_key"],
                    api_secret=credentials["private_key"],
                    market_1=markets[0],
                    # With a single ticker the same market is used for both legs.
                    market_2=markets[1] if len(markets) > 1 else markets[0],
                    config=Config()
                )

            trader.redis_client = sync_redis.Redis.from_url(settings.redis_url, decode_responses=True)
            trader.instance_id = instance_id
            # Required on the Kalshi path: a missing key fails the task.
            trader.contract_increment = config["contract_increment"]

            if single_market:
                trader.one_side_first_mode = (config.get("market_priority") == "expensive")
                # Interval market config
                trader.interval_mode = config.get("interval_mode", False)
                trader.interval_repeating = config.get("interval_repeating", False)
            else:
                trader.higher_first_mode = (config.get("market_priority") == "expensive")

            # Telegram notifications
            trader.telegram_chat_id = config.get("telegram_chat_id") or settings.TELEGRAM_USER_ID

            # Set flags
            trader.running = True
            trader.active = True

            loop.run_until_complete(trader.refresh_market_data_async())

            # Higher-first mode: select active market based on orderbook data,
            # unless an in-progress rebalancing was resumed.
            if not single_market and trader.higher_first_mode:
                resumed = loop.run_until_complete(trader.check_and_resume_rebalancing())
                if not resumed:
                    trader.select_active_market()
                    loop.run_until_complete(trader.initialize_orders_async())
            else:
                loop.run_until_complete(trader.initialize_orders_async())

            async def trading_loop_with_status():
                """Poll commands and market data once per second, publishing status."""
                try:
                    while trader.running and trader.active:
                        await trader.process_redis_commands_async()

                        if not trader.active:
                            break

                        await trader.refresh_market_data_async()
                        trader.check_fills()

                        if not trader.active:
                            break

                        if trader.stopping:
                            # Stop only once both sides are filled.
                            if await trader.both_filled_async():
                                await trader.cancel_all_orders_async()
                                trader.active = False
                                break
                        elif trader.paused:
                            await trader.update_orders_async()
                        elif trader.waiting_for_manual_resume:
                            # The flag is cleared regardless of the result.
                            await trader.initialize_orders_async()
                            trader.waiting_for_manual_resume = False
                        elif await trader.both_filled_async():
                            if not trader.active:
                                break
                            await trader.start_new_cycle_async()
                        else:
                            await trader.update_orders_async()

                        status_data = await format_status(trader, instance_id, markets, script_type, platform="kalshi")
                        # Add jump state so bot can check before triggering
                        status_data["jump_active"] = {
                            m: trader.jump_active.get(m, False) for m in markets
                        }
                        # Interval market state
                        if getattr(trader, 'interval_mode', False):
                            status_data["interval_mode"] = True
                            status_data["interval_repeating"] = trader.interval_repeating
                            status_data["interval_waiting_for_next"] = getattr(trader, 'interval_waiting_for_next', False)
                            status_data["current_market"] = trader.market_id
                        status_json = json.dumps(status_data)
                        await redis_client_async.publish(
                            f"trading:instance:{instance_id}:updates",
                            status_json
                        )
                        # Snapshot expires after 30s so a dead worker does not
                        # leave a stale "running" status behind.
                        await redis_client_async.set(
                            f"trading:instance:{instance_id}:status",
                            status_json, ex=30
                        )

                        await asyncio.sleep(1)
                finally:
                    # Close the trader's session even if an iteration raised.
                    await trader.close_session()

            loop.run_until_complete(trading_loop_with_status())

        # Common tail for every platform: update DB status to stopped.
        _run_status_update(instance_id, 'stopped')
        return {'status': 'completed'}
    except Exception as e:
        import traceback
        # Log full traceback server-side only (never store in task results)
        print(f"CELERY ERROR: {type(e).__name__}: {str(e)}\n{traceback.format_exc()}")

        # Update DB status to error. This stays best-effort, but a failure is
        # now logged rather than silently discarded.
        try:
            _run_status_update(instance_id, 'error')
        except Exception as status_exc:
            print(
                f"⚠️ Could not mark instance #{instance_id} as error: "
                f"{type(status_exc).__name__}: {status_exc}"
            )

        # Return sanitized error — no traceback, no locals with keys
        return {'status': 'error', 'error': f"{type(e).__name__}: {str(e)}"}
    finally:
        try:
            _release_task_resources(loop, trader, redis_client_async)
        finally:
            # Do not leave a closed loop installed as this thread's current loop.
            asyncio.set_event_loop(None)
            loop.close()


def _yes_last_price(market_data):
    """Return the market's ``last_price`` (YES price in cents), or 0 when unavailable."""
    if isinstance(market_data, dict) and market_data:
        # ``market`` may be present but null; treat that like a missing key.
        return (market_data.get("market") or {}).get("last_price", 0) or 0
    return 0


def _top_book_levels(levels, depth=5):
    """Return up to *depth* ``[price, size]`` levels with positive size, highest price first."""
    live = [lvl for lvl in (levels or []) if lvl[1] > 0]
    return sorted(live, key=lambda lvl: lvl[0], reverse=True)[:depth]


def _format_book_view(trader, ob_data, side, cache_key, yes_last):
    """Build the status entry for one side of one order book.

    Args:
        trader: Object exposing ``last_prices``, ``cached_resting`` and
            ``cached_queue_position`` dicts.
        ob_data: Raw order-book response; expected to hold an ``"orderbook"``
            dict with ``"yes"`` / ``"no"`` lists of ``[price, size]``.
        side: ``"yes"`` or ``"no"``, the side whose levels become the bids.
        cache_key: Key used to look this book up in the trader's caches
            (the side name in single-market mode, the ticker otherwise).
        yes_last: Last traded YES price in cents, or 0 if unknown.
    """
    label = side.upper()
    if not ob_data or 'orderbook' not in ob_data:
        return {"side": label, "last_traded": 0, "volume": 0, "resting_order": None, "queue_position": None, "bids": [], "asks": []}

    # The book itself, or either side of it, may be null when empty.
    orderbook = ob_data['orderbook'] or {}
    opposite = "no" if side == "yes" else "yes"

    bids = [{"price": lvl[0], "size": lvl[1]} for lvl in _top_book_levels(orderbook.get(side))]
    # Asks are the opposite side's levels with the price flipped to 100 - price.
    asks = [{"price": 100 - lvl[0], "size": lvl[1]} for lvl in _top_book_levels(orderbook.get(opposite))]

    # last_traded from market data: YES side = last_price, NO side = 100 - last_price.
    # Without a last price, fall back to the best bid.
    if yes_last > 0:
        last_traded = yes_last if side == "yes" else 100 - yes_last
    else:
        last_traded = bids[0]['price'] if bids else 0

    resting_price = trader.last_prices.get(cache_key)
    resting_qty = trader.cached_resting.get(cache_key, 0) or 0
    resting_order = None

    if resting_price and resting_qty > 0:
        # Prices below 1 are treated as dollar fractions and scaled to cents.
        resting_price_cents = round(resting_price * 100) if resting_price < 1 else round(resting_price)
        # Only report the resting order if its level is among the displayed bids.
        if any(b['price'] == resting_price_cents for b in bids):
            resting_order = {"price_level": resting_price_cents, "quantity": resting_qty}

    queue_pos = trader.cached_queue_position.get(cache_key)
    queue_position = None
    if resting_order and queue_pos is not None:
        queue_position = {"price_level": resting_order["price_level"], "position": queue_pos}

    return {
        "side": label,
        "last_traded": last_traded,
        "volume": 0,
        "resting_order": resting_order,
        "queue_position": queue_position,
        "bids": bids,
        "asks": asks
    }


async def format_status(trader, instance_id, markets, script_type, platform="kalshi"):
    """Format status with full 5-level orderbook.

    For Turbine/Polymarket the trader's own ``format_status_data`` is returned
    unchanged. For Kalshi the order book and market data are fetched through
    ``trader._request`` and combined with the trader's cached state.

    Null or missing order-book sides, a null ``orderbook`` and a null
    ``market`` object are all treated as empty rather than raising.

    Args:
        trader: Trader instance for the platform.
        instance_id: Identifier echoed back as ``"id"``.
        markets: Market tickers; one is used in single-market mode, two
            otherwise (a single ticker is reused for both).
        script_type: ``"single_market"``/``"auto1"`` produces YES/NO views of
            one market; anything else produces a NO view per market.
        platform: ``"kalshi"`` (default), ``"turbine"`` or ``"polymarket"``.

    Returns:
        A JSON-serialisable status dict with ``id``, ``status``, ``position``,
        ``pnl``, ``orderbook`` and ``current_increment`` keys (Kalshi), or
        whatever ``trader.format_status_data`` returns.
    """
    # Turbine/Polymarket status is handled by their own format_status_data()
    if platform in ("turbine", "polymarket"):
        return trader.format_status_data(instance_id)

    if script_type in ("single_market", "auto1"):
        m = markets[0]
        ob, market_data = await asyncio.gather(
            trader._request("GET", f"/markets/{m}/orderbook"),
            trader._request("GET", f"/markets/{m}"),
        )
        yes_last = _yes_last_price(market_data)

        return {
            "id": instance_id,
            "status": "running",
            "position": (trader.cached_position.get("yes", 0) or 0) + (trader.cached_position.get("no", 0) or 0),
            "pnl": "+$0.00",
            "orderbook": {
                "yes": _format_book_view(trader, ob, "yes", "yes", yes_last),
                "no": _format_book_view(trader, ob, "no", "no", yes_last)
            },
            "current_increment": {
                "yes": {"filled": trader.current_increment.get("yes", 0), "total": trader.contract_increment},
                "no": {"filled": trader.current_increment.get("no", 0), "total": trader.contract_increment}
            }
        }

    m1 = markets[0]
    m2 = markets[1] if len(markets) > 1 else markets[0]
    m1_ob, m2_ob, m1_mkt, m2_mkt = await asyncio.gather(
        trader._request("GET", f"/markets/{m1}/orderbook"),
        trader._request("GET", f"/markets/{m2}/orderbook"),
        trader._request("GET", f"/markets/{m1}"),
        trader._request("GET", f"/markets/{m2}"),
    )

    # Two-market mode shows the NO side of each market, keyed by ticker.
    return {
        "id": instance_id,
        "status": "running",
        "position": (trader.cached_position.get(m1, 0) or 0) + (trader.cached_position.get(m2, 0) or 0),
        "pnl": "+$0.00",
        "orderbook": {
            m1: _format_book_view(trader, m1_ob, "no", m1, _yes_last_price(m1_mkt)),
            m2: _format_book_view(trader, m2_ob, "no", m2, _yes_last_price(m2_mkt))
        },
        "current_increment": {
            "m1": {"filled": trader.current_increment.get(m1, 0), "total": trader.contract_increment},
            "m2": {"filled": trader.current_increment.get(m2, 0), "total": trader.contract_increment}
        }
    }