🚨 Python GNN Threat Detector GUI | Real-Time Cyber Attack Detection with AI + CustomTkinter

Demo:


Click the video below 👇👇👇


























📌 Features

  • Embedded YouTube Short 🎥

  • Project explanation (code + GUI screenshots)

  • SEO keywords for AI in Cybersecurity, Python GNN, Threat Detection App


Code:


# gnn_threat_gui.py

# Machine Learning – Graph Neural Networks (GNNs) for Threat Detection

# Analyze network traffic as a graph to detect attacks, with a CustomTkinter (CTk) GUI

# Author: FuzzuTech Plan (Modern + Stylish + Attractive)


import os

import sys

import time

import math

import threading

import queue

import random

from dataclasses import dataclass, field

from typing import Dict, List, Tuple


# ---- Third-party ----

import numpy as np

import torch

import torch.nn as nn

import torch.optim as optim

import networkx as nx

import matplotlib

matplotlib.use("Agg")  # for off-screen draw; embedding handles canvas

import matplotlib.pyplot as plt

from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

import customtkinter as ctk

import tkinter as tk

from tkinter import messagebox


# -----------------------------

# Utility: Reproducibility

# -----------------------------

def set_seed(seed=42):
    """Seed every random number generator this module draws from.

    Seeds, in order: Python's ``random`` module (pair sampling, shuffling),
    NumPy's global legacy generator (packet sizes and success draws) and
    PyTorch's default generator (weight init, dropout).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


set_seed(42)


# -----------------------------

# Simple GCN Layer (no PyG)

# H = ReLU(D^-1/2 Â D^-1/2 X W), where Â = A + I (self loops added) and D is the degree matrix of Â

# -----------------------------

class GCNLayer(nn.Module):
    """One dense graph-convolution layer, implemented without PyG.

    Computes ``dropout(act(D^-1/2 (A + I) D^-1/2 X W))``, where ``D`` is the
    degree matrix of ``A + I`` and ``W`` is a bias-free linear map.

    Args:
        in_dim: Number of input features per node.
        out_dim: Number of output features per node.
        activation: Apply ReLU after the linear map when True; identity otherwise.
        dropout: Dropout probability applied to the layer output.
    """

    def __init__(self, in_dim, out_dim, activation=True, dropout=0.0):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim, bias=False)
        self.act = nn.ReLU() if activation else nn.Identity()
        self.do = nn.Dropout(dropout)

    def forward(self, X, A):
        """Propagate node features ``X`` (N, in_dim) over adjacency ``A`` (N, N).

        ``A`` must NOT already contain self loops. They are added here, so a
        matrix that already has them would weight each node's own features twice.

        Returns:
            Tensor of shape (N, out_dim).
        """
        # Â = A + I: self loops let each node keep its own features.
        A_hat = A + torch.eye(A.size(0), device=A.device)
        # The epsilon guards against zero-degree rows. With a non-negative A,
        # every degree is >= 1 once the self loop is added.
        deg_inv_sqrt = torch.pow(A_hat.sum(dim=1) + 1e-8, -0.5)
        # Symmetric normalization by scaling rows and columns via broadcasting.
        # This is O(N^2), versus O(N^3) for two dense matmuls with diagonal
        # matrices, and produces the same values.
        A_norm = deg_inv_sqrt.unsqueeze(1) * A_hat * deg_inv_sqrt.unsqueeze(0)
        H = self.lin(A_norm @ X)
        return self.do(self.act(H))


# -----------------------------

# GCN Model for Node Anomaly

# -----------------------------

class GCNAnomaly(nn.Module):
    """Two-layer GCN that assigns every node an anomaly probability.

    The first layer maps ``in_dim`` features to ``hidden`` ReLU units (dropout 0.1).
    The second collapses them to one logit per node, which is passed through a sigmoid.
    """

    def __init__(self, in_dim, hidden=32):
        super().__init__()
        self.gcn1 = GCNLayer(in_dim, hidden, activation=True, dropout=0.1)
        # Output layer: one raw logit per node, with no activation and no dropout.
        self.gcn2 = GCNLayer(hidden, 1, activation=False, dropout=0.0)
        self.sig = nn.Sigmoid()

    def forward(self, X, A):
        """Return a 1-D tensor of per-node probabilities in [0, 1]."""
        hidden_repr = self.gcn1(X, A)
        logits = self.gcn2(hidden_repr, A)  # shape (N, 1)
        probs = self.sig(logits)
        return probs.squeeze(-1)


# -----------------------------

# Stream Simulator (synthetic packets -> graph)

# -----------------------------

@dataclass
class NodeStats:
    """Running traffic counters for a single host (graph node)."""

    # Bytes received by / sent from this host.
    in_bytes: float = field(default=0.0)
    out_bytes: float = field(default=0.0)
    # Failed / successful packets, counted on the sending host only.
    fails: int = field(default=0)
    succ: int = field(default=0)
    # Timestamp of the latest packet touching this host (0.0 means never seen).
    last_seen: float = field(default=0.0)
    # Node degree in the traffic graph, refreshed after each simulation step.
    degree: int = field(default=0)


@dataclass
class TrafficGraph:
    """Synthetic host-to-host traffic graph driven by a random packet simulator.

    Nodes are host identifiers. The ``weight`` of an undirected edge accumulates
    the bytes exchanged between its endpoints. Per-host counters live in ``stats``.

    Attributes:
        num_nodes: Number of hosts. When ``nodes`` is passed explicitly, this is
            overwritten with ``len(nodes)`` so the two never disagree.
        nodes: Host ids. Defaults to ``10.0.0.1`` .. ``10.0.0.<num_nodes>``.
        G: The underlying undirected NetworkX graph.
        stats: Per-host ``NodeStats`` keyed by node id.
    """

    num_nodes: int = 48
    nodes: List[str] = field(default_factory=list)
    G: nx.Graph = field(default_factory=nx.Graph)
    stats: Dict[str, NodeStats] = field(default_factory=dict)

    def __post_init__(self):
        if self.nodes:
            # Keep the count consistent with an explicit node list. Score arrays
            # and the attacker-set size are derived from num_nodes.
            self.num_nodes = len(self.nodes)
        else:
            self.nodes = [f"10.0.0.{i+1}" for i in range(self.num_nodes)]
        self.G.add_nodes_from(self.nodes)
        for n in self.nodes:
            self.stats[n] = NodeStats()

    def step(self, now=None, attack_mode=False):
        """Simulate one burst of 30-80 random packets and fold it into the graph.

        Args:
            now: Timestamp recorded as ``last_seen`` for every host touched.
                Defaults to ``time.time()``.
            attack_mode: When True, packets touching the first
                ``max(2, num_nodes // 16)`` hosts succeed with probability 0.6
                (instead of 0.96) and carry 20-100% of their drawn size.

        Raises:
            ValueError: If there are fewer than two nodes to sample a pair from.
        """
        if now is None:
            now = time.time()

        # The attacker community is fixed for the whole call, so build it once
        # rather than once per packet.
        attacker_ids = (
            set(self.nodes[:max(2, self.num_nodes // 16)]) if attack_mode else set()
        )

        for _ in range(random.randint(30, 80)):
            src, dst = random.sample(self.nodes, 2)
            # Heavy-tailed packet size, clipped to [100, 3e6] bytes.
            bytes_sent = np.clip(np.random.lognormal(mean=7.5, sigma=1.0), 100, 3e6)
            success_prob = 0.96
            if src in attacker_ids or dst in attacker_ids:
                success_prob = 0.6
                bytes_sent *= np.random.uniform(0.2, 1.0)
            success = np.random.rand() < success_prob

            src_stats = self.stats[src]
            dst_stats = self.stats[dst]
            src_stats.out_bytes += bytes_sent
            dst_stats.in_bytes += bytes_sent
            # Success/failure is attributed to the initiating host only.
            if success:
                src_stats.succ += 1
            else:
                src_stats.fails += 1
            src_stats.last_seen = now
            dst_stats.last_seen = now

            # Accumulate traffic volume on the (undirected) edge.
            w = self.G.get_edge_data(src, dst, default={"weight": 0})["weight"] + bytes_sent
            self.G.add_edge(src, dst, weight=w)

        for n in self.nodes:
            self.stats[n].degree = self.G.degree[n]

    def feature_matrix(self, now=None) -> np.ndarray:
        """Build the (num_nodes, 6) float32 node feature matrix.

        Columns: [degree, in_bytes, out_bytes, fail_ratio, activity, clustering].
        - Degree and byte counts are divided by their column maximum.
        - fail_ratio is fails / (fails + succ), or 0 with no packets.
        - Activity is ``exp(-age / 60)``, where age is seconds since the host's
          last packet; it is 0 if the host was never seen.

        Args:
            now: Reference time for the activity feature. Defaults to
                ``time.time()``. Pass the same clock used for ``step(now=...)``
                when simulating with synthetic timestamps.
        """
        if now is None:
            now = time.time()
        all_stats = self.stats.values()
        max_in = max(1.0, max((s.in_bytes for s in all_stats), default=1.0))
        max_out = max(1.0, max((s.out_bytes for s in all_stats), default=1.0))
        # Weighted clustering coefficient per node.
        clust = nx.clustering(self.G, weight="weight")

        features = []
        for n in self.nodes:
            s = self.stats[n]
            total = s.succ + s.fails
            fail_ratio = (s.fails / total) if total > 0 else 0.0
            # Recency score: 1 right after a packet, about 0.37 after 60 s.
            # The age is clamped so a last_seen later than ``now`` cannot push it above 1.
            age = max(0.0, now - s.last_seen)
            activity = math.exp(-age / 60.0) if s.last_seen > 0 else 0.0
            features.append([
                float(s.degree),
                s.in_bytes / max_in,
                s.out_bytes / max_out,
                fail_ratio,
                activity,
                clust.get(n, 0.0),
            ])

        # reshape keeps a (0, 6) shape when there are no nodes.
        X = np.array(features, dtype=np.float32).reshape(-1, 6)
        # Scale the degree column into [0, 1] by its current maximum.
        if X.shape[0] > 0 and X[:, 0].max() > 0:
            X[:, 0] = X[:, 0] / (X[:, 0].max() + 1e-8)
        return X

    def adjacency_matrix(self) -> np.ndarray:
        """Return the unweighted 0/1 adjacency matrix as float32.

        Rows and columns follow ``self.nodes``. Edge weights are deliberately
        ignored (``weight=None``), and no self loops are added here.
        """
        A = nx.to_numpy_array(self.G, nodelist=self.nodes, weight=None, dtype=np.float32)
        return A


# -----------------------------

# Synthetic Training Data Builder

# -----------------------------

def build_synthetic_batch(n_graphs=12, n_nodes=48) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Simulate ``n_graphs`` traffic snapshots and pseudo-label their nodes.

    Each graph runs 4-8 normal warm-up steps, then 3-6 further steps. Graphs
    with an even index run those later steps with attacks injected.

    Labels are a heuristic, not ground truth. A node is labelled 1.0 when its
    fail ratio is strictly above the graph's 80th percentile AND its
    normalized degree is strictly above the 60th percentile. The rule is
    applied to clean graphs too, so they can also receive positive labels.

    Returns:
        A list of ``(X, A, y)`` tuples: node features, adjacency matrix and
        float32 per-node labels.
    """
    dataset = []
    for idx in range(n_graphs):
        graph = TrafficGraph(num_nodes=n_nodes)
        for _ in range(random.randint(4, 8)):
            graph.step(attack_mode=False)
        inject = idx % 2 == 0
        for _ in range(random.randint(3, 6)):
            graph.step(attack_mode=inject)

        features = graph.feature_matrix()
        adjacency = graph.adjacency_matrix()

        fail_ratio = features[:, 3]
        degree = features[:, 0]
        high_fail = fail_ratio > np.quantile(fail_ratio, 0.8)
        high_degree = degree > np.quantile(degree, 0.6)
        labels = (high_fail & high_degree).astype(np.float32)

        dataset.append((features, adjacency, labels))
    return dataset


def numpy_to_torch(X: np.ndarray, A: np.ndarray, y: np.ndarray, device) -> Tuple[torch.Tensor,...]:
    """Wrap features, adjacency and labels as tensors on ``device``.

    ``torch.from_numpy`` shares memory with its source array, and ``.to`` is a
    no-op for a tensor already on ``device``. On CPU the results therefore
    alias the input arrays rather than copying them.

    Returns:
        ``(Xt, At, yt)``, in the same order as the arguments.
    """
    return tuple(torch.from_numpy(arr).to(device) for arr in (X, A, y))


# -----------------------------

# Trainer

# -----------------------------

def train_quick(model: nn.Module, device="cpu", epochs=60, lr=1e-2, n_graphs=16, n_nodes=48, progress_cb=None):
    """Train ``model`` on a freshly simulated synthetic dataset.

    Uses Adam (weight decay 1e-4) and binary cross-entropy on the model's
    per-node probabilities, with one optimizer step per graph.

    Args:
        model: Module mapping ``(X, A)`` tensors to per-node probabilities in [0, 1].
        device: Torch device for the model and the data.
        epochs: Number of passes over the synthetic dataset.
        lr: Adam learning rate.
        n_graphs: Number of simulated graphs in the dataset.
        n_nodes: Nodes per simulated graph.
        progress_cb: Optional ``cb(epoch, epochs, mean_loss)``, called after each
            epoch with a 1-based epoch number.

    Returns:
        The same ``model`` instance, trained in place.
    """
    model.to(device)
    opt = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    loss_fn = nn.BCELoss()
    model.train()

    # The dataset never changes between epochs, so convert it to tensors once
    # instead of re-wrapping the NumPy arrays on every step.
    batch = [
        numpy_to_torch(X, A, y, device)
        for (X, A, y) in build_synthetic_batch(n_graphs=n_graphs, n_nodes=n_nodes)
    ]
    if not batch:
        # Nothing to train on; this also avoids dividing by zero below.
        return model

    for ep in range(1, epochs + 1):
        total_loss = 0.0
        random.shuffle(batch)
        for Xt, At, yt in batch:
            # Re-assert training mode each step. infer_scores() calls
            # model.eval() on the same model, which would otherwise silently
            # disable dropout mid-training.
            model.train()
            opt.zero_grad()
            loss = loss_fn(model(Xt, At), yt)
            loss.backward()
            opt.step()
            total_loss += float(loss.item())
        if progress_cb:
            progress_cb(ep, epochs, total_loss / len(batch))
    return model


# -----------------------------

# Live Inference

# -----------------------------

def infer_scores(model: nn.Module, tg: "TrafficGraph", device="cpu") -> np.ndarray:
    """Run ``model`` on the current state of ``tg`` and return scores as NumPy.

    Args:
        model: Module mapping ``(X, A)`` tensors to per-node scores. For
            ``GCNAnomaly`` these are anomaly probabilities in [0, 1].
        tg: Object exposing ``feature_matrix()`` and ``adjacency_matrix()``.
        device: Torch device to run inference on.

    Returns:
        The model output as a NumPy array, one score per node in ``tg.nodes`` order.
    """
    X = tg.feature_matrix()
    A = tg.adjacency_matrix()
    # Inference needs eval mode (no dropout), but the model object is shared
    # with training. Restore its previous mode instead of leaving it in eval().
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            Xt = torch.from_numpy(X).to(device)
            At = torch.from_numpy(A).to(device)
            scores = model(Xt, At).cpu().numpy()
    finally:
        model.train(was_training)
    return scores


# -----------------------------

# GUI (CustomTkinter)

# -----------------------------

class ThreatGUI(ctk.CTk):
    """Main window: training controls, live traffic stream, graph plot and alerts.

    Threading model: Tk widgets and Tk variables are touched only on the Tk
    (main) thread. The training and streaming worker threads never call
    widgets directly. Instead, they enqueue ``(callable, args)`` pairs on
    ``self.stream_q`` via ``_post``, and ``_drain_ui_queue`` runs those on the
    Tk thread every ``_UI_POLL_MS`` ms.

    Locks:
        _graph_lock: Guards ``self.tg``. The stream worker mutates it, the plot
            reads it, and "Reset Graph" replaces it.
        _model_lock: Held for a whole training run. While it is taken, the
            stream worker skips inference and the plot keeps the previous
            scores, so training and inference never use the model at once.
    """

    _UI_POLL_MS = 50  # interval for draining queued UI work on the Tk thread

    def __init__(self):
        super().__init__()
        self.title("GNN Threat Detection – FuzzuTech")
        ctk.set_appearance_mode("dark")
        ctk.set_default_color_theme("blue")

        self.geometry("1100x700")
        self.minsize(980, 620)

        # Data / Model
        self.device = "cpu"
        self.tg = TrafficGraph(num_nodes=48)
        self.model = GCNAnomaly(in_dim=6, hidden=48)
        self.streaming = False
        self.stream_thread = None
        self.stream_q = queue.Queue()  # (callable, args) pairs to run on the Tk thread
        self.attack_toggle = tk.BooleanVar(value=True)
        self.threshold = tk.DoubleVar(value=0.65)
        self.last_scores = np.zeros(self.tg.num_nodes, dtype=np.float32)
        self.pos_cache = None  # layout cache

        # Concurrency state
        self._graph_lock = threading.Lock()
        self._model_lock = threading.Lock()
        self._stop_event = threading.Event()
        self._training = False
        self._closing = False
        # Plain-Python mirror of attack_toggle that the worker thread can read
        # without touching Tk.
        self._attack_mode = bool(self.attack_toggle.get())
        # Nodes currently at/above the threshold. An alert fires only when a
        # node enters this set.
        self._alerting = set()

        # Layout: left control panel, right graph panel
        self.columnconfigure(0, weight=0)
        self.columnconfigure(1, weight=1)
        self.rowconfigure(0, weight=1)

        self._build_left_panel()
        self._build_graph_panel()

        self.protocol("WM_DELETE_WINDOW", self._on_close)
        self._poll_id = self.after(self._UI_POLL_MS, self._drain_ui_queue)

        # Initial plot
        self.after(300, self._refresh_plot)

    # Left panel with cards
    def _build_left_panel(self):
        """Build the status, threshold, attack toggle, training, stream and alert cards."""
        panel = ctk.CTkFrame(self, corner_radius=16, fg_color=("gray12", "gray11"))
        panel.grid(row=0, column=0, sticky="nsw", padx=14, pady=14)
        panel.grid_propagate(False)
        panel.configure(width=310)

        # Header
        title = ctk.CTkLabel(panel, text="Threat Dashboard", font=("Segoe UI Semibold", 20))
        title.pack(pady=(14, 6))

        # Status Card
        self.status_lbl = ctk.CTkLabel(panel, text="Status: Idle", font=("Inter", 14))
        self.status_lbl.pack(pady=(6, 6))

        # Threshold slider
        thr_frame = ctk.CTkFrame(panel, corner_radius=12)
        thr_frame.pack(fill="x", padx=10, pady=8)
        ctk.CTkLabel(thr_frame, text="Anomaly Threshold").pack(pady=(8, 0))
        thr = ctk.CTkSlider(thr_frame, from_=0.3, to=0.95, number_of_steps=65,
                            variable=self.threshold, command=lambda v: self._update_status())
        thr.pack(fill="x", padx=12, pady=10)
        self.thr_value_lbl = ctk.CTkLabel(thr_frame, text=f"{self.threshold.get():.2f}")
        self.thr_value_lbl.pack(pady=(0, 10))

        # Attack toggle (for synthetic stream)
        atk_frame = ctk.CTkFrame(panel, corner_radius=12)
        atk_frame.pack(fill="x", padx=10, pady=8)
        atk_chk = ctk.CTkCheckBox(atk_frame, text="Inject Attacks in Stream",
                                  variable=self.attack_toggle, onvalue=True, offvalue=False,
                                  command=self._update_status)
        atk_chk.pack(padx=10, pady=10)

        # Train button + progress
        train_frame = ctk.CTkFrame(panel, corner_radius=12)
        train_frame.pack(fill="x", padx=10, pady=8)
        ctk.CTkLabel(train_frame, text="Model Training").pack(pady=(8, 0))
        self.progress = ctk.CTkProgressBar(train_frame)
        self.progress.set(0)
        self.progress.pack(fill="x", padx=12, pady=6)
        self.loss_lbl = ctk.CTkLabel(train_frame, text="Loss: —")
        self.loss_lbl.pack(pady=(0, 8))
        # Kept on self so it can be disabled while a training run is active.
        self.btn_train = ctk.CTkButton(train_frame, text="Train (Quick ~60 epochs)",
                                       command=self._train_async, corner_radius=12)
        self.btn_train.pack(padx=10, pady=(0, 10))

        # Stream controls
        ctl_frame = ctk.CTkFrame(panel, corner_radius=12)
        ctl_frame.pack(fill="x", padx=10, pady=8)
        ctk.CTkLabel(ctl_frame, text="Live Stream").pack(pady=(8, 0))
        btn_start = ctk.CTkButton(ctl_frame, text="Start Stream", command=self._start_stream, corner_radius=12)
        btn_stop = ctk.CTkButton(ctl_frame, text="Stop Stream", command=self._stop_stream, corner_radius=12)
        btn_reset = ctk.CTkButton(ctl_frame, text="Reset Graph", command=self._reset_graph, corner_radius=12)
        btn_start.pack(fill="x", padx=10, pady=(8, 4))
        btn_stop.pack(fill="x", padx=10, pady=4)
        btn_reset.pack(fill="x", padx=10, pady=(4, 10))

        # Alerts panel
        alerts = ctk.CTkFrame(panel, corner_radius=12)
        alerts.pack(fill="both", expand=True, padx=10, pady=(8, 12))
        ctk.CTkLabel(alerts, text="Alerts").pack(pady=(8, 6))
        self.alert_list = tk.Listbox(alerts, height=8, bg="#121212", fg="#ED5E68",
                                     activestyle="none", highlightthickness=0, borderwidth=0)
        self.alert_list.pack(fill="both", expand=True, padx=10, pady=(0, 10))

    # Right graph panel
    def _build_graph_panel(self):
        """Build the header, the embedded matplotlib canvas and the legend."""
        right = ctk.CTkFrame(self, corner_radius=16)
        right.grid(row=0, column=1, sticky="nsew", padx=(0, 14), pady=14)
        right.rowconfigure(1, weight=1)
        right.columnconfigure(0, weight=1)

        header = ctk.CTkLabel(right, text="Network Graph (Live)", font=("Segoe UI Semibold", 20))
        header.grid(row=0, column=0, sticky="w", padx=16, pady=(12, 6))

        # matplotlib figure
        self.fig = plt.figure(figsize=(6.5, 4.5), dpi=100)
        self.ax = self.fig.add_subplot(111)
        self.ax.axis("off")
        self.canvas = FigureCanvasTkAgg(self.fig, master=right)
        self.canvas.get_tk_widget().grid(row=1, column=0, sticky="nsew", padx=12, pady=12)

        # legend
        legend_frame = ctk.CTkFrame(right, corner_radius=10)
        legend_frame.grid(row=2, column=0, sticky="ew", padx=12, pady=(0, 12))
        ctk.CTkLabel(legend_frame, text="Legend:  • Normal  • Suspicious  • Attack", text_color="#cfcfcf").pack(pady=6)

    # ---- Cross-thread UI plumbing ----
    def _post(self, fn, *args):
        """Queue ``fn(*args)`` to run on the Tk thread. Safe to call from any thread."""
        self.stream_q.put((fn, args))

    def _drain_ui_queue(self):
        """Run all queued UI callbacks on the Tk thread, then reschedule itself."""
        try:
            while True:
                try:
                    fn, args = self.stream_q.get_nowait()
                except queue.Empty:
                    break
                try:
                    fn(*args)
                except Exception as e:  # one failed update must not stop the pump
                    print("UI update error:", e, file=sys.stderr)
        finally:
            if not self._closing:
                self._poll_id = self.after(self._UI_POLL_MS, self._drain_ui_queue)

    def _on_close(self):
        """Stop background work, cancel the UI pump and destroy the window."""
        self._closing = True
        self.streaming = False
        self._stop_event.set()
        self.after_cancel(self._poll_id)
        plt.close(self.fig)
        self.destroy()

    # ---- Training (threaded) ----
    def _train_async(self):
        """Start a background training run unless one is already in progress."""
        if self._training:
            return
        self._training = True
        self.btn_train.configure(state="disabled")
        self._set_status("Training...")

        def cb(ep, total, loss):
            # Runs on the worker thread, so defer the widget updates to the Tk thread.
            self._post(self._show_progress, ep, total, loss)

        def run():
            try:
                # Holding the model lock makes the stream worker skip inference
                # until training is done.
                with self._model_lock:
                    train_quick(self.model, device=self.device, epochs=60, lr=1e-2,
                                n_graphs=16, n_nodes=self.tg.num_nodes, progress_cb=cb)
            except Exception as e:
                self._post(self._set_status, "Training error")
                self._post(messagebox.showerror, "Training Error", str(e))
            else:
                self._post(self._set_status, "Training complete ✓")
            finally:
                self._post(self._training_finished)

        threading.Thread(target=run, daemon=True).start()

    def _show_progress(self, ep, total, loss):
        """Update the progress bar and loss label (Tk thread)."""
        self.progress.set(ep / total)
        self.loss_lbl.configure(text=f"Loss: {loss:.4f}  (Epoch {ep}/{total})")

    def _training_finished(self):
        """Re-enable training once a run ends (Tk thread)."""
        self._training = False
        self.btn_train.configure(state="normal")

    # ---- Streaming ----
    def _start_stream(self):
        """Start the stream worker unless it is already running."""
        if self.streaming:
            return
        self.streaming = True
        self._set_status("Streaming...")
        # Each run gets its own stop Event. A previous worker still sleeping on
        # its (already set) event then exits instead of running alongside this one.
        self._stop_event = threading.Event()
        self.stream_thread = threading.Thread(target=self._stream_loop,
                                              args=(self._stop_event,), daemon=True)
        self.stream_thread.start()

    def _stop_stream(self):
        """Signal the stream worker to exit after its current tick."""
        self.streaming = False
        self._stop_event.set()
        self._set_status("Stopped")

    def _reset_graph(self):
        """Replace the traffic graph with a fresh one and clear scores and alerts."""
        with self._graph_lock:
            self.tg = TrafficGraph(num_nodes=self.tg.num_nodes)
        self.alert_list.delete(0, tk.END)
        self._alerting.clear()
        self.last_scores = np.zeros(self.tg.num_nodes, dtype=np.float32)
        self.pos_cache = None
        self._refresh_plot()
        self._set_status("Graph reset")

    def _stream_loop(self, stop_event=None):
        """Worker thread: advance the simulation, score it, hand results to the UI.

        Args:
            stop_event: Event that ends the loop when set. Defaults to the
                instance's current stop event.
        """
        stop = self._stop_event if stop_event is None else stop_event
        while not stop.is_set():
            scores = None
            with self._graph_lock:
                tg = self.tg
                tg.step(attack_mode=self._attack_mode)
                # Skip scoring while a training run holds the model.
                if self._model_lock.acquire(blocking=False):
                    try:
                        scores = infer_scores(self.model, tg, device=self.device)
                    except Exception as e:
                        self._post(self._set_status, "Inference error")
                        print("Inference error:", e, file=sys.stderr)
                    finally:
                        self._model_lock.release()
            self._post(self._apply_stream_update, tg, scores)
            stop.wait(0.7)

    def _apply_stream_update(self, tg, scores):
        """Apply one stream tick on the Tk thread: store scores, raise alerts, redraw."""
        if tg is not self.tg:
            return  # produced for a graph that has since been reset
        if scores is not None:
            self.last_scores = scores
            self._evaluate_alerts(scores)
        self._refresh_plot()

    def _evaluate_alerts(self, scores: np.ndarray):
        """Add an alert row for each node that has just crossed the threshold.

        A node alerts once when its score rises to or above the threshold. It is
        re-armed after dropping below, so a persistently hot node does not
        flood the list. Runs on the Tk thread.
        """
        thr = float(self.threshold.get())
        ts = time.strftime("%H:%M:%S")
        hot = set()
        for i in np.flatnonzero(scores >= thr):
            node = self.tg.nodes[i]
            hot.add(node)
            if node in self._alerting:
                continue
            self.alert_list.insert(tk.END, f"[{ts}] ATTACK? {node}  (p={scores[i]:.2f})")
            # keep list short
            if self.alert_list.size() > 200:
                self.alert_list.delete(0)
        self._alerting = hot

    # ---- Plotting ----
    def _refresh_plot(self):
        """Redraw the graph, coloring nodes by score relative to the threshold (Tk thread)."""
        self.ax.clear()
        self.ax.axis("off")
        thr = float(self.threshold.get())
        scores = self.last_scores if self.last_scores is not None else np.zeros(len(self.tg.nodes))

        # The stream worker mutates the graph, so read it under the lock.
        with self._graph_lock:
            tg = self.tg
            # layout cache for stability
            if self.pos_cache is None or len(self.pos_cache) != len(tg.nodes):
                self.pos_cache = nx.spring_layout(tg.G, seed=42, k=0.45, iterations=50, weight="weight")
            pos = self.pos_cache

            colors = []
            sizes = []
            for i in range(len(tg.nodes)):
                s = scores[i] if i < len(scores) else 0.0
                # color bands: red >= thr, amber 0.4..thr, green < 0.4
                if s >= thr:
                    colors.append("#F44336")  # red
                    sizes.append(180)
                elif s >= 0.4:
                    colors.append("#FFC107")  # amber
                    sizes.append(140)
                else:
                    colors.append("#4CAF50")  # green
                    sizes.append(100)

            # edges with low alpha
            nx.draw_networkx_edges(tg.G, pos, ax=self.ax, alpha=0.25)
            nx.draw_networkx_nodes(tg.G, pos, node_color=colors, node_size=sizes, ax=self.ax, linewidths=0.5)
            # top-5 labels by score
            if len(scores) > 0:
                topk = np.argsort(scores)[-5:][::-1]
                labels = {tg.nodes[i]: f"{tg.nodes[i]} ({scores[i]:.2f})" for i in topk}
                nx.draw_networkx_labels(tg.G, pos, labels=labels, font_size=7, ax=self.ax, font_color="#E0E0E0")

        self.ax.set_title("Live Threat Scores (GCN)", fontsize=12, color="#C8E6C9", pad=6)
        self.canvas.draw_idle()

        # small UI sync
        self.thr_value_lbl.configure(text=f"{thr:.2f}")
        self.update_idletasks()

    # ---- helpers ----
    def _set_status(self, text):
        """Show ``text`` in the status label (Tk thread only)."""
        self.status_lbl.configure(text=f"Status: {text}")

    def _update_status(self):
        """React to slider/checkbox changes.

        Mirrors the attack toggle for the worker thread, updates the threshold
        label, and recolors the plot for the new threshold.
        """
        self._attack_mode = bool(self.attack_toggle.get())
        thr = float(self.threshold.get())
        self.thr_value_lbl.configure(text=f"{thr:.2f}")
        self._refresh_plot()


# -----------------------------

# Main

# -----------------------------

if __name__ == "__main__":
    try:
        app = ThreatGUI()
        app.mainloop()
    except Exception as e:
        # Report on stderr with the full traceback. The bare message alone
        # hides where a startup or runtime failure actually came from.
        import traceback
        traceback.print_exc()
        print("Fatal error:", e, file=sys.stderr)
        sys.exit(1)

Comments

Popular posts from this blog

🚀 Simple Login & Registration System in Python Tkinter 📱

📑 Fuzzu Packet Sniffer – Python GUI for Real-Time IP Monitoring | Tkinter + Scapy

🔥 Advanced MP3 Music Player in Python | CustomTkinter + Pygame | Free Source Code