torchscalers: PyTorch-Native Feature Scaling

April 10, 2026

Motivation

Anyone who has trained a neural network on real-world data knows the drill: fit a scaler on the training split, pickle it alongside the model weights, and remember to apply it (and later invert it) at every boundary between raw data and model outputs.

Scikit-learn's scalers are great, but they live outside the PyTorch world. They don't know about devices, they aren't part of state_dict, and they can't participate in an nn.Sequential pipeline. The result is boilerplate that every project re-implements in slightly different ways.

torchscalers solves this by making scalers first-class nn.Module objects.

Design

Every scaler in torchscalers is a subclass of torch.nn.Module. Fitted statistics (mean, std, min, max, …) are stored as module buffers. This single design decision gives you three things for free:

  • Automatic checkpointing: buffers are included in model.state_dict(), so torch.save / torch.load capture the scaler statistics along with the model weights.
  • Device transparency: calling .to(device) on your model moves the scaler statistics along with its parameters, with no manual .to(device) calls on a separate scaler object.
  • nn.Sequential compatibility: calling scaler(x) is identical to calling scaler.transform(x), so scalers compose naturally with other modules.
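The buffer mechanism behind all three properties fits in a few lines. BufferedZScore below is a hypothetical class written for illustration, not torchscalers' source: unlike the library's argument-free ZScoreScaler(), it takes num_features up front so its buffer shapes are fixed at construction, and its eps floor on the standard deviation is likewise an assumption.

```python
import torch
import torch.nn as nn


class BufferedZScore(nn.Module):
    """Hypothetical z-score scaler (illustration only) that keeps its statistics in buffers."""

    def __init__(self, num_features: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        # Buffers are saved in state_dict() and moved by .to(), but are not trainable
        self.register_buffer("mean", torch.zeros(num_features))
        self.register_buffer("std", torch.ones(num_features))

    @torch.no_grad()
    def fit(self, x: torch.Tensor) -> "BufferedZScore":
        # Per-feature statistics over the sample dimension (dim 0)
        self.mean.copy_(x.mean(dim=0))
        # Assumed guard against zero-variance features
        self.std.copy_(x.std(dim=0).clamp(min=self.eps))
        return self

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std

    def inverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.std + self.mean

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Calling the module is the same as calling transform()
        return self.transform(x)


scaler = BufferedZScore(4).fit(torch.randn(200, 4))
print(sorted(scaler.state_dict().keys()))  # ['mean', 'std']
print(list(scaler.parameters()))           # []

# Composes with ordinary layers because forward() is the transform
pipeline = nn.Sequential(scaler, nn.Linear(4, 1))
out = pipeline(torch.randn(16, 4))
print(out.shape)  # torch.Size([16, 1])
```

Because mean and std are buffers rather than nn.Parameter objects, they appear in state_dict() and follow .to(device), yet an optimizer built from model.parameters() never updates them.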

Quick Start

import torch
from torchscalers import ZScoreScaler
 
scaler = ZScoreScaler()
 
X_train = torch.randn(500, 8)
X_test  = torch.randn(100, 8)
 
# Fit on training data, then scale
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled  = scaler(X_test)          # equivalent to scaler.transform(X_test)
 
# Recover the original scale
X_recovered = scaler.inverse_transform(X_test_scaled)
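A z-score scaler is an affine map, so the fit / transform / inverse_transform contract can be spelled out in plain tensor arithmetic. This sketch assumes per-feature statistics over dimension 0 and PyTorch's default (unbiased) std; whether ZScoreScaler uses the same estimator is an assumption.

```python
import torch

torch.manual_seed(0)
X_train = torch.randn(500, 8) * 3.0 + 5.0
X_test = torch.randn(100, 8) * 3.0 + 5.0

# "fit": per-feature statistics, computed on the training split only
mean = X_train.mean(dim=0)  # shape (8,)
std = X_train.std(dim=0)    # shape (8,), unbiased estimator by default

# "transform": the same training statistics are applied to every split
X_train_scaled = (X_train - mean) / std
X_test_scaled = (X_test - mean) / std

# "inverse_transform": undo the affine map
X_recovered = X_test_scaled * std + mean
```

The scaled training split has zero mean and unit standard deviation per feature; the scaled test split generally does not, which is expected, since it is scaled with statistics it never contributed to.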

Install via uv:

uv add torchscalers

Embedding a Scaler in a Model

Because scalers are nn.Module subclasses, storing them as attributes of your model registers them as child modules. Checkpointing then captures scaler statistics automatically — no extra steps required.

import torch
import torch.nn as nn
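from torchscalers import ZScoreScaler
 
# Placeholder data; substitute your own training and test splits
X_train = torch.randn(500, 8)
y_train = torch.randn(500, 1)
X_test  = torch.randn(100, 8)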
 
class MyModel(nn.Module):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.feature_scaler = ZScoreScaler()
        self.target_scaler  = ZScoreScaler()
        self.linear = nn.Linear(in_features, out_features)
 
    def forward(self, x):
        return self.linear(self.feature_scaler(x))
 
model = MyModel(8, 1)
 
# Fit scalers on the training split before the training loop
model.feature_scaler.fit(X_train)
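model.target_scaler.fit(y_train)
 
# (Training loop omitted.) Train against scaled targets, model.target_scaler(y),
# so that the model's outputs live in the scaled target space.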
 
# Save — scaler statistics are included in state_dict automatically
torch.save(model.state_dict(), "checkpoint.pt")
 
# Reload into a fresh model
fresh = MyModel(8, 1)
fresh.load_state_dict(torch.load("checkpoint.pt", weights_only=True))
 
# Inverse-transform predictions back to the original target scale
with torch.no_grad():
    pred_scaled = fresh(X_test)
    pred_orig   = fresh.target_scaler.inverse_transform(pred_scaled)
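To see what "registered as child modules" means for a checkpoint, the sketch below uses a hypothetical stand-in scaler (TinyScaler, illustration only, not part of torchscalers) with the same buffer-based design and round-trips a model through an in-memory checkpoint. The scaler's statistics appear in state_dict() prefixed with the attribute name, e.g. feature_scaler.mean.

```python
import io

import torch
import torch.nn as nn


class TinyScaler(nn.Module):
    """Hypothetical stand-in for a fitted scaler: statistics live in buffers."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        self.register_buffer("mean", torch.zeros(num_features))
        self.register_buffer("std", torch.ones(num_features))

    def fit(self, x: torch.Tensor) -> None:
        self.mean.copy_(x.mean(dim=0))
        self.std.copy_(x.std(dim=0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.feature_scaler = TinyScaler(8)  # attribute assignment registers a child module
        self.linear = nn.Linear(8, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.feature_scaler(x))


model = Net()
model.feature_scaler.fit(torch.randn(500, 8) * 2.0 + 1.0)

# Round-trip through an in-memory buffer instead of a file on disk
buf = io.BytesIO()
torch.save(model.state_dict(), buf)
buf.seek(0)

fresh = Net()
fresh.load_state_dict(torch.load(buf, weights_only=True))

print([k for k in fresh.state_dict() if k.startswith("feature_scaler.")])
# ['feature_scaler.mean', 'feature_scaler.std']
```

A fresh Net() can load these statistics here because TinyScaler fixes its buffer shapes at construction; load_state_dict raises an error on shape mismatches.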

PyTorch Lightning Integration

In a Lightning workflow, scalers should be fitted inside DataModule.setup() on the training split only — fitting on the validation or test split would be data leakage. Pass the fitted instances to the LightningModule so they are registered as child modules and therefore included in every checkpoint.

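import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from torch.utils.data import DataLoader, TensorDataset
from torchscalers import ZScoreScaler
 
class MyDataModule(L.LightningDataModule):
    def __init__(self, X, y, batch_size: int = 32):
        super().__init__()
        self.X, self.y = X, y
        self.batch_size = batch_size
        self.feature_scaler = ZScoreScaler()
        self.target_scaler  = ZScoreScaler()
 
    def setup(self, stage):
        if stage == "fit":
            # First 80% of rows is the training split; fit the scalers on it only
            self.n_train = int(len(self.X) * 0.8)
            self.feature_scaler.fit(self.X[:self.n_train])
            self.target_scaler.fit(self.y[:self.n_train])
 
    def train_dataloader(self):
        train_ds = TensorDataset(self.X[:self.n_train], self.y[:self.n_train])
        return DataLoader(train_ds, batch_size=self.batch_size, shuffle=True)
 
class MyModel(L.LightningModule):
    def __init__(self, feature_scaler, target_scaler):
        super().__init__()
        self.feature_scaler = feature_scaler   # registered as child module
        self.target_scaler  = target_scaler
        self.net = nn.Linear(8, 1)
 
    def forward(self, x):
        return self.net(self.feature_scaler(x))
 
    def training_step(self, batch, batch_idx):
        x, y = batch
        # The loss is computed in the scaled target space
        loss = F.mse_loss(self(x), self.target_scaler(y))
        self.log("train_loss", loss)
        return loss
 
    def configure_optimizers(self):
        # The Trainer requires an optimizer; Adam with lr=1e-3 is an arbitrary choice
        return torch.optim.Adam(self.parameters(), lr=1e-3)
 
# Placeholder data; substitute your own features and targets
X_all = torch.randn(1000, 8)
y_all = torch.randn(1000, 1)
 
dm = MyDataModule(X_all, y_all)
dm.setup(stage="fit")   # fit the scalers before the model is built
 
model = MyModel(
    feature_scaler=dm.feature_scaler,
    target_scaler=dm.target_scaler,
)
 
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, dm)   # calls dm.setup("fit") again, refitting on the same split
# Scaler statistics are saved in every checkpoint automatically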

Available Scalers

Scaler              Description
ZScoreScaler        Standardise to zero mean and unit variance.
MinMaxScaler        Scale to [0, 1] using per-feature min/max.
MaxAbsScaler        Scale to [−1, 1] by dividing by the maximum absolute value.
RobustScaler        Scale using median and IQR (robust to outliers).
ShiftScaleScaler    Apply a user-specified (x + shift) * scale transformation.
LogScaler           Apply a log transformation: log(x + eps).
PerDomainScaler     Apply a separate scaler instance per string domain ID.
MixedDomainScaler   Apply a different scaler type per string domain ID.
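For reference, the single-feature transformations in the table correspond to the standard textbook definitions below, written for feature j with statistics taken from the training data (mean μ_j, standard deviation σ_j, minimum m_j, maximum M_j, median x̃_j, quartiles Q_{1,j} and Q_{3,j}). These forms are inferred from the one-line descriptions; how torchscalers handles edge cases such as zero variance, zero range, or the quantile method used for the IQR is not stated here, and the natural logarithm is assumed for LogScaler.

```latex
\begin{aligned}
\text{ZScoreScaler:}\quad     & z_j = \frac{x_j - \mu_j}{\sigma_j} \\
\text{MinMaxScaler:}\quad     & z_j = \frac{x_j - m_j}{M_j - m_j} \\
\text{MaxAbsScaler:}\quad     & z_j = \frac{x_j}{\max_i \lvert x_{ij} \rvert} \\
\text{RobustScaler:}\quad     & z_j = \frac{x_j - \tilde{x}_j}{Q_{3,j} - Q_{1,j}} \\
\text{ShiftScaleScaler:}\quad & z_j = (x_j + \text{shift}) \cdot \text{scale} \\
\text{LogScaler:}\quad        & z_j = \log(x_j + \varepsilon)
\end{aligned}
```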

Links