# app/modules/terminal/auto/schema.py

from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Any, Literal
from datetime import datetime
from enum import Enum
import re
import uuid


def sanitize_string(value: str) -> str:
    """Reject text containing denylisted substrings; otherwise return it stripped.

    The check is a coarse, case-insensitive substring denylist (a few HTML/JS
    fragments, SQL comment/statement punctuation and SQL keywords followed by
    a space). It is NOT a substitute for output escaping or parameterised
    queries, and it also rejects harmless text such as "a; b" or
    "Update me".

    Args:
        value: Raw user-supplied text. Falsy input (``""`` / ``None``) is
            returned unchanged without any checks.

    Returns:
        ``value`` with leading and trailing whitespace removed.

    Raises:
        ValueError: if any denylisted pattern occurs in ``value`` (the check
            runs on the unstripped text).
    """
    if not value:
        return value

    lowered = value.lower()
    blocked = (
        '<script', 'javascript:', 'onerror=', 'onclick=',
        '--', ';', 'drop ', 'delete ', 'insert ', 'update ', 'union ', 'select ',
    )
    if any(token in lowered for token in blocked):
        raise ValueError('Invalid characters or patterns detected')

    return value.strip()


class TradeModeEnum(str, Enum):
    """The two trade modes, exposed as their display strings.

    Because of the ``str`` mixin, each member compares equal to its value
    (``TradeModeEnum.JOIN == "Join"``) and JSON-encodes as that string.
    """

    # Member order is significant for iteration; keep JOIN first.
    JOIN = "Join"
    GRID = "Grid"


def _normalize_evm_private_key(value: str) -> str:
    """Check the shape of an EVM-style wallet private key and return it stripped.

    The key must be exactly 64 hexadecimal characters, optionally preceded by
    a lowercase ``0x`` prefix. Surrounding whitespace is ignored for the check
    and removed from the returned value (the prefix, if present, is kept), so
    callers never receive a whitespace-padded key.

    Raises:
        ValueError: if the key does not have the expected shape.
    """
    stripped = value.strip()
    hex_digits = stripped[2:] if stripped.startswith('0x') else stripped
    if not re.fullmatch(r'[a-fA-F0-9]{64}', hex_digits):
        raise ValueError('Invalid private key format (must be 64 hex characters, optionally 0x-prefixed)')
    return stripped


