# To learn more about passport MRZ types please visit https://www.doubango.org/SDKs/mrz/docs/MRZ_formats.html
import base64
import io
import os
import re
from dataclasses import dataclass
from datetime import date
from functools import lru_cache
from typing import List, Optional, Tuple

import cv2
import numpy as np
from paddleocr import PaddleOCR
from PIL import Image
from ultralytics import YOLO

# ----------------------- Config -----------------------


@dataclass(frozen=True)
class MRZConfig:
    """Tunable settings for MRZ detection, OCR and parsing."""

    # Path to the YOLO weights used to locate the MRZ region.
    yolo_weights: str = "ml/models/yolo_ocr/runs/detect/train/weights/best.pt"
    imgsz: int = 640  # inference size passed to YOLO.predict
    conf: float = 0.25  # YOLO confidence threshold
    iou: float = 0.45  # YOLO NMS IoU threshold
    mrz_class_id: int = 0  # detector class index treated as "MRZ"
    margin_ratio: float = 0.03  # margin added around the box, as a fraction of its size
    ocr_lang: str = "en"  # language passed to PaddleOCR
    ocr_min_score: float = 0.50  # OCR lines scoring below this are discarded
    ocr_target_min_h: int = 220  # OCR input is upscaled to at least this height (px)
    crop_target_min_h: int = 180  # MRZ crop is upscaled to at least this height (px)
    repair: bool = True  # use check-digit based self-repair when parsing


# ----------------------- Public API -----------------------


def extract_mrz_info(
    base64_str: str, cfg: MRZConfig = MRZConfig()
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]:
    """
    Decode base64 image -> detect + crop MRZ -> OCR with PaddleOCR -> parse MRZ.

    When the detector finds no MRZ box, OCR runs on the whole image instead.

    Args:
        base64_str: Image encoded as base64, optionally as a ``data:...;base64,`` URL.
        cfg: Detection / OCR / parsing settings.

    Returns:
        (full_name, dob, gender, nationality, expiry). Dates are ``YYYY-MM-DD``,
        gender is "M", "F" or "X", and nationality is the MRZ code with '<'
        filler removed. If any pipeline stage fails, all five values are None.
        Individual fields may also be None when they cannot be decoded.
    """
    empty = (None, None, None, None, None)

    img_bgr = _b64_to_bgr(base64_str)
    if img_bgr is None:
        return empty

    crop = _detect_and_crop_mrz(img_bgr, cfg)
    if crop is None:
        crop = img_bgr  # no MRZ box found: fall back to OCR on the full image
    crop = _ensure_min_height(crop, cfg.crop_target_min_h)

    mrz_text = _paddle_ocr(crop, cfg)
    if not mrz_text:
        return empty

    mrz_lines = _sanitize_mrz_lines(mrz_text)
    if not mrz_lines:
        return empty

    parsed = _parse_mrz_lines_repair(mrz_lines) if cfg.repair else _parse_mrz_lines(mrz_lines)
    if not parsed:
        return empty

    full_name = _compose_full_name(parsed.get("surname"), parsed.get("given_names"))
    dob = _format_date(parsed.get("dob_yyMMdd"), kind="dob")
    expiry = _format_date(parsed.get("exp_yyMMdd"), kind="expiry")
    gender = _normalize_gender(parsed.get("sex"))
    # The nationality field is right-padded with '<' (e.g. "D<<"); callers want the bare code.
    nationality = (parsed.get("nationality") or "").strip("<") or None
    return full_name, dob, gender, nationality, expiry


# ----------------------- Detection -----------------------


@lru_cache(maxsize=1)
def _get_yolo(weights_path: str) -> YOLO:
    """Load (and cache) the YOLO model from ``weights_path``.

    Raises:
        FileNotFoundError: if the weights file does not exist.
    """
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"YOLO weights not found: {weights_path}")
    return YOLO(weights_path)


