lit-wsl
My personal library of reusable PyTorch Lightning components.
Features
Installation
With pip:
python -m pip install lit-wsl
With uv:
uv add lit-wsl
How to use it
IntermediateLayerGetter
Capture intermediate layer outputs during the forward pass:
import torch
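from torchvision.models import ResNet18_Weights, resnet18
from lit_wsl.models.intermediate_layer_getter import IntermediateLayerGetter
# `pretrained=True` is deprecated since torchvision 0.13; pass a weights enum instead
model = resnet18(weights=ResNet18_Weights.DEFAULT)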
# Specify which layers to capture: {layer_name: output_name}
return_layers = {"layer2": "feat1", "layer4": "feat2"}
layer_getter = IntermediateLayerGetter(model, return_layers, keep_output=True)
x = torch.randn(1, 3, 224, 224)
intermediate_outputs, final_output = layer_getter(x)
# intermediate_outputs is OrderedDict with keys "feat1" and "feat2"
print(intermediate_outputs["feat1"].shape) # torch.Size([1, 128, 28, 28])
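For intuition, capturing intermediate outputs can be done with PyTorch forward hooks. The sketch below is hypothetical: `capture_intermediates` is not part of lit-wsl, the library may work differently, and it assumes the `return_layers` keys are names from `model.named_modules()`. A tiny `nn.Sequential` stands in for ResNet so nothing has to be downloaded.

```python
from collections import OrderedDict

import torch
from torch import nn


def capture_intermediates(model, return_layers, x):
    """Hypothetical helper: run `model` on `x` and collect named submodule outputs.

    `return_layers` maps names from `model.named_modules()` to output keys,
    mirroring the {layer_name: output_name} convention above.
    """
    captured = {}
    modules = dict(model.named_modules())
    handles = []
    for layer_name, out_name in return_layers.items():
        # Bind out_name as a default argument so each hook writes its own key.
        def hook(module, inputs, output, key=out_name):
            captured[key] = output

        handles.append(modules[layer_name].register_forward_hook(hook))
    try:
        final = model(x)
    finally:
        # Remove the hooks so the model is left unmodified afterwards.
        for handle in handles:
            handle.remove()
    # Return outputs in the order given by return_layers.
    outputs = OrderedDict((key, captured[key]) for key in return_layers.values())
    return outputs, final


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
    nn.AdaptiveAvgPool2d(1),
)
feats, out = capture_intermediates(model, {"0": "feat1", "2": "feat2"}, torch.randn(1, 3, 32, 32))
print(feats["feat1"].shape)  # torch.Size([1, 8, 32, 32])
print(feats["feat2"].shape)  # torch.Size([1, 16, 16, 16])
print(out.shape)  # torch.Size([1, 16, 1, 1])
```

Hooks leave the model's structure untouched, at the cost of always running the full forward pass even when only early layers are needed.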
WeightRenamer
Rename keys in checkpoint files:
from lit_wsl.models.weight_renamer import WeightRenamer
# Load checkpoint
renamer = WeightRenamer("old_model.pth")
# Remove common prefix
renamer.remove_prefix("model.")
# Rename specific keys
renamer.rename_keys({
    "backbone.conv1": "encoder.conv1",
    "head.fc": "classifier.fc",
})
# Save modified checkpoint
renamer.save("renamed_model.pth")
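What these two operations might do to a state dict can be sketched on a plain `dict`. This is an assumption about the semantics, not WeightRenamer's code: `remove_prefix` and `rename_keys` below are hypothetical stand-ins, and `rename_keys` treats each mapping key as a dotted-path prefix, so `"backbone.conv1"` also renames `"backbone.conv1.weight"` but leaves `"backbone.conv10.weight"` alone.

```python
def remove_prefix(state_dict, prefix):
    """Hypothetical: strip `prefix` from every key that starts with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def rename_keys(state_dict, mapping):
    """Hypothetical: rename keys whose dotted path starts with an old name.

    Matching stops at a "." boundary, and a KeyError is raised if two keys
    would end up with the same name.
    """
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in mapping.items():
            if key == old or key.startswith(old + "."):
                new_key = new + key[len(old):]
                break
        if new_key in renamed:
            raise KeyError(f"rename collision on {new_key!r}")
        renamed[new_key] = value
    return renamed


ckpt = {
    "model.backbone.conv1.weight": 1,
    "model.backbone.conv10.weight": 2,
    "model.head.fc.bias": 3,
}
sd = remove_prefix(ckpt, "model.")
sd = rename_keys(sd, {"backbone.conv1": "encoder.conv1", "head.fc": "classifier.fc"})
print(sorted(sd))
# ['backbone.conv10.weight', 'classifier.fc.bias', 'encoder.conv1.weight']
```

Matching on dot boundaries rather than raw string prefixes is what keeps `conv1` from accidentally capturing `conv10`.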
WeightMapper
Automatically map weights between different model architectures:
import torch
from lit_wsl.mapper.weight_mapper import WeightMapper
from lit_wsl.models.weight_renamer import WeightRenamer
# Placeholders: substitute your own nn.Module classes whose layer names differ
old_model = OldModelArchitecture()
new_model = NewModelArchitecture()
# Analyze and suggest mapping
mapper = WeightMapper(old_model, new_model)
mapping, unmatched = mapper.suggest_mapping(threshold=0.6)
# Apply mapping to checkpoint
renamer = WeightRenamer("old_weights.pth")
renamer.rename_keys(mapping)
renamer.save("adapted_weights.pth")
# Load adapted weights
new_model.load_state_dict(torch.load("adapted_weights.pth"))
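To see roughly what a threshold-based suggestion involves, here is a hypothetical greedy matcher. It is not WeightMapper's algorithm: it assumes parameters are described by name-to-shape dicts, uses `difflib.SequenceMatcher` for name similarity, and only pairs tensors of identical shape, since `load_state_dict` rejects size mismatches.

```python
from difflib import SequenceMatcher


def suggest_mapping(old_shapes, new_shapes, threshold=0.6):
    """Hypothetical greedy matcher: pair parameter names by string similarity.

    old_shapes / new_shapes map parameter names to shape tuples. Returns
    (mapping from old to new names, list of old names left unmatched).
    """
    candidates = []
    for old_name, old_shape in old_shapes.items():
        for new_name, new_shape in new_shapes.items():
            if old_shape != new_shape:
                continue  # a tensor can only be loaded into a same-shaped slot
            score = SequenceMatcher(None, old_name, new_name).ratio()
            if score >= threshold:
                candidates.append((score, old_name, new_name))
    # Accept the highest-scoring pairs first; each name is used at most once.
    candidates.sort(reverse=True)
    mapping, used_new = {}, set()
    for _score, old_name, new_name in candidates:
        if old_name not in mapping and new_name not in used_new:
            mapping[old_name] = new_name
            used_new.add(new_name)
    unmatched = [name for name in old_shapes if name not in mapping]
    return mapping, unmatched


old = {
    "backbone.conv1.weight": (16, 3, 3, 3),
    "head.fc.weight": (10, 64),
    "aux.proj.weight": (5, 5),
}
new = {
    "encoder.conv1.weight": (16, 3, 3, 3),
    "classifier.fc.weight": (10, 64),
}
mapping, unmatched = suggest_mapping(old, new)
print(mapping)
print(unmatched)  # ['aux.proj.weight']
```

With real models the shape dicts could come from `{k: tuple(v.shape) for k, v in model.state_dict().items()}`. Names left in `unmatched` need a manual mapping entry; otherwise they keep their old names and `load_state_dict` will report them as unexpected keys.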
Docs
Build the documentation site into ./_build/:
uv run mkdocs build -f ./mkdocs.yml -d ./_build/
Update template
Pull the latest changes from the Copier project template:
copier update --trust -A --vcs-ref=HEAD