# PyTorch Native
This page documents the official RankSEG integration path for standard PyTorch semantic-segmentation inference.
If you already have a PyTorch semantic-segmentation model and want to replace
`argmax` with a metric-aware post-processing step at inference time, this is the
recommended starting point.
## Where RankSEG fits
Most PyTorch segmentation code has three stages:
images -> model(images) -> logits -> argmax/threshold -> masks
RankSEG keeps the model and logits unchanged. The only change is the final prediction step:
images -> model(images) -> logits -> probabilities -> RankSEG -> masks
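The claim that only the final step changes can be sketched without any framework. The example below is a minimal, dependency-free illustration that assumes probabilities are stored as nested lists of shape `(num_classes, num_pixels)` instead of tensors. `softmax_over_classes`, `argmax_decide`, and `segment` are hypothetical names; in real code the `decide` step would be `probs.argmax(dim=1)` or `rankseg.predict(probs)`.

```python
import math
from typing import Callable, List

# Hypothetical layout for this sketch: scores[c][i] is the score of class c at
# pixel i, i.e. shape (num_classes, num_pixels), mirroring the class axis at
# dim=1 of a PyTorch (batch_size, num_classes, H, W) tensor.
Matrix = List[List[float]]
Decide = Callable[[Matrix], List[int]]


def softmax_over_classes(logits: Matrix) -> Matrix:
    """Normalize each pixel's scores across the class axis."""
    num_classes, num_pixels = len(logits), len(logits[0])
    probs = [[0.0] * num_pixels for _ in range(num_classes)]
    for i in range(num_pixels):
        column = [logits[c][i] for c in range(num_classes)]
        peak = max(column)  # subtract the max for numerical stability
        exps = [math.exp(v - peak) for v in column]
        total = sum(exps)
        for c in range(num_classes):
            probs[c][i] = exps[c] / total
    return probs


def argmax_decide(probs: Matrix) -> List[int]:
    """Baseline decision step: pick the most probable class at each pixel."""
    num_classes, num_pixels = len(probs), len(probs[0])
    return [max(range(num_classes), key=lambda c: probs[c][i]) for i in range(num_pixels)]


def segment(logits: Matrix, decide: Decide) -> List[int]:
    """Shared pipeline: logits -> probabilities -> decision step -> labels."""
    return decide(softmax_over_classes(logits))


# Two classes, three pixels.
logits = [[2.0, 0.0, -1.0],
          [0.0, 1.0, 3.0]]
print(segment(logits, argmax_decide))  # [0, 1, 1]
```

Swapping `argmax_decide` for another callable changes only the final labels; the logits and probabilities computed before it are untouched, which is the property the RankSEG integration relies on.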
The default official configuration is:
```python
rankseg = RankSEG(metric="dice", solver="RMA", output_mode="multiclass")
```
This is the most practical setup for standard semantic segmentation. Try it first unless you already know you need a different metric or output mode.
## Minimal integration

If your model already returns a logits tensor, the integration is only a few lines:

```python
import torch
from rankseg import RankSEG

# `model` and `images` are your own trained network and input batch.
model.eval()
with torch.inference_mode():
    logits = model(images)  # (batch_size, num_classes, *image_shape)
    probs = torch.softmax(logits, dim=1)
    rankseg = RankSEG(metric="dice", solver="RMA", output_mode="multiclass")
    preds = rankseg.predict(probs)  # (batch_size, *image_shape)
```
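Many models do not return a bare logits tensor. torchvision segmentation models return a dict whose `"out"` entry holds the logits, and Hugging Face segmentation models return an output object with a `.logits` attribute (for some, such as SegFormer, at a lower resolution than the input, so the logits may need upsampling before RankSEG). The sketch below is a hypothetical helper, `extract_logits`, for normalizing these cases before `torch.softmax`; it is not part of RankSEG, and the key and attribute it checks are assumptions about your model.

```python
from collections.abc import Mapping
from typing import Any


def extract_logits(output: Any, key: str = "out") -> Any:
    """Return the logits from a model's forward output (hypothetical helper)."""
    # Hugging Face outputs expose `.logits` and are also dict-like, so check it first.
    logits = getattr(output, "logits", None)
    if logits is not None:
        return logits
    # torchvision segmentation models return a dict such as {"out": ..., "aux": ...}.
    if isinstance(output, Mapping):
        return output[key]
    # Otherwise assume the model already returned the logits tensor.
    return output


# Usage inside the inference block (torch objects assumed):
#     logits = extract_logits(model(images))
#     probs = torch.softmax(logits, dim=1)
```

If the extracted logits are smaller than the input, one option is `torch.nn.functional.interpolate(logits, size=images.shape[-2:], mode="bilinear", align_corners=False)` before the softmax.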
## Usual PyTorch inference vs RankSEG inference

This example is self-contained and uses a tiny convolutional model only to show
the tensor contract. Replace `model` with your trained network.

```python
import torch
from rankseg import RankSEG

model = torch.nn.Conv2d(in_channels=3, out_channels=21, kernel_size=1)
images = torch.randn(2, 3, 256, 256)

model.eval()
with torch.inference_mode():
    logits = model(images)
    probs = torch.softmax(logits, dim=1)

    # Usual PyTorch inference: per-pixel argmax over the class dimension.
    baseline_preds = probs.argmax(dim=1)

    # RankSEG inference: same probabilities, different final prediction step.
    rankseg = RankSEG(metric="dice", solver="RMA", output_mode="multiclass")
    rankseg_preds = rankseg.predict(probs)

print("logits: ", tuple(logits.shape))          # (2, 21, 256, 256)
print("baseline:", tuple(baseline_preds.shape))  # (2, 256, 256)
print("RankSEG: ", tuple(rankseg_preds.shape))   # (2, 256, 256)
```
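To check whether switching from `argmax` to RankSEG helps on your own labeled data, score both `baseline_preds` and `rankseg_preds` against the ground truth with the metric RankSEG was configured for. The sketch below is a minimal, dependency-free per-class Dice computation on flattened label maps. `dice_per_class` is a hypothetical helper, and it assumes integer class labels and skips any class that appears in neither the prediction nor the target; your evaluation code may treat such classes differently.

```python
from typing import Dict, List


def dice_per_class(pred: List[int], target: List[int], num_classes: int) -> Dict[int, float]:
    """Per-class Dice = 2|P & G| / (|P| + |G|) on flattened label maps.

    Classes absent from both pred and target are left out of the result
    (an assumption of this sketch).
    """
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same number of pixels")
    scores: Dict[int, float] = {}
    for c in range(num_classes):
        pred_count = sum(1 for p in pred if p == c)
        target_count = sum(1 for t in target if t == c)
        if pred_count + target_count == 0:
            continue
        overlap = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        scores[c] = 2.0 * overlap / (pred_count + target_count)
    return scores


target = [0, 0, 1, 1, 2, 2]
pred = [0, 0, 0, 1, 2, 2]
print(dice_per_class(pred, target, num_classes=3))
```

With tensors, one option is to pass `baseline_preds[0].flatten().tolist()` and the matching target; for whole datasets a vectorized or library implementation will be much faster.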
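## Choosing options

| Option | Recommended starting value | When to change it |
|---|---|---|
| `metric` | `"dice"` | Change it only if you already know you need a different metric. |
| `solver` | `"RMA"` | Start here for Dice/IoU. Use other solvers only when you need their specific trade-offs. |
| `output_mode` | `"multiclass"` | Change it only if you already know you need a different output mode; see Getting Started for `multiclass` vs `multilabel`. |
| Activation before RankSEG | `torch.softmax(logits, dim=1)` | Keep it consistent with `output_mode`; see Getting Started for `softmax` vs `sigmoid`. |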
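The activation row matters because the two common activations produce different kinds of probabilities. Softmax over the class axis makes each pixel's class probabilities sum to 1 (one label per pixel, the usual multiclass setting), while an element-wise sigmoid scores every class independently, so a pixel can have several classes above 0.5 (the usual multilabel setting). The pure-Python sketch below shows this for a single pixel; the pairing of activation and output mode is the common convention, and Getting Started is the place to confirm what RankSEG expects.

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    """Mutually exclusive classes: probabilities compete and sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(scores: List[float]) -> List[float]:
    """Independent labels: each class gets its own probability in (0, 1)."""
    return [1.0 / (1.0 + math.exp(-s)) for s in scores]


pixel_logits = [2.0, 1.5, -3.0]  # one pixel, three classes

print([round(p, 3) for p in softmax(pixel_logits)])  # sums to 1
print([round(p, 3) for p in sigmoid(pixel_logits)])  # classes 0 and 1 both exceed 0.5
```

In PyTorch the equivalents are `torch.softmax(logits, dim=1)` and `torch.sigmoid(logits)`, applied to the whole `(batch_size, num_classes, *image_shape)` tensor.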
## Notebook and script
For a notebook-style walkthrough with a pretrained model, see:
The maintained script version is:
For a more complete explanation of:

- probability conventions
- `softmax` vs `sigmoid`
- `multiclass` vs `multilabel`
- input / output tensor shapes
- common integration mistakes

see Getting Started.