def _detect_and_crop_mrz(img_bgr: np.ndarray, cfg: MRZConfig) -> Optional[np.ndarray]:
    """Return a copy of the highest-confidence MRZ box (plus margin), or None.

    Only boxes of class ``cfg.mrz_class_id`` are considered. The chosen box is
    enlarged by ``cfg.margin_ratio`` of its width/height on each side and
    clamped to the image bounds.
    """
    model = _get_yolo(cfg.yolo_weights)
    res = model.predict(img_bgr, imgsz=cfg.imgsz, conf=cfg.conf, iou=cfg.iou, verbose=False)[0]
    if res.boxes is None or len(res.boxes) == 0:
        return None

    xyxy = res.boxes.xyxy.cpu().numpy()
    scores = res.boxes.conf.cpu().numpy()
    classes = res.boxes.cls.cpu().numpy().astype(int)

    idxs = [i for i, c in enumerate(classes) if c == cfg.mrz_class_id]
    if not idxs:
        return None
    best = max(idxs, key=lambda i: scores[i])

    x1, y1, x2, y2 = map(int, xyxy[best])
    H, W = img_bgr.shape[:2]
    mw = int(cfg.margin_ratio * (x2 - x1))
    mh = int(cfg.margin_ratio * (y2 - y1))
    # Slice ends are exclusive, so clamp them to W/H (not W-1/H-1); otherwise the
    # last pixel column/row of the image can never be part of the crop.
    x1, y1 = max(0, x1 - mw), max(0, y1 - mh)
    x2, y2 = min(W, x2 + mw), min(H, y2 + mh)
    if x2 <= x1 or y2 <= y1:
        return None
    return img_bgr[y1:y2, x1:x2].copy()