class DeployConfig(BaseModel):
    """Configuration for deploying a trading instance.

    Credential fields are conditionally required depending on ``platform``.
    The ``always=True`` validators enforce that even when a field is omitted.
    Validators that take ``values`` only see fields declared *above* them, so
    the field order below is load-bearing: ``platform`` and ``num_markets``
    must stay before ``markets``, ``min_spread`` before ``max_spread``, and so
    on. Do not reorder fields.
    """
    platform: str = Field("kalshi", pattern="^(kalshi|turbine|polymarket)$", description="Trading platform")
    num_markets: int = Field(..., ge=1, le=2)
    mode: str = Field(..., pattern="^(hotkeys|automated)$")
    markets: List[str] = Field(..., min_items=1, max_items=2)

    # Kalshi credentials (required when platform=kalshi)
    kalshi_api_key: Optional[str] = Field(None, min_length=36, max_length=36, description="Kalshi API key (UUID format)")
    # Despite its name, this field carries the PEM key *content*, not a file path.
    rsa_key_path: Optional[str] = Field(None, min_length=100, max_length=10000, description="RSA private key content")

    # Turbine credentials (required when platform=turbine)
    turbine_private_key: Optional[str] = Field(None, min_length=64, max_length=70, description="EVM wallet private key (0x-prefixed hex)")
    turbine_asset: Optional[str] = Field(None, pattern="^(BTC|ETH|SOL)$", description="Turbine asset (BTC/ETH/SOL)")

    # Polymarket credentials (required when platform=polymarket)
    poly_private_key: Optional[str] = Field(None, min_length=64, max_length=70, description="Polygon wallet private key (0x-prefixed hex)")
    poly_market_type: Optional[str] = Field(None, pattern="^(interval|event)$", description="Polymarket market type")
    # NOTE(review): poly_asset / poly_interval are not required by any validator,
    # even when poly_market_type == "interval" — confirm that is intended.
    poly_asset: Optional[str] = Field(None, pattern="^(BTC|ETH|SOL)$", description="Polymarket interval asset")
    poly_interval: Optional[str] = Field(None, pattern="^(5m|15m)$", description="Polymarket interval duration")

    # Kalshi interval market support
    interval_mode: bool = Field(False, description="Enable interval/series market auto-rotation")
    interval_repeating: bool = Field(False, description="Auto-continue to next interval market (vs one-time)")

    both_side: Optional[str] = None
    market_priority: Optional[str] = None
    side_priority: Optional[str] = None

    min_spread: int = Field(1, ge=1, le=99)
    max_spread: int = Field(7, ge=1, le=99)
    m1_bounds: List[int] = Field([1, 7, 93, 99], min_items=4, max_items=4)
    m2_bounds: Optional[List[int]] = Field(None, min_items=4, max_items=4)

    position_increment: int = Field(10, ge=1, le=1000)
    max_position: int = Field(1000, ge=1, le=100000)

    trade_strategy: Optional[str] = Field("join_jump", pattern="^(join_jump|rolling_avg|cdf)$", description="Trading strategy")
    join_only: bool = False
    grid_mode: bool = False
    grid_levels: Optional[List[List[int]]] = None
    price_feed: Optional[str] = Field(None, pattern="^(poly_mid|kalshi_mid|pyth|ccxt_cfbm)$", description="Price feed source")
    rolling_avg_window: Optional[int] = Field(None, ge=0, le=300, description="Rolling average window in seconds (0=instant)")
    rolling_avg_spread: Optional[int] = Field(None, ge=1, le=99, description="Spread in cents to quote around rolling avg fair price")

    contract_increment: int = Field(3, ge=1, le=30, description="Contract increment (1-30)")

    # Telegram bot association
    telegram_bot_id: Optional[int] = Field(None, description="ID of saved Telegram bot to associate with this instance")

    # ========================================================================
    # VALIDATORS
    #
    # ``values`` contains only earlier-declared fields that passed validation,
    # so lookups fall back to a default when an earlier field was invalid.
    # ========================================================================

    @validator('markets', each_item=True)
    def validate_market_ticker(cls, v, values):
        """Validate and normalise one Kalshi market ticker.

        Turbine and Polymarket entries are returned untouched (asset name /
        condition id / auto-discovery). Kalshi tickers are denylist-checked,
        stripped and upper-cased *before* the format check, so lowercase
        ASCII input is accepted and returned in upper case.
        """
        platform = values.get('platform', 'kalshi')
        if platform in ('turbine', 'polymarket'):
            return v  # Non-Kalshi platforms don't use Kalshi-style tickers

        if not v:
            raise ValueError('Market ticker cannot be empty')

        # Upper-case first so lowercase input is normalised rather than
        # rejected (the pattern only admits A-Z). The ASCII guard stops
        # characters like 'ı' or 'ſ' from being upper-cased into ASCII letters.
        v = sanitize_string(v).upper()

        # Pattern: KXNBA-25MAR05-LALMIA-LAL (3+ segments separated by hyphens).
        # A match already implies >= 3 segments and >= 5 characters.
        if not (v.isascii() and re.match(r'^[A-Z0-9]+(-[A-Z0-9]+){2,}$', v)):
            raise ValueError(f'Invalid market ticker format: {v}. Expected: XXX-XXX-XXX (3+ segments)')

        if len(v) < 5 or len(v) > 100:
            raise ValueError('Market ticker must be 5-100 characters')

        return v

    @validator('markets')
    def validate_markets_count(cls, v, values):
        """For Kalshi, require exactly ``num_markets`` distinct markets."""
        platform = values.get('platform', 'kalshi')
        if platform == 'turbine':
            return v  # Turbine only uses 1 market (asset name)
        if platform == 'polymarket':
            return v  # Polymarket: 1 for interval, 2 for event

        num_markets = values.get('num_markets', 1)
        if len(v) != num_markets:
            raise ValueError(f'Must provide exactly {num_markets} market(s), got {len(v)}')

        if len(v) != len(set(v)):
            raise ValueError('Duplicate markets not allowed')

        return v

    @validator('kalshi_api_key', always=True)
    def validate_api_key(cls, v, values):
        """Require a UUID v4 API key when ``platform`` is ``kalshi``."""
        platform = values.get('platform', 'kalshi')
        if platform != 'kalshi':
            return v  # Not required for Turbine/Polymarket

        if not v:
            raise ValueError('Kalshi API key is required')

        # ``version=4`` only overwrites the version bits of the parsed UUID; it
        # does not check them. This call only confirms the string parses; the
        # regex below is what actually enforces the v4 layout.
        try:
            uuid.UUID(v, version=4)
        except ValueError:
            raise ValueError('Invalid API key format (must be UUID v4)')

        if not re.match(r'^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$', v.lower()):
            raise ValueError('API key must be valid UUID v4')

        return v

    @validator('rsa_key_path', always=True)
    def validate_rsa_key(cls, v, values):
        """Require an unencrypted PEM RSA private key (>= 2048 bits) for Kalshi.

        Raises:
            ValueError: on missing/oversized/non-PEM input, unparsable key
                material, a non-RSA key, or an RSA key under 2048 bits.
            ImportError: if the ``cryptography`` package is unavailable. That
                is a server misconfiguration, not a problem with user input.
        """
        platform = values.get('platform', 'kalshi')
        if platform != 'kalshi':
            return v  # Not required for Turbine/Polymarket

        if not v or len(v) < 100:
            raise ValueError('RSA key too short or empty')

        if len(v) > 10000:
            raise ValueError('RSA key exceeds maximum length')

        if not ('BEGIN PRIVATE KEY' in v or 'BEGIN RSA PRIVATE KEY' in v):
            raise ValueError('Invalid RSA key: must be PEM format')

        if not ('END PRIVATE KEY' in v or 'END RSA PRIVATE KEY' in v):
            raise ValueError('Invalid RSA key: missing END marker')

        dangerous = ('<script', 'javascript:', 'onerror=', 'drop table',
                     'delete from', 'insert into', 'union select')
        v_lower = v.lower()
        if any(pattern in v_lower for pattern in dangerous):
            raise ValueError('Invalid or malicious content detected in RSA key')

        from cryptography.exceptions import UnsupportedAlgorithm
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.primitives.asymmetric import rsa

        # Only the parse step sits in the try. Our own key-type and key-size
        # errors below must not be re-wrapped as "corrupted".
        try:
            key = serialization.load_pem_private_key(
                v.encode('utf-8'),
                password=None,
                backend=default_backend()
            )
        except (ValueError, TypeError, UnsupportedAlgorithm) as e:
            raise ValueError(f'Invalid or corrupted RSA private key: {e}') from e

        # PKCS#8 ("BEGIN PRIVATE KEY") may also hold EC/DSA/Ed25519 keys.
        if not isinstance(key, rsa.RSAPrivateKey):
            raise ValueError('Invalid RSA key: PEM does not contain an RSA private key')

        if key.key_size < 2048:
            raise ValueError(f'RSA key too weak: {key.key_size} bits (minimum 2048 required)')

        return v

    @validator('turbine_private_key', always=True)
    def validate_turbine_key(cls, v, values):
        """Require a well-formed EVM private key for Turbine; return it stripped."""
        platform = values.get('platform', 'kalshi')
        if platform != 'turbine':
            return v  # Not required for Kalshi/Polymarket

        if not v:
            raise ValueError('Turbine wallet private key is required')

        return _normalize_evm_private_key(v)

    @validator('turbine_asset', always=True)
    def validate_turbine_asset(cls, v, values):
        """Validate Turbine asset — required for Turbine only"""
        platform = values.get('platform', 'kalshi')
        if platform != 'turbine':
            return v

        if not v:
            raise ValueError('Turbine asset (BTC/ETH/SOL) is required')

        return v.upper()

    @validator('poly_private_key', always=True)
    def validate_poly_key(cls, v, values):
        """Require a well-formed Polygon private key for Polymarket; return it stripped."""
        platform = values.get('platform', 'kalshi')
        if platform != 'polymarket':
            return v

        if not v:
            raise ValueError('Polymarket wallet private key is required')

        return _normalize_evm_private_key(v)

    @validator('poly_market_type', always=True)
    def validate_poly_market_type(cls, v, values):
        """Validate market type — required for Polymarket only"""
        platform = values.get('platform', 'kalshi')
        if platform != 'polymarket':
            return v
        if not v:
            raise ValueError('Polymarket market type (interval/event) is required')
        return v

    # ``always=True`` so that a client overriding only ``min_spread`` (e.g. to
    # 10) is still checked against the default ``max_spread`` of 7.
    @validator('max_spread', always=True)
    def validate_spread_range(cls, v, values):
        """Ensure min_spread < max_spread and the gap is at most 50."""
        min_spread = values.get('min_spread')
        if min_spread is not None:
            if v <= min_spread:
                raise ValueError(f'max_spread ({v}) must be greater than min_spread ({min_spread})')

            if v - min_spread > 50:
                raise ValueError('Spread range too wide (max difference: 50)')

        return v

    @validator('m1_bounds')
    def validate_m1_bounds(cls, v):
        """Validate [lower1, upper1, lower2, upper2] ordering for market 1."""
        if len(v) != 4:
            raise ValueError('m1_bounds must have exactly 4 values')

        if not all(isinstance(x, int) and 1 <= x <= 99 for x in v):
            raise ValueError('All bounds must be integers between 1 and 99')

        # [lower1, upper1, lower2, upper2]
        if v[0] >= v[1]:
            raise ValueError(f'First range invalid: lower ({v[0]}) must be < upper ({v[1]})')

        if v[2] >= v[3]:
            raise ValueError(f'Second range invalid: lower ({v[2]}) must be < upper ({v[3]})')

        # Ranges should not overlap
        if v[1] > v[2]:
            raise ValueError(f'Ranges overlap: first upper ({v[1]}) must be <= second lower ({v[2]})')

        return v

    @validator('m2_bounds')
    def validate_m2_bounds(cls, v):
        """Validate [lower1, upper1, lower2, upper2] ordering for market 2 (optional)."""
        if v is None:
            return v

        if len(v) != 4:
            raise ValueError('m2_bounds must have exactly 4 values')

        if not all(isinstance(x, int) and 1 <= x <= 99 for x in v):
            raise ValueError('All m2_bounds must be integers between 1 and 99')

        if v[0] >= v[1]:
            raise ValueError(f'First range invalid: lower ({v[0]}) must be < upper ({v[1]})')

        if v[2] >= v[3]:
            raise ValueError(f'Second range invalid: lower ({v[2]}) must be < upper ({v[3]})')

        if v[1] > v[2]:
            raise ValueError(f'Ranges overlap: first upper ({v[1]}) must be <= second lower ({v[2]})')

        return v

    # ``always=True`` so the "required when num_markets=2" check also fires
    # when the client omits ``both_side`` entirely.
    @validator('both_side', always=True)
    def validate_both_side(cls, v, values):
        """Require ``both_side`` in {"yes", "no"} for 2-market mode; normalise it.

        Outside 2-market mode, any supplied value is only sanitised and
        lower-cased.
        """
        num_markets = values.get('num_markets')

        if num_markets == 2:
            if v is None:
                raise ValueError('both_side required when num_markets=2')
            # Normalise before comparing so " Yes " is accepted like "yes".
            normalized = sanitize_string(v).lower()
            if normalized not in ('yes', 'no'):
                raise ValueError('both_side must be "yes" or "no"')
            return normalized

        return sanitize_string(v).lower() if v else v

    @validator('market_priority')
    def validate_market_priority(cls, v):
        """Validate market_priority options"""
        if v is None:
            return v

        v = sanitize_string(v).lower()
        valid = ['none', 'market1', 'market2', 'expensive']

        if v not in valid:
            raise ValueError(f'market_priority must be one of: {", ".join(valid)}')

        return v

    @validator('side_priority')
    def validate_side_priority(cls, v):
        """Validate side_priority options"""
        if v is None:
            return v

        v = sanitize_string(v).lower()
        valid = ['yes', 'no', 'expensive', 'cheap']

        if v not in valid:
            raise ValueError(f'side_priority must be one of: {", ".join(valid)}')

        return v

    @validator('position_increment')
    def validate_position_increment(cls, v):
        """Bound position_increment to 1..1000 (mirrors the Field constraints)."""
        if v < 1:
            raise ValueError('position_increment must be at least 1')
        if v > 1000:
            raise ValueError('position_increment cannot exceed 1,000 contracts')
        return v

    @validator('max_position')
    def validate_max_position(cls, v):
        """Bound max_position to 1..100000 (mirrors the Field constraints)."""
        if v < 1:
            raise ValueError('max_position must be at least 1')
        if v > 100000:
            raise ValueError('max_position cannot exceed 100,000 contracts')
        return v

    @validator('grid_levels')
    def validate_grid_levels(cls, v, values):
        """Validate grid levels as a list of [price (1-99), size (>0)] pairs.

        A non-empty list is only allowed when ``grid_mode`` is True.
        """
        if v is None:
            return v

        grid_mode = values.get('grid_mode', False)
        if not grid_mode and v:
            raise ValueError('grid_levels should only be set when grid_mode=True')

        if not isinstance(v, list):
            raise ValueError('grid_levels must be a list')

        for i, level in enumerate(v):
            if not isinstance(level, list) or len(level) != 2:
                raise ValueError(f'Grid level {i} must be [price, size]')

            price, size = level
            if not (isinstance(price, int) and 1 <= price <= 99):
                raise ValueError(f'Grid level {i} price must be integer 1-99')

            if not (isinstance(size, int) and size > 0):
                raise ValueError(f'Grid level {i} size must be positive')

        return v

    @validator('contract_increment')
    def validate_contract_increment(cls, v):
        """Bound contract_increment to 1..30 (mirrors the Field constraints)."""
        if not 1 <= v <= 30:
            raise ValueError('contract_increment must be 1-30')
        return v


