## Motivation
Anyone who has trained a neural network on real-world data knows the drill: fit a scaler on the training split, pickle it alongside the model weights, and carefully remember to apply — and later invert — it at every boundary between raw data and model outputs.
Scikit-learn's scalers are great, but they live outside the PyTorch world. They don't know about devices, they aren't part of `state_dict`, and they can't participate in an `nn.Sequential` pipeline. The result is boilerplate that every project re-implements in slightly different ways.

torchscalers solves this by making scalers first-class `nn.Module` objects.
## Design
Every scaler in torchscalers is a subclass of torch.nn.Module. Fitted statistics (mean, std, min, max, …) are stored as module buffers. This single design decision gives you three things for free:

- **Automatic checkpointing** — buffers are included in `model.state_dict()`, so `torch.save`/`torch.load` captures scaler parameters along with model weights.
- **Device transparency** — calling `.to(device)` on your model moves the scaler statistics alongside your parameters. No manual `.to(device)` calls on a separate scaler object.
- **`nn.Sequential` compatibility** — `scaler(x)` is identical to `scaler.transform(x)`, so scalers compose naturally with other modules.
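To make the buffer mechanism concrete, here is a minimal toy sketch of a z-score scaler built the same way. This is illustrative only, not torchscalers' actual implementation, and the class name `MiniZScore` is made up; it just shows how `register_buffer` provides checkpointing, device transparency, and callable-module composition in one stroke:

```python
import torch
import torch.nn as nn

class MiniZScore(nn.Module):
    """Toy z-score scaler (illustrative sketch, not the library's code)."""

    def __init__(self) -> None:
        super().__init__()
        # register_buffer puts the statistics into state_dict and makes
        # them follow .to(device), without treating them as trainable
        # parameters.
        self.register_buffer("mean", torch.zeros(1))
        self.register_buffer("std", torch.ones(1))

    def fit(self, x: torch.Tensor) -> "MiniZScore":
        # Assigning a tensor to a registered buffer name replaces the buffer.
        self.mean = x.mean(dim=0)
        self.std = x.std(dim=0)
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std

    def inverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.std + self.mean

scaler = MiniZScore().fit(torch.randn(64, 3))
# scaler.state_dict() now contains "mean" and "std" alongside any
# surrounding model weights, and scaler(x) standardises per feature.
```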
## Quick Start
```python
import torch
from torchscalers import ZScoreScaler

scaler = ZScoreScaler()

X_train = torch.randn(500, 8)
X_test = torch.randn(100, 8)

# Fit on training data, then scale
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler(X_test)  # equivalent to scaler.transform(X_test)

# Recover the original scale
X_recovered = scaler.inverse_transform(X_test_scaled)
```

Install via uv:

```shell
uv add torchscalers
```

## Embedding a Scaler in a Model
Because scalers are nn.Module subclasses, storing them as attributes of your model registers them as child modules. Checkpointing then captures scaler statistics automatically — no extra steps required.
```python
import torch
import torch.nn as nn
from torchscalers import ZScoreScaler

class MyModel(nn.Module):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.feature_scaler = ZScoreScaler()
        self.target_scaler = ZScoreScaler()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(self.feature_scaler(x))

model = MyModel(8, 1)

# Fit scalers on the training split before the training loop
model.feature_scaler.fit(X_train)
model.target_scaler.fit(y_train)

# Save — scaler statistics are included in state_dict automatically
torch.save(model.state_dict(), "checkpoint.pt")

# Reload into a fresh model
fresh = MyModel(8, 1)
fresh.load_state_dict(torch.load("checkpoint.pt", weights_only=True))

# Inverse-transform predictions back to the original target scale
with torch.no_grad():
    pred_scaled = fresh(X_test)
    pred_orig = fresh.target_scaler.inverse_transform(pred_scaled)
```

## PyTorch Lightning Integration
In a Lightning workflow, scalers should be fitted inside DataModule.setup() on the training split only — fitting on the validation or test split would be data leakage. Pass the fitted instances to the LightningModule so they are registered as child modules and therefore included in every checkpoint.
```python
import lightning as L
import torch.nn as nn
import torch.nn.functional as F

from torchscalers import ZScoreScaler

class MyDataModule(L.LightningDataModule):
    def __init__(self, X, y):
        super().__init__()
        self.X, self.y = X, y
        self.feature_scaler = ZScoreScaler()
        self.target_scaler = ZScoreScaler()

    def setup(self, stage):
        if stage == "fit":
            # Fit on the training split only to avoid data leakage
            n = int(len(self.X) * 0.8)
            self.feature_scaler.fit(self.X[:n])
            self.target_scaler.fit(self.y[:n])

    # (train_dataloader / val_dataloader hooks omitted for brevity)

class MyModel(L.LightningModule):
    def __init__(self, feature_scaler, target_scaler):
        super().__init__()
        self.feature_scaler = feature_scaler  # registered as child module
        self.target_scaler = target_scaler
        self.net = nn.Linear(8, 1)

    def forward(self, x):
        return self.net(self.feature_scaler(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self(x), self.target_scaler(y))
        self.log("train_loss", loss)
        return loss

dm = MyDataModule(X_all, y_all)
dm.setup(stage="fit")

model = MyModel(
    feature_scaler=dm.feature_scaler,
    target_scaler=dm.target_scaler,
)

trainer = L.Trainer(max_epochs=10)
trainer.fit(model, dm)
# Scaler statistics are saved in every checkpoint automatically
```

## Available Scalers
| Scaler | Description |
|---|---|
| `ZScoreScaler` | Standardise to zero mean and unit variance. |
| `MinMaxScaler` | Scale to [0, 1] using per-feature min/max. |
| `MaxAbsScaler` | Scale to [−1, 1] by dividing by the maximum absolute value. |
| `RobustScaler` | Scale using median and IQR — robust to outliers. |
| `ShiftScaleScaler` | Apply a user-specified `(x + shift) * scale` transformation. |
| `LogScaler` | Apply a log transformation: `log(x + eps)`. |
| `PerDomainScaler` | Apply a separate scaler instance per string domain ID. |
| `MixedDomainScaler` | Apply a different scaler type per string domain ID. |
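To make the table concrete, the min-max transformation can be written in a few lines of plain torch. This is a sketch of the underlying math only, not the library's code, and the helper name `minmax_scale` is made up:

```python
import torch

def minmax_scale(x: torch.Tensor) -> torch.Tensor:
    # Per-feature min-max scaling to [0, 1]: (x - min) / (max - min),
    # with min/max taken independently along each column.
    lo = x.min(dim=0).values
    hi = x.max(dim=0).values
    return (x - lo) / (hi - lo)

x = torch.tensor([[0.0, 10.0],
                  [5.0, 20.0],
                  [10.0, 30.0]])
scaled = minmax_scale(x)
# Each column is mapped independently onto [0, 1]:
# column 0: 0, 5, 10  ->  0.0, 0.5, 1.0
# column 1: 10, 20, 30 -> 0.0, 0.5, 1.0
```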
## Links
- GitHub: github.com/pauknerd/torchscalers
- Documentation: pauknerd.github.io/torchscalers