
lit-wsl


My personal library of reusable PyTorch Lightning components.

Features

Installation

With pip:

python -m pip install lit-wsl

With uv:

uv add lit-wsl

How to use it

IntermediateLayerGetter

Capture the outputs of intermediate layers during the forward pass:

import torch
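from torchvision.models import ResNet18_Weights, resnet18
from lit_wsl.models.intermediate_layer_getter import IntermediateLayerGetter

# `pretrained=True` is deprecated since torchvision 0.13; request weights explicitly
model = resnet18(weights=ResNet18_Weights.DEFAULT)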
# Specify which layers to capture: {layer_name: output_name}
return_layers = {"layer2": "feat1", "layer4": "feat2"}
layer_getter = IntermediateLayerGetter(model, return_layers, keep_output=True)

x = torch.randn(1, 3, 224, 224)
intermediate_outputs, final_output = layer_getter(x)
# intermediate_outputs is OrderedDict with keys "feat1" and "feat2"
print(intermediate_outputs["feat1"].shape)  # torch.Size([1, 128, 28, 28])
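The bookkeeping behind a layer getter can be sketched in plain Python. This is a minimal sketch, not `lit_wsl`'s implementation: plain callables stand in for a model's `nn.Module` children, the helper `get_intermediate` is hypothetical, and the early exit when `keep_output=False` is an assumption modeled on torchvision's own `IntermediateLayerGetter`, which stops after the last requested layer.

```python
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Tuple


def get_intermediate(
    layers: "OrderedDict[str, Callable[[Any], Any]]",
    return_layers: Dict[str, str],
    x: Any,
    keep_output: bool = True,
) -> Tuple["OrderedDict[str, Any]", Optional[Any]]:
    """Hypothetical helper: run `layers` in order and record the requested outputs.

    `return_layers` maps a layer name to the key its output is stored under,
    e.g. {"layer2": "feat1"}. Plain callables stand in for nn.Module children.
    """
    unknown = set(return_layers) - set(layers)
    if unknown:
        raise ValueError(f"layers not found: {sorted(unknown)}")
    pending = set(return_layers)
    captured: "OrderedDict[str, Any]" = OrderedDict()
    for name, layer in layers.items():
        x = layer(x)
        if name in return_layers:
            captured[return_layers[name]] = x
            pending.discard(name)
        # Assumed keep_output=False behavior: skip everything after the last requested layer
        if not pending and not keep_output:
            return captured, None
    return captured, x


# Toy "model": each layer is a simple arithmetic step
layers = OrderedDict(
    stem=lambda v: v + 1,
    layer2=lambda v: v * 2,
    layer4=lambda v: v - 3,
    fc=lambda v: v * 10,
)
feats, out = get_intermediate(layers, {"layer2": "feat1", "layer4": "feat2"}, 1)
print(dict(feats), out)  # {'feat1': 4, 'feat2': 1} 10
```

Like torchvision's version, this sketch only walks top-level children, which is why names such as `layer2` and `layer4` work directly for ResNet.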

WeightRenamer

Rename keys in checkpoint files:

from lit_wsl.models.weight_renamer import WeightRenamer

# Load checkpoint
renamer = WeightRenamer("old_model.pth")

# Remove common prefix
renamer.remove_prefix("model.")

# Rename specific keys
renamer.rename_keys({
    "backbone.conv1": "encoder.conv1",
    "head.fc": "classifier.fc"
})

# Save modified checkpoint
renamer.save("renamed_model.pth")
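The two renaming steps can be pictured as pure transformations of a state dict. The sketch below is an assumption about their semantics, not `WeightRenamer`'s code: the hypothetical helpers `remove_prefix` and `rename_keys` work on a plain mapping of parameter names to values, and `rename_keys` treats each entry as a module-path prefix that must end on a `.` boundary.

```python
from collections import OrderedDict
from typing import Any, Mapping


def remove_prefix(state_dict: Mapping[str, Any], prefix: str) -> "OrderedDict[str, Any]":
    """Hypothetical helper: drop `prefix` from keys that start with it; keep other keys as-is."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def rename_keys(state_dict: Mapping[str, Any], mapping: Mapping[str, str]) -> "OrderedDict[str, Any]":
    """Hypothetical helper: swap a leading module path, matching only on '.' boundaries."""
    renamed: "OrderedDict[str, Any]" = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        for old, new in mapping.items():
            # "backbone.conv1" matches "backbone.conv1.weight" but not "backbone.conv10.weight"
            if key == old or key.startswith(old + "."):
                new_key = new + key[len(old):]
                break
        if new_key in renamed:
            raise KeyError(f"two keys would be renamed to {new_key!r}")
        renamed[new_key] = value
    return renamed


# Strings stand in for tensors to keep the example dependency-free
state_dict = OrderedDict([
    ("model.backbone.conv1.weight", "w1"),
    ("model.backbone.conv10.weight", "w10"),
    ("model.head.fc.bias", "b"),
])
state_dict = remove_prefix(state_dict, "model.")
state_dict = rename_keys(state_dict, {"backbone.conv1": "encoder.conv1", "head.fc": "classifier.fc"})
print(list(state_dict))
# ['encoder.conv1.weight', 'backbone.conv10.weight', 'classifier.fc.bias']
```

Matching on the dot boundary is a design choice of this sketch; a plain `str.replace` would also turn `backbone.conv10` into `encoder.conv10`. Note too that PyTorch Lightning `.ckpt` files nest the model weights under a `"state_dict"` key, so renaming has to happen inside that entry rather than on the top-level keys.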

WeightMapper

Automatically map weights between different model architectures:

import torch
from lit_wsl.mapper.weight_mapper import WeightMapper
from lit_wsl.models.weight_renamer import WeightRenamer

# Placeholders: substitute your own source and target nn.Module classes,
# whose parameter names differ
old_model = OldModelArchitecture()
new_model = NewModelArchitecture()

# Analyze and suggest mapping
mapper = WeightMapper(old_model, new_model)
mapping, unmatched = mapper.suggest_mapping(threshold=0.6)

# Apply mapping to checkpoint
renamer = WeightRenamer("old_weights.pth")
renamer.rename_keys(mapping)
renamer.save("adapted_weights.pth")

# Load adapted weights
new_model.load_state_dict(torch.load("adapted_weights.pth"))
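One way to suggest such a mapping is to pair parameters whose shapes agree and whose names are similar. The sketch below is a hypothetical heuristic, not `WeightMapper`'s actual algorithm: it scores names with the standard library's `difflib.SequenceMatcher.ratio()`, keeps shape-compatible pairs scoring at least `threshold`, and then greedily assigns the best-scoring pairs one-to-one. The `suggest_mapping` function here only mirrors the `(mapping, unmatched)` return shape used above.

```python
from difflib import SequenceMatcher
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def suggest_mapping(
    old_shapes: Dict[str, Shape],
    new_shapes: Dict[str, Shape],
    threshold: float = 0.6,
) -> Tuple[Dict[str, str], List[str]]:
    """Hypothetical heuristic: pair same-shaped parameters with similar names."""
    candidates = []
    for old_name, old_shape in old_shapes.items():
        for new_name, new_shape in new_shapes.items():
            if old_shape != new_shape:
                continue  # a tensor can only be loaded into a parameter of the same shape
            score = SequenceMatcher(None, old_name, new_name).ratio()
            if score >= threshold:
                candidates.append((score, old_name, new_name))

    # Greedy one-to-one assignment, best-scoring pairs first
    mapping: Dict[str, str] = {}
    taken = set()
    for _, old_name, new_name in sorted(candidates, reverse=True):
        if old_name not in mapping and new_name not in taken:
            mapping[old_name] = new_name
            taken.add(new_name)

    unmatched = [name for name in old_shapes if name not in mapping]
    return mapping, unmatched


old_shapes = {
    "backbone.conv1.weight": (64, 3, 7, 7),
    "backbone.bn1.weight": (64,),
    "head.fc.weight": (10, 512),
    "aux.proj.weight": (32, 64),  # no counterpart in the new model
}
new_shapes = {
    "encoder.conv1.weight": (64, 3, 7, 7),
    "encoder.bn1.weight": (64,),
    "classifier.fc.weight": (10, 512),
}
mapping, unmatched = suggest_mapping(old_shapes, new_shapes, threshold=0.6)
print(mapping)
print(unmatched)  # ['aux.proj.weight']
```

With real models, the shape dictionaries could be built as `{k: tuple(v.shape) for k, v in model.state_dict().items()}`. Greedy assignment is cheap but not globally optimal; when many parameters share a shape (for example, normalization layers of equal width), name similarity alone can pair them wrongly, so it is worth reviewing the suggested mapping before saving.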

Docs

uv run mkdocs build -f ./mkdocs.yml -d ./_build/

Update template

copier update --trust -A --vcs-ref=HEAD

Credits

This project was generated with 🚀 python project template.