This tutorial walks through the entire process of training a YOLO segmentation model for ID document detection using the MIDV500 dataset and ultralytics: dataset acquisition, conversion to YOLO format, training and ONNX export, and building a GUI application for document detection.
Demo Video: YOLO Segmentation for ID Document Detection
Getting the MIDV500 Dataset
The MIDV500 dataset is a comprehensive collection of identity documents designed for document detection and recognition tasks. It contains video clips and per-frame images of 50 identity document types from different countries, each annotated with the document's quadrilateral. The individual archives are hosted on the Smart Engines FTP server; the lists below cover MIDV-500 and the MIDV-2019 extension:
midv500_links = [
"ftp://smartengines.com/midv-500/dataset/01_alb_id.zip",
"ftp://smartengines.com/midv-500/dataset/02_aut_drvlic_new.zip",
"ftp://smartengines.com/midv-500/dataset/03_aut_id_old.zip",
"ftp://smartengines.com/midv-500/dataset/04_aut_id.zip",
"ftp://smartengines.com/midv-500/dataset/05_aze_passport.zip",
"ftp://smartengines.com/midv-500/dataset/06_bra_passport.zip",
"ftp://smartengines.com/midv-500/dataset/07_chl_id.zip",
"ftp://smartengines.com/midv-500/dataset/08_chn_homereturn.zip",
"ftp://smartengines.com/midv-500/dataset/09_chn_id.zip",
"ftp://smartengines.com/midv-500/dataset/10_cze_id.zip",
"ftp://smartengines.com/midv-500/dataset/11_cze_passport.zip",
"ftp://smartengines.com/midv-500/dataset/12_deu_drvlic_new.zip",
"ftp://smartengines.com/midv-500/dataset/13_deu_drvlic_old.zip",
"ftp://smartengines.com/midv-500/dataset/14_deu_id_new.zip",
"ftp://smartengines.com/midv-500/dataset/15_deu_id_old.zip",
"ftp://smartengines.com/midv-500/dataset/16_deu_passport_new.zip",
"ftp://smartengines.com/midv-500/dataset/17_deu_passport_old.zip",
"ftp://smartengines.com/midv-500/dataset/18_dza_passport.zip",
"ftp://smartengines.com/midv-500/dataset/19_esp_drvlic.zip",
"ftp://smartengines.com/midv-500/dataset/20_esp_id_new.zip",
"ftp://smartengines.com/midv-500/dataset/21_esp_id_old.zip",
"ftp://smartengines.com/midv-500/dataset/22_est_id.zip",
"ftp://smartengines.com/midv-500/dataset/23_fin_drvlic.zip",
"ftp://smartengines.com/midv-500/dataset/24_fin_id.zip",
"ftp://smartengines.com/midv-500/dataset/25_grc_passport.zip",
"ftp://smartengines.com/midv-500/dataset/26_hrv_drvlic.zip",
"ftp://smartengines.com/midv-500/dataset/27_hrv_passport.zip",
"ftp://smartengines.com/midv-500/dataset/28_hun_passport.zip",
"ftp://smartengines.com/midv-500/dataset/29_irn_drvlic.zip",
"ftp://smartengines.com/midv-500/dataset/30_ita_drvlic.zip",
"ftp://smartengines.com/midv-500/dataset/31_jpn_drvlic.zip",
"ftp://smartengines.com/midv-500/dataset/32_lva_passport.zip",
"ftp://smartengines.com/midv-500/dataset/33_mac_id.zip",
"ftp://smartengines.com/midv-500/dataset/34_mda_passport.zip",
"ftp://smartengines.com/midv-500/dataset/35_nor_drvlic.zip",
"ftp://smartengines.com/midv-500/dataset/36_pol_drvlic.zip",
"ftp://smartengines.com/midv-500/dataset/37_prt_id.zip",
"ftp://smartengines.com/midv-500/dataset/38_rou_drvlic.zip",
"ftp://smartengines.com/midv-500/dataset/39_rus_internalpassport.zip",
"ftp://smartengines.com/midv-500/dataset/40_srb_id.zip",
"ftp://smartengines.com/midv-500/dataset/41_srb_passport.zip",
"ftp://smartengines.com/midv-500/dataset/42_svk_id.zip",
"ftp://smartengines.com/midv-500/dataset/43_tur_id.zip",
"ftp://smartengines.com/midv-500/dataset/44_ukr_id.zip",
"ftp://smartengines.com/midv-500/dataset/45_ukr_passport.zip",
"ftp://smartengines.com/midv-500/dataset/46_ury_passport.zip",
"ftp://smartengines.com/midv-500/dataset/47_usa_bordercrossing.zip",
"ftp://smartengines.com/midv-500/dataset/48_usa_passportcard.zip",
"ftp://smartengines.com/midv-500/dataset/49_usa_ssn82.zip",
"ftp://smartengines.com/midv-500/dataset/50_xpo_id.zip",
]
midv2019_links = [
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/01_alb_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/02_aut_drvlic_new.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/03_aut_id_old.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/04_aut_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/05_aze_passport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/06_bra_passport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/07_chl_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/08_chn_homereturn.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/09_chn_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/10_cze_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/11_cze_passport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/12_deu_drvlic_new.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/13_deu_drvlic_old.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/14_deu_id_new.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/15_deu_id_old.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/16_deu_passport_new.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/17_deu_passport_old.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/18_dza_passport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/19_esp_drvlic.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/20_esp_id_new.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/21_esp_id_old.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/22_est_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/23_fin_drvlic.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/24_fin_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/25_grc_passport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/26_hrv_drvlic.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/27_hrv_passport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/28_hun_passport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/29_irn_drvlic.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/30_ita_drvlic.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/31_jpn_drvlic.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/32_lva_passport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/33_mac_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/34_mda_passport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/35_nor_drvlic.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/36_pol_drvlic.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/37_prt_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/38_rou_drvlic.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/39_rus_internalpassport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/40_srb_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/41_srb_passport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/42_svk_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/43_tur_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/44_ukr_id.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/45_ukr_passport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/46_ury_passport.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/47_usa_bordercrossing.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/48_usa_passportcard.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/49_usa_ssn82.zip",
"ftp://smartengines.com/midv-500/extra/midv-2019/dataset/50_xpo_id.zip",
]
Installation
pip install midv500
Download
import midv500
dataset_dir = 'midv500_data/'
dataset_name = "all"
midv500.download_dataset(dataset_dir, dataset_name)
The expected directory structure after downloading is as follows:
midv500_data/
midv500/
01_alb_id/
ground_truth/
01_alb_id.json
CA/, CS/, HA/, HS/, KA/, KS/, PA/, PS/, TA/, TS/
*.json (annotation files)
images/
01_alb_id.tif
CA/, CS/, HA/, HS/, KA/, KS/, PA/, PS/, TA/, TS/
*.tif (image files)
videos/
*.MOV, *.mp4 (video files)
02_aut_drvlic_new/
... (similar structure)
... (50 document types total)
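Each per-frame JSON under ground_truth/ stores the four corners of the document in image coordinates under a quad key, which is the primary field the conversion script below relies on. A representative file looks roughly like this (coordinates are illustrative):
{
  "quad": [
    [120.0, 215.0],
    [680.0, 230.0],
    [665.0, 590.0],
    [105.0, 575.0]
  ]
}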
Converting MIDV500 to YOLO Format
The MIDV500 dataset stores each document's corners as a quadrilateral in JSON annotations, while YOLO segmentation expects one plain-text label file per image containing a class id followed by normalized polygon coordinates. We need to convert the dataset with a custom Python script.
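For our single-class dataset, each label file ends up with one line of the form class_id x1 y1 x2 y2 x3 y3 x4 y4, where the coordinates are normalized by image width and height. The values below are made up purely for illustration:
0 0.132031 0.214583 0.867187 0.221250 0.853906 0.778333 0.119531 0.771667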
Prerequisites
Install the required dependencies:
pip install ultralytics opencv-python numpy tqdm pyyaml pillow
Conversion Script
Create and run the conversion script:
import os, re, json, shutil, random, math
from pathlib import Path
import yaml
import cv2
import numpy as np
from tqdm import tqdm
MIDV_ROOT = Path("midv500_data/midv500")
OUT_ROOT = Path("./dataset")
TRAIN_SPLIT = 0.9
CLASS_NAME = "document"
IMG_EXT = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
JSON_KEYS = ["quad", "quadrilateral", "polygon", "points"]
def find_images(root: Path):
return [p for p in root.rglob("*") if p.suffix.lower() in IMG_EXT]
def parse_annotation(img_path: Path):
if "images" in img_path.parts:
parts = list(img_path.parts)
try:
images_idx = parts.index("images")
parts[images_idx] = "ground_truth"
json_path = Path(*parts).with_suffix(".json")
if json_path.exists():
try:
data = json.loads(json_path.read_text())
if "quad" in data and isinstance(data["quad"], list) and len(data["quad"]) >= 4:
quad = data["quad"][:4]
result = []
for pt in quad:
if isinstance(pt, (list, tuple)) and len(pt) >= 2:
result.append([float(pt[0]), float(pt[1])])
if len(result) == 4:
return result
except Exception as e:
print(f"Error parsing {json_path}: {e}")
pass
except ValueError:
pass
sidecar = img_path.with_suffix(".json")
if sidecar.exists():
try:
data = json.loads(sidecar.read_text())
for k in JSON_KEYS:
if k in data and isinstance(data[k], (list, tuple)) and len(data[k]) >= 4:
pts = data[k][:4]
quad = []
for pt in pts:
if isinstance(pt, dict):
quad.append([float(pt["x"]), float(pt["y"])])
else:
quad.append([float(pt[0]), float(pt[1])])
return quad
except Exception:
pass
for parent in [img_path.parent, img_path.parent.parent]:
gt = parent / "ground_truth.txt"
if gt.exists():
stem = img_path.stem
with gt.open("r", encoding="utf-8", errors="ignore") as f:
for line in f:
parts = re.split(r"\s+", line.strip())
if len(parts) >= 9:
# first token can be filename or id
key = Path(parts[0]).stem
if key == stem:
try:
vals = list(map(float, parts[1:9]))
quad = [[vals[0], vals[1]],
[vals[2], vals[3]],
[vals[4], vals[5]],
[vals[6], vals[7]]]
return quad
except Exception:
pass
return None
def order_quad(quad):
"""Order points TL, TR, BR, BL."""
pts = np.array(quad, dtype=np.float32)
s = pts.sum(axis=1)
d = np.diff(pts, axis=1).ravel()
ordered = np.array([pts[np.argmin(s)],
pts[np.argmin(d)],
pts[np.argmax(s)],
pts[np.argmax(d)]], dtype=np.float32)
return ordered.tolist()
def main():
all_imgs = find_images(MIDV_ROOT)
if not all_imgs:
raise SystemExit(f"No images found under {MIDV_ROOT}")
records = []
for img in tqdm(all_imgs, desc="Scanning"):
quad = parse_annotation(img)
if quad is None:
continue
im = cv2.imread(str(img))
if im is None:
continue
h, w = im.shape[:2]
quad = order_quad(quad)
norm = []
for x,y in quad:
x = max(0, min(w-1, x)) / w
y = max(0, min(h-1, y)) / h
norm.append([x, y])
records.append((img, norm, (w,h)))
if not records:
raise SystemExit("Found no usable annotations. Adjust parse_annotation().")
random.shuffle(records)
n_train = int(len(records) * TRAIN_SPLIT)
train, val = records[:n_train], records[n_train:]
(OUT_ROOT / "images" / "train").mkdir(parents=True, exist_ok=True)
(OUT_ROOT / "images" / "val").mkdir(parents=True, exist_ok=True)
(OUT_ROOT / "labels" / "train").mkdir(parents=True, exist_ok=True)
(OUT_ROOT / "labels" / "val").mkdir(parents=True, exist_ok=True)
def write_example(split_dir, img_path, norm_quad):
dst_img = OUT_ROOT / "images" / split_dir / (img_path.stem + ".jpg")
dst_lbl = OUT_ROOT / "labels" / split_dir / (img_path.stem + ".txt")
img = cv2.imread(str(img_path))
if img is not None:
cv2.imwrite(str(dst_img), img, [cv2.IMWRITE_JPEG_QUALITY, 95])
else:
shutil.copy2(img_path, dst_img)
flat = " ".join(f"{x:.6f} {y:.6f}" for x,y in norm_quad)
dst_lbl.write_text(f"0 {flat}\n")
for img, quad, _ in tqdm(train, desc="Write train"):
write_example("train", img, quad)
for img, quad, _ in tqdm(val, desc="Write val"):
write_example("val", img, quad)
data = {
"path": str(OUT_ROOT.resolve()),
"train": "images/train",
"val": "images/val",
"names": [CLASS_NAME],
"nc": 1,
}
yaml.safe_dump(data, open(OUT_ROOT / "doc.yaml", "w"))
print(f"\nDONE. Wrote dataset at: {OUT_ROOT}\nTotal: {len(records)}, Train: {len(train)}, Val: {len(val)}")
if __name__ == "__main__":
main()
Running the Conversion
Execute the conversion script:
python prep_midv500_to_yolov11seg.py
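As a quick sanity check, confirm that every converted image has a matching label file. A minimal snippet, assuming the default ./dataset output directory:
from pathlib import Path

for split in ("train", "val"):
    imgs = {p.stem for p in Path(f"dataset/images/{split}").glob("*.jpg")}
    lbls = {p.stem for p in Path(f"dataset/labels/{split}").glob("*.txt")}
    # Every image should have exactly one label file with the same stem
    print(f"{split}: {len(imgs)} images, {len(lbls)} labels, "
          f"{len(imgs - lbls)} images without labels")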
Dataset Visualization and Verification
Before training, it's crucial to verify that the dataset conversion was successful. We'll create a visualization tool to inspect the annotations.
# dataset_viewer.py
import sys
import os
import cv2
import numpy as np
from pathlib import Path
from PySide6.QtWidgets import (QApplication, QMainWindow, QVBoxLayout,
QHBoxLayout, QWidget, QLabel, QPushButton,
QComboBox, QSlider, QCheckBox)
from PySide6.QtCore import Qt
from PySide6.QtGui import QPixmap, QImage
class DatasetViewer(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("YOLO Dataset Viewer")
self.setGeometry(100, 100, 1200, 800)
self.dataset_path = Path("dataset")
self.current_split = "train"
self.current_index = 0
self.show_annotations = True
self.image_files = []
self.init_ui()
self.load_dataset()
self.display_current_image()
def init_ui(self):
central_widget = QWidget()
self.setCentralWidget(central_widget)
layout = QVBoxLayout(central_widget)
controls_layout = QHBoxLayout()
self.split_combo = QComboBox()
self.split_combo.addItems(["train", "val"])
self.split_combo.currentTextChanged.connect(self.change_split)
controls_layout.addWidget(QLabel("Split:"))
controls_layout.addWidget(self.split_combo)
self.prev_btn = QPushButton("Previous")
self.prev_btn.clicked.connect(self.prev_image)
controls_layout.addWidget(self.prev_btn)
self.next_btn = QPushButton("Next")
self.next_btn.clicked.connect(self.next_image)
controls_layout.addWidget(self.next_btn)
self.annotation_checkbox = QCheckBox("Show Annotations")
self.annotation_checkbox.setChecked(True)
self.annotation_checkbox.toggled.connect(self.toggle_annotations)
controls_layout.addWidget(self.annotation_checkbox)
self.counter_label = QLabel("0 / 0")
controls_layout.addWidget(self.counter_label)
controls_layout.addStretch()
layout.addLayout(controls_layout)
self.image_label = QLabel()
self.image_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.image_label.setStyleSheet("border: 1px solid gray;")
layout.addWidget(self.image_label)
def load_dataset(self):
"""Load image files for current split"""
images_dir = self.dataset_path / "images" / self.current_split
if images_dir.exists():
self.image_files = sorted(list(images_dir.glob("*.jpg")))
else:
self.image_files = []
self.current_index = 0
self.update_counter()
def display_current_image(self):
"""Display the current image with optional annotations"""
if not self.image_files:
self.image_label.setText("No images found")
return
if self.current_index >= len(self.image_files):
self.current_index = 0
img_path = self.image_files[self.current_index]
img = cv2.imread(str(img_path))
if img is None:
self.image_label.setText("Failed to load image")
return
if self.show_annotations:
img = self.draw_annotations(img, img_path)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
h, w, ch = img_rgb.shape
bytes_per_line = ch * w
qt_image = QImage(img_rgb.data, w, h, bytes_per_line, QImage.Format.Format_RGB888)
pixmap = QPixmap.fromImage(qt_image)
scaled_pixmap = pixmap.scaled(self.image_label.size(),
Qt.AspectRatioMode.KeepAspectRatio,
Qt.TransformationMode.SmoothTransformation)
self.image_label.setPixmap(scaled_pixmap)
self.update_counter()
def draw_annotations(self, img, img_path):
"""Draw YOLO annotations on image"""
labels_dir = self.dataset_path / "labels" / self.current_split
label_path = labels_dir / f"{img_path.stem}.txt"
if not label_path.exists():
return img
h, w = img.shape[:2]
with open(label_path, 'r') as f:
for line in f:
line = line.strip()
if not line:
continue
parts = line.split()
if len(parts) < 9:
continue
class_id = int(parts[0])
coords = list(map(float, parts[1:]))
points = []
for i in range(0, len(coords), 2):
x = int(coords[i] * w)
y = int(coords[i+1] * h)
points.append([x, y])
points = np.array(points, np.int32)
cv2.polylines(img, [points], True, (0, 255, 0), 2)
cv2.putText(img, "document", tuple(points[0]),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
return img
def update_counter(self):
"""Update the image counter display"""
total = len(self.image_files)
current = self.current_index + 1 if total > 0 else 0
self.counter_label.setText(f"{current} / {total}")
def change_split(self, split):
"""Change the dataset split"""
self.current_split = split
self.load_dataset()
self.display_current_image()
def prev_image(self):
"""Go to previous image"""
if self.image_files:
self.current_index = (self.current_index - 1) % len(self.image_files)
self.display_current_image()
def next_image(self):
"""Go to next image"""
if self.image_files:
self.current_index = (self.current_index + 1) % len(self.image_files)
self.display_current_image()
def toggle_annotations(self, checked):
"""Toggle annotation visibility"""
self.show_annotations = checked
self.display_current_image()
if __name__ == "__main__":
app = QApplication(sys.argv)
viewer = DatasetViewer()
viewer.show()
sys.exit(app.exec())
Run the dataset viewer:
python dataset_viewer.py
Training the YOLO Model
Now that we have prepared and verified our dataset, let's train the YOLO segmentation model and export it as an ONNX model.
Training Script
import subprocess
import sys
import os
import yaml
import tempfile
from pathlib import Path
def create_dynamic_config():
"""Create a temporary YAML config with dynamic absolute path."""
dataset_dir = Path(__file__).parent / "dataset"
dataset_dir = dataset_dir.resolve()
config = {
'names': ['document'],
'nc': 1,
'path': str(dataset_dir),
'train': 'images/train',
'val': 'images/val'
}
temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False)
yaml.dump(config, temp_file, default_flow_style=False)
temp_file.close()
print(f"Created temporary config: {temp_file.name}")
print(f"Dataset path: {dataset_dir}")
return temp_file.name
def cleanup_temp_file(temp_path):
"""Clean up temporary file."""
try:
os.unlink(temp_path)
print(f"Cleaned up temporary config: {temp_path}")
except:
pass
def run_command(cmd, description):
"""Run a command and print its output."""
print(f"\n=== {description} ===")
print(f"Running: {cmd}")
try:
result = subprocess.run(cmd, shell=True, check=True)
return True
except subprocess.CalledProcessError as e:
print(f"Command failed with exit code {e.returncode}")
return False
def main():
if not Path("dataset/images/train").exists() or not Path("dataset/images/val").exists():
print("Error: dataset/images/train or dataset/images/val not found.")
print("Please run prep_midv500_to_yolov11seg.py first.")
sys.exit(1)
temp_config = create_dynamic_config()
try:
download_cmd = "yolo task=segment model=yolo11s-seg.pt"
print("Ensuring YOLO11s-seg model is available...")
subprocess.run(download_cmd, shell=True, check=False, capture_output=True)
train_cmd = (
f"yolo task=segment mode=train model=yolo11s-seg.pt data={temp_config} "
"imgsz=640 epochs=20 batch=16 device=0 patience=5"
)
print("Starting YOLO11 Segmentation Training for Document Detection")
print("Dataset: MIDV500 (converted)")
print("Model: YOLO11s-seg")
print("Image size: 640")
print("Epochs: 20")
print("Batch size: 16")
print("Device: GPU 0")
print("Early stopping patience: 5 epochs")
if run_command(train_cmd, "Training YOLO11 Segmentation Model"):
print("\n=== Training completed successfully! ===")
runs_dir = Path("runs/segment")
if runs_dir.exists():
train_dirs = [d for d in runs_dir.iterdir() if d.is_dir() and d.name.startswith("train")]
if train_dirs:
latest_train = max(train_dirs, key=lambda x: x.stat().st_mtime)
best_model = latest_train / "weights" / "best.pt"
if best_model.exists():
print(f"\nBest model saved at: {best_model}")
export_cmd = f"yolo export model={best_model} format=onnx opset=12"
if run_command(export_cmd, "Exporting model to ONNX"):
onnx_model = best_model.with_suffix(".onnx")
print(f"\nONNX model exported to: {onnx_model}")
print("\n=== Training and Export Complete! ===")
print(f"PyTorch model: {best_model}")
print(f"ONNX model: {onnx_model}")
else:
print("ONNX export failed, but training completed.")
else:
print("Warning: best.pt not found in expected location.")
else:
print("Training failed. Please check the error messages above.")
sys.exit(1)
finally:
cleanup_temp_file(temp_config)
if __name__ == "__main__":
main()
Training with Python Script
The easiest way to train the model is to run the Python script shown above (saved as train_yolo_doc_detection.py), which handles the dataset configuration automatically:
python train_yolo_doc_detection.py
Alternative: Command Line Training
If you prefer direct YOLO commands, you can also use:
# Basic training command (requires proper path configuration in doc.yaml)
yolo task=segment mode=train model=yolo11s-seg.pt data=dataset/doc.yaml imgsz=640 epochs=30 batch=16 device=0
# Training with Automatic Mixed Precision (faster on recent GPUs)
yolo task=segment mode=train model=yolo11s-seg.pt data=dataset/doc.yaml imgsz=640 epochs=30 batch=16 device=0 amp=True
# Training with CPU (slower but works without GPU)
yolo task=segment mode=train model=yolo11s-seg.pt data=dataset/doc.yaml imgsz=640 epochs=30 batch=8 device=cpu
# Resume training from checkpoint
yolo task=segment mode=train model=runs/segment/train/weights/last.pt resume=True
Note: the doc.yaml written by the conversion script contains the absolute dataset path from the machine where conversion ran, so direct YOLO commands only work if that path is still valid. The Python training script regenerates the configuration with the correct absolute path automatically.
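For reference, a working dataset/doc.yaml looks like the following; the path value below is a placeholder and must point to the absolute location of your converted dataset:
names:
- document
nc: 1
path: /home/user/project/dataset
train: images/train
val: images/val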
Training Parameters Explained
- model: Pre-trained YOLO11 segmentation model
- data: Path to dataset YAML configuration file
- epochs: Number of complete passes over the training set
- imgsz: Input image size (640 is standard)
- batch: Batch size (adjust based on GPU memory)
- device: Training device (0 for GPU, cpu for CPU)
- amp: Automatic Mixed Precision for faster training
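If you prefer to stay in Python rather than the CLI, the same training run can be launched through the ultralytics API; a minimal sketch using the same settings as the training script above:
from ultralytics import YOLO

model = YOLO("yolo11s-seg.pt")      # pre-trained segmentation weights
model.train(
    data="dataset/doc.yaml",        # dataset config with an absolute path
    imgsz=640,
    epochs=20,
    batch=16,
    device=0,                       # use device="cpu" if no GPU is available
    patience=5,
)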
Exporting to ONNX Format
After training, export the model to ONNX format for deployment and inference optimization.
# Export best model to ONNX (default run directory is runs/segment/train)
yolo export model=runs/segment/train/weights/best.pt format=onnx opset=12
# Export with graph simplification, keeping FP32 weights
yolo export model=runs/segment/train/weights/best.pt format=onnx opset=12 simplify=True half=False
# Export additional formats (one format per command)
yolo export model=runs/segment/train/weights/best.pt format=torchscript
yolo export model=runs/segment/train/weights/best.pt format=tflite
ONNX Model Testing
import os, sys, warnings
import numpy as np
import cv2
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # don't let PyTorch probe CUDA
os.environ["ORT_LOGGING_LEVEL"] = "4" # onnxruntime: errors only
warnings.filterwarnings("ignore")
try:
import onnxruntime as ort
ort.set_default_logger_severity(4)
except ImportError:
print("❌ onnxruntime not installed. Run: pip install onnxruntime")
sys.exit(1)
from ultralytics import YOLO
MODEL_ONNX = "runs/segment/train/weights/best.onnx" # your exported ONNX
IMAGE_PATH = "CA01_01.jpg" # test image
IMG_SIZE = 640
CONF = 0.25
def order_pts(pts: np.ndarray) -> np.ndarray:
"""Order 4 points TL, TR, BR, BL."""
s = pts.sum(1)
d = np.diff(pts, axis=1).ravel()
return np.array([pts[np.argmin(s)],
pts[np.argmin(d)],
pts[np.argmax(s)],
pts[np.argmax(d)]], dtype=np.float32)
def polygon_to_quad(poly_xy: np.ndarray) -> np.ndarray:
"""
poly_xy: (K,2) float32 in original image coordinates.
Returns ordered 4x2 quad (float32). Uses approxPolyDP, with minAreaRect fallback.
"""
cnt = poly_xy.reshape(-1, 1, 2).astype(np.float32)
peri = cv2.arcLength(cnt, True)
for eps in (0.03, 0.025, 0.02, 0.015, 0.01):
approx = cv2.approxPolyDP(cnt, eps * peri, True)
if len(approx) == 4:
return order_pts(approx.reshape(-1, 2).astype(np.float32))
rect = cv2.minAreaRect(cnt)
box = cv2.boxPoints(rect).astype(np.float32)
return order_pts(box)
def main():
if not os.path.exists(MODEL_ONNX):
print(f"❌ ONNX model not found: {MODEL_ONNX}")
sys.exit(1)
if not os.path.exists(IMAGE_PATH):
print(f"❌ Image not found: {IMAGE_PATH}")
sys.exit(1)
print("Providers (onnxruntime):", ort.get_available_providers())
print(f"Loading ONNX model: {MODEL_ONNX}")
model = YOLO(MODEL_ONNX)
print("✅ ONNX model loaded")
img = cv2.imread(IMAGE_PATH)
if img is None:
print(f"❌ Failed to read image: {IMAGE_PATH}")
sys.exit(1)
H, W = img.shape[:2]
print(f"Image size: {W}x{H}")
res = model(img, imgsz=IMG_SIZE, conf=CONF, device='cpu')[0]
print("✅ Inference completed")
if res.masks is None or len(res.masks.data) == 0:
print("⚠️ No document detected.")
return
polys = res.masks.xy
areas = [cv2.contourArea(p.astype(np.float32)) for p in polys]
idx = int(np.argmax(areas))
poly = polys[idx].astype(np.float32)
quad = polygon_to_quad(poly)
w_out = int(max(np.linalg.norm(quad[1]-quad[0]), np.linalg.norm(quad[2]-quad[3])))
h_out = int(max(np.linalg.norm(quad[2]-quad[1]), np.linalg.norm(quad[3]-quad[0])))
dst = np.array([[0, 0],
[w_out - 1, 0],
[w_out - 1, h_out - 1],
[0, h_out - 1]], dtype=np.float32)
M = cv2.getPerspectiveTransform(quad, dst)
warped = cv2.warpPerspective(img, M, (w_out, h_out))
overlay = img.copy()
cv2.polylines(overlay, [quad.astype(int)], True, (0, 255, 0), 3)
cv2.imwrite("detect_overlay.jpg", overlay)
cv2.imwrite("document_warped.jpg", warped)
print("Quad (TL, TR, BR, BL):")
for i, (x, y) in enumerate(quad):
print(f" P{i}: ({x:.1f}, {y:.1f})")
print("Saved: detect_overlay.jpg, document_warped.jpg")
if __name__ == "__main__":
main()
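Independently of ultralytics, you can inspect the exported graph with onnxruntime to confirm the input shape (typically 1x3x640x640 for this static export) and the output tensors; a small sketch, assuming the same best.onnx path:
import onnxruntime as ort

sess = ort.InferenceSession("runs/segment/train/weights/best.onnx",
                            providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    print("input :", inp.name, inp.shape, inp.type)
for out in sess.get_outputs():
    print("output:", out.name, out.shape, out.type)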
Creating a GUI Application
Now let's create a GUI application for document detection using PySide6.
Document Detection Engine
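The snippets below are excerpts from a single document_detector_gui.py file. The detection engine assumes these imports at the top (the GUI class further down additionally needs the usual PySide6 widget, core, and gui imports, similar to the dataset viewer):
import os
from typing import Optional, Tuple

import cv2
import numpy as np
from ultralytics import YOLO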
class DocumentDetector:
"""Core document detection functionality"""
def __init__(self):
self.model = None
self.load_model()
def load_model(self):
"""Load the trained model (PyTorch first, then ONNX fallback)"""
model_paths = [
("runs/segment/train/weights/best.pt", "PyTorch"),
("runs/segment/train/weights/best.onnx", "ONNX"),
("yolo11s-seg.pt", "Pre-trained")
]
for model_path, model_type in model_paths:
if os.path.exists(model_path):
try:
self.model = YOLO(model_path)
print(f"✅ {model_type} model loaded: {model_path}")
return
except Exception as e:
print(f"❌ Failed to load {model_type} model: {e}")
continue
raise RuntimeError("No working model found!")
@staticmethod
def order_points(pts):
"""Order points in clockwise order starting from top-left"""
s = pts.sum(1)
d = np.diff(pts, axis=1).ravel()
return np.array([
pts[np.argmin(s)], # top-left
pts[np.argmin(d)], # top-right
pts[np.argmax(s)], # bottom-right
pts[np.argmax(d)] # bottom-left
], dtype="float32")
@staticmethod
def polygon_to_quad(poly_xy: np.ndarray) -> np.ndarray:
cnt = poly_xy.reshape(-1, 1, 2).astype(np.float32)
peri = cv2.arcLength(cnt, True)
for eps in (0.03, 0.025, 0.02, 0.015, 0.01):
approx = cv2.approxPolyDP(cnt, eps * peri, True)
if len(approx) == 4:
return DocumentDetector.order_points(approx.reshape(-1, 2).astype(np.float32))
rect = cv2.minAreaRect(cnt)
box = cv2.boxPoints(rect).astype(np.float32)
return DocumentDetector.order_points(box)
def detect_document(self, image: np.ndarray, conf: float = 0.25) -> Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
if self.model is None:
raise RuntimeError("Model not loaded")
orig_h, orig_w = image.shape[:2]
print(f"Debug - Image size: {orig_w}x{orig_h}")
try:
results = self.model(image, imgsz=640, conf=conf, device='cpu')[0]
except Exception as e:
raise RuntimeError(f"Inference failed: {e}")
if results.masks is None or len(results.masks.data) == 0:
return None
polys = results.masks.xy
if not polys:
return None
areas = [cv2.contourArea(p.astype(np.float32)) for p in polys]
idx = int(np.argmax(areas))
poly = polys[idx].astype(np.float32)
print(f"Debug - Selected polygon has {len(poly)} points, area: {areas[idx]:.0f}")
quad = self.polygon_to_quad(poly)
if quad is None:
return None
overlay = image.copy()
cv2.polylines(overlay, [quad.astype(int)], True, (0, 255, 0), 3)
for i, point in enumerate(quad):
x, y = int(round(point[0])), int(round(point[1]))
cv2.circle(overlay, (x, y), 8, (255, 0, 0), -1)
cv2.putText(overlay, str(i), (x + 10, y - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
h, w = image.shape[:2]
w_out = int(max(np.linalg.norm(quad[1] - quad[0]), np.linalg.norm(quad[2] - quad[3])))
h_out = int(max(np.linalg.norm(quad[2] - quad[1]), np.linalg.norm(quad[3] - quad[0])))
dst = np.array([[0, 0], [w_out - 1, 0], [w_out - 1, h_out - 1], [0, h_out - 1]], dtype="float32")
M = cv2.getPerspectiveTransform(quad, dst)
rectified = cv2.warpPerspective(image, M, (w_out, h_out))
return overlay, rectified, quad
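The GUI below also relies on two small helpers that are not shown in full here: a DetectionWorker that runs detection on a background QThread, and a DropLabel that accepts drag-and-dropped image files. A minimal sketch of both (the exact implementations in the repository may differ):
from PySide6.QtCore import QObject, Signal, Qt
from PySide6.QtGui import QPixmap
from PySide6.QtWidgets import QLabel


class DetectionWorker(QObject):
    """Runs DocumentDetector.detect_document off the GUI thread."""
    finished = Signal(object)   # emits (overlay, rectified, quad) or None
    error = Signal(str)

    def __init__(self, detector, image, conf=0.25):
        super().__init__()
        self.detector = detector
        self.image = image
        self.conf = conf

    def run(self):
        try:
            self.finished.emit(self.detector.detect_document(self.image, self.conf))
        except Exception as e:
            self.error.emit(str(e))


class DropLabel(QLabel):
    """QLabel that accepts a dropped image file and reports its path."""
    file_dropped = Signal(str)

    def __init__(self, text=""):
        super().__init__(text)
        self.setAcceptDrops(True)
        self.setAlignment(Qt.AlignmentFlag.AlignCenter)

    def dragEnterEvent(self, event):
        if event.mimeData().hasUrls():
            event.acceptProposedAction()

    def dropEvent(self, event):
        urls = event.mimeData().urls()
        if urls:
            self.file_dropped.emit(urls[0].toLocalFile())

    def set_pixmap_scaled(self, pixmap: QPixmap):
        # Scale to the current label size while keeping the aspect ratio
        self.setPixmap(pixmap.scaled(self.size(),
                                     Qt.AspectRatioMode.KeepAspectRatio,
                                     Qt.TransformationMode.SmoothTransformation))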
Main GUI Application
class DocumentDetectorGUI(QMainWindow):
def __init__(self):
super().__init__()
self.detector = None
self.current_image = None
self.current_overlay = None
self.current_rectified = None
self.current_quad = None
self.detection_thread = None
self.detection_worker = None
self.init_ui()
self.init_detector()
def init_ui(self):
self.setWindowTitle("Document Detection GUI")
self.setGeometry(100, 100, 1400, 800)
central_widget = QWidget()
self.setCentralWidget(central_widget)
main_layout = QHBoxLayout(central_widget)
left_panel = self.create_left_panel()
main_layout.addWidget(left_panel)
right_panel = self.create_right_panel()
main_layout.addWidget(right_panel, stretch=3)
self.statusBar().showMessage("Ready - Drop an image or click Browse to start")
def create_left_panel(self):
"""Create the left control panel"""
panel = QFrame()
panel.setMaximumWidth(300)
panel.setFrameStyle(QFrame.Shape.Box)
layout = QVBoxLayout(panel)
title = QLabel("Document Detection")
title.setAlignment(Qt.AlignmentFlag.AlignCenter)
title.setFont(QFont("Arial", 16, QFont.Weight.Bold))
layout.addWidget(title)
self.browse_btn = QPushButton("📁 Browse Image")
self.browse_btn.clicked.connect(self.browse_image)
layout.addWidget(self.browse_btn)
self.detect_btn = QPushButton("🔍 Detect Document")
self.detect_btn.clicked.connect(self.detect_document)
self.detect_btn.setEnabled(False)
layout.addWidget(self.detect_btn)
self.save_btn = QPushButton("💾 Save Rectified")
self.save_btn.clicked.connect(self.save_rectified)
self.save_btn.setEnabled(False)
layout.addWidget(self.save_btn)
layout.addWidget(QLabel()) # Spacer
self.progress_bar = QProgressBar()
self.progress_bar.setVisible(False)
layout.addWidget(self.progress_bar)
info_group = QGroupBox("Detection Info")
info_layout = QVBoxLayout(info_group)
self.info_text = QTextEdit()
self.info_text.setMaximumHeight(200)
self.info_text.setFont(QFont("Consolas", 9))
info_layout.addWidget(self.info_text)
layout.addWidget(info_group)
layout.addItem(QSpacerItem(20, 40, QSizePolicy.Policy.Minimum, QSizePolicy.Policy.Expanding))
return panel
def create_right_panel(self):
panel = QWidget()
layout = QVBoxLayout(panel)
splitter = QSplitter(Qt.Orientation.Horizontal)
original_group = QGroupBox("Original with Overlay")
original_layout = QVBoxLayout(original_group)
self.original_scroll = QScrollArea()
self.original_label = DropLabel("Drop image here or click Browse")
self.original_label.file_dropped.connect(self.load_image)
self.original_scroll.setWidget(self.original_label)
self.original_scroll.setWidgetResizable(True)
original_layout.addWidget(self.original_scroll)
splitter.addWidget(original_group)
rectified_group = QGroupBox("Rectified Document")
rectified_layout = QVBoxLayout(rectified_group)
self.rectified_scroll = QScrollArea()
self.rectified_label = QLabel("Rectified document will appear here")
self.rectified_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.rectified_label.setStyleSheet("""
QLabel {
border: 1px solid #ddd;
background-color: #f5f5f5;
min-height: 200px;
color: #888;
}
""")
self.rectified_scroll.setWidget(self.rectified_label)
self.rectified_scroll.setWidgetResizable(True)
rectified_layout.addWidget(self.rectified_scroll)
splitter.addWidget(rectified_group)
splitter.setSizes([1, 1]) # Equal sizes
layout.addWidget(splitter)
return panel
def init_detector(self):
try:
self.detector = DocumentDetector()
self.statusBar().showMessage("Model loaded successfully - Ready for detection")
except Exception as e:
QMessageBox.critical(self, "Error", f"Failed to load model: {e}")
self.statusBar().showMessage("Error loading model")
def browse_image(self):
"""Browse for an image file"""
file_path, _ = QFileDialog.getOpenFileName(
self,
"Select Image",
"",
"Image Files (*.jpg *.jpeg *.png *.bmp *.tiff *.tif)"
)
if file_path:
self.load_image(file_path)
def load_image(self, file_path: str):
"""Load an image from file path"""
try:
self.current_image = cv2.imread(file_path)
if self.current_image is None:
raise ValueError("Failed to load image")
height, width, channel = self.current_image.shape
bytes_per_line = 3 * width
q_image = QImage(
cv2.cvtColor(self.current_image, cv2.COLOR_BGR2RGB).data,
width, height, bytes_per_line, QImage.Format.Format_RGB888
)
pixmap = QPixmap.fromImage(q_image)
self.original_label.set_pixmap_scaled(pixmap)
self.detect_btn.setEnabled(True)
self.save_btn.setEnabled(False)
self.rectified_label.setText("Rectified document will appear here")
self.rectified_label.setPixmap(QPixmap())
file_name = Path(file_path).name
self.info_text.setText(f"Loaded: {file_name}\nDimensions: {width}x{height}\nReady for detection")
self.statusBar().showMessage(f"Loaded: {file_name}")
self.current_overlay = None
self.current_rectified = None
self.current_quad = None
self.detect_document()
except Exception as e:
QMessageBox.critical(self, "Error", f"Failed to load image: {e}")
def detect_document(self):
"""Run document detection"""
if self.current_image is None or self.detector is None:
print("Detection skipped: No image or detector")
return
try:
self.progress_bar.setVisible(True)
self.progress_bar.setRange(0, 0)
self.detect_btn.setEnabled(False)
self.statusBar().showMessage("Detecting document...")
self.cleanup_detection_thread()
self.detection_worker = DetectionWorker(self.detector, self.current_image, 0.25)
self.detection_thread = QThread()
self.detection_worker.moveToThread(self.detection_thread)
self.detection_thread.started.connect(self.detection_worker.run)
self.detection_worker.finished.connect(self.on_detection_finished)
self.detection_worker.error.connect(self.on_detection_error)
self.detection_worker.finished.connect(self.cleanup_detection_thread)
print("Starting detection thread...")
self.detection_thread.start()
except Exception as e:
print(f"Error in detect_document: {e}")
import traceback
traceback.print_exc()
self.progress_bar.setVisible(False)
self.detect_btn.setEnabled(True)
QMessageBox.critical(self, "Detection Error", f"Failed to start detection: {e}")
...
Running the GUI Application
# Install additional dependencies for GUI
pip install PySide6
# Run the GUI application
python document_detector_gui.py
Source Code
https://github.com/yushulx/python-barcode-qrcode-sdk/tree/main/examples/document-detection