Code for the paper "Improving the Generalisation of Learned Reconstruction Frameworks" - A deep learning approach for CT image reconstruction using Graph Neural Networks and CNN modules.
This project implements a learned reconstruction framework for computed tomography (CT) imaging that combines:
- Graph neural network for Line Manifolds (GLM): GNNs that process sinogram data using information about the acquisition geometry
- Convolutional Neural Networks: For both sinogram and image domain processing
- Pseudo-inverse operators: Including backprojection and filtered backprojection
- End-to-end training: Joint optimization of sinogram processing and image reconstruction
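The way these components fit together can be sketched as a composition of three stages. This is a minimal illustration, not the repository's implementation: `compose`, `sinogram_model`, `pseudo_inverse` and `image_model` are hypothetical names, and plain Python callables stand in for the PyTorch modules and tomographic operators.

```python
from functools import reduce
from typing import Callable, Sequence

Stage = Callable[[list[float]], list[float]]


def compose(stages: Sequence[Stage]) -> Stage:
    """Chain stages left to right: the output of one feeds the next."""
    def pipeline(data: list[float]) -> list[float]:
        return reduce(lambda value, stage: stage(value), stages, data)
    return pipeline


# Hypothetical stand-ins for the learned and analytic components.
def sinogram_model(sinogram: list[float]) -> list[float]:
    # Placeholder for the GLM or sinogram CNN: here, a simple rescaling.
    return [2.0 * value for value in sinogram]


def pseudo_inverse(sinogram: list[float]) -> list[float]:
    # Placeholder for backprojection or filtered backprojection.
    return [value + 1.0 for value in sinogram]


def image_model(image: list[float]) -> list[float]:
    # Placeholder for the image-domain CNN: clamp negative values to zero.
    return [max(value, 0.0) for value in image]


reconstruct = compose([sinogram_model, pseudo_inverse, image_model])
print(reconstruct([-1.0, 0.0, 0.5]))  # [0.0, 1.0, 2.0]
```

In the end-to-end setting, the learned stages of such a chain would be optimised together through a single loss on the final output.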
The framework is designed to work with the 2DeteCT dataset and uses DVC (Data Version Control) for experiment tracking and pipeline management.
Key features:
- Multi-stage training pipeline (preprocessing → pretraining → training)
- Distributed training support with PyTorch DDP
- Graph-based sinogram processing using PyTorch Geometric
- Integration with ODL (Operator Discretization Library) for tomographic operators
- Experiment tracking with DVCLive
- Support for various pseudo-inverse methods (backprojection, filtered backprojection)
Prerequisites:
- Python 3.11 or higher
- CUDA-capable GPU (recommended for training)
- uv package manager
Installation:
- Clone the repository
      git clone https://github.com/Emvlt/GLM
      cd glm
- Install uv (if not already installed)
      curl -LsSf https://astral.sh/uv/install.sh | sh
- Install dependencies
  The project uses uv for dependency management. Dependencies are installed automatically when running commands with uv run. To install them explicitly:
      uv sync
- Verify installation
      uv run python src/glm/install_test.py
This should print version information for all major dependencies:
- ODL (a non-official version on my fork): Geometry, Tomographic Operator, 2DeteCT dataloader, experimental code for Graph export and Surfaces handling
- ASTRA Toolbox: Tomographic Operator Backend
- PyTorch: Deep Learning
- PyTorch Geometric: Geometric Deep Learning
- imageio: raw data IO
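A version check of this kind can be approximated with the standard library alone. The sketch below is not the content of `install_test.py`. It assumes the distribution names match those in the dependency list, and `installed_versions` is a hypothetical helper.

```python
from importlib import metadata
from typing import Optional

# Distribution names assumed from the dependency list; adjust if they differ.
PACKAGES = ["odl", "astra-toolbox", "torch", "torch-geometric", "imageio"]


def installed_versions(packages: list[str]) -> dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

Run it inside the project environment (for example through `uv run python`) so that it inspects the packages uv installed rather than the system interpreter.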
The project automatically installs:
- astra-toolbox: For CT reconstruction algorithms
- torch: Deep learning framework
- torch-geometric: Graph neural network library
- odl: Custom fork for tomographic operators (installed from GitHub)
- dvc & dvclive: Experiment tracking and pipeline management
- matplotlib, imageio: Visualization and image I/O
This project uses the 2DeteCT dataset.
The expected directory structure is:
datasets/
├── raw/
│   └── 2detect/
│       ├── slice00001/
│       │   ├── mode2/
│       │   │   ├── dark.tif
│       │   │   ├── flat1.tif
│       │   │   ├── flat2.tif
│       │   │   ├── sinogram.tif
│       │   │   ├── reconstruction.tif
│       │   │   └── segmentation.tif
│       │   └── ...
│       └── ...
└── processed/
    └── 2detect/
        └── (generated by preprocessing)
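Before running the preprocessing, it can help to confirm that the raw data matches the layout above. The following is a minimal sketch using only the standard library. `missing_files` is a hypothetical helper, and it assumes every slice directory contains a `mode2` folder with the six files listed.

```python
from pathlib import Path

# File names copied from the directory layout above.
EXPECTED_FILES = (
    "dark.tif", "flat1.tif", "flat2.tif",
    "sinogram.tif", "reconstruction.tif", "segmentation.tif",
)


def missing_files(raw_root: Path, mode: str = "mode2") -> dict[str, list[str]]:
    """Return, per slice directory, the expected files that are absent."""
    problems: dict[str, list[str]] = {}
    for slice_dir in sorted(raw_root.glob("slice*")):
        if not slice_dir.is_dir():
            continue
        absent = [name for name in EXPECTED_FILES
                  if not (slice_dir / mode / name).is_file()]
        if absent:
            problems[slice_dir.name] = absent
    return problems
```

For example, `missing_files(Path("datasets/raw/2detect"))` returns an empty dictionary when every slice is complete.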
Set the dataset paths in params.yaml:
- data.raw_path: Path to raw dataset
- data.processed_path: Path for processed data
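The dotted names refer to nested keys in the YAML mapping. As a small illustration, the sketch below resolves such a name against an already parsed configuration. The `params` dictionary is a stand-in with placeholder paths, and `get_by_dotted_key` is a hypothetical helper, not part of the repository.

```python
from functools import reduce
from typing import Any, Mapping


def get_by_dotted_key(config: Mapping[str, Any], dotted_key: str) -> Any:
    """Resolve 'a.b.c' by descending one nested mapping per component."""
    return reduce(lambda node, key: node[key], dotted_key.split("."), config)


# Stand-in for the parsed params.yaml; the paths are placeholders.
params = {
    "data": {
        "raw_path": "/path/to/raw/2detect",
        "processed_path": "/path/to/processed/2detect",
    }
}

print(get_by_dotted_key(params, "data.raw_path"))  # /path/to/raw/2detect
```

A missing component raises `KeyError`, which makes a misspelled key visible immediately.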
The complete pipeline consists of three stages managed by DVC:
- Data Preprocessing
      uv run dvc repro prepare_training
  Preprocesses raw sinograms and reconstructions from the 2DeteCT dataset.
- Sinogram Model Pretraining
      uv run dvc repro pretraining
  Pretrains the sinogram processing model (GLM or CNN) in a self-supervised manner for just one epoch.
- End-to-End Training
      uv run dvc repro training
  Trains the complete reconstruction pipeline including sinogram processing, pseudo-inverse, and image refinement.
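The three stages can also be launched in sequence from a small driver script. The sketch below only wraps the commands shown above. `stage_command` and `run_pipeline` are hypothetical helpers, and the default dry run returns the commands without executing them.

```python
import subprocess

# Stage names as used in the commands above, in pipeline order.
STAGES = ("prepare_training", "pretraining", "training")


def stage_command(stage: str) -> list[str]:
    """Build the argument list for reproducing one DVC stage through uv."""
    if stage not in STAGES:
        raise ValueError(f"unknown stage: {stage!r}")
    return ["uv", "run", "dvc", "repro", stage]


def run_pipeline(dry_run: bool = True) -> list[list[str]]:
    """Run (or, with dry_run, only list) the stages in order."""
    commands = [stage_command(stage) for stage in STAGES]
    if not dry_run:
        for command in commands:
            # check=True stops at the first failing stage.
            subprocess.run(command, check=True)
    return commands
```

Calling `run_pipeline(dry_run=False)` from the repository root would execute the stages one after another and raise `subprocess.CalledProcessError` if any of them fails.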
Main configuration file: params.yaml
Data paths:
    data:
      raw_path: /path/to/raw/2detect
      processed_path: /path/to/processed/2detect
Pretraining hyperparameters:
    pretrain_parameters:
      hyperparameters:
        learning_rate: 5e-4
        epochs: 1
        batch_size: 8
        downsampling: 1  # Angle downsampling factor
      active_model: GLM  # or sinogram_CNN
Training hyperparameters:
    train_parameters:
      hyperparameters:
        learning_rate: 5e-5
        epochs: 40
        batch_size: 8
      active_pseudo_inverse: filtered_backprojection
      active_image_model: image_CNN
Project structure:
glm/
├── src/glm/
│   ├── models/
│   │   ├── gnn.py              # Graph neural network modules
│   │   ├── cnn.py              # Convolutional neural networks
│   │   └── utils.py            # Model loading utilities
│   ├── dataset.py              # Dataset and dataloader
│   ├── preprocess_2detect.py   # Data preprocessing
│   ├── pretrain.py             # Sinogram model pretraining
│   ├── train.py                # End-to-end training
│   ├── run_demo.py             # Demo script
│   └── utils.py                # General utilities
├── params.yaml                 # Configuration file
├── dvc.yaml                    # DVC pipeline definition
├── dvc.lock                    # DVC pipeline lock file
└── pyproject.toml              # Project dependencies
The project uses DVCLive for experiment tracking. Metrics and plots are saved in the dvclive/ directory:
- Training/validation PSNR
- Loss curves
- Sinogram and reconstruction visualizations
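For reference, PSNR is conventionally defined as 10 · log10(R² / MSE), where R is the data range and MSE the mean squared error. The sketch below implements that textbook definition on flat sequences of pixel values. How the repository computes it (data range, per-image or per-batch averaging) is not stated here, so treat the details as assumptions.

```python
import math
from typing import Sequence


def psnr(reference: Sequence[float], estimate: Sequence[float],
         data_range: float = 1.0) -> float:
    """PSNR in dB: 10 * log10(data_range**2 / MSE)."""
    if len(reference) != len(estimate) or len(reference) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0.0:
        # Identical inputs: the ratio is unbounded.
        return math.inf
    return 10.0 * math.log10(data_range ** 2 / mse)
```

With a data range of 1, a uniform error of 0.1 per pixel gives an MSE of 0.01 and therefore a PSNR of 20 dB.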
View experiments:
    dvc plots show
Trained models are saved in:
- src/glm/saved_models/pretrained_sinogram_model.pt: Pretrained sinogram processor
- src/glm/saved_models/end_to_end_model.pt: Complete reconstruction model
CUDA out of memory:
- Reduce batch_size in params.yaml
- Reduce n_channels in model parameters
- Increase downsampling to use fewer projection angles
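To see why a larger downsampling factor reduces memory use, consider how it shrinks the set of projection angles. The sketch below assumes that downsampling keeps every k-th angle, which is an assumption about the preprocessing rather than a description of it. The 360-angle acquisition is illustrative only, and `downsample_angles` is a hypothetical helper.

```python
import math


def downsample_angles(angles: list[float], factor: int) -> list[float]:
    """Keep every `factor`-th angle, starting from the first."""
    if factor < 1:
        raise ValueError("factor must be a positive integer")
    return angles[::factor]


# Illustrative acquisition: 360 angles evenly spaced over a full rotation.
angles = [2.0 * math.pi * k / 360 for k in range(360)]
for factor in (1, 2, 4):
    print(factor, len(downsample_angles(angles, factor)))
```

Under this assumption the number of sinogram rows, and with it the size of the data passed to the sinogram model, scales roughly as 1/factor.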
Data loading errors:
- Verify dataset paths in params.yaml
- Check that raw data follows the expected directory structure
- Ensure preprocessing completed successfully
Import errors:
- Run uv sync to ensure all dependencies are installed
- Verify installation with uv run python src/glm/install_test.py
If you use this code in your research, please cite:
@misc{valat2025improvinggeneralisationlearnedreconstruction,
title={Improving the Generalisation of Learned Reconstruction Frameworks},
author={Emilien Valat and Ozan Öktem},
year={2025},
eprint={2511.12730},
archivePrefix={arXiv},
primaryClass={eess.IV},
url={https://arxiv.org/abs/2511.12730},
}
🔑 Apache License 2.0
For questions or issues, please open an issue on GitHub or contact:
- Email: emilienvalat@gmail.com