class InstanceResponse(BaseModel):
    """Trading instance as returned by the API.

    ``from_attributes`` allows building this model from an object's attributes
    (e.g. an ORM row), not only from a dict.
    """
    id: int
    script: str
    markets: List[str]
    status: str
    # NOTE(review): under pydantic v2, ``Optional[X]`` with no default is
    # required-but-nullable, so the source must always supply these
    # Optional fields. Only ``telegram_bot_id`` defaults to None. Confirm this
    # asymmetry is intended.
    start_time: Optional[str]
    position: int
    pnl: str  # string-typed; the exact format is not determined by this schema
    config: Dict[str, Any]
    trade_mode: str  # presumably a TradeModeEnum value ("Join"/"Grid") — confirm
    orderbook: Optional[Dict[str, Any]]
    celery_task_id: Optional[str]
    current_increment: Optional[Dict[str, Any]]
    telegram_bot_id: Optional[int] = None
    
    class Config:
        from_attributes = True


class InstanceListResponse(BaseModel):
    """List of trading instances.

    ``total`` is supplied by the caller. This schema does not check it
    against ``len(instances)``.
    """
    instances: List[InstanceResponse]
    total: int


class PresetCreate(BaseModel):
    """Request body for creating a named bot strategy preset."""
    name: str = Field(..., min_length=1, max_length=100)
    strategy_config: Dict[str, Any]

    @validator('name')
    def validate_name(cls, v):
        """Run the shared denylist, strip whitespace, and reject blank names.

        ``min_length`` is applied to the raw input before this validator, so
        a whitespace-only name would otherwise pass and be stored as ``""``.
        """
        cleaned = sanitize_string(v)
        if not cleaned:
            raise ValueError('name cannot be blank')
        return cleaned

    @validator('strategy_config')
    def validate_strategy_config_size(cls, v):
        """Reject configs whose JSON encoding exceeds 10,000 characters."""
        import json as _json
        # json.dumps escapes non-ASCII by default, so character count ~= bytes.
        if len(_json.dumps(v)) > 10_000:
            raise ValueError('strategy_config exceeds maximum size (10KB)')
        return v


