# quad_ai.py - QUAD Artificial Intelligence Learning System
import random
import math
from config import *

class Node:
    """
    A node can represent anything: a feeling, action, concept, or state.
    Each node has a floating-point weight that determines its influence.

    Nodes form a simple tree: ``expand`` creates child nodes and records them
    in ``connections``.
    """

    def __init__(self, name, weight=0.5):
        self.name = name
        self.weight = weight  # 0.0 to 1.0 (enforced by adjust(), not here)
        self.connections = []  # Expanded (child) nodes, in creation order

    def expand(self, child_name, child_weight=0.3):
        """
        EXPANDING: Create an associated sub-node.

        The new child is appended to ``connections`` and returned so the
        caller can keep a direct handle on it.
        """
        spawned = Node(child_name, child_weight)
        self.connections.append(spawned)
        return spawned

    def adjust(self, delta):
        """
        Shift the weight by ``delta`` and clamp the result into [0.0, 1.0].

        Calling ``adjust(0)`` just re-clamps the current weight.
        """
        capped_above = min(1.0, self.weight + delta)
        self.weight = max(0.0, capped_above)

class QUAD_Brain:
    """
    STEP 1: Node Initialization
    Each unit has a brain with personality nodes.

    Core personality nodes (aggression, fear, accuracy, grouping) are seeded
    with random weights; tactical nodes are expanded from them. The brain also
    holds short-lived assumptions about enemy positions, a list of band mates,
    and simple learning counters updated by ``apply_feedback``.
    """

    def __init__(self, faction="neutral"):
        # Core personality nodes, each seeded inside its own random range.
        self.aggression = Node("Aggression", random.uniform(0.4, 0.7))
        self.fear = Node("Fear", random.uniform(0.2, 0.5))
        self.accuracy = Node("Accuracy", random.uniform(0.3, 0.6))
        self.grouping = Node("Grouping", random.uniform(0.4, 0.8))

        # Tactical nodes (EXPANDING from core nodes)
        self.cover_seeking = self.fear.expand("CoverSeeking", 0.3)
        self.flanking = self.aggression.expand("Flanking", 0.2)
        self.suppression = self.fear.expand("Suppression", 0.1)

        # Memory nodes (ASSUMING)
        self.assumed_enemies = []  # dicts with "pos", "confidence", "decay"
        self.known_cover = []      # Known safe positions, indexed as [0]=x, [1]=y

        # Team behavior
        self.faction = faction
        self.band_mates = []  # BANDING: linked units

        # Learning metrics
        self.hits_landed = 0
        self.hits_received = 0
        self.survival_time = 0  # incremented once per apply_feedback() call
        self.kills = 0

    def sense_environment(self, unit, all_units, obstacles):
        """
        STEP 3: Sensory Input

        Args:
            unit: The unit doing the sensing (needs ``id``, ``x``, ``y``).
            all_units: Units to scan; dead units (health <= 0) and ``unit``
                itself are skipped.
            obstacles: Currently unused; kept for interface compatibility.

        Returns:
            [nearest_enemy_dist, nearest_ally_dist, projectile_density].
            Distances are ``float('inf')`` when no such unit was found.
        """
        nearest_enemy = float('inf')
        nearest_ally = float('inf')

        for other in all_units:
            if other.id == unit.id or other.health <= 0:
                continue

            dist = math.dist((unit.x, unit.y), (other.x, other.y))

            if self._is_ally(unit, other):
                nearest_ally = min(nearest_ally, dist)
            else:
                nearest_enemy = min(nearest_enemy, dist)

        # "Projectile density" is a heuristic, not a count of projectiles: the
        # fear weight, halved unless an enemy is within 100 units.
        projectile_density = self.fear.weight * (1.0 if nearest_enemy < 100 else 0.5)

        return [nearest_enemy, nearest_ally, projectile_density]

    def _is_ally(self, unit1, unit2):
        """
        Check if two units are allies.

        Units are split into two sides: Combine (EliteUnit / CivilProtection)
        and everyone else. Two units are allies when they share a side.
        """
        # Imported lazily -- presumably to avoid circular imports with the
        # unit modules; confirm before hoisting to module level.
        from combine_overwatch import EliteUnit
        from civil_protection import CivilProtection

        combine_types = (EliteUnit, CivilProtection)
        return isinstance(unit1, combine_types) == isinstance(unit2, combine_types)

    def predict_best_move(self, unit, sensory_input, all_units):
        """
        STEP 5: The Prediction Loop
        Predict which direction minimizes "Damage Received" weight.

        Args:
            unit: The unit being moved.
            sensory_input: The list returned by ``sense_environment``.
            all_units: All units in play.

        Returns:
            A (dx, dy) movement step.
        """
        nearest_enemy_dist, nearest_ally_dist, projectile_density = sensory_input

        aggression_factor = self.aggression.weight
        # projectile_density is itself derived from fear (see
        # sense_environment), so fear effectively contributes twice here.
        fear_factor = self.fear.weight + projectile_density
        grouping_factor = self.grouping.weight

        # PREDICTION: If fear > aggression, seek cover
        if fear_factor > aggression_factor:
            if nearest_enemy_dist < 150:
                return self._predict_retreat(unit, all_units)
            return self._predict_to_cover(unit)

        # If grouping is high, stay near allies
        if grouping_factor > 0.6 and nearest_ally_dist > 100:
            return self._predict_to_allies(unit, all_units)

        # Otherwise, advance aggressively
        return self._predict_advance(unit, all_units)

    def _predict_retreat(self, unit, all_units):
        """
        Predict best retreat vector.

        Steps +/-3 per axis away from the centroid of all living enemies;
        returns (0, 0) when there are none.
        """
        enemy_x, enemy_y, count = 0, 0, 0
        for other in all_units:
            if not self._is_ally(unit, other) and other.health > 0:
                enemy_x += other.x
                enemy_y += other.y
                count += 1

        if count > 0:
            avg_x, avg_y = enemy_x / count, enemy_y / count
            # Move away from average enemy position
            dx = -3 if avg_x > unit.x else 3
            dy = -3 if avg_y > unit.y else 3
            return dx, dy
        return 0, 0

    def _predict_to_cover(self, unit):
        """
        Predict movement toward cover.

        Steps +/-2 per axis toward a random entry of ``known_cover``; with no
        known cover, returns a random step in {-1, 0, 1} per axis.
        """
        if self.known_cover:
            target = random.choice(self.known_cover)
            dx = 2 if target[0] > unit.x else -2
            dy = 2 if target[1] > unit.y else -2
            return dx, dy
        return random.choice([-1, 0, 1]), random.choice([-1, 0, 1])

    def _predict_to_allies(self, unit, all_units):
        """
        Predict movement toward the closest living ally.

        Steps +/-2 per axis toward that ally; returns (0, 0) if none exists.
        """
        closest_ally = None
        min_dist = float('inf')

        for other in all_units:
            if self._is_ally(unit, other) and other.id != unit.id and other.health > 0:
                dist = math.dist((unit.x, unit.y), (other.x, other.y))
                if dist < min_dist:
                    min_dist = dist
                    closest_ally = other

        if closest_ally:
            dx = 2 if closest_ally.x > unit.x else -2
            dy = 2 if closest_ally.y > unit.y else -2
            return dx, dy
        return 0, 0

    def _predict_advance(self, unit, all_units):
        """
        Predict aggressive advance toward the nearest living enemy.

        Returns a (dx, dy) step of +/-3 per axis toward that enemy, or (1, 1)
        when no living enemy exists. Ties in distance go to the enemy that
        appears first in ``all_units``.
        """
        enemies = [
            other for other in all_units
            if other.health > 0 and not self._is_ally(unit, other)
        ]
        if not enemies:
            return 1, 1

        target = min(enemies, key=lambda e: math.dist((unit.x, unit.y), (e.x, e.y)))
        dx = 3 if target.x > unit.x else -3
        dy = 3 if target.y > unit.y else -3
        return dx, dy

    def assume_enemy_position(self, last_known_x, last_known_y, velocity_x=0, velocity_y=0):
        """
        STEP 6: The "Assuming" Engine
        When enemy goes behind obstacle, assume they're still there + movement.

        The assumed position is the last known position extrapolated by two
        velocity steps. It is stored with confidence 0.7 and a 30-tick decay.
        """
        assumed_x = last_known_x + (velocity_x * 2)
        assumed_y = last_known_y + (velocity_y * 2)

        self.assumed_enemies.append({
            "pos": (assumed_x, assumed_y),
            "confidence": 0.7,
            "decay": 30  # Frames until we forget
        })

    def update_assumptions(self):
        """
        Decay assumed enemy positions over time.

        Each call decrements ``decay`` by 1 and multiplies ``confidence`` by
        0.95. Assumptions whose decay reaches 0, or whose confidence drops
        below 0.1, are dropped. The list is filtered in place, so outside
        references to ``assumed_enemies`` stay valid.
        """
        for assumption in self.assumed_enemies:
            assumption["decay"] -= 1
            assumption["confidence"] *= 0.95

        self.assumed_enemies[:] = [
            assumption for assumption in self.assumed_enemies
            if assumption["decay"] > 0 and assumption["confidence"] >= 0.1
        ]

    def apply_feedback(self, hit_enemy=False, got_hit=False, killed_enemy=False):
        """
        STEP 7: Feedback (Reward)
        Reinforce successful behaviors, discourage failures.

        Note: ``killed_enemy`` is only counted when ``hit_enemy`` is also True.
        Every call counts as one tick of survival time.
        """
        if hit_enemy:
            self.aggression.adjust(0.05)
            self.accuracy.adjust(0.03)
            self.hits_landed += 1
            if killed_enemy:
                self.kills += 1
                self.aggression.adjust(0.08)

        if got_hit:
            self.fear.adjust(0.04)
            self.cover_seeking.adjust(0.05)
            self.hits_received += 1
            # Reduce aggression slightly
            self.aggression.adjust(-0.02)

        # Survival time increases confidence
        self.survival_time += 1
        if self.survival_time > 500:
            self.accuracy.adjust(0.001)  # Veteran bonus

    def band_with_allies(self, nearby_allies):
        """
        STEP 2: Banding Setup
        Share learned weights with nearby units.

        Blends this brain's aggression and fear toward the band's average by
        ``grouping.weight * 0.1``. Each ally is expected to expose ``.brain``.
        """
        self.band_mates = nearby_allies

        if not nearby_allies:
            return

        avg_aggression = sum(ally.brain.aggression.weight for ally in nearby_allies) / len(nearby_allies)
        avg_fear = sum(ally.brain.fear.weight for ally in nearby_allies) / len(nearby_allies)

        # Partial update - blend with band mates
        blend_factor = self.grouping.weight * 0.1
        self.aggression.weight += (avg_aggression - self.aggression.weight) * blend_factor
        self.fear.weight += (avg_fear - self.fear.weight) * blend_factor

    def get_flattened_data(self):
        """
        STEP 8: Flattened Execution
        Return GPU-ready flattened format.

        Layout (9 floats): aggression, fear, accuracy, grouping,
        cover_seeking, flanking, hits_landed, hits_received, survival_time.
        """
        return [
            self.aggression.weight,
            self.fear.weight,
            self.accuracy.weight,
            self.grouping.weight,
            self.cover_seeking.weight,
            self.flanking.weight,
            float(self.hits_landed),
            float(self.hits_received),
            float(self.survival_time),
        ]

    def clone_for_next_gen(self):
        """
        STEP 9: Iterative Refinement
        Units that survive longer pass their weights to next generation.

        Only the four core weights are inherited, each with a uniform
        +/-0.05 mutation and then clamped to [0, 1]. Tactical nodes, memory
        and learning counters start fresh in the new brain.
        """
        new_brain = QUAD_Brain(self.faction)

        mutation = 0.05
        for node_name in ("aggression", "fear", "accuracy", "grouping"):
            child_node = getattr(new_brain, node_name)
            child_node.weight = getattr(self, node_name).weight + random.uniform(-mutation, mutation)
            child_node.adjust(0)  # re-clamp into [0, 1]

        return new_brain

