import xml.etree.ElementTree as ET
import customtkinter as ctk
from tkinter import filedialog, messagebox, StringVar
import os

class QETRenumberingApp:
    """GUI tool that numbers the conductors of one QElectroTech folio.

    The user opens a project file, picks a folio (a ``<diagram>`` element),
    chooses a prefix / number format / start value, and the conductors of
    that folio are grouped into equipotentials and numbered from top to
    bottom, then left to right. The modified XML tree can then be saved.
    """

    def __init__(self, root):
        """Build the window widgets inside *root* and reset the data state.

        Args:
            root: the top-level customtkinter window hosting the application.
        """
        self.root = root
        self.root.title("Numérotation de Fils QElectroTech")
        self.root.geometry("700x550")

        # Data state
        self.file_path = None
        self.tree = None
        self.root_xml = None
        self.folios = []  # <diagram> elements, in document order

        # --- User interface ---

        # File section
        frame_file = ctk.CTkFrame(root)
        label_file = ctk.CTkLabel(frame_file, text="Fichier Projet")
        label_file.pack(pady=5)
        frame_file.pack(fill="x", padx=10, pady=5)

        self.btn_open = ctk.CTkButton(frame_file, text="Ouvrir fichier .qet", command=self.open_file)
        self.btn_open.pack(side="left", padx=5, pady=5)

        self.lbl_file = ctk.CTkLabel(frame_file, text="Aucun fichier sélectionné", text_color="gray")
        self.lbl_file.pack(side="left", padx=5)

        # Folio list section
        frame_list = ctk.CTkFrame(root)
        label_list = ctk.CTkLabel(frame_list, text="Sélection du Folio")
        label_list.pack(pady=5)
        frame_list.pack(fill="both", expand=True, padx=10, pady=5)

        self.scrollable_frame = ctk.CTkScrollableFrame(frame_list)
        self.scrollable_frame.pack(fill="both", expand=True, padx=5, pady=5)
        self.folio_buttons = []
        self.selected_folio_index = None

        # Options section
        frame_options = ctk.CTkFrame(root)
        label_options = ctk.CTkLabel(frame_options, text="Paramètres de Numérotation")
        label_options.grid(row=0, column=0, columnspan=5, pady=5)
        frame_options.grid_columnconfigure(1, weight=1)
        frame_options.pack(fill="x", padx=10, pady=5)

        ctk.CTkLabel(frame_options, text="Préfixe :").grid(row=1, column=0, padx=5, pady=5, sticky="e")
        self.entry_prefix = ctk.CTkEntry(frame_options, width=200)
        self.entry_prefix.grid(row=1, column=1, padx=5, pady=5, sticky="w")
        self.entry_prefix.insert(0, "")  # empty by default

        ctk.CTkLabel(frame_options, text="Format :").grid(row=1, column=2, padx=15, pady=5, sticky="e")
        self.var_format = StringVar(value="0")
        ctk.CTkRadioButton(frame_options, text="0, 1, 2...", variable=self.var_format, value="0").grid(row=1, column=3)
        ctk.CTkRadioButton(frame_options, text="00, 01, 02...", variable=self.var_format, value="00").grid(row=1, column=4)

        ctk.CTkLabel(frame_options, text="Début :").grid(row=2, column=0, padx=5, pady=5, sticky="e")
        self.entry_start = ctk.CTkEntry(frame_options, width=200)
        self.entry_start.grid(row=2, column=1, padx=5, pady=5, sticky="w")
        self.entry_start.insert(0, "0")

        self.btn_set_08 = ctk.CTkButton(frame_options, text="Mettre 08", command=lambda: self.set_start_value("08"))
        self.btn_set_08.grid(row=2, column=2, padx=5, pady=5, sticky="w")

        # Actions section
        frame_actions = ctk.CTkFrame(root)
        frame_actions.pack(fill="x", padx=10, pady=10)

        # NOTE(review): height=2 looks carried over from tkinter, where the
        # unit is text lines; customtkinter sizes are presumably pixels -
        # confirm the intended button height.
        self.btn_process = ctk.CTkButton(frame_actions, text="Numéroter le folio sélectionné", command=self.process_folio, state="disabled", fg_color="#dddddd", height=2)
        self.btn_process.pack(side="left", fill="x", expand=True, padx=5)

        self.btn_save = ctk.CTkButton(frame_actions, text="Sauvegarder le projet sous...", command=self.save_file, state="disabled", height=2)
        self.btn_save.pack(side="right", fill="x", expand=True, padx=5)

    def set_start_value(self, val):
        """Replace the content of the start-number entry with *val*."""
        self.entry_start.delete(0, "end")
        self.entry_start.insert(0, val)

    def open_file(self):
        """Ask the user for a project file and load it.

        The file label is only updated once the file has been parsed
        successfully, so a failed load leaves the previous project (and its
        displayed name) untouched.
        """
        path = filedialog.askopenfilename(filetypes=[("Projet QElectroTech", "*.qet"), ("Fichiers XML", "*.xml")])
        if not path:
            return
        if self.parse_qet(path):
            self.lbl_file.configure(text=os.path.basename(path), text_color="black")

    def parse_qet(self, path=None):
        """Parse a project file and rebuild the folio button list.

        Args:
            path: file to load; defaults to ``self.file_path`` when omitted.

        Returns:
            True when the file was loaded, False when parsing failed (an
            error dialog is shown and the current state is left unchanged).
        """
        if path is None:
            path = self.file_path

        # Parse into locals first so a failure cannot leave the application
        # half-switched between the old and the new project.
        try:
            tree = ET.parse(path)
            root_xml = tree.getroot()
            diagrams = root_xml.findall('diagram')
        except Exception as e:  # UI boundary: report any read/parse failure
            messagebox.showerror("Erreur", f"Impossible de lire le fichier :\n{e}")
            return False

        self.file_path = path
        self.tree = tree
        self.root_xml = root_xml
        self.folios = []

        # Drop the buttons of the previously loaded project
        for btn in self.folio_buttons:
            btn.destroy()
        self.folio_buttons = []
        self.selected_folio_index = None

        for i, diagram in enumerate(diagrams):
            title = diagram.get('title', 'Sans titre')
            # Use the folio attribute when present, else the 1-based index
            num = diagram.get('folio', str(i + 1))
            label = f"N° {num} : {title}"
            self.folios.append(diagram)
            # idx=i binds the current index (avoids the late-binding closure trap)
            btn = ctk.CTkButton(self.scrollable_frame, text=label, command=lambda idx=i: self.on_folio_select(idx))
            btn.pack(fill="x", pady=2)
            self.folio_buttons.append(btn)

        self.btn_save.configure(state="normal")
        # No folio is selected any more: restore the inactive look as well
        self.btn_process.configure(state="disabled", fg_color="#dddddd")
        return True

    def on_folio_select(self, idx):
        """Remember folio *idx* as selected and enable the process button."""
        self.selected_folio_index = idx
        self.btn_process.configure(state="normal", fg_color="#aaffaa")  # light green when active

        # Reset the start value according to the chosen format
        default_val = "00" if self.var_format.get() == "00" else "0"
        self.entry_start.delete(0, "end")
        self.entry_start.insert(0, default_val)

    def process_folio(self):
        """Number the selected folio using the options entered in the UI.

        A blank start value means 0; a non-numeric one is reported to the
        user and nothing is modified. Any error raised during numbering is
        shown in a dialog.
        """
        if self.selected_folio_index is None:
            return

        diagram = self.folios[self.selected_folio_index]
        prefix = self.entry_prefix.get()
        fmt = self.var_format.get()

        raw_start = self.entry_start.get().strip()
        try:
            start_num = int(raw_start) if raw_start else 0
        except ValueError:
            # Numbering from a silently substituted 0 could create wrong
            # labels, so refuse instead of guessing.
            messagebox.showerror("Erreur", f"Valeur de début invalide : {raw_start!r}")
            return

        try:
            count = self.renumber_wires_logic(diagram, prefix, fmt, start_num)
            messagebox.showinfo("Succès", f"Folio renuméroté avec succès.\n{count} équipotentielles traitées.")
        except Exception as e:
            messagebox.showerror("Erreur", f"Une erreur est survenue lors de la numérotation :\n{e}")

    def get_definitions(self, root_xml):
        """Index the element definitions embedded in the project.

        Walks ``<collection>`` recursively through nested ``<category>``
        nodes and maps ``embed://<category path>/<element name>`` to the
        ``<definition>`` node of each ``<element>`` that has one.

        Args:
            root_xml: root element of the project document.

        Returns:
            dict mapping URI strings to definition elements; empty when the
            project has no ``<collection>``.
        """
        definitions = {}
        collection = root_xml.find('collection')
        if collection is None:
            return definitions

        def traverse(element, current_path):
            for child in element:
                name = child.get('name')
                if child.tag == 'category':
                    new_path = f"{current_path}/{name}" if current_path else name
                    traverse(child, new_path)
                elif child.tag == 'element':
                    uri = f"embed://{current_path}/{name}" if current_path else f"embed://{name}"
                    definition = child.find('definition')
                    if definition is not None:
                        definitions[uri] = definition

        traverse(collection, "")
        return definitions

    @staticmethod
    def _parse_xy(node):
        """Return the (x, y) attributes of *node* as floats, or None if invalid."""
        try:
            return float(node.get('x', 0)), float(node.get('y', 0))
        except (TypeError, ValueError):
            return None

    def _index_elements(self, diagram, definitions):
        """Collect continuity elements and element/terminal positions of *diagram*.

        Returns:
            (continuity_uuids, terminals_pos, elements_pos) where
            continuity_uuids is a set of element uuids, terminals_pos maps
            (element_uuid, terminal_uuid) to absolute (x, y), and
            elements_pos maps element uuid to (x, y).
        """
        continuity_uuids = set()
        terminals_pos = {}
        elements_pos = {}

        for elem in diagram.findall('elements/element'):
            uuid = elem.get('uuid')
            type_uri = elem.get('type', '')

            # A missing/invalid position only affects geometric sorting; it
            # must not prevent the element from being seen as a continuity
            # terminal, otherwise its equipotential would be split.
            origin = self._parse_xy(elem)
            if origin is not None:
                elements_pos[uuid] = origin

            def_node = definitions.get(type_uri)
            if 'continuite' in type_uri.lower():
                continuity_uuids.add(uuid)
            elif def_node is not None and 'continuite' in def_node.get('link_type', '').lower():
                continuity_uuids.add(uuid)

            if def_node is None or origin is None:
                continue

            # Terminal coordinates in the definition are relative to the element
            for term in def_node.findall('description/terminal'):
                t_uuid = term.get('uuid')
                if not t_uuid:
                    continue
                offset = self._parse_xy(term)
                if offset is not None:
                    terminals_pos[(uuid, t_uuid)] = (origin[0] + offset[0], origin[1] + offset[1])

        return continuity_uuids, terminals_pos, elements_pos

    @staticmethod
    def _group_conductors(conductors, continuity_uuids):
        """Partition *conductors* into equipotentials with a union-find.

        Two conductors are joined when they share an (element, terminal)
        end; every terminal of a continuity element counts as one common
        node. Groups are returned as lists, ordered by the first conductor
        of each group.
        """
        parent = list(range(len(conductors)))

        def find(i):
            # Iterative on purpose: long daisy chains would otherwise exceed
            # the interpreter recursion limit.
            root = i
            while parent[root] != root:
                root = parent[root]
            while parent[i] != root:  # path compression
                nxt = parent[i]
                parent[i] = root
                i = nxt
            return root

        # (element uuid, terminal uuid) -> indices of conductors ending there
        terminals_map = {}
        for i, cond in enumerate(conductors):
            ends = (
                (cond.get('element1'), cond.get('terminal1')),
                (cond.get('element2'), cond.get('terminal2')),
            )
            for e_uid, t_uid in ends:
                if not e_uid:
                    continue
                if e_uid in continuity_uuids:
                    key = (e_uid, 'COMMON_POTENTIAL')
                else:
                    key = (e_uid, t_uid)
                terminals_map.setdefault(key, []).append(i)

        for indices in terminals_map.values():
            base = indices[0]
            for other in indices[1:]:
                root_a = find(base)
                root_b = find(other)
                if root_a != root_b:
                    parent[root_a] = root_b

        groups = {}
        for i, cond in enumerate(conductors):
            groups.setdefault(find(i), []).append(cond)
        return list(groups.values())

    def _conductor_anchor_points(self, cond, terminals_pos, elements_pos):
        """Return the (x, y) points that locate *cond* on the folio.

        Uses the conductor's valid ``<point>`` children when it has any;
        otherwise, for each end, the terminal position if known, else the
        position of the element it is attached to.
        """
        points = [xy for xy in map(self._parse_xy, cond.findall('point')) if xy is not None]
        if points:
            return points

        anchors = []
        for side in ('1', '2'):
            e_uid = cond.get('element' + side)
            t_uid = cond.get('terminal' + side)
            if e_uid and t_uid and (e_uid, t_uid) in terminals_pos:
                anchors.append(terminals_pos[(e_uid, t_uid)])
            elif e_uid and e_uid in elements_pos:
                anchors.append(elements_pos[e_uid])
        return anchors

    def renumber_wires_logic(self, diagram, prefix, fmt, start_num=0):
        """Number every unnumbered equipotential of *diagram* in place.

        Equipotentials in which any conductor already has a ``num`` other
        than ``""`` or ``"_"`` are left untouched. The remaining ones are
        sorted by their top-most then left-most coordinate and receive
        consecutive labels written to the ``num`` attribute of each of
        their conductors.

        Args:
            diagram: the ``<diagram>`` element to process.
            prefix: text placed before each number.
            fmt: ``"00"`` pads numbers to two digits, anything else does not.
            start_num: first number to assign.

        Returns:
            The number of equipotentials that were numbered.
        """
        definitions = self.get_definitions(self.root_xml)
        continuity_uuids, terminals_pos, elements_pos = self._index_elements(diagram, definitions)

        conductors = diagram.findall('conductors/conductor')
        groups = self._group_conductors(conductors, continuity_uuids)

        equipotentials_to_process = []
        for group in groups:
            # Already numbered (locked) equipotentials are skipped entirely
            if any(cond.get('num', '') not in ('', '_') for cond in group):
                continue

            # Top-left corner of the bounding box of the whole network
            min_x = float('inf')
            min_y = float('inf')
            for cond in group:
                for px, py in self._conductor_anchor_points(cond, terminals_pos, elements_pos):
                    min_x = min(min_x, px)
                    min_y = min(min_y, py)

            # No usable coordinate at all: fall back to the origin
            if min_x == float('inf'):
                min_x = 0
            if min_y == float('inf'):
                min_y = 0

            equipotentials_to_process.append({'wires': group, 'x': min_x, 'y': min_y})

        # Top to bottom (y), then left to right (x): row-by-row sweep
        equipotentials_to_process.sort(key=lambda e: (e['y'], e['x']))

        for counter, eq in enumerate(equipotentials_to_process, start=start_num):
            # zfill pads to two digits and keeps a leading sign in front
            num_part = str(counter).zfill(2) if fmt == "00" else str(counter)
            label = f"{prefix}{num_part}"
            for cond in eq['wires']:
                cond.set('num', label)

        return len(equipotentials_to_process)

    def save_file(self):
        """Ask for a destination and write the current XML tree to it."""
        if self.tree is None:
            return

        path = filedialog.asksaveasfilename(defaultextension=".qet", filetypes=[("Projet QElectroTech", "*.qet")])
        if path:
            try:
                self.tree.write(path, encoding="UTF-8", xml_declaration=True)
                messagebox.showinfo("Sauvegardé", f"Fichier sauvegardé :\n{path}")
            except Exception as e:
                messagebox.showerror("Erreur", f"Erreur lors de la sauvegarde :\n{e}")

if __name__ == "__main__":
    # Script entry point: create the main window, attach the application
    # to it, then hand control to the GUI event loop until the window closes.
    main_window = ctk.CTk()
    application = QETRenumberingApp(main_window)
    main_window.mainloop()