class PresetUpdate(BaseModel):
    """Partial update for a bot strategy preset.

    Both fields are optional. ``None`` means "not supplied" (presumably the
    handler leaves that attribute unchanged — confirm against the caller).
    """
    name: Optional[str] = Field(None, min_length=1, max_length=100)
    strategy_config: Optional[Dict[str, Any]] = None

    @validator('name')
    def validate_name(cls, v):
        """Sanitise a supplied name and reject one that is blank after stripping.

        ``min_length`` is applied to the raw input, so a whitespace-only name
        would otherwise pass and be stored as ``""``.
        """
        if v is None:
            return v
        cleaned = sanitize_string(v)
        if not cleaned:
            raise ValueError('name cannot be blank')
        return cleaned

    @validator('strategy_config')
    def validate_strategy_config_size(cls, v):
        """Reject a supplied config whose JSON encoding exceeds 10,000 characters."""
        if v is None:
            return v
        import json as _json
        if len(_json.dumps(v)) > 10_000:
            raise ValueError('strategy_config exceeds maximum size (10KB)')
        return v


class PresetResponse(BaseModel):
    """Bot strategy preset as returned by the API.

    ``from_attributes`` allows building this model directly from an object's
    attributes (e.g. an ORM row).
    """
    id: int
    name: str
    strategy_config: Dict[str, Any]
    created_by_username: str
    created_at: datetime  # timezone-awareness not enforced by this schema
    updated_at: datetime

    class Config:
        from_attributes = True