class TeamPersonality:
    """
    STEP 10: Stage 50 Integration
    Aggregate team behaviors into faction-wide personalities.

    Holds faction-wide averages of aggression, fear and accuracy plus a
    derived ``tactical_doctrine`` label.
    """

    def __init__(self, faction_name):
        self.faction_name = faction_name
        self.avg_aggression = 0.5
        self.avg_fear = 0.5
        self.avg_accuracy = 0.5
        self.tactical_doctrine = "balanced"  # aggressive, defensive, balanced

    def aggregate_from_units(self, units):
        """
        Learn team personality from all active units.

        Units without a ``brain`` attribute are ignored. If ``units`` is empty
        or None, or no unit has a brain, the current averages and doctrine are
        left unchanged.

        Args:
            units: Any iterable of units. It is consumed exactly once, so
                generators and iterators work too.
        """
        if not units:
            return

        # Materialize once: the old multi-pass version silently broke on
        # one-shot iterators (later passes saw nothing).
        brains = [u.brain for u in units if hasattr(u, 'brain')]
        count = len(brains)
        if count == 0:
            return

        self.avg_aggression = sum(b.aggression.weight for b in brains) / count
        self.avg_fear = sum(b.fear.weight for b in brains) / count
        self.avg_accuracy = sum(b.accuracy.weight for b in brains) / count

        # Determine doctrine (aggression takes precedence over fear)
        if self.avg_aggression > 0.65:
            self.tactical_doctrine = "aggressive"
        elif self.avg_fear > 0.6:
            self.tactical_doctrine = "defensive"
        else:
            self.tactical_doctrine = "balanced"

    def influence_new_unit(self, brain):
        """
        New units adopt team personality.

        Moves the brain's aggression, fear and accuracy 30% of the way toward
        the team averages (weights are modified in place).
        """
        blend = 0.3
        brain.aggression.weight += (self.avg_aggression - brain.aggression.weight) * blend
        brain.fear.weight += (self.avg_fear - brain.fear.weight) * blend
        brain.accuracy.weight += (self.avg_accuracy - brain.accuracy.weight) * blend
