diff --git a/.claude/commands/implement-phase.md b/.claude/commands/implement-phase.md
deleted file mode 100644
index 617b1b1..0000000
--- a/.claude/commands/implement-phase.md
+++ /dev/null
@@ -1,13 +0,0 @@
-# /implement-phase
-
-Read README.md, CLAUDE.md, PROJECT_SPEC.md, and TASKS.md.
-
-Implement the next unfinished task from TASKS.md.
-
-Rules:
-- Make a small focused change.
-- Keep modules separated.
-- Add or update tests.
-- Update documentation if public behavior changes.
-- Do not add heavy ML dependencies without asking.
-- Preserve future mesh-based dewarping architecture.
diff --git a/.claude/commands/plan-next.md b/.claude/commands/plan-next.md
deleted file mode 100644
index 8c7ebc7..0000000
--- a/.claude/commands/plan-next.md
+++ /dev/null
@@ -1,12 +0,0 @@
-# /plan-next
-
-Read README.md, CLAUDE.md, PROJECT_SPEC.md, and TASKS.md.
-
-Then produce a short implementation plan for the next logical feature.
-
-Rules:
-- Do not edit files.
-- Keep the plan practical.
-- Include files that need to be created or changed.
-- Include tests that should be added.
-- Mention risks or assumptions.
diff --git a/.claude/commands/review-cv-pipeline.md b/.claude/commands/review-cv-pipeline.md
deleted file mode 100644
index f8ec9ab..0000000
--- a/.claude/commands/review-cv-pipeline.md
+++ /dev/null
@@ -1,17 +0,0 @@
-# /review-cv-pipeline
-
-Review the current computer vision pipeline.
-
-Focus on:
-- Document boundary detection quality
-- Corner ordering correctness
-- Perspective transform robustness
-- Debug output usefulness
-- Failure cases
-- Whether the architecture supports future mesh dewarping
-
-Return:
-- Bugs
-- Missing tests
-- Suggested improvements
-- Priority order
diff --git a/.gitattributes b/.gitattributes
index dfe0770..efa005e 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,2 +1,5 @@
-# Auto detect text files and perform LF normalization
* text=auto
+*.png binary
+*.jpg binary
+*.jpeg binary
+*.onnx binary
diff --git a/.gitignore b/.gitignore
index 4ef0cf4..d337628 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,46 +10,46 @@ env/
.env
.env.*
-# Build
+# Build / packaging
build/
dist/
*.egg-info/
-# Test/cache
+# Test / cache
.pytest_cache/
.mypy_cache/
.ruff_cache/
.coverage
htmlcov/
-# IDE
+# IDE / OS
.vscode/
.idea/
-
-# OS
.DS_Store
Thumbs.db
-# Project outputs
-output/
-outputs/
-debug/
-tmp/
-temp/
-
-# Large sample data
-samples/private/
+# Large local models
+models/*.onnx
+models/*.engine
+models/*.trt
+models/*.pth
+models/*.pt
+!models/.gitkeep
+
+# Generated outputs
+outputs/*
+!outputs/.gitkeep
+!outputs/deshadowed/
+!outputs/deshadowed/.gitkeep
+!outputs/warp_detection/
+!outputs/warp_detection/.gitkeep
+!outputs/template_rectified/
+!outputs/template_rectified/.gitkeep
+!outputs/template_rectified/Untitled5_template_matches.png
+
+# Local datasets / archives
+input/private/
datasets/
*.zip
*.7z
*.rar
-
-.venv/
-__pycache__/
-*.pyc
-
-outputs/
-models/*.onnx
-
-.DS_Store
-Thumbs.db
\ No newline at end of file
diff --git a/CLAUDE.md b/CLAUDE.md
deleted file mode 100644
index 34814f4..0000000
--- a/CLAUDE.md
+++ /dev/null
@@ -1,500 +0,0 @@
-# CLAUDE.md
-
-This file gives Claude Code project-specific context, rules, and implementation guidance for **WarpLess Docs**.
-
-## Project Summary
-
-WarpLess Docs is an open-source document image dewarping and cleanup toolkit.
-
-The project converts real-world document photos into clean scanner-like images that are ready for OCR, form extraction, archival workflows, and AI document understanding.
-
-The target documents are not perfect scans. They may be:
-
-- Folded
-- Curved
-- Skewed
-- Shadowed
-- Captured by phone camera
-- Perspective distorted
-- Printed as forms or tables
-- Handwritten
-- Misaligned with a clean template PDF
-
-The project must avoid being a simple deskew script. It should be designed as a practical document preprocessing engine.
-
----
-
-## Product Goal
-
-The long-term product goal is:
-
-> Convert poor-quality real-world document images into reliable OCR-ready and form-extraction-ready outputs.
-
-Important outputs:
-
-- Clean flattened image
-- OCR-friendly image
-- PDF export
-- Template-aligned image
-- Cell crops for structured forms
-- JSON metadata
-- Debug visualizations
-
----
-
-## Main User Personas
-
-1. **Developer**
- - Wants a Python package or CLI to clean document photos before OCR.
-
-2. **AI/OCR Engineer**
- - Wants reliable preprocessing for document AI pipelines.
-
-3. **Business Automation Team**
- - Wants form reading from camera photos, scanned forms, invoices, contracts, or medical documents.
-
-4. **Persian/Arabic OCR Developer**
- - Wants better preprocessing for right-to-left forms and handwritten digits.
-
----
-
-## Non-Negotiable Engineering Principles
-
-1. **Do not implement only global homography**
- - Perspective correction is only Phase 1.
- - The architecture must allow local mesh-based dewarping.
-
-2. **Keep modules small and testable**
- - Detection, dewarp, cleanup, alignment, and export must be separate.
-
-3. **Always produce debug outputs**
- - Every major step should optionally output visual debug images.
-
-4. **Do not destroy original image data**
- - Keep original input untouched.
- - Store transformation metadata.
-
-5. **Prefer deterministic CV first, AI second**
- - Use OpenCV and geometry for baseline.
- - Add ML/AI only where classical methods are insufficient.
-
-6. **Design for real-world bad inputs**
- - Low light
- - Shadows
- - Paper folds
- - Perspective distortion
- - Curved table forms
- - Phone camera noise
-
-7. **Make the CLI useful from the first version**
- - The repository should be demoable from terminal early.
-
-8. **Make the README demo-friendly**
- - Add before/after examples as soon as sample outputs exist.
-
----
-
-## Suggested Architecture
-
-```text
-warpless/
-├── __init__.py
-├── pipeline.py
-├── types.py
-├── config.py
-├── detection/
-│ ├── __init__.py
-│ ├── page_detector.py
-│ ├── contour_detector.py
-│ └── corner_refinement.py
-├── dewarp/
-│ ├── __init__.py
-│ ├── perspective.py
-│ ├── mesh.py
-│ ├── grid.py
-│ └── local_warp.py
-├── cleanup/
-│ ├── __init__.py
-│ ├── shadows.py
-│ ├── contrast.py
-│ └── binarize.py
-├── alignment/
-│ ├── __init__.py
-│ ├── template_loader.py
-│ ├── feature_matcher.py
-│ ├── form_aligner.py
-│ └── cells.py
-├── export/
-│ ├── __init__.py
-│ ├── image_exporter.py
-│ ├── pdf_exporter.py
-│ └── metadata_exporter.py
-├── debug/
-│ ├── __init__.py
-│ └── visualizer.py
-└── utils/
- ├── __init__.py
- ├── image_io.py
- ├── geometry.py
- └── logging.py
-```
-
----
-
-## Suggested Data Types
-
-Create clear data classes early.
-
-Recommended core types:
-
-```python
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-
-
-@dataclass
-class DocumentCorners:
- top_left: tuple[float, float]
- top_right: tuple[float, float]
- bottom_right: tuple[float, float]
- bottom_left: tuple[float, float]
-
-
-@dataclass
-class WarpLessConfig:
- output_width: int | None = None
- output_height: int | None = None
- enable_shadow_cleanup: bool = True
- enable_mesh_dewarp: bool = False
- debug: bool = False
- debug_dir: Path | None = None
-
-
-@dataclass
-class WarpLessResult:
- original_image: np.ndarray
- corrected_image: np.ndarray
- corners: DocumentCorners | None
- metadata: dict[str, Any]
-```
-
----
-
-## MVP Implementation Order
-
-Follow this order unless the user explicitly requests otherwise.
-
-### Step 1: Project skeleton
-
-Create:
-
-- Package structure
-- `pyproject.toml`
-- CLI entrypoint
-- Basic image IO
-- Result data structures
-- Debug directory support
-
-### Step 2: Basic page detection
-
-Implement:
-
-- Load image
-- Resize for processing
-- Grayscale
-- Blur
-- Edge detection
-- Contour detection
-- Largest quadrilateral selection
-- Corner ordering
-- Debug overlay
-
-### Step 3: Perspective correction
-
-Implement:
-
-- Four-point transform
-- Auto output size
-- Save corrected image
-- CLI command
-
-### Step 4: Shadow cleanup
-
-Implement:
-
-- Background estimation
-- Illumination normalization
-- Adaptive threshold option
-- OCR-friendly grayscale output
-
-### Step 5: Mesh dewarp prototype
-
-Implement:
-
-- Uniform mesh grid
-- Manual or detected control points
-- Piecewise affine transform
-- Debug grid overlay
-
-### Step 6: Template alignment
-
-Implement:
-
-- Load template image or PDF
-- Feature matching
-- Homography alignment
-- Export aligned output
-- Optional cell crop export
-
-### Step 7: Demo assets
-
-Add:
-
-- `samples/`
-- Before/after images
-- README examples
-- GIF or comparison grid
-
----
-
-## CLI Requirements
-
-The first working CLI should support:
-
-```bash
-warpless input.jpg --output output.png
-```
-
-Future CLI flags:
-
-```bash
-warpless input.jpg \
- --output output.png \
- --debug \
- --debug-dir output/debug \
- --shadow-cleanup \
- --template form.pdf \
- --export-cells output/cells
-```
-
-CLI behavior:
-
-- Never overwrite files silently unless `--overwrite` is provided.
-- Print clear progress messages.
-- Print output paths at the end.
-- Exit with non-zero code on failure.
-- Show helpful error messages for unsupported files.
-
----
-
-## Coding Style
-
-Use:
-
-- Python 3.11+
-- Type hints
-- Dataclasses or Pydantic for config/result objects
-- Small pure functions where possible
-- Clear docstrings for public functions
-- `pathlib.Path` instead of raw string paths
-- `logging` instead of random `print`, except CLI user messages
-
-Avoid:
-
-- Large monolithic scripts
-- Hidden global state
-- Silent failures
-- Hardcoded absolute paths
-- Overwriting source images
-- Mixing UI/demo code into core package
-
----
-
-## Dependencies Guidance
-
-Preferred baseline dependencies:
-
-- `opencv-python`
-- `numpy`
-- `scipy`
-- `scikit-image`
-- `pillow`
-- `typer`
-- `rich`
-- `pydantic` or dataclasses
-- `pytest`
-
-Optional future dependencies:
-
-- `torch`
-- `onnxruntime`
-- `pymupdf`
-- `fastapi`
-- `uvicorn`
-
-Do not add heavy ML dependencies until there is a working classical CV baseline.
-
----
-
-## Testing Strategy
-
-Add tests for:
-
-- Corner ordering
-- Four-point perspective transform
-- Image loading/saving
-- Metadata generation
-- CLI argument parsing
-- Template alignment utilities
-- Mesh grid generation
-
-Test names should be clear:
-
-```text
-test_order_corners_returns_tl_tr_br_bl
-test_four_point_transform_preserves_document_area
-test_cli_fails_on_missing_input_file
-```
-
-Use small synthetic images where possible.
-
----
-
-## Debug Output Rules
-
-When `debug=True`, save visual files such as:
-
-```text
-debug/
-├── 01_input_resized.png
-├── 02_edges.png
-├── 03_contours.png
-├── 04_detected_corners.png
-├── 05_perspective_corrected.png
-├── 06_shadow_cleaned.png
-└── metadata.json
-```
-
-Every debug image should be understandable without reading code.
-
----
-
-## README Requirements
-
-Whenever a working feature is added, update README with:
-
-- What the feature does
-- CLI example
-- Python API example if available
-- Before/after image if possible
-- Known limitations
-
-README should remain friendly for recruiters, developers, and open-source contributors.
-
----
-
-## GitHub Positioning
-
-The repository should look like a serious open-source project.
-
-Emphasize these keywords naturally:
-
-- Document AI
-- OCR preprocessing
-- Document dewarping
-- Image rectification
-- Form understanding
-- Computer vision
-- Mesh-based warping
-- Persian OCR
-- Arabic OCR
-- Mobile document scanning
-
----
-
-## Important Limitations to Mention Publicly
-
-Be honest in docs:
-
-- Early versions may fail on very dark images.
-- Severe occlusion may not be recoverable.
-- Mesh dewarping is experimental at first.
-- Template alignment requires a good reference image or PDF.
-- OCR is not the first goal. The first goal is preparing images for OCR.
-
----
-
-## Commit Style
-
-Use clear commit messages:
-
-```text
-feat: add contour-based page detection
-feat: add four-point perspective correction
-fix: improve corner ordering for rotated documents
-docs: add first before-after example
-test: add geometry utility tests
-```
-
----
-
-## Suggested First Issues
-
-Create GitHub issues for:
-
-1. Implement basic CLI
-2. Add page boundary detection
-3. Add four-point perspective correction
-4. Add debug image export
-5. Add shadow cleanup
-6. Add mesh grid prototype
-7. Add template alignment prototype
-8. Add sample before/after images
-9. Add Persian/Arabic digit preprocessing notes
-10. Add web demo plan
-
----
-
-## Claude Behavior Rules
-
-When working in this repository, Claude should:
-
-1. First inspect the current file tree.
-2. Read `README.md`, `CLAUDE.md`, and `PROJECT_SPEC.md` before editing.
-3. Make small, coherent changes.
-4. Prefer one feature per commit-sized change.
-5. Add or update tests when adding logic.
-6. Update docs when public behavior changes.
-7. Avoid adding large dependencies without asking.
-8. Avoid rewriting the whole project when a focused patch is enough.
-9. Keep the architecture ready for future mesh-based dewarping.
-10. Always preserve sample input images and original data.
-
----
-
-## Good First Implementation Prompt
-
-Use this prompt with Claude Code after creating the repository:
-
-```text
-Read README.md, CLAUDE.md, and PROJECT_SPEC.md. Then create the initial Python package structure for WarpLess Docs with a Typer CLI, basic image loading/saving utilities, dataclass-based config/result types, and a placeholder pipeline that copies the input image to the output. Add pytest tests for image IO and CLI missing-file behavior. Do not add dewarping logic yet.
-```
-
----
-
-## Second Implementation Prompt
-
-```text
-Implement Phase 1 page detection and perspective correction. Add contour-based page boundary detection with debug overlays, corner ordering, and a four-point transform. Expose it through the CLI as `warpless input.jpg --output output.png --debug`. Add tests for corner ordering and transform sizing. Update README with usage examples.
-```
-
----
-
-## Third Implementation Prompt
-
-```text
-Implement basic shadow cleanup as an optional pipeline step. Use background estimation and illumination normalization. Add `--shadow-cleanup` to the CLI, save debug outputs, and update README with before/after examples and limitations.
-```
diff --git a/PROJECT_SPEC.md b/PROJECT_SPEC.md
deleted file mode 100644
index 98b1cc4..0000000
--- a/PROJECT_SPEC.md
+++ /dev/null
@@ -1,102 +0,0 @@
-# Project Specification: WarpLess Docs
-
-## One-Line Description
-
-WarpLess Docs is a document image preprocessing toolkit that turns folded, curved, skewed, shadowed, and perspective-distorted document photos into clean scanner-like, OCR-ready images.
-
----
-
-## Problem
-
-Real-world document photos are often not suitable for OCR or form extraction because they contain:
-
-- Perspective distortion
-- Curved paper
-- Folded or wavy surfaces
-- Shadows
-- Low contrast
-- Phone camera noise
-- Misalignment with the original form template
-
-Most OCR engines fail when the image is not geometrically and visually normalized first.
-
----
-
-## Proposed Solution
-
-WarpLess Docs provides a pipeline for:
-
-1. Detecting document boundaries
-2. Correcting perspective
-3. Dewarping local page distortions
-4. Reducing shadows
-5. Aligning the image to a reference template
-6. Exporting OCR-ready images, cell crops, and metadata
-
----
-
-## MVP Scope
-
-The first release should include:
-
-- Python package
-- CLI
-- Page boundary detection
-- Perspective correction
-- Shadow cleanup
-- Debug visualizations
-- Clean output image export
-
-The first release does not need:
-
-- Full OCR
-- Deep learning dewarping
-- Web UI
-- Production-grade template alignment
-- Handwritten digit recognition
-
----
-
-## Success Criteria for MVP
-
-The MVP is successful if a user can run:
-
-```bash
-warpless samples/input.jpg --output output/clean.png --debug
-```
-
-and receive:
-
-- A clean corrected image
-- A debug folder with intermediate steps
-- Clear terminal messages
-- No destructive changes to original files
-
----
-
-## Future Scope
-
-Future versions may include:
-
-- Mesh-based dewarping
-- Template alignment
-- Cell crop extraction
-- Persian and Arabic OCR preprocessing
-- Web demo
-- Batch processing
-- PDF export
-- Integration with OCR engines
-- Dataset generation tools
-
----
-
-## Target Repository Quality
-
-The repository should look professional enough for:
-
-- GitHub trending potential
-- Recruiter review
-- CTO-level portfolio presentation
-- Open-source collaboration
-- Technical blog posts
-- LinkedIn demo videos
diff --git a/README.md b/README.md
index 85ee067..38a88a4 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,11 @@
# WarpLess Docs
- AI-powered document cleanup for real-world photos.
+ AI-powered document cleanup and template rectification for real-world form photos.
- Turn warped, shadowed, skewed, folded, and camera-captured document photos into cleaner scanner-like images for OCR, form extraction, archiving, and automation.
+ Turn shadowed, warped, folded, and camera-captured document photos into cleaner, template-aligned outputs for OCR, form extraction, and document automation.
@@ -17,112 +17,102 @@
---
-## What is WarpLess Docs?
+## Overview
-**WarpLess Docs** is a Python toolkit for preparing real-world document photos for downstream AI and automation pipelines.
+**WarpLess Docs** is a modular Python toolkit for preparing real-world document photos for downstream AI systems.
-Most document scanners work well when the paper is flat, bright, and perfectly aligned. Real documents are rarely like that.
+The project currently focuses on three practical stages:
-WarpLess Docs focuses on the hard cases:
+1. **ML shadow removal** — reduces document shadows while preserving text and table structure.
+2. **Template-based rectification** — aligns a real photographed form to a clean, flat template image.
+3. **Warp diagnostics** — highlights curved or non-straight document structures before deeper dewarping.
-- phone-captured documents
-- shadows from hands, phones, desks, or uneven lighting
-- folded or curved paper
-- skewed and perspective-distorted images
-- forms, tables, contracts, invoices, and handwritten documents
-- OCR and form extraction preprocessing
-
-The project is designed as a modular document restoration pipeline, starting with **ML-powered document shadow removal** and expanding toward dewarping, template alignment, and OCR-ready exports.
+This is useful for OCR preprocessing, handwritten form extraction, invoice processing, archival workflows, and document automation pipelines.
---
-## Current Capability: ML Document Shadow Removal
-
-The first working module integrates a document-specific ML shadow removal backend based on **DocShadow SD7K** through **ONNX Runtime**.
-
-This module is designed to reduce shadows while preserving the important parts of a document:
-
-- printed text
-- table lines
-- form structure
-- handwriting
-- signatures
-- scanned-paper texture
+## Current Pipeline
-Unlike a simple brightness or threshold filter, this backend uses a trained document shadow removal model.
+```text
+input/samples/*.jpg
+ ↓
+main.py
+ ↓
+outputs/deshadowed/*_deshadowed.png
+ ↓
+rectify_with_template.py + input/template/template_page_1.png
+ ↓
+outputs/template_rectified/*_template_rectified.png
+outputs/template_rectified/*_template_matches.png
+ ↓
+detect_warp.py
+ ↓
+outputs/warp_detection/*_warp_overlay.png
+```
---
-## Before / After Samples
+## Shadow Removal Before / After
-The repository includes real sample inputs and generated outputs.
+The first stage uses a document-specific ML shadow removal backend based on **DocShadow SD7K** through **ONNX Runtime**. It is designed to reduce uneven lighting and shadows while keeping document details readable.
| Input Photo | Shadow-Removed Output |
|---|---|
-|
|
|
-|
|
|
+|
|
|
+|
|
|
-These examples demonstrate the first public milestone of WarpLess Docs: improving difficult document photos before OCR or form extraction.
+This stage improves the visual quality of the photo before template alignment, OCR, and geometry diagnostics.
---
-## Why This Project Matters
-
-Document AI systems are only as reliable as the images they receive.
-
-In production workflows, documents often arrive as mobile photos instead of clean scans. That creates problems for OCR, table extraction, handwritten digit recognition, and form automation.
+## Template Matching Example
-WarpLess Docs is built to solve that preprocessing layer.
+The rectification stage uses the clean template as the reference coordinate system and finds matching points between the real document photo and the template.
-Potential use cases:
+
+
+
-- OCR preprocessing
-- document AI pipelines
-- form reading systems
-- handwritten digit extraction
-- Persian, Arabic, and multilingual document cleanup
-- invoice and contract scanning
-- medical or government form digitization
-- dataset preparation for OCR models
-- mobile scanner applications
-- enterprise document automation
+The match visualization is a debugging aid: it shows whether the system found stable correspondence points before warping the image into the template coordinate space.
---
-## Pipeline Vision
+## Project Structure
```text
-Real-world document photo
- ↓
-Document detection
- ↓
-Perspective correction
- ↓
-Mesh-based dewarping
- ↓
-ML shadow removal
- ↓
-Template alignment
- ↓
-OCR-ready output
+WarpLess-Docs/
+├── input/
+│ ├── samples/ # Real photographed documents
+│ └── template/ # Clean flat template image
+│ └── template_page_1.png
+├── models/
+│ └── docshadow_sd7k.onnx # Not committed; download separately
+├── outputs/
+│ ├── deshadowed/ # Shadow-removal outputs
+│ ├── template_rectified/ # Template aligned outputs + match debug images
+│ └── warp_detection/ # Curvature diagnostic overlays
+├── src/
+│ └── warpless_docs/
+│ ├── shadow_removal/
+│ ├── template_rectification/
+│ └── warp_detection/
+├── main.py # Stage 1: shadow removal
+├── rectify_with_template.py # Stage 2: template-based rectification
+├── detect_warp.py # Stage 3: warp diagnostics
+├── requirements.txt
+├── pyproject.toml
+└── README.md
```
-The current repository starts with the **ML shadow removal** stage and is structured to grow into the full pipeline.
+Generated images are written under `outputs/`. Large model files are kept under `models/` and are intentionally not committed.
---
## Installation
-Clone the repository:
-
```bash
git clone https://github.com/ehsanwwe/WarpLess-Docs.git
cd WarpLess-Docs
-```
-
-Create and activate a virtual environment:
-
-```bash
python -m venv .venv
```
@@ -130,31 +120,27 @@ Windows PowerShell:
```powershell
.\.venv\Scripts\Activate.ps1
+pip install -r requirements.txt
```
Linux / macOS:
```bash
source .venv/bin/activate
-```
-
-Install dependencies:
-
-```bash
pip install -r requirements.txt
```
Recommended Python version:
```text
-Python 3.11+
+Python 3.11 or 3.12
```
---
-## Model Download
+## Download the Shadow Removal Model
-The ONNX model is not committed to the repository because it is a large binary file.
+The ONNX model is not committed to GitHub because it is a large binary file.
Create the model directory:
@@ -168,7 +154,7 @@ Download the DocShadow SD7K ONNX model:
curl -L "https://github.com/fabio-sim/DocShadow-ONNX-TensorRT/releases/download/v1.0.0/docshadow_sd7k.onnx" -o "models/docshadow_sd7k.onnx"
```
-Windows PowerShell alternative:
+Windows PowerShell:
```powershell
Invoke-WebRequest `
@@ -184,199 +170,179 @@ models/docshadow_sd7k.onnx
---
-## Run the Demo
+## Prepare Inputs
-Process all sample images:
+Put real document photos here:
-```bash
-python main.py
+```text
+input/samples/
```
-Process a single image:
+Put the clean flat template here:
-```bash
-python main.py --input "input/samples/Untitled2.jpg"
+```text
+input/template/template_page_1.png
```
-Use a larger inference size:
+The template should be the same form type as the photographed documents. If a sample belongs to another form/template, remove it from `input/samples` or use a different template for that sample.
+
+---
+
+## Run Stage 1: Shadow Removal
+
+Process all samples:
```bash
-python main.py --size 1024 1024
+python main.py
```
-Force CPU execution:
+Process one image:
```bash
-python main.py --cpu
+python main.py --input "input/samples/Untitled5.jpg"
```
-Outputs are saved to:
+Outputs:
```text
-outputs/
+outputs/deshadowed/Untitled5_deshadowed.png
```
-Example:
+For higher-quality inference:
-```text
-input/samples/Untitled2.jpg
-outputs/Untitled2_deshadowed.png
+```bash
+python main.py --size 1024 1024
```
----
+Force CPU:
-## Python API
+```bash
+python main.py --cpu
+```
-```python
-from warpless_docs.shadow_removal import DocShadowONNXRemover
+---
-remover = DocShadowONNXRemover(
- model_path="models/docshadow_sd7k.onnx",
- input_size=(768, 768),
-)
+## Run Stage 2: Template Rectification
-remover.remove_shadow_from_path(
- input_path="input/samples/Untitled2.jpg",
- output_path="outputs/Untitled2_deshadowed.png",
-)
+Process all deshadowed outputs using the first template in `input/template`:
+
+```bash
+python rectify_with_template.py
```
----
+Process one image with an explicit template:
-## Repository Structure
+```bash
+python rectify_with_template.py \
+ --template "input/template/template_page_1.png" \
+ --input "outputs/deshadowed/Untitled5_deshadowed.png"
+```
+
+Outputs:
```text
-WarpLess-Docs/
-├── input/
-│ └── samples/
-│ ├── Untitled2.jpg
-│ └── Untitled6.jpg
-├── outputs/
-│ ├── Untitled2_deshadowed.png
-│ └── Untitled6_deshadowed.png
-├── models/
-│ └── docshadow_sd7k.onnx
-├── src/
-│ └── warpless_docs/
-│ └── shadow_removal/
-│ ├── __init__.py
-│ └── docshadow_onnx.py
-├── main.py
-├── requirements.txt
-├── README.md
-└── LICENSE
+outputs/template_rectified/Untitled5_template_rectified.png
+outputs/template_rectified/Untitled5_template_rectification_report.json
+outputs/template_rectified/Untitled5_template_matches.png
```
----
+The rectified image has the exact same canvas size and coordinate system as the template.
-## Roadmap
+Useful options:
-### Phase 1: Document Cleanup Core
+```bash
+python rectify_with_template.py --feature-method sift
+python rectify_with_template.py --feature-method orb
+python rectify_with_template.py --homography-only
+python rectify_with_template.py --no-debug
+```
-- [x] ML-based document shadow removal backend
-- [x] Batch processing for sample images
-- [x] Before/after output generation
-- [ ] output quality comparison tools
-- [ ] debug visualizations
-- [ ] automatic image quality scoring
+---
-### Phase 2: Document Geometry
+## Run Stage 3: Warp Diagnostics
-- [ ] page boundary detection
-- [ ] perspective correction
-- [ ] auto-rotation
-- [ ] corner refinement
-- [ ] background removal
+The warp detector prefers deshadowed images from `outputs/deshadowed/`.
-### Phase 3: Advanced Dewarping
+```bash
+python detect_warp.py
+```
-- [ ] mesh grid generation
-- [ ] local document dewarping
-- [ ] fold and curve correction
-- [ ] wavy edge correction
-- [ ] multi-strategy dewarping comparison
+Process one image:
-### Phase 4: Template Alignment
+```bash
+python detect_warp.py --input "outputs/deshadowed/Untitled5_deshadowed.png"
+```
+
+Outputs:
-- [ ] PDF-to-image conversion
-- [ ] reference form alignment
-- [ ] keypoint matching
-- [ ] table cell extraction
-- [ ] JSON metadata export
-- [ ] visual alignment report
+```text
+outputs/warp_detection/Untitled5_warp_overlay.png
+outputs/warp_detection/Untitled5_warp_report.json
+```
-### Phase 5: OCR Preparation
+If the detector is too conservative:
-- [ ] OCR-friendly preprocessing
-- [ ] cell crop extraction
-- [ ] grayscale and binary export
-- [ ] Persian and Arabic digit normalization
-- [ ] OCR integration examples
+```bash
+python detect_warp.py --min-deviation 1.2 --min-angle-change 0.6
+```
---
-## Technical Direction
+## Python API
-WarpLess Docs is built around a modular computer vision architecture.
+### Shadow removal
-Planned modules:
+```python
+from warpless_docs.shadow_removal import DocShadowONNXRemover
-```text
-warpless_docs/
-├── detection/
-├── dewarp/
-├── shadow_removal/
-├── alignment/
-├── ocr_prep/
-├── export/
-└── utils/
+remover = DocShadowONNXRemover(model_path="models/docshadow_sd7k.onnx")
+remover.remove_shadow_from_path(
+ input_path="input/samples/Untitled5.jpg",
+ output_path="outputs/deshadowed/Untitled5_deshadowed.png",
+)
```
-The goal is to make each stage replaceable, testable, and useful both as a standalone module and as part of a complete document AI pipeline.
-
----
+### Template rectification
-## Tech Stack
-
-- Python
-- NumPy
-- Pillow
-- OpenCV
-- ONNX Runtime
-- PyTorch-compatible research backend options
-- Future API layer with FastAPI
-- Future web demo with React / Next.js
+```python
+from warpless_docs.template_rectification import TemplateRectifier
+
+rectifier = TemplateRectifier(template_path="input/template/template_page_1.png")
+rectifier.rectify_path(
+ input_path="outputs/deshadowed/Untitled5_deshadowed.png",
+ output_path="outputs/template_rectified/Untitled5_template_rectified.png",
+ output_json_path="outputs/template_rectified/Untitled5_template_rectification_report.json",
+ output_debug_path="outputs/template_rectified/Untitled5_template_matches.png",
+)
+```
---
-## Project Vision
-
-WarpLess Docs is not just a scanner filter.
+## Notes on Template-Based Rectification
-It is a preprocessing engine for document understanding systems.
+Template rectification works best when:
-The long-term goal is to make difficult real-world document photos reliable enough for:
+- the sample image is the same form layout as the template
+- the document is visible enough after shadow removal
+- the photo is not extremely blurred
+- there are enough printed structures, table lines, logos, or text blocks for feature matching
-- OCR
-- AI extraction
-- form automation
-- archival workflows
-- enterprise document pipelines
+If a sample does not match the template, the script will skip it and print a clear reason instead of crashing the whole batch.
---
-## Contributing
-
-Contributions, ideas, datasets, benchmarks, and real-world document samples are welcome.
-
-Good first contribution areas:
+## Roadmap
-- new sample documents
-- shadow removal comparisons
-- document boundary detection
-- OCR preprocessing experiments
-- dewarping strategies
-- benchmark scripts
+- [x] ML-based document shadow removal
+- [x] Batch processing for sample photos
+- [x] Curvature and warp diagnostic overlays
+- [x] Template-based rectification using a clean reference image
+- [x] Match visualization for debugging alignment quality
+- [ ] OCR-ready cell extraction
+- [ ] Template region mapping and JSON field export
+- [ ] Stronger local mesh dewarping
+- [ ] FastAPI demo endpoint
+- [ ] Web demo for upload, preview, and download
---
@@ -388,6 +354,4 @@ MIT License
## Author
-Built by **Ehsan Moradi**
-
-Senior software engineer, AI developer, computer vision engineer, and real-time graphics specialist.
+Built by **Ehsan Moradi**.
diff --git a/TASKS.md b/TASKS.md
deleted file mode 100644
index 80103f4..0000000
--- a/TASKS.md
+++ /dev/null
@@ -1,53 +0,0 @@
-# TASKS.md
-
-## Phase 0: Repository Setup
-
-- [ ] Create Python package structure
-- [ ] Add `pyproject.toml`
-- [ ] Add CLI entrypoint
-- [ ] Add tests folder
-- [ ] Add sample input/output folders
-- [ ] Add basic documentation
-
-## Phase 1: Baseline Scanner
-
-- [ ] Load image
-- [ ] Detect page contour
-- [ ] Order page corners
-- [ ] Apply perspective transform
-- [ ] Save output image
-- [ ] Add debug overlays
-- [ ] Add CLI example
-
-## Phase 2: Cleanup
-
-- [ ] Add grayscale export
-- [ ] Add contrast normalization
-- [ ] Add shadow cleanup
-- [ ] Add OCR-friendly preprocessing
-- [ ] Add debug output for each step
-
-## Phase 3: Dewarp Prototype
-
-- [ ] Add mesh grid structure
-- [ ] Add grid visualization
-- [ ] Add piecewise affine transform
-- [ ] Add manual control point JSON support
-- [ ] Add local warp debug output
-
-## Phase 4: Template Alignment
-
-- [ ] Load template image
-- [ ] Load template PDF
-- [ ] Match features
-- [ ] Align captured image to template
-- [ ] Export alignment metadata
-- [ ] Export cell crops
-
-## Phase 5: Public Demo
-
-- [ ] Add before/after sample images
-- [ ] Add README screenshots
-- [ ] Add short demo GIF
-- [ ] Add web demo plan
-- [ ] Add LinkedIn launch post draft
diff --git a/detect_warp.py b/detect_warp.py
new file mode 100644
index 0000000..7b96ed0
--- /dev/null
+++ b/detect_warp.py
@@ -0,0 +1,152 @@
+import argparse
+import sys
+from pathlib import Path
+
+PROJECT_ROOT = Path(__file__).resolve().parent
+SRC_DIR = PROJECT_ROOT / "src"
+
+if str(SRC_DIR) not in sys.path:
+ sys.path.insert(0, str(SRC_DIR))
+
+from warpless_docs.warp_detection import DocumentWarpDetector, WarpDetectorConfig
+
+
+SUPPORTED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description="WarpLess Docs - document curvature and warp diagnostics"
+ )
+
+ parser.add_argument(
+ "--input",
+ default=None,
+ help="Single image path. Prefer a deshadowed image from outputs/deshadowed.",
+ )
+ parser.add_argument(
+ "--input-dir",
+ default="outputs/deshadowed",
+ help="Directory containing deshadowed document images.",
+ )
+ parser.add_argument(
+ "--input-pattern",
+ default="*_deshadowed.*",
+ help="Glob pattern for images inside --input-dir.",
+ )
+ parser.add_argument(
+ "--fallback-samples-dir",
+ default="input/samples",
+ help="Fallback raw samples directory if no deshadowed outputs exist.",
+ )
+ parser.add_argument(
+ "--output-dir",
+ default="outputs/warp_detection",
+ help="Directory to save warp overlays and JSON reports.",
+ )
+ parser.add_argument("--max-side", type=int, default=1800)
+ parser.add_argument("--min-deviation", type=float, default=2.0)
+ parser.add_argument("--min-angle-change", type=float, default=1.0)
+
+ return parser.parse_args()
+
+
+def is_supported_image(path: Path) -> bool:
+ return path.is_file() and path.suffix.lower() in SUPPORTED_EXTENSIONS
+
+
+def find_images(directory: Path, pattern: str = "*") -> list[Path]:
+ if not directory.exists():
+ return []
+ return sorted(path for path in directory.glob(pattern) if is_supported_image(path))
+
+
+def find_sample_images(directory: Path) -> list[Path]:
+ if not directory.exists():
+ return []
+ return sorted(path for path in directory.rglob("*") if is_supported_image(path))
+
+
+def normalized_stem(input_path: Path) -> str:
+ stem = input_path.stem
+ if stem.endswith("_deshadowed"):
+ return stem[: -len("_deshadowed")]
+ return stem
+
+
+def build_output_paths(input_path: Path, output_dir: Path) -> tuple[Path, Path]:
+ stem = normalized_stem(input_path)
+ return output_dir / f"{stem}_warp_overlay.png", output_dir / f"{stem}_warp_report.json"
+
+
+def resolve_input_paths(args: argparse.Namespace) -> list[Path]:
+    """Pick inputs: --input, else --input-dir matches, else raw samples ([] if none)."""
+    if args.input:
+        return [Path(args.input)]
+    deshadowed = find_images(Path(args.input_dir), args.input_pattern)
+    if deshadowed:
+        print(f"Using deshadowed images from: {args.input_dir}")
+        return deshadowed
+    fallback = find_sample_images(Path(args.fallback_samples_dir))
+    if fallback:
+        print(f"No deshadowed outputs found. Falling back to: {args.fallback_samples_dir}")
+    # Unconditional return: an empty list (never None) tells main() nothing was found.
+    return fallback
+
+
+def main() -> None:
+ args = parse_args()
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ input_paths = resolve_input_paths(args)
+ if not input_paths:
+ print("No images found. Run shadow removal first with: python main.py")
+ return
+
+ detector = DocumentWarpDetector(
+ config=WarpDetectorConfig(
+ max_image_side=args.max_side,
+ min_deviation_px=args.min_deviation,
+ min_angle_change_deg=args.min_angle_change,
+ )
+ )
+
+ print("WarpLess Docs curvature detection")
+ print(f"Output dir : {output_dir}")
+ print(f"Images : {len(input_paths)}")
+ print(f"Max side : {args.max_side}")
+ print(f"Min deviation : {args.min_deviation}")
+ print(f"Min angle change : {args.min_angle_change}")
+ print("-" * 72)
+
+ for index, input_path in enumerate(input_paths, start=1):
+ if not input_path.exists():
+ print(f"[SKIP] Missing input: {input_path}")
+ continue
+
+ overlay_path, report_path = build_output_paths(input_path, output_dir)
+ print(f"[{index}/{len(input_paths)}] Processing: {input_path}")
+
+ try:
+ result = detector.analyze_path(
+ input_path=input_path,
+ output_overlay_path=overlay_path,
+ output_json_path=report_path,
+ )
+ print(
+ f" findings={result.total_findings} "
+ f"strong={result.strong_count} medium={result.medium_count} mild={result.mild_count}"
+ )
+ print(f" overlay: {overlay_path}")
+ print(f" report : {report_path}")
+ except Exception as exc:
+ print(f" failed: {input_path}")
+ print(f" reason: {exc}")
+
+ print("-" * 72)
+ print("Done.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/__init__.py b/input/samples/.gitkeep
similarity index 100%
rename from __init__.py
rename to input/samples/.gitkeep
diff --git a/input/samples/Untitled.jpg b/input/samples/Untitled.jpg
deleted file mode 100644
index 2c92084..0000000
Binary files a/input/samples/Untitled.jpg and /dev/null differ
diff --git a/input/template/.gitkeep b/input/template/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/input/template/template.png b/input/template/template.png
new file mode 100644
index 0000000..37143cf
Binary files /dev/null and b/input/template/template.png differ
diff --git a/input/template/template_page_1.png b/input/template/template_page_1.png
new file mode 100644
index 0000000..d698bfb
Binary files /dev/null and b/input/template/template_page_1.png differ
diff --git a/main.py b/main.py
index 9f403c2..7fc8632 100644
--- a/main.py
+++ b/main.py
@@ -11,95 +11,57 @@
from warpless_docs.shadow_removal import DocShadowONNXRemover
-SUPPORTED_EXTENSIONS = {
- ".jpg",
- ".jpeg",
- ".png",
- ".bmp",
- ".webp",
- ".tif",
- ".tiff",
-}
+SUPPORTED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
def parse_args() -> argparse.Namespace:
- parser = argparse.ArgumentParser(
- description="WarpLess Docs - Document shadow removal test runner"
- )
+ parser = argparse.ArgumentParser(description="WarpLess Docs - ML document shadow removal")
parser.add_argument(
"--input",
default=None,
- help="Single input image path. If empty, all images inside input/samples will be processed.",
+ help="Single input image path. If empty, all images inside input/samples are processed.",
)
-
parser.add_argument(
"--samples-dir",
default="input/samples",
- help="Directory containing sample document photos.",
+ help="Directory containing real document photos.",
)
-
parser.add_argument(
"--output-dir",
- default="outputs",
- help="Directory to save processed outputs.",
+ default="outputs/deshadowed",
+ help="Directory to save shadow-removed images.",
)
-
parser.add_argument(
"--model",
default="models/docshadow_sd7k.onnx",
help="Path to DocShadow ONNX model.",
)
-
parser.add_argument(
"--size",
nargs=2,
type=int,
default=[768, 768],
metavar=("WIDTH", "HEIGHT"),
- help="Inference resize size. Start with 768 768. Try 1024 1024 for better quality.",
- )
-
- parser.add_argument(
- "--cpu",
- action="store_true",
- help="Force CPU provider.",
+ help="Inference resize size. Try 1024 1024 for higher quality.",
)
+ parser.add_argument("--cpu", action="store_true", help="Force CPU provider.")
return parser.parse_args()
-def find_sample_images(samples_dir: Path) -> list[Path]:
- if not samples_dir.exists():
- raise FileNotFoundError(f"Samples directory not found: {samples_dir}")
+def is_supported_image(path: Path) -> bool:
+ return path.is_file() and path.suffix.lower() in SUPPORTED_EXTENSIONS
- images = [
- path
- for path in samples_dir.rglob("*")
- if path.is_file() and path.suffix.lower() in SUPPORTED_EXTENSIONS
- ]
- return sorted(images)
+def find_images(directory: Path) -> list[Path]:
+ if not directory.exists():
+ raise FileNotFoundError(f"Samples directory not found: {directory}")
+ return sorted(path for path in directory.rglob("*") if is_supported_image(path))
def build_output_path(input_path: Path, output_dir: Path) -> Path:
- output_name = f"{input_path.stem}_deshadowed.png"
- return output_dir / output_name
-
-
-def process_image(
- remover: DocShadowONNXRemover,
- input_path: Path,
- output_dir: Path,
-) -> Path:
- output_path = build_output_path(input_path, output_dir)
-
- saved_path = remover.remove_shadow_from_path(
- input_path=input_path,
- output_path=output_path,
- )
-
- return saved_path
+ return output_dir / f"{input_path.stem}_deshadowed.png"
def main() -> None:
@@ -115,43 +77,37 @@ def main() -> None:
prefer_gpu=not args.cpu,
)
- if args.input:
- input_paths = [Path(args.input)]
- else:
- input_paths = find_sample_images(Path(args.samples_dir))
+ input_paths = [Path(args.input)] if args.input else find_images(Path(args.samples_dir))
if not input_paths:
print("No sample images found.")
return
- print("WarpLess Docs shadow removal test")
+ print("WarpLess Docs shadow removal")
print(f"Model : {model_path}")
print(f"Output dir: {output_dir}")
print(f"Providers : {remover.providers}")
print(f"Images : {len(input_paths)}")
- print("-" * 60)
+ print("-" * 68)
for index, input_path in enumerate(input_paths, start=1):
if not input_path.exists():
print(f"[SKIP] Input not found: {input_path}")
continue
+ output_path = build_output_path(input_path, output_dir)
print(f"[{index}/{len(input_paths)}] Processing: {input_path}")
try:
- saved_path = process_image(
- remover=remover,
- input_path=input_path,
- output_dir=output_dir,
- )
- print(f"Saved: {saved_path}")
+ saved_path = remover.remove_shadow_from_path(input_path=input_path, output_path=output_path)
+ print(f" saved: {saved_path}")
except Exception as exc:
- print(f"Failed: {input_path}")
- print(f"Reason: {exc}")
+ print(f" failed: {input_path}")
+ print(f" reason: {exc}")
- print("-" * 60)
+ print("-" * 68)
print("Done.")
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/models/.gitkeep b/models/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/outputs/.gitkeep b/outputs/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/pyproject.toml b/pyproject.toml
index c4152bf..f8d4bcb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,47 +1,31 @@
[project]
name = "warpless-docs"
-version = "0.1.0"
-description = "AI-powered document dewarping toolkit for turning warped document photos into clean OCR-ready scans."
+version = "0.2.0"
+description = "Document cleanup and template-based rectification toolkit for real-world scanned forms."
readme = "README.md"
requires-python = ">=3.11"
-authors = [
- { name = "Ehsan Moradi" }
-]
+authors = [{ name = "Ehsan Moradi" }]
license = { text = "MIT" }
keywords = [
"document-ai",
"ocr",
"computer-vision",
"document-dewarping",
- "image-rectification",
- "form-understanding"
+ "template-alignment",
+ "form-understanding",
]
dependencies = [
- "opencv-python>=4.9.0",
- "numpy>=1.26.0",
+ "numpy>=1.24.0",
"pillow>=10.0.0",
- "typer>=0.12.0",
- "rich>=13.0.0"
+ "opencv-python>=4.8.0",
+ "onnxruntime>=1.17.0",
+ "onnx>=1.14.0",
]
[project.optional-dependencies]
-dev = [
- "pytest>=8.0.0",
- "ruff>=0.5.0",
- "mypy>=1.10.0"
-]
-advanced = [
- "scipy>=1.12.0",
- "scikit-image>=0.23.0",
- "pymupdf>=1.24.0"
-]
-
-[project.scripts]
-warpless = "warpless.cli:app"
+gpu = ["onnxruntime-gpu>=1.17.0"]
+dev = ["pytest>=8.0.0", "ruff>=0.5.0"]
[tool.ruff]
line-length = 100
target-version = "py311"
-
-[tool.pytest.ini_options]
-testpaths = ["tests"]
diff --git a/rectify_with_template.py b/rectify_with_template.py
new file mode 100644
index 0000000..2c0450c
--- /dev/null
+++ b/rectify_with_template.py
@@ -0,0 +1,204 @@
+import argparse
+import sys
+from pathlib import Path
+
+PROJECT_ROOT = Path(__file__).resolve().parent
+SRC_DIR = PROJECT_ROOT / "src"
+
+if str(SRC_DIR) not in sys.path:
+ sys.path.insert(0, str(SRC_DIR))
+
+from warpless_docs.template_rectification import TemplateRectifier, TemplateRectifierConfig
+
+
+SUPPORTED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff"}
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description="WarpLess Docs - template-based document rectification"
+ )
+
+ parser.add_argument(
+ "--template",
+ default=None,
+ help="Template image path. If empty, the first image inside input/template is used.",
+ )
+ parser.add_argument(
+ "--template-dir",
+ default="input/template",
+ help="Directory containing the clean flat template image.",
+ )
+ parser.add_argument(
+ "--input",
+ default=None,
+ help="Single input image path. If empty, batch mode is used.",
+ )
+ parser.add_argument(
+ "--deshadowed-dir",
+ default="outputs/deshadowed",
+ help="Preferred directory containing shadow-removed images.",
+ )
+ parser.add_argument(
+ "--legacy-deshadowed-dir",
+ default="outputs",
+ help="Backward-compatible fallback for older outputs placed directly in outputs/.",
+ )
+ parser.add_argument(
+ "--samples-dir",
+ default="input/samples",
+ help="Fallback directory containing raw sample photos.",
+ )
+ parser.add_argument(
+ "--output-dir",
+ default="outputs/template_rectified",
+ help="Directory to save rectified outputs, reports, and match visualizations.",
+ )
+ parser.add_argument("--max-side", type=int, default=1800)
+ parser.add_argument(
+ "--feature-method",
+ choices=["auto", "sift", "orb"],
+ default="auto",
+ help="Feature detector for template matching. Auto tries SIFT first, then ORB.",
+ )
+ parser.add_argument("--homography-only", action="store_true", help="Disable piecewise affine refinement.")
+ parser.add_argument("--no-debug", action="store_true", help="Do not save match visualization images.")
+
+ return parser.parse_args()
+
+
+def is_supported_image(path: Path) -> bool:
+ return path.is_file() and path.suffix.lower() in SUPPORTED_EXTENSIONS
+
+
+def find_images(directory: Path, pattern: str = "*") -> list[Path]:
+ if not directory.exists():
+ return []
+ return sorted(path for path in directory.glob(pattern) if is_supported_image(path))
+
+
+def find_recursive_images(directory: Path) -> list[Path]:
+ if not directory.exists():
+ return []
+ return sorted(path for path in directory.rglob("*") if is_supported_image(path))
+
+
+def resolve_template(template_path: str | None, template_dir: Path) -> Path:
+ if template_path:
+ path = Path(template_path)
+ if not path.exists():
+ raise FileNotFoundError(f"Template image not found: {path}")
+ return path
+
+ candidates = find_images(template_dir)
+ if not candidates:
+ raise FileNotFoundError(
+ f"No template image found in {template_dir}. "
+ "Put your clean template there, for example input/template/template_page_1.png"
+ )
+
+ return candidates[0]
+
+
+def resolve_inputs(args: argparse.Namespace) -> list[Path]:
+    """Return images to rectify ([] if none); a missing --input raises FileNotFoundError."""
+    if args.input:
+        path = Path(args.input)
+        if not path.exists():
+            raise FileNotFoundError(f"Input image not found: {path}")
+        return [path]
+
+    # Stage 2 should normally consume stage 1 outputs.
+    deshadowed = find_images(Path(args.deshadowed_dir), "*_deshadowed.*")
+    if deshadowed:
+        print(f"Using deshadowed images from: {args.deshadowed_dir}")
+        return deshadowed
+    # Backward compatibility for older project versions.
+    legacy = find_images(Path(args.legacy_deshadowed_dir), "*_deshadowed.*")
+    if legacy:
+        print(f"Using legacy deshadowed images from: {args.legacy_deshadowed_dir}")
+        return legacy
+    fallback = find_recursive_images(Path(args.samples_dir))
+    if fallback:
+        print(f"No deshadowed outputs found. Falling back to raw samples from: {args.samples_dir}")
+    # Unconditional return: an empty list (never None) lets main() report "No input images found."
+    return fallback
+
+
+def normalized_stem(input_path: Path) -> str:
+ stem = input_path.stem
+ if stem.endswith("_deshadowed"):
+ return stem[: -len("_deshadowed")]
+ return stem
+
+
+def build_output_paths(input_path: Path, output_dir: Path) -> tuple[Path, Path, Path]:
+ stem = normalized_stem(input_path)
+ rectified_path = output_dir / f"{stem}_template_rectified.png"
+ report_path = output_dir / f"{stem}_template_rectification_report.json"
+ debug_path = output_dir / f"{stem}_template_matches.png"
+ return rectified_path, report_path, debug_path
+
+
+def main() -> None:
+ args = parse_args()
+
+ template_path = resolve_template(args.template, Path(args.template_dir))
+ input_paths = resolve_inputs(args)
+
+ if not input_paths:
+ print("No input images found.")
+ print("Run shadow removal first with: python main.py")
+ return
+
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ rectifier = TemplateRectifier(
+ template_path=template_path,
+ config=TemplateRectifierConfig(
+ max_image_side=args.max_side,
+ feature_method=args.feature_method,
+ use_piecewise_warp=not args.homography_only,
+ output_match_debug=not args.no_debug,
+ ),
+ )
+
+ print("WarpLess Docs template rectification")
+ print(f"Template : {template_path}")
+ print(f"Output dir: {output_dir}")
+ print(f"Images : {len(input_paths)}")
+ print(f"Features : {args.feature_method}")
+ print(f"Mode : {'homography only' if args.homography_only else 'homography + piecewise affine'}")
+ print("-" * 78)
+
+ for index, input_path in enumerate(input_paths, start=1):
+ rectified_path, report_path, debug_path = build_output_paths(input_path, output_dir)
+ print(f"[{index}/{len(input_paths)}] Processing: {input_path}")
+
+ try:
+ result = rectifier.rectify_path(
+ input_path=input_path,
+ output_path=rectified_path,
+ output_json_path=report_path,
+ output_debug_path=None if args.no_debug else debug_path,
+ )
+ print(
+ f" mode={result.mode} feature={result.feature_method} "
+ f"good_matches={result.good_matches} inliers={result.inlier_matches} "
+ f"inlier_ratio={result.inlier_ratio:.2f}"
+ )
+ print(f" rectified: {rectified_path}")
+ print(f" report : {report_path}")
+ if not args.no_debug:
+ print(f" matches : {debug_path}")
+ except Exception as exc:
+ print(f" failed: {input_path}")
+ print(f" reason: {exc}")
+
+ print("-" * 78)
+ print("Done.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/requirements.txt b/requirements.txt
index ad81c50..c3f9129 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
-
numpy>=1.24.0
pillow>=10.0.0
opencv-python>=4.8.0
-onnxruntime-gpu>=1.17.0
-onnx>=1.14.0
\ No newline at end of file
+onnxruntime>=1.17.0
+onnx>=1.14.0
diff --git a/src/warpless_docs/__init__.py b/src/warpless_docs/__init__.py
new file mode 100644
index 0000000..d3ec452
--- /dev/null
+++ b/src/warpless_docs/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.2.0"
diff --git a/shadow_removal/__init__.py b/src/warpless_docs/shadow_removal/__init__.py
similarity index 58%
rename from shadow_removal/__init__.py
rename to src/warpless_docs/shadow_removal/__init__.py
index f9583cd..662b522 100644
--- a/shadow_removal/__init__.py
+++ b/src/warpless_docs/shadow_removal/__init__.py
@@ -1,3 +1,3 @@
from .docshadow_onnx import DocShadowONNXRemover
-__all__ = ["DocShadowONNXRemover"]
\ No newline at end of file
+__all__ = ["DocShadowONNXRemover"]
diff --git a/shadow_removal/docshadow_onnx.py b/src/warpless_docs/shadow_removal/docshadow_onnx.py
similarity index 76%
rename from shadow_removal/docshadow_onnx.py
rename to src/warpless_docs/shadow_removal/docshadow_onnx.py
index 8f3d712..849ac1c 100644
--- a/shadow_removal/docshadow_onnx.py
+++ b/src/warpless_docs/shadow_removal/docshadow_onnx.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from pathlib import Path
from typing import Optional, Tuple, Union
@@ -10,6 +12,8 @@
class DocShadowONNXRemover:
+ """Document shadow-removal backend powered by a DocShadow ONNX model."""
+
def __init__(
self,
model_path: Union[str, Path] = "models/docshadow_sd7k.onnx",
@@ -22,28 +26,28 @@ def __init__(
if not self.model_path.exists():
raise FileNotFoundError(
f"ONNX model not found: {self.model_path}\n"
- "Put docshadow_sd7k.onnx inside the models folder."
+ "Download docshadow_sd7k.onnx into the models folder first."
)
- self.providers = self._select_providers(prefer_gpu)
-
- self.session = ort.InferenceSession(
- str(self.model_path),
- providers=self.providers,
- )
-
+ self.providers = self._select_providers(prefer_gpu=prefer_gpu)
+ self.session = ort.InferenceSession(str(self.model_path), providers=self.providers)
self.input_name = self.session.get_inputs()[0].name
self.output_name = self.session.get_outputs()[0].name
@staticmethod
def _select_providers(prefer_gpu: bool) -> list[str]:
available = ort.get_available_providers()
- providers = []
+ providers: list[str] = []
if prefer_gpu and "CUDAExecutionProvider" in available:
providers.append("CUDAExecutionProvider")
- providers.append("CPUExecutionProvider")
+ if "CPUExecutionProvider" in available:
+ providers.append("CPUExecutionProvider")
+
+ if not providers:
+ raise RuntimeError(f"No usable ONNX Runtime provider found: {available}")
+
return providers
@staticmethod
@@ -57,10 +61,8 @@ def _to_pil_rgb(image: ImageInput) -> Image.Image:
if isinstance(image, np.ndarray):
if image.ndim != 3:
raise ValueError("Expected image array with shape HxWxC.")
-
if image.shape[2] == 4:
image = image[:, :, :3]
-
return Image.fromarray(image.astype(np.uint8)).convert("RGB")
raise TypeError(f"Unsupported input type: {type(image)}")
@@ -68,47 +70,31 @@ def _to_pil_rgb(image: ImageInput) -> Image.Image:
@staticmethod
def _preprocess(image: Image.Image) -> np.ndarray:
arr = np.asarray(image).astype(np.float32) / 255.0
-
- # HWC to NCHW
- arr = arr.transpose(2, 0, 1)
- arr = np.expand_dims(arr, axis=0)
-
- return arr.astype(np.float32)
+ return arr.transpose(2, 0, 1)[None].astype(np.float32)
@staticmethod
def _postprocess(output: np.ndarray) -> Image.Image:
if output.ndim != 4:
raise ValueError(f"Expected output shape NCHW, got: {output.shape}")
- arr = output[0]
-
- # CHW to HWC
- arr = arr.transpose(1, 2, 0)
-
+ arr = output[0].transpose(1, 2, 0)
if arr.max() <= 1.5:
arr = arr * 255.0
arr = np.clip(arr, 0, 255).astype(np.uint8)
-
return Image.fromarray(arr, mode="RGB")
def remove_shadow(self, image: ImageInput) -> Image.Image:
pil_image = self._to_pil_rgb(image)
original_size = pil_image.size
+ model_image = pil_image
if self.input_size is not None:
width, height = self.input_size
model_image = pil_image.resize((width, height), Image.Resampling.LANCZOS)
- else:
- model_image = pil_image
input_tensor = self._preprocess(model_image)
-
- output = self.session.run(
- [self.output_name],
- {self.input_name: input_tensor},
- )[0]
-
+ output = self.session.run([self.output_name], {self.input_name: input_tensor})[0]
result = self._postprocess(output)
if result.size != original_size:
@@ -123,8 +109,5 @@ def remove_shadow_from_path(
) -> Path:
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
-
- result = self.remove_shadow(input_path)
- result.save(output_path)
-
- return output_path
\ No newline at end of file
+ self.remove_shadow(input_path).save(output_path)
+ return output_path
diff --git a/src/warpless_docs/template_rectification/__init__.py b/src/warpless_docs/template_rectification/__init__.py
new file mode 100644
index 0000000..6738598
--- /dev/null
+++ b/src/warpless_docs/template_rectification/__init__.py
@@ -0,0 +1,3 @@
+from .template_rectifier import TemplateRectifier, TemplateRectifierConfig, TemplateRectificationResult
+
+__all__ = ["TemplateRectifier", "TemplateRectifierConfig", "TemplateRectificationResult"]
diff --git a/src/warpless_docs/template_rectification/template_rectifier.py b/src/warpless_docs/template_rectification/template_rectifier.py
new file mode 100644
index 0000000..eb75b95
--- /dev/null
+++ b/src/warpless_docs/template_rectification/template_rectifier.py
@@ -0,0 +1,497 @@
+from __future__ import annotations
+
+import json
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Literal, Optional, Union
+
+import cv2
+import numpy as np
+from PIL import Image
+
+
+ImageInput = Union[str, Path, Image.Image, np.ndarray]
+FeatureMethod = Literal["auto", "sift", "orb"]
+
+
+@dataclass
+class TemplateRectifierConfig:
+ """Configuration for template-based document rectification."""
+
+ max_image_side: int = 1800
+ feature_method: FeatureMethod = "auto"
+
+ sift_features: int = 6000
+ orb_features: int = 10000
+ ratio_test: float = 0.78
+ min_good_matches: int = 35
+ min_inliers: int = 20
+ ransac_reprojection_threshold: float = 4.0
+
+ use_piecewise_warp: bool = True
+ piecewise_grid_cols: int = 5
+ piecewise_grid_rows: int = 8
+ piecewise_min_points: int = 45
+
+ output_match_debug: bool = True
+ max_debug_matches: int = 160
+
+
+@dataclass
+class TemplateRectificationResult:
+ template_width: int
+ template_height: int
+ input_width: int
+ input_height: int
+ mode: str
+ feature_method: str
+ total_keypoints_template: int
+ total_keypoints_input: int
+ good_matches: int
+ inlier_matches: int
+ inlier_ratio: float
+ homography: list[list[float]]
+
+ def to_dict(self) -> dict:
+ return asdict(self)
+
+ def save_json(self, output_path: Union[str, Path]) -> Path:
+ output_path = Path(output_path)
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ output_path.write_text(
+ json.dumps(self.to_dict(), ensure_ascii=False, indent=2),
+ encoding="utf-8",
+ )
+ return output_path
+
+
+class TemplateRectifier:
+ """
+ Align a photographed/deshadowed form to a clean flat template.
+
+ The template image defines the target coordinate system. The rectified output
+ always has the exact same width and height as the template image.
+ """
+
+ def __init__(
+ self,
+ template_path: Union[str, Path],
+ config: Optional[TemplateRectifierConfig] = None,
+ ) -> None:
+ self.template_path = Path(template_path)
+ self.config = config or TemplateRectifierConfig()
+
+ if not self.template_path.exists():
+ raise FileNotFoundError(f"Template image not found: {self.template_path}")
+
+ self.template_rgb = self._resize_if_needed(self._to_rgb_array(self.template_path))
+ self.template_gray = self._prepare_gray(self.template_rgb)
+
+ def rectify(
+ self,
+ image: ImageInput,
+ ) -> tuple[np.ndarray, TemplateRectificationResult, Optional[np.ndarray]]:
+ input_rgb = self._resize_if_needed(self._to_rgb_array(image))
+ input_gray = self._prepare_gray(input_rgb)
+
+ feature_method, template_kp, template_desc, input_kp, input_desc = self._detect_features(
+ template_gray=self.template_gray,
+ input_gray=input_gray,
+ )
+
+ if template_desc is None or input_desc is None:
+ raise RuntimeError("Could not compute enough feature descriptors for template or input image.")
+
+ good_matches = self._match_descriptors(feature_method, input_desc, template_desc)
+
+ if len(good_matches) < self.config.min_good_matches:
+ raise RuntimeError(
+ f"Not enough template matches: {len(good_matches)}. "
+ f"Need at least {self.config.min_good_matches}."
+ )
+
+ input_points = np.float32([input_kp[m.queryIdx].pt for m in good_matches])
+ template_points = np.float32([template_kp[m.trainIdx].pt for m in good_matches])
+
+ homography, inlier_mask = cv2.findHomography(
+ input_points,
+ template_points,
+ cv2.RANSAC,
+ ransacReprojThreshold=self.config.ransac_reprojection_threshold,
+ maxIters=6000,
+ confidence=0.995,
+ )
+
+ if homography is None or inlier_mask is None:
+ raise RuntimeError("Could not estimate homography from input image to template.")
+
+ inlier_mask_flat = inlier_mask.ravel().astype(bool)
+ inlier_count = int(inlier_mask_flat.sum())
+
+ if inlier_count < self.config.min_inliers:
+ raise RuntimeError(
+ f"Not enough inlier matches: {inlier_count}. Need at least {self.config.min_inliers}."
+ )
+
+ template_height, template_width = self.template_rgb.shape[:2]
+ coarse = cv2.warpPerspective(
+ input_rgb,
+ homography,
+ (template_width, template_height),
+ flags=cv2.INTER_CUBIC,
+ borderMode=cv2.BORDER_REPLICATE,
+ )
+
+ mode = "homography"
+ rectified = coarse
+
+ inlier_input_points = input_points[inlier_mask_flat]
+ inlier_template_points = template_points[inlier_mask_flat]
+
+ if self.config.use_piecewise_warp and inlier_count >= self.config.piecewise_min_points:
+ try:
+ rectified = self._piecewise_affine_warp(
+ input_rgb=input_rgb,
+ homography=homography,
+ source_points=inlier_input_points,
+ target_points=inlier_template_points,
+ output_size=(template_width, template_height),
+ )
+ mode = "piecewise_affine"
+ except Exception:
+ rectified = coarse
+ mode = "homography_fallback"
+
+ debug_image = None
+ if self.config.output_match_debug:
+ debug_image = self._draw_debug_matches(
+ input_rgb=input_rgb,
+ input_kp=input_kp,
+ template_rgb=self.template_rgb,
+ template_kp=template_kp,
+ matches=good_matches,
+ inlier_mask=inlier_mask_flat,
+ )
+
+ result = TemplateRectificationResult(
+ template_width=template_width,
+ template_height=template_height,
+ input_width=input_rgb.shape[1],
+ input_height=input_rgb.shape[0],
+ mode=mode,
+ feature_method=feature_method,
+ total_keypoints_template=len(template_kp),
+ total_keypoints_input=len(input_kp),
+ good_matches=len(good_matches),
+ inlier_matches=inlier_count,
+ inlier_ratio=float(inlier_count / max(1, len(good_matches))),
+ homography=homography.astype(float).tolist(),
+ )
+
+ return rectified, result, debug_image
+
+ def rectify_path(
+ self,
+ input_path: Union[str, Path],
+ output_path: Union[str, Path],
+ output_json_path: Optional[Union[str, Path]] = None,
+ output_debug_path: Optional[Union[str, Path]] = None,
+ ) -> TemplateRectificationResult:
+ rectified, result, debug = self.rectify(input_path)
+ self.save_image(rectified, output_path)
+
+ if output_json_path is not None:
+ result.save_json(output_json_path)
+
+ if output_debug_path is not None and debug is not None:
+ self.save_image(debug, output_debug_path)
+
+ return result
+
+    def _detect_features(
+        self,
+        template_gray: np.ndarray,
+        input_gray: np.ndarray,
+    ) -> tuple[str, list[cv2.KeyPoint], np.ndarray, list[cv2.KeyPoint], np.ndarray]:
+        """Detect keypoints/descriptors on template and input with the first usable method.
+
+        Tries SIFT then ORB when ``feature_method`` is "auto", otherwise only the configured
+        method. Raises RuntimeError if no method yields at least 10 keypoints on both images.
+        """
+        if self.config.feature_method == "auto":
+            preferred_methods = ["sift", "orb"]
+        else:
+            preferred_methods = [self.config.feature_method]
+
+        last_error: Optional[Exception] = None
+        too_few: list[str] = []  # detectors that ran but found too little to match on
+        for method in preferred_methods:
+            try:
+                if method == "sift" and hasattr(cv2, "SIFT_create"):
+                    detector = cv2.SIFT_create(
+                        nfeatures=self.config.sift_features,
+                        contrastThreshold=0.015,
+                        edgeThreshold=10,
+                    )
+                elif method == "orb":
+                    detector = cv2.ORB_create(
+                        nfeatures=self.config.orb_features,
+                        scaleFactor=1.2,
+                        nlevels=8,
+                        edgeThreshold=15,
+                        patchSize=31,
+                        fastThreshold=7,
+                    )
+                else:
+                    continue  # unknown method, or SIFT missing from this OpenCV build
+                template_kp, template_desc = detector.detectAndCompute(template_gray, None)
+                input_kp, input_desc = detector.detectAndCompute(input_gray, None)
+                has_desc = template_desc is not None and input_desc is not None
+                if has_desc and len(template_kp) >= 10 and len(input_kp) >= 10:
+                    return method, template_kp, template_desc, input_kp, input_desc
+                too_few.append(method)
+            except Exception as exc:
+                last_error = exc
+        if last_error is not None:
+            raise RuntimeError(f"Feature detection failed: {last_error}") from last_error
+        if too_few:  # a detector worked, so the image is the problem, not the OpenCV install
+            raise RuntimeError(f"Too few image features for matching (tried: {', '.join(too_few)}).")
+        raise RuntimeError("No usable feature detector found. Try installing a recent opencv-python.")
+
+ def _match_descriptors(
+ self,
+ method: str,
+ input_desc: np.ndarray,
+ template_desc: np.ndarray,
+ ) -> list[cv2.DMatch]:
+ if method == "sift":
+ matcher = cv2.BFMatcher(cv2.NORM_L2, crossCheck=False)
+ else:
+ matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
+
+ raw_matches = matcher.knnMatch(input_desc, template_desc, k=2)
+
+ good_matches: list[cv2.DMatch] = []
+ for pair in raw_matches:
+ if len(pair) != 2:
+ continue
+ best, second = pair
+ if best.distance < self.config.ratio_test * second.distance:
+ good_matches.append(best)
+
+ good_matches.sort(key=lambda match: match.distance)
+ return good_matches
+
+ def _piecewise_affine_warp(
+ self,
+ input_rgb: np.ndarray,
+ homography: np.ndarray,
+ source_points: np.ndarray,
+ target_points: np.ndarray,
+ output_size: tuple[int, int],
+ ) -> np.ndarray:
+ output_width, output_height = output_size
+
+ target_points, source_points = self._add_mesh_anchor_points(
+ target_points=target_points,
+ source_points=source_points,
+ homography=homography,
+ output_width=output_width,
+ output_height=output_height,
+ )
+
+ triangles = self._build_delaunay_triangles(target_points, output_width, output_height)
+ canvas = np.zeros((output_height, output_width, 3), dtype=np.uint8)
+
+ for triangle_indices in triangles:
+ dst_tri = np.float32([target_points[i] for i in triangle_indices])
+ src_tri = np.float32([source_points[i] for i in triangle_indices])
+ self._warp_triangle(input_rgb, canvas, src_tri, dst_tri)
+
+ coarse = cv2.warpPerspective(
+ input_rgb,
+ homography,
+ (output_width, output_height),
+ flags=cv2.INTER_CUBIC,
+ borderMode=cv2.BORDER_REPLICATE,
+ )
+ empty_mask = np.all(canvas == 0, axis=2)
+ canvas[empty_mask] = coarse[empty_mask]
+ return canvas
+
+ def _add_mesh_anchor_points(
+ self,
+ target_points: np.ndarray,
+ source_points: np.ndarray,
+ homography: np.ndarray,
+ output_width: int,
+ output_height: int,
+ ) -> tuple[np.ndarray, np.ndarray]:
+ grid_points = []
+ for row in range(self.config.piecewise_grid_rows + 1):
+ y = output_height * row / self.config.piecewise_grid_rows
+ for col in range(self.config.piecewise_grid_cols + 1):
+ x = output_width * col / self.config.piecewise_grid_cols
+ grid_points.append([x, y])
+
+ grid_points.extend(
+ [
+ [0, 0],
+ [output_width - 1, 0],
+ [output_width - 1, output_height - 1],
+ [0, output_height - 1],
+ [output_width / 2, 0],
+ [output_width - 1, output_height / 2],
+ [output_width / 2, output_height - 1],
+ [0, output_height / 2],
+ ]
+ )
+
+ anchors_target = np.float32(grid_points)
+ inv_h = np.linalg.inv(homography)
+ anchors_source = cv2.perspectiveTransform(anchors_target.reshape(-1, 1, 2), inv_h).reshape(-1, 2)
+
+ return (
+ np.vstack([target_points, anchors_target]).astype(np.float32),
+ np.vstack([source_points, anchors_source]).astype(np.float32),
+ )
+
+ @staticmethod
+ def _build_delaunay_triangles(points: np.ndarray, width: int, height: int) -> list[tuple[int, int, int]]:
+ subdiv = cv2.Subdiv2D((0, 0, width, height))
+
+ clipped_points = []
+ for point in points:
+ x = float(np.clip(point[0], 0, width - 1))
+ y = float(np.clip(point[1], 0, height - 1))
+ clipped_points.append((x, y))
+ subdiv.insert((x, y))
+
+ triangle_list = subdiv.getTriangleList()
+ point_array = np.array(clipped_points, dtype=np.float32)
+
+ triangles: list[tuple[int, int, int]] = []
+ seen: set[tuple[int, int, int]] = set()
+
+ for triangle in triangle_list:
+ coords = np.array(
+ [[triangle[0], triangle[1]], [triangle[2], triangle[3]], [triangle[4], triangle[5]]],
+ dtype=np.float32,
+ )
+
+ if np.any(coords[:, 0] < 0) or np.any(coords[:, 0] >= width):
+ continue
+ if np.any(coords[:, 1] < 0) or np.any(coords[:, 1] >= height):
+ continue
+
+ indices = []
+ for coord in coords:
+ distances = np.linalg.norm(point_array - coord, axis=1)
+ indices.append(int(np.argmin(distances)))
+
+ if len(set(indices)) != 3:
+ continue
+
+ key = tuple(sorted(indices))
+ if key in seen:
+ continue
+
+ seen.add(key)
+ triangles.append(tuple(indices))
+
+ return triangles
+
+ @staticmethod
+ def _warp_triangle(source: np.ndarray, destination: np.ndarray, src_tri: np.ndarray, dst_tri: np.ndarray) -> None:
+ src_rect = cv2.boundingRect(src_tri)
+ dst_rect = cv2.boundingRect(dst_tri)
+
+ x, y, w, h = src_rect
+ dx, dy, dw, dh = dst_rect
+ if w <= 0 or h <= 0 or dw <= 0 or dh <= 0:
+ return
+
+ source_patch = source[y : y + h, x : x + w]
+ if source_patch.size == 0:
+ return
+
+ src_offset = np.float32([[p[0] - x, p[1] - y] for p in src_tri])
+ dst_offset = np.float32([[p[0] - dx, p[1] - dy] for p in dst_tri])
+ transform = cv2.getAffineTransform(src_offset, dst_offset)
+
+ warped_patch = cv2.warpAffine(
+ source_patch,
+ transform,
+ (dw, dh),
+ flags=cv2.INTER_CUBIC,
+ borderMode=cv2.BORDER_REFLECT_101,
+ )
+
+ mask = np.zeros((dh, dw, 3), dtype=np.float32)
+ cv2.fillConvexPoly(mask, np.int32(dst_offset), (1.0, 1.0, 1.0), lineType=cv2.LINE_AA)
+
+ destination_roi = destination[dy : dy + dh, dx : dx + dw]
+ if destination_roi.shape[:2] != warped_patch.shape[:2]:
+ return
+
+ blended = destination_roi.astype(np.float32) * (1.0 - mask) + warped_patch.astype(np.float32) * mask
+ destination[dy : dy + dh, dx : dx + dw] = np.clip(blended, 0, 255).astype(np.uint8)
+
+ def _draw_debug_matches(
+ self,
+ input_rgb: np.ndarray,
+ input_kp: list[cv2.KeyPoint],
+ template_rgb: np.ndarray,
+ template_kp: list[cv2.KeyPoint],
+ matches: list[cv2.DMatch],
+ inlier_mask: np.ndarray,
+ ) -> np.ndarray:
+ draw_matches = [match for index, match in enumerate(matches) if inlier_mask[index]]
+ draw_matches = draw_matches[: self.config.max_debug_matches]
+ draw_mask = [1] * len(draw_matches)
+
+ debug_bgr = cv2.drawMatches(
+ cv2.cvtColor(input_rgb, cv2.COLOR_RGB2BGR),
+ input_kp,
+ cv2.cvtColor(template_rgb, cv2.COLOR_RGB2BGR),
+ template_kp,
+ draw_matches,
+ None,
+ matchesMask=draw_mask,
+ flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
+ )
+ return cv2.cvtColor(debug_bgr, cv2.COLOR_BGR2RGB)
+
+ @staticmethod
+ def save_image(rgb: np.ndarray, output_path: Union[str, Path]) -> Path:
+ output_path = Path(output_path)
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ cv2.imwrite(str(output_path), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
+ return output_path
+
+ @staticmethod
+ def _to_rgb_array(image: ImageInput) -> np.ndarray:
+ if isinstance(image, (str, Path)):
+ return np.asarray(Image.open(image).convert("RGB"))
+ if isinstance(image, Image.Image):
+ return np.asarray(image.convert("RGB"))
+ if isinstance(image, np.ndarray):
+ if image.ndim != 3:
+ raise ValueError("Expected image array with shape HxWxC.")
+ if image.shape[2] == 4:
+ image = image[:, :, :3]
+ return image.astype(np.uint8)
+ raise TypeError(f"Unsupported image input type: {type(image)}")
+
+ @staticmethod
+ def _prepare_gray(rgb: np.ndarray) -> np.ndarray:
+ gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
+ return cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
+
+ def _resize_if_needed(self, rgb: np.ndarray) -> np.ndarray:
+ height, width = rgb.shape[:2]
+ max_side = max(height, width)
+ if max_side <= self.config.max_image_side:
+ return rgb
+ scale = self.config.max_image_side / float(max_side)
+ return cv2.resize(rgb, (max(1, int(width * scale)), max(1, int(height * scale))), interpolation=cv2.INTER_AREA)
diff --git a/src/warpless_docs/warp_detection/__init__.py b/src/warpless_docs/warp_detection/__init__.py
new file mode 100644
index 0000000..e6cdadc
--- /dev/null
+++ b/src/warpless_docs/warp_detection/__init__.py
@@ -0,0 +1,13 @@
+from .warp_detector import (
+ DocumentWarpDetector,
+ WarpDetectionResult,
+ WarpDetectorConfig,
+ WarpFinding,
+)
+
+__all__ = [
+ "DocumentWarpDetector",
+ "WarpDetectionResult",
+ "WarpDetectorConfig",
+ "WarpFinding",
+]
diff --git a/src/warpless_docs/warp_detection/warp_detector.py b/src/warpless_docs/warp_detection/warp_detector.py
index 19741f1..dc1d9b1 100644
--- a/src/warpless_docs/warp_detection/warp_detector.py
+++ b/src/warpless_docs/warp_detection/warp_detector.py
@@ -9,7 +9,6 @@
import numpy as np
from PIL import Image
-
ImageInput = Union[str, Path, Image.Image, np.ndarray]
AxisType = Literal["horizontal", "vertical"]
SeverityType = Literal["mild", "medium", "strong"]
@@ -17,6 +16,8 @@
@dataclass
class WarpDetectorConfig:
+ """Line-based detector for curved document structures."""
+
max_image_side: int = 1800
adaptive_block_size: int = 41
adaptive_c: int = 13
@@ -24,18 +25,18 @@ class WarpDetectorConfig:
vertical_kernel_ratio: float = 0.018
bins: int = 110
peak_density: float = 0.10
- max_peaks_per_bin: int = 80
- link_tolerance_px: float = 18.0
- max_missing_bins: int = 6
- min_track_bins_ratio: float = 0.18
- min_track_span_ratio: float = 0.18
+ max_peaks_per_bin: int = 90
+ link_tolerance_px: float = 20.0
+ max_missing_bins: int = 7
+ min_track_bins_ratio: float = 0.16
+ min_track_span_ratio: float = 0.16
min_deviation_px: float = 2.0
- min_angle_change_deg: float = 1.2
+ min_angle_change_deg: float = 1.0
medium_deviation_px: float = 6.0
strong_deviation_px: float = 12.0
medium_angle_change_deg: float = 3.5
strong_angle_change_deg: float = 7.0
- max_findings_to_draw: int = 40
+ max_findings_to_draw: int = 45
@dataclass
@@ -95,7 +96,13 @@ def save_json(self, output_path: Union[str, Path]) -> Path:
class DocumentWarpDetector:
- """Detect curved horizontal/vertical document structures before dewarping."""
+ """
+ Detect curved horizontal/vertical document structures before dewarping.
+
+ The module slices the page, finds line peaks inside each slice, links peaks
+ into centerlines, then measures how far each centerline deviates from a
+ straight line.
+ """
def __init__(self, config: Optional[WarpDetectorConfig] = None) -> None:
self.config = config or WarpDetectorConfig()
@@ -131,7 +138,6 @@ def analyze_path(
if output_overlay_path is not None:
self.save_overlay(overlay, output_overlay_path)
-
if output_json_path is not None:
result.save_json(output_json_path)
@@ -159,10 +165,7 @@ def draw_overlay(self, rgb: np.ndarray, result: WarpDetectionResult) -> np.ndarr
label = f"{finding.axis} {finding.severity} dev={finding.max_deviation_px:.1f}px angle={finding.angle_change_deg:.1f}"
cv2.putText(overlay, label, (x, max(18, y - 6)), cv2.FONT_HERSHEY_SIMPLEX, 0.42, color, 1, cv2.LINE_AA)
- summary = (
- f"Warp findings: {result.total_findings} | "
- f"strong={result.strong_count} medium={result.medium_count} mild={result.mild_count}"
- )
+ summary = f"Warp findings: {result.total_findings} | strong={result.strong_count} medium={result.medium_count} mild={result.mild_count}"
cv2.rectangle(overlay, (12, 12), (min(result.image_width - 12, 740), 50), (0, 0, 0), -1)
cv2.putText(overlay, summary, (24, 38), cv2.FONT_HERSHEY_SIMPLEX, 0.68, (255, 255, 255), 2, cv2.LINE_AA)
return overlay
@@ -181,14 +184,12 @@ def _to_rgb_array(self, image: ImageInput) -> np.ndarray:
raise TypeError(f"Unsupported image input type: {type(image)}")
def _resize_if_needed(self, rgb: np.ndarray) -> np.ndarray:
- height, width = rgb.shape[:2]
- max_side = max(height, width)
-
+ h, w = rgb.shape[:2]
+ max_side = max(h, w)
if max_side <= self.config.max_image_side:
return rgb
-
scale = self.config.max_image_side / float(max_side)
- return cv2.resize(rgb, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_AREA)
+ return cv2.resize(rgb, (max(1, int(w * scale)), max(1, int(h * scale))), interpolation=cv2.INTER_AREA)
def _make_binary_mask(self, rgb: np.ndarray) -> np.ndarray:
gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
@@ -206,29 +207,28 @@ def _make_binary_mask(self, rgb: np.ndarray) -> np.ndarray:
block_size,
self.config.adaptive_c,
)
- edges = cv2.Canny(gray, 50, 150)
+ edges = cv2.Canny(gray, 45, 145)
binary = cv2.bitwise_or(adaptive, edges)
return cv2.morphologyEx(binary, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2)), iterations=1)
def _extract_axis_mask(self, binary: np.ndarray, axis: AxisType) -> np.ndarray:
- height, width = binary.shape[:2]
-
+ h, w = binary.shape[:2]
if axis == "horizontal":
- kernel_width = max(9, int(width * self.config.horizontal_kernel_ratio))
- open_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_width, 1))
+ kernel_size = max(9, int(w * self.config.horizontal_kernel_ratio))
+ open_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, 1))
reconnect_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 3))
else:
- kernel_height = max(9, int(height * self.config.vertical_kernel_ratio))
- open_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, kernel_height))
+ kernel_size = max(9, int(h * self.config.vertical_kernel_ratio))
+ open_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, kernel_size))
reconnect_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 7))
mask = cv2.morphologyEx(binary, cv2.MORPH_OPEN, open_kernel, iterations=1)
return cv2.dilate(mask, reconnect_kernel, iterations=1)
def _track_axis(self, mask: np.ndarray, axis: AxisType) -> list[WarpFinding]:
- height, width = mask.shape[:2]
+ h, w = mask.shape[:2]
bins = max(20, self.config.bins)
- scan_length = width if axis == "horizontal" else height
+ scan_length = w if axis == "horizontal" else h
edges = np.linspace(0, scan_length, bins + 1).astype(int)
active_tracks: list[dict] = []
@@ -239,7 +239,7 @@ def _track_axis(self, mask: np.ndarray, axis: AxisType) -> list[WarpFinding]:
continue
peaks = self._find_peaks_in_bin(mask, axis, int(start), int(end))
- used_peaks: set[int] = set()
+ used: set[int] = set()
next_active: list[dict] = []
for track in active_tracks:
@@ -252,7 +252,7 @@ def _track_axis(self, mask: np.ndarray, axis: AxisType) -> list[WarpFinding]:
best_distance = float("inf")
for peak_index, peak in enumerate(peaks):
- if peak_index in used_peaks:
+ if peak_index in used:
continue
distance = abs(float(peak["center"]) - float(last_cross))
if distance < best_distance:
@@ -262,18 +262,16 @@ def _track_axis(self, mask: np.ndarray, axis: AxisType) -> list[WarpFinding]:
if best_index is not None and best_distance <= self.config.link_tolerance_px:
track["points"].append(self._make_point(axis, (start + end) / 2.0, float(peaks[best_index]["center"])))
track["last_bin"] = bin_index
- used_peaks.add(best_index)
+ used.add(best_index)
next_active.append(track)
for peak_index, peak in enumerate(peaks):
- if peak_index not in used_peaks:
- next_active.append(
- {
- "points": [self._make_point(axis, (start + end) / 2.0, float(peak["center"]))],
- "last_bin": bin_index,
- }
- )
+ if peak_index not in used:
+ next_active.append({
+ "points": [self._make_point(axis, (start + end) / 2.0, float(peak["center"]))],
+ "last_bin": bin_index,
+ })
active_tracks = next_active
@@ -292,37 +290,34 @@ def _track_axis(self, mask: np.ndarray, axis: AxisType) -> list[WarpFinding]:
if metrics is None:
continue
- max_deviation_px, mean_deviation_px, angle_change_deg, length_px = metrics
- if length_px < min_span:
+ max_dev, mean_dev, angle_change, length = metrics
+ if length < min_span:
continue
-
- if max_deviation_px < self.config.min_deviation_px and angle_change_deg < self.config.min_angle_change_deg:
+ if max_dev < self.config.min_deviation_px and angle_change < self.config.min_angle_change_deg:
continue
- severity = self._classify_severity(max_deviation_px, angle_change_deg)
+ severity = self._classify_severity(max_dev, angle_change)
pts = np.array(points, dtype=np.float32)
x_min, y_min = np.floor(pts.min(axis=0)).astype(int)
x_max, y_max = np.ceil(pts.max(axis=0)).astype(int)
pad = 6
- x = max(0, x_min - pad)
- y = max(0, y_min - pad)
- w = min(width - x, x_max - x_min + 1 + pad * 2)
- h = min(height - y, y_max - y_min + 1 + pad * 2)
-
- findings.append(
- WarpFinding(
- axis=axis,
- severity=severity,
- bbox=(int(x), int(y), int(w), int(h)),
- max_deviation_px=float(max_deviation_px),
- mean_deviation_px=float(mean_deviation_px),
- angle_change_deg=float(angle_change_deg),
- length_px=float(length_px),
- coverage_ratio=float(length_px / max(1.0, scan_length)),
- point_count=len(points),
- centerline=[(int(round(px)), int(round(py))) for px, py in points],
- )
- )
+ x = max(0, int(x_min - pad))
+ y = max(0, int(y_min - pad))
+ bw = min(w - x, int(x_max - x_min + 1 + pad * 2))
+ bh = min(h - y, int(y_max - y_min + 1 + pad * 2))
+
+ findings.append(WarpFinding(
+ axis=axis,
+ severity=severity,
+ bbox=(x, y, bw, bh),
+ max_deviation_px=float(max_dev),
+ mean_deviation_px=float(mean_dev),
+ angle_change_deg=float(angle_change),
+ length_px=float(length),
+ coverage_ratio=float(length / max(1.0, scan_length)),
+ point_count=len(points),
+ centerline=[(int(round(px)), int(round(py))) for px, py in points],
+ ))
return findings
@@ -336,9 +331,7 @@ def _find_peaks_in_bin(self, mask: np.ndarray, axis: AxisType, start: int, end:
return []
projection = projection.astype(np.float32)
- kernel = np.ones(7, dtype=np.float32) / 7.0
- projection = np.convolve(projection, kernel, mode="same")
-
+ projection = np.convolve(projection, np.ones(7, dtype=np.float32) / 7.0, mode="same")
threshold = max(2.0, max(1, end - start) * self.config.peak_density)
active = np.where(projection >= threshold)[0]
if active.size == 0:
@@ -354,12 +347,13 @@ def _find_peaks_in_bin(self, mask: np.ndarray, axis: AxisType, start: int, end:
peaks: list[dict] = []
for group in groups:
- indexes = np.array(group, dtype=np.float32)
- weights = projection[np.array(group, dtype=np.int32)]
+ idx = np.array(group, dtype=np.float32)
+ idx_int = np.array(group, dtype=np.int32)
+ weights = projection[idx_int]
total = float(weights.sum())
if total <= 0:
continue
- peaks.append({"center": float((indexes * weights).sum() / total), "strength": float(weights.max())})
+ peaks.append({"center": float((idx * weights).sum() / total), "strength": float(weights.max())})
peaks.sort(key=lambda item: item["strength"], reverse=True)
return peaks[: self.config.max_peaks_per_bin]
@@ -367,8 +361,8 @@ def _find_peaks_in_bin(self, mask: np.ndarray, axis: AxisType, start: int, end:
@staticmethod
def _make_point(axis: AxisType, scan_center: float, cross_center: float) -> tuple[float, float]:
if axis == "horizontal":
- return (scan_center, cross_center)
- return (cross_center, scan_center)
+ return scan_center, cross_center
+ return cross_center, scan_center
@staticmethod
def _smooth_track(points: list[tuple[float, float]]) -> list[tuple[float, float]]:
@@ -377,13 +371,11 @@ def _smooth_track(points: list[tuple[float, float]]) -> list[tuple[float, float]
arr = np.array(points, dtype=np.float32)
kernel = np.ones(5, dtype=np.float32) / 5.0
- smooth_x = np.convolve(arr[:, 0], kernel, mode="same")
- smooth_y = np.convolve(arr[:, 1], kernel, mode="same")
- smooth_x[:2] = arr[:2, 0]
- smooth_y[:2] = arr[:2, 1]
- smooth_x[-2:] = arr[-2:, 0]
- smooth_y[-2:] = arr[-2:, 1]
- return list(zip(smooth_x.tolist(), smooth_y.tolist()))
+ sx = np.convolve(arr[:, 0], kernel, mode="same")
+ sy = np.convolve(arr[:, 1], kernel, mode="same")
+ sx[:2], sy[:2] = arr[:2, 0], arr[:2, 1]
+ sx[-2:], sy[-2:] = arr[-2:, 0], arr[-2:, 1]
+ return list(zip(sx.tolist(), sy.tolist()))
def _measure_curve(self, points: list[tuple[float, float]], axis: AxisType) -> Optional[tuple[float, float, float, float]]:
if len(points) < 6:
@@ -408,8 +400,8 @@ def _measure_curve(self, points: list[tuple[float, float]], axis: AxisType) -> O
return None
residual = dependent - np.polyval(line_coeff, normalized)
- max_deviation = float(np.max(np.abs(residual)))
- mean_deviation = float(np.mean(np.abs(residual)))
+ max_dev = float(np.max(np.abs(residual)))
+ mean_dev = float(np.mean(np.abs(residual)))
a, b, _ = curve_coeff
left = float(normalized.min())
@@ -417,8 +409,7 @@ def _measure_curve(self, points: list[tuple[float, float]], axis: AxisType) -> O
slope_left = (2.0 * a * left + b) / span
slope_right = (2.0 * a * right + b) / span
angle_change = abs(float(np.degrees(np.arctan(slope_right)) - np.degrees(np.arctan(slope_left))))
-
- return max_deviation, mean_deviation, angle_change, span
+ return max_dev, mean_dev, angle_change, span
def _dedupe_findings(self, findings: list[WarpFinding]) -> list[WarpFinding]:
deduped: list[WarpFinding] = []
@@ -470,7 +461,7 @@ def _severity_rank(severity: SeverityType) -> int:
@staticmethod
def _severity_color(severity: SeverityType) -> tuple[int, int, int]:
if severity == "strong":
- return (255, 60, 60)
+ return 255, 60, 60
if severity == "medium":
- return (255, 185, 40)
- return (60, 190, 255)
+ return 255, 185, 40
+ return 60, 190, 255