class PresetListResponse(BaseModel):
    """A page (or full list) of presets plus a caller-supplied ``total``."""
    # PresetResponse is defined above, so no forward-reference string is needed.
    presets: List[PresetResponse]
    total: int


class TelegramBotCreate(BaseModel):
    """Request body for registering a user-owned Telegram bot."""
    name: str = Field(..., min_length=1, max_length=100)
    bot_token: str = Field(..., min_length=30, max_length=200)
    chat_id: str = Field(..., min_length=1, max_length=50)

    @validator('name')
    def validate_name(cls, v):
        """Run the shared denylist, strip whitespace, and reject blank names.

        ``min_length`` is applied to the raw input, so a whitespace-only name
        would otherwise pass and be stored as ``""``.
        """
        cleaned = sanitize_string(v)
        if not cleaned:
            raise ValueError('name cannot be blank')
        return cleaned

    @validator('bot_token')
    def validate_bot_token(cls, v):
        """Strip whitespace and require the ``<ASCII digits>:<[A-Za-z0-9_-]+>`` shape."""
        v = v.strip()
        # [0-9] rather than \d: in str patterns \d also matches non-ASCII digits.
        if not re.fullmatch(r'[0-9]+:[A-Za-z0-9_-]+', v):
            raise ValueError('Invalid Telegram bot token format')
        return v

    @validator('chat_id')
    def validate_chat_id(cls, v):
        """Strip whitespace and require an optionally negative ASCII integer."""
        v = v.strip()
        if not re.fullmatch(r'-?[0-9]+', v):
            raise ValueError('Chat ID must be a numeric value')
        return v