----------------------- OCR (PaddleOCR) ----------------------- @lru_cache(maxsize=1) def _get_paddle_ocr(lang: str) -> PaddleOCR: ocr = PaddleOCR( text_detection_model_name="PP-OCRv5_server_det", use_doc_orientation_classify=False, use_doc_unwarping=False, use_textline_orientation=False, lang=lang ) return ocr def _paddle_ocr( crop_bgr: np.ndarray, cfg: MRZConfig ) -> Optional[str]: if crop_bgr is None or crop_bgr.size == 0: return None h, w = crop_bgr.shape[:2] if h < cfg.ocr_target_min_h: scale = float(cfg.ocr_target_min_h) / max(1.0, float(h)) crop_bgr = cv2.resize(crop_bgr, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_CUBIC) img_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB) ocr = _get_paddle_ocr(lang=cfg.ocr_lang) out = ocr.predict(img_rgb) if isinstance(out, list): out = out[0] if out else {} rec_texts = list(out.get("rec_texts", [])) rec_scores = list(out.get("rec_scores", [1.0] * len(rec_texts))) if not rec_texts: return None order = _order_indices_by_vertical_center(out, len(rec_texts)) lines: List[str] = [] for i in order: if i >= len(rec_texts): continue txt = rec_texts[i] sc = float(rec_scores[i] if i < len(rec_scores) else 1.0) if sc < cfg.ocr_min_score or not txt: continue cleaned = _clean_mrz_token(txt) if cleaned: lines.append(cleaned) if not lines: return None lines = sorted(lines, key=len, reverse=True)[:3] if len(lines) not in (2, 3): return None return "\n".join(lines) # ----------------------- Utilities ----------------------- _ALLOWED = set("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789<") def _b64_to_bgr(b64: str) -> Optional[np.ndarray]: """Decode base64 (with or without data URL) to BGR numpy image.""" if "," in b64 and ";base64" in b64[:64]: b64 = b64.split(",", 1)[1] try: data = base64.b64decode(b64, validate=True) img = Image.open(io.BytesIO(data)).convert("RGB") return np.array(img)[:, :, ::-1] # RGB->BGR except Exception: return None def _ensure_min_height(img_bgr: np.ndarray, target_min_h: int) -> np.ndarray: h, w = 
img_bgr.shape[:2] if h >= target_min_h: return img_bgr scale = target_min_h / float(h) return cv2.resize(img_bgr, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_CUBIC) def _clean_mrz_token(s: str) -> str: s = s.upper().replace(" ", "") return "".join(ch for ch in s if ch in _ALLOWED) def _sanitize_mrz_lines(txt: str) -> List[str]: txt = txt.upper().replace(" ", "") lines = [re.sub(r"[^A-Z0-9<]", "", ln) for ln in txt.splitlines() if ln.strip()] lines = [ln for ln in lines if ln] return lines[:3] def _order_indices_by_vertical_center(out: dict, n: int) -> List[int]: rb = out.get("rec_boxes", None) if rb is not None: rb = np.asarray(rb) if rb.ndim == 2 and rb.shape[1] >= 4 and rb.shape[0] == n: ycenters = 0.5 * (rb[:, 1].astype(float) + rb[:, 3].astype(float)) return list(np.argsort(ycenters)) rp = out.get("rec_polys", None) if rp is not None and len(rp) == n: ycenters = [np.mean(np.asarray(poly)[:, 1].astype(float)) for poly in rp] return list(np.argsort(ycenters)) dp = out.get("dt_polys", None) if dp is not None and len(dp) == n: ycenters = [np.mean(np.asarray(poly)[:, 1].astype(float)) for poly in dp] return list(np.argsort(ycenters)) return list(range(n)) # ----------------------- MRZ Parsing ----------------------- def _parse_mrz_lines_repair(lines: List[str]) -> Optional[dict]: if len(lines) == 3: l1 = _pad_to(lines[0], 30); l2 = _pad_to(lines[1], 30); l3 = _pad_to(lines[2], 30) parsed = _parse_td1(l1, l2, l3) repaired = _repair_td1([l1, l2, l3], parsed.copy()) return repaired or parsed elif len(lines) == 2: if max(len(lines[0]), len(lines[1])) >= 40: l1 = _pad_to(lines[0], 44); l2 = _pad_to(lines[1], 44) parsed = _parse_td3(l1, l2) repaired = _repair_td3([l1, l2], parsed.copy()) if repaired: return repaired l1b = _pad_to(lines[0], 36); l2b = _pad_to(lines[1], 36) parsed2 = _parse_td2(l1b, l2b) repaired2 = _repair_td2([l1b, l2b], parsed2.copy()) return repaired2 or parsed else: l1 = _pad_to(lines[0], 36); l2 = _pad_to(lines[1], 36) parsed = 
_parse_td2(l1, l2) repaired = _repair_td2([l1, l2], parsed.copy()) if repaired: return repaired l1b = _pad_to(lines[0], 44); l2b = _pad_to(lines[1], 44) parsed2 = _parse_td3(l1b, l2b) repaired2 = _repair_td3([l1b, l2b], parsed2.copy()) return repaired2 or parsed return None def _parse_mrz_lines(lines: List[str]) -> Optional[dict]: if len(lines) == 3: # TD1 (3 x 30) l1 = _pad_to(lines[0], 30) l2 = _pad_to(lines[1], 30) l3 = _pad_to(lines[2], 30) return _parse_td1(l1, l2, l3) elif len(lines) == 2: # TD2 (2 x 36) or TD3 (2 x 44) if max(len(lines[0]), len(lines[1])) >= 40: l1 = _pad_to(lines[0], 44) l2 = _pad_to(lines[1], 44) return _parse_td3(l1, l2) else: l1 = _pad_to(lines[0], 36) l2 = _pad_to(lines[1], 36) parsed = _parse_td2(l1, l2) if not parsed or not parsed.get("nationality"): l1 = _pad_to(lines[0], 44) l2 = _pad_to(lines[1], 44) return _parse_td3(l1, l2) return parsed def _pad_to(s: str, n: int) -> str: s = "".join(ch for ch in s if ch in _ALLOWED) return (s + "<" * n)[:n] def _compact_name(field: str) -> Tuple[str, str]: parts = field.split("<<", 1) surname = parts[0].replace("<", " ").strip() given = parts[1] if len(parts) > 1 else "" given = given.replace("<", " ").strip() return surname, given # TD3 (Passports) 2 x 44 def _parse_td3(l1: str, l2: str) -> dict: doc_type = l1[0:2] issuing = l1[2:5] name_field= l1[5:44] surname, given = _compact_name(name_field) number = l2[0:9] nationality = l2[10:13] dob = l2[13:19] sex = l2[20:21] expiry = l2[21:27] return dict( doc_type=doc_type, issuing=issuing, surname=surname, given_names=given, number=number, nationality=nationality, sex=sex, dob_yyMMdd=dob, exp_yyMMdd=expiry ) # TD2 (ID, 2 x 36) def _parse_td2(l1: str, l2: str) -> dict: doc_type = l1[0:2] issuing = l1[2:5] name_field= l1[5:36] surname, given = _compact_name(name_field) number = l2[0:9] nationality = l2[10:13] dob = l2[13:19] sex = l2[20:21] expiry = l2[21:27] return dict( doc_type=doc_type, issuing=issuing, surname=surname, given_names=given, 
number=number, nationality=nationality, sex=sex, dob_yyMMdd=dob, exp_yyMMdd=expiry ) # TD1 (ID, 3 x 30) def _parse_td1(l1: str, l2: str, l3: str) -> dict: doc_type = l1[0:2] issuing = l1[2:5] number = l1[5:14] # l1[14] number check digit (ignored here) # l1[15:30] optional dob = l2[0:6] # l2[6] dob check sex = l2[7:8] expiry = l2[8:14] # l2[14] expiry check nationality = l2[15:18] name_field= l3[0:30] surname, given = _compact_name(name_field) return dict( doc_type=doc_type, issuing=issuing, surname=surname, given_names=given, number=number, nationality=nationality, sex=sex, dob_yyMMdd=dob, exp_yyMMdd=expiry ) # ----------------------- Formatting ----------------------- def _compose_full_name(surname: Optional[str], given: Optional[str]) -> Optional[str]: if not (surname or given): return None parts = [] if given: parts.append(given) if surname: parts.append(surname) return " ".join(" ".join(p.split()) for p in parts).strip() or None def _normalize_gender(sex: Optional[str]) -> Optional[str]: if not sex: return None s = sex[0].upper() if s == "M": return "M" if s == "F": return "F" return "X" def _format_date(yymmdd: Optional[str], kind: str) -> Optional[str]: if not yymmdd or len(yymmdd) < 6: return None try: yy = int(yymmdd[0:2]); mm = int(yymmdd[2:4]); dd = int(yymmdd[4:6]) d = _infer_century_two_digit(yy, mm, dd, kind) return d.strftime("%Y-%m-%d") except Exception: return None def _infer_century_two_digit(yy: int, mm: int, dd: int, kind: str) -> date: cands: List[date] = [] for century in (1900, 2000): try: cands.append(date(century + yy, mm, dd)) except ValueError: pass if not cands: raise ValueError("invalid date") today = date.today() if kind == "dob": # plausible human age: 0..120 best = None; best_pen = 1e12 for d in cands: age = (today - d).days / 365.2425 pen = 0 if 0 <= age <= 120 else 1e6 if pen < best_pen: best_pen, best = pen, d return best else: best = None; best_pen = 1e12 for d in cands: delta = (d - today).days / 365.2425 pen = 0 if delta < -15 
or delta > 25: pen += 1e3 if delta < 0: pen += 10 if pen < best_pen: best_pen, best = pen, d return best # ----------------------- MRZ Checksums & Self-repair ----------------------- def _mrz_val(ch: str) -> int: if ch == '<': return 0 if '0' <= ch <= '9': return ord(ch) - 48 if 'A' <= ch <= 'Z': return ord(ch) - 55 return 0 def _mrz_check_digit(s: str) -> str: w = (7, 3, 1) total = 0 for i, ch in enumerate(s): total += _mrz_val(ch) * w[i % 3] return str(total % 10) _CONF_MAP = { 'O': '0', 'I': '1', 'Z': '2', 'S': '5', 'B': '8', '0': 'O', '1': 'I', '2': 'Z', '5': 'S', '8': 'B' } def _one_edit_fix(field: str, allowed: str, expected_cd: str) -> str | None: N = len(field) for i in range(N): orig = field[i] if orig not in _CONF_MAP: continue cand = _CONF_MAP[orig] if allowed == 'num' and not cand.isdigit(): if orig in 'OIZSB': cand = {'O':'0','I':'1','Z':'2','S':'5','B':'8'}[orig] else: continue f2 = field[:i] + cand + field[i+1:] if _mrz_check_digit(f2) == expected_cd: return f2 return None def _is_valid_yyMMdd(s: str) -> bool: if len(s) != 6 or not s.isdigit(): return False yy, mm, dd = int(s[:2]), int(s[2:4]), int(s[4:6]) if not (1 <= mm <= 12 and 1 <= dd <= 31): return False return True def _repair_td3(lines: list[str], parsed: dict) -> dict | None: l2 = (lines[1] + "<"*44)[:44] num, num_cd = l2[0:9], l2[9] nat = l2[10:13] dob, dob_cd = l2[13:19], l2[19] sex = l2[20:21] exp, exp_cd = l2[21:27], l2[27] opt, opt_cd = l2[28:42], l2[42] comp_cd = l2[43] changed = False if _mrz_check_digit(num) != num_cd: fixed = _one_edit_fix(num, 'alnum', num_cd) if fixed: num, changed = fixed, True if (not _is_valid_yyMMdd(dob)) or _mrz_check_digit(dob) != dob_cd: fixed = _one_edit_fix(dob, 'num', dob_cd) if fixed and _is_valid_yyMMdd(fixed): dob, changed = fixed, True if (not _is_valid_yyMMdd(exp)) or _mrz_check_digit(exp) != exp_cd: fixed = _one_edit_fix(exp, 'num', exp_cd) if fixed and _is_valid_yyMMdd(fixed): exp, changed = fixed, True if _mrz_check_digit(opt) != opt_cd: fixed = 
_one_edit_fix(opt, 'alnum', opt_cd) if fixed: opt, changed = fixed, True comp_input = num + num_cd + dob + dob_cd + exp + exp_cd + opt + opt_cd if _mrz_check_digit(comp_input) != comp_cd: fixed = _one_edit_fix(opt, 'alnum', opt_cd) if fixed: opt, changed = fixed, True comp_input = num + num_cd + dob + dob_cd + exp + exp_cd + opt + opt_cd if _mrz_check_digit(comp_input) != comp_cd and not changed: return None parsed['number'] = num parsed['dob_yyMMdd'] = dob parsed['exp_yyMMdd'] = exp parsed['sex'] = sex parsed['nationality'] = nat return parsed def _repair_td2(lines: list[str], parsed: dict) -> dict | None: l2 = (lines[1] + "<"*36)[:36] num, num_cd = l2[0:9], l2[9] nat = l2[10:13] dob, dob_cd = l2[13:19], l2[19] sex = l2[20:21] exp, exp_cd = l2[21:27], l2[27] opt = l2[28:35] comp_cd = l2[35] changed = False if _mrz_check_digit(num) != num_cd: fixed = _one_edit_fix(num, 'alnum', num_cd) if fixed: num, changed = fixed, True if (not _is_valid_yyMMdd(dob)) or _mrz_check_digit(dob) != dob_cd: fixed = _one_edit_fix(dob, 'num', dob_cd) if fixed and _is_valid_yyMMdd(fixed): dob, changed = fixed, True if (not _is_valid_yyMMdd(exp)) or _mrz_check_digit(exp) != exp_cd: fixed = _one_edit_fix(exp, 'num', exp_cd) if fixed and _is_valid_yyMMdd(fixed): exp, changed = fixed, True comp_input = num + num_cd + dob + dob_cd + exp + exp_cd + opt if _mrz_check_digit(comp_input) != comp_cd and not changed: fixed = _one_edit_fix(num, 'alnum', num_cd) if fixed: num, changed = fixed, True comp_input = num + num_cd + dob + dob_cd + exp + exp_cd + opt if _mrz_check_digit(comp_input) != comp_cd and not changed: return None parsed['number'] = num parsed['dob_yyMMdd'] = dob parsed['exp_yyMMdd'] = exp parsed['sex'] = sex parsed['nationality'] = nat return parsed def _repair_td1(lines: list[str], parsed: dict) -> dict | None: l1 = (lines[0] + "<"*30)[:30] l2 = (lines[1] + "<"*30)[:30] num, num_cd = l1[5:14], l1[14] opt1 = l1[15:30] dob, dob_cd = l2[0:6], l2[6] sex = l2[7:8] exp, exp_cd = l2[8:14], 
l2[14] nat = l2[15:18] opt2 = l2[18:29] comp_cd = l2[29] changed = False if _mrz_check_digit(num) != num_cd: fixed = _one_edit_fix(num, 'alnum', num_cd) if fixed: num, changed = fixed, True if (not _is_valid_yyMMdd(dob)) or _mrz_check_digit(dob) != dob_cd: fixed = _one_edit_fix(dob, 'num', dob_cd) if fixed and _is_valid_yyMMdd(fixed): dob, changed = fixed, True if (not _is_valid_yyMMdd(exp)) or _mrz_check_digit(exp) != exp_cd: fixed = _one_edit_fix(exp, 'num', exp_cd) if fixed and _is_valid_yyMMdd(fixed): exp, changed = fixed, True comp_input = num + num_cd + opt1 + dob + dob_cd + exp + exp_cd + opt2 if _mrz_check_digit(comp_input) != comp_cd and not changed: fixed = _one_edit_fix(opt2, 'alnum', _mrz_check_digit(opt2)) if fixed: opt2, changed = fixed, True comp_input = num + num_cd + opt1 + dob + dob_cd + exp + exp_cd + opt2 if _mrz_check_digit(comp_input) != comp_cd and not changed: return None parsed['number'] = num parsed['dob_yyMMdd'] = dob parsed['exp_yyMMdd'] = exp parsed['sex'] = sex parsed['nationality'] = nat return parsed def image_to_base64(image_path): with open(image_path, "rb") as image_file: encoded = base64.b64encode(image_file.read()).decode("utf-8") return encoded