class TelegramBotUpdate(BaseModel):
    """Partial update for a user-owned Telegram bot; every field is optional."""
    name: Optional[str] = Field(None, min_length=1, max_length=100)
    bot_token: Optional[str] = Field(None, min_length=30, max_length=200)
    chat_id: Optional[str] = Field(None, min_length=1, max_length=50)

    @validator('name')
    def validate_name(cls, v):
        """Sanitise a supplied name and reject one that is blank after stripping."""
        if v is None:
            return v
        cleaned = sanitize_string(v)
        if not cleaned:
            raise ValueError('name cannot be blank')
        return cleaned

    @validator('bot_token')
    def validate_bot_token(cls, v):
        """Strip a supplied token and require the ``<ASCII digits>:<[A-Za-z0-9_-]+>`` shape."""
        if v is None:
            return v
        v = v.strip()
        # [0-9] rather than \d: in str patterns \d also matches non-ASCII digits.
        if not re.fullmatch(r'[0-9]+:[A-Za-z0-9_-]+', v):
            raise ValueError('Invalid Telegram bot token format')
        return v

    @validator('chat_id')
    def validate_chat_id(cls, v):
        """Strip a supplied chat id and require an optionally negative ASCII integer."""
        if v is None:
            return v
        v = v.strip()
        if not re.fullmatch(r'-?[0-9]+', v):
            raise ValueError('Chat ID must be a numeric value')
        return v


class TelegramBotResponse(BaseModel):
    """Telegram bot as returned by the API.

    The bot token is deliberately not part of this response schema. Only the
    display name and chat id are exposed.
    """
    id: int
    name: str
    chat_id: str
    active_instance_id: Optional[int]  # required-but-nullable under pydantic v2 (no default)
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True


class TelegramBotListResponse(BaseModel):
    """List of telegram bots plus a caller-supplied ``total`` (not checked against ``len(bots)``)."""
    bots: List[TelegramBotResponse]
    total: int


class InstanceControlRequest(BaseModel):
    """Request body carrying one control action for a trading instance.

    ``next_market`` is optional. This schema does not tie it to any specific
    action.
    """
    action: Literal["pause", "resume", "toggle_pause", "single_fire", "stop", "force_stop", "end", "cancel_orders", "accept_next_market", "decline_next_market"]
    next_market: Optional[str] = None

    @validator('action')
    def validate_action(cls, v):
        """Re-check ``action`` against the allowed set with an explicit message.

        The ``Literal`` annotation already restricts the value. Keep this tuple
        in sync with it (same members, same order).
        """
        allowed = (
            "pause", "resume", "toggle_pause", "single_fire", "stop",
            "force_stop", "end", "cancel_orders", "accept_next_market",
            "decline_next_market",
        )
        if v in allowed:
            return v
        raise ValueError(f'action must be one of: {", ".join(allowed)}')


class SigningProxyCreate(BaseModel):
    """Create signing proxy (deprecated - use Redis session storage)"""
    kalshi_api_key: str = Field(..., min_length=36, max_length=36)
    private_key_content: str = Field(..., min_length=100, max_length=10000)

    @validator('kalshi_api_key')
    def validate_api_key(cls, v):
        """Require the key to parse as an RFC 4122 UUID whose version is 4."""
        # uuid.UUID(..., version=4) would overwrite the version bits instead of
        # checking them, so parse plainly and inspect ``.version`` (which is
        # None for non-RFC-4122 variants).
        try:
            parsed = uuid.UUID(v)
        except ValueError:
            raise ValueError('Invalid API key format (must be UUID v4)') from None
        if parsed.version != 4:
            raise ValueError('Invalid API key format (must be UUID v4)')
        return v

    @validator('private_key_content')
    def validate_private_key(cls, v):
        """Require both PEM BEGIN and END private-key markers (no key parsing)."""
        has_begin = 'BEGIN PRIVATE KEY' in v or 'BEGIN RSA PRIVATE KEY' in v
        has_end = 'END PRIVATE KEY' in v or 'END RSA PRIVATE KEY' in v
        if not (has_begin and has_end):
            raise ValueError('Invalid private key format (must be PEM)')
        return v


class SigningProxyResponse(BaseModel):
    """Signing proxy as returned by the API.

    Exposes the Kalshi API key id only. The private key content is not part
    of this response schema.
    """
    id: int
    user_id: int
    kalshi_api_key: str
    is_active: bool
    created_at: datetime
    
    class Config:
        from_attributes = True


class OrderbookLevel(BaseModel):
    """One orderbook price level: price in 1..99 and a size in 0..1,000,000."""
    price: int = Field(..., ge=1, le=99)
    size: int = Field(..., ge=0)

    @validator('size')
    def validate_size(cls, v):
        """Bound ``size``; the lower bound duplicates ``ge=0`` on the field."""
        upper_limit = 1000000
        if v < 0:
            raise ValueError('Size cannot be negative')
        if v <= upper_limit:
            return v
        raise ValueError('Size exceeds maximum (1,000,000)')


class MarketOrderbook(BaseModel):
    """Orderbook snapshot for one market: last trade, volume, resting orders and levels."""
    side: Optional[str] = None
    last_traded: int = Field(..., ge=1, le=99)
    volume: int = Field(..., ge=0)
    resting_order: Optional[Dict[str, Any]] = None
    resting_yes: Optional[Dict[str, Any]] = None
    resting_no: Optional[Dict[str, Any]] = None
    bids: List[OrderbookLevel]
    asks: List[OrderbookLevel]

    @validator('side')
    def validate_side(cls, v):
        """Accept "yes"/"no" case-insensitively and return it lower-cased.

        ``None`` passes through. Any other value, including ``""``, is rejected.
        """
        if v is None:
            return v
        lowered = v.lower()
        if lowered not in ('yes', 'no'):
            raise ValueError('side must be "yes" or "no"')
        return lowered


class InstanceStatusUpdate(BaseModel):
    """Real-time status payload for one trading instance."""
    id: int
    status: str
    position: int = Field(..., ge=-100000, le=100000)
    pnl: str
    orderbook: Optional[Dict[str, Any]] = None
    current_increment: Optional[Dict[str, Any]] = None

    @validator('status')
    def validate_status(cls, v):
        """Normalise ``status`` to lower case and require a known value."""
        lowered = v.lower()
        allowed = ('running', 'paused', 'stopped', 'error', 'starting')
        if lowered in allowed:
            return lowered
        raise ValueError(f'status must be one of: {", ".join(allowed)}')

    @validator('position')
    def validate_position(cls, v):
        """Keep ``position`` within ±100,000 (duplicates the Field bounds)."""
        if -100000 <= v <= 100000:
            return v
        raise ValueError('Position exceeds maximum (±100,000)')