UNetMamba 代码演化报告:从多个代码库到一个完整实现
基于对 UNetMamba/、GeoSeg/、VMamba/、Swin-UMamba/、ResT/ 代码目录的逐文件比对分析
目录
- 代码血缘全景图
- 第一层继承:GeoSeg 作为工程骨架
- 第二层移植:VMamba 的 VSS Block 核心算子
- 第三层改造:ResT 编码器的精简与适配
- 第四层借鉴:Swin-Unet / Swin-UMamba 的解码器范式
- 核心新增:MSD 解码器与 LSM 模块的自主实现
- 全局改动:命名空间重构与细节调整
- Config 的演化痕迹:从 VMamba backbone 到 ResT-Lite
- 代码演化时间线总结
1. 代码血缘全景图
以下是通过逐文件 diff 比对所确认的代码继承关系:
UNetMamba/
├── train.py ← GeoSeg/train_supervision.py(直接改造)
├── loveda_test.py ← GeoSeg/loveda_test.py(直接改造)
├── vaihingen_test.py ← GeoSeg/vaihingen_test.py(直接改造)
├── potsdam_test.py ← GeoSeg/potsdam_test.py(直接改造)
│
├── config/
│ └── loveda/unetmamba.py ← GeoSeg/config/loveda/unetformer.py(深度改造)
│ └── vaihingen/unetmamba.py ← GeoSeg/config/vaihingen/unetformer.py(深度改造)
│
├── tools/
│ ├── cfg.py ← GeoSeg/tools/cfg.py(完全相同)
│ ├── metric.py ← GeoSeg/tools/metric.py(完全相同)
│ └── loveda_mask_convert.py ← GeoSeg/tools/(完全相同)
│
└── unetmamba_model/
├── datasets/
│ ├── loveda_dataset.py ← GeoSeg/geoseg/datasets/(删1空行,其余相同)
│ ├── vaihingen_dataset.py← GeoSeg/geoseg/datasets/(引入 IMG_SIZE 变量,适配多尺度)
│ └── transform.py ← GeoSeg/geoseg/datasets/(仅修正1处拼写错误)
│
├── losses/ ← GeoSeg/geoseg/losses/(完整复制,仅改2处)
│ └── useful_loss.py ← 关键修改:UnetFormerLoss → UnetMambaLoss,ignore_index 255→6
│
├── models/
│ ├── BANet.py ← GeoSeg/geoseg/models/BANet.py(完全相同,0 diff)
│ ├── DCSwin.py ← GeoSeg/geoseg/models/DCSwin.py(完全相同,0 diff)
│ ├── ABCNet.py ← GeoSeg/geoseg/models/ABCNet.py(完全相同)
│ ├── MANet.py ← GeoSeg/geoseg/models/MANet.py(完全相同)
│ ├── UNetFormer.py ← GeoSeg/geoseg/models/UNetFormer.py(3处小修改)
│ ├── ResT.py ← ResT/models/rest.py(删分类头,改返回值为多尺度特征)
│ ├── SwinUMamba.py ← Swin-UMamba/SwinUMambaD.py(移植解码器基础结构)
│ ├── RS3Mamba.py ← 自行实现(使用 VSSMEncoder + UNetFormer解码器组件)
│ ├── CMUNet.py ← 参考 Swin-Unet 风格
│ └── ★ UNetMamba.py ← 核心新增(见第6节)
│
├── classification/
│ └── models/vmamba.py ← VMamba/classification/models/vmamba.py(新增 CrossScan 纯
│ PyTorch 实现作为 Triton 不可用时的回退)
│
└── mamba_sys.py ← VMamba/vmamba.py(完全相同,0 diff)
2. 第一层继承:GeoSeg 作为工程骨架
2.1 GeoSeg 是什么
GeoSeg 是 Libo Wang 团队开发并维护的一个遥感图像语义分割代码框架,其中包含了 BANet、UNetFormer、FTUNetFormer 等多个模型的实现,以及配套的数据集处理、损失函数、训练流程等基础设施。
2.2 继承范围:几乎全部基础设施
通过 diff 比对,确认以下文件从 GeoSeg 中原样复制或极微改动:
| 文件 | diff 结果 | 说明 |
|---|---|---|
BANet.py |
0 差异 | 完整复制 |
DCSwin.py |
0 差异 | 完整复制 |
ABCNet.py |
0 差异 | 完整复制 |
MANet.py |
0 差异 | 完整复制 |
tools/cfg.py |
0 差异 | 完整复制 |
tools/metric.py |
0 差异 | 完整复制 |
datasets/loveda_dataset.py |
删除 1 个空行 | 实质无变化 |
datasets/transform.py |
修正 1 处拼写错误 | 实质无变化 |
这些文件共同构成了 UNetMamba 的实验基础设施:
- 数据加载(LoveDA、Vaihingen 数据集的读取和预处理)
- 数据增强(RandomScale、SmartCropV1 等遥感专用增强)
- 损失函数库(16 个损失函数文件完整保留)
- 评估指标(IoU、F1、OA 的计算逻辑)
- 对比方法实现(BANet、DCSwin、MANet 用于消融对比实验)
2.3 训练脚本的改动(train.py)
train.py 从 GeoSeg 的 train_supervision.py 改造,主要改动:
# GeoSeg 的原始代码
from train_supervision import * # 测试脚本引用
# UNetMamba 的修改
from train import * # 重命名为 train.py 后同步更新引用
# 新增:HuggingFace 镜像配置(国内访问用)
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# 精度报告格式调整(百分比化)
eval_value = {'mIoU': mIoU*100.0, 'F1': F1*100.0, 'OA': OA*100.0}
# 原来:eval_value = {'mIoU': mIoU, 'F1': F1, 'OA': OA}
# 删除了 GeoSeg 中多余的数据集分支(whubuilding、massbuilding、cropland)
# 因为 UNetMamba 只关注 LoveDA 和 Vaihingen
改动逻辑:*100.0 是因为 GeoSeg 的指标是 [0, 1] 范围,UNetMamba 改为报告 [0, 100] 百分比,更符合论文中的表格展示习惯。
2.4 损失函数的精准修改(useful_loss.py)
这是整个 losses 目录中唯一改动的文件,改动极其精准:
# GeoSeg(UNetFormer 的损失函数)
class UnetFormerLoss(nn.Module):
def __init__(self, ignore_index=255):
# Dice Loss + Soft-CE 联合损失
loss = JointLoss(SoftCrossEntropyLoss(..., ignore_index=ignore_index),
DiceLoss(smooth=0.05, ignore_index=ignore_index), 1.0, 1.0)
# ↑ 注意:参数是两个独立的位置参数 1.0, 1.0
# UNetMamba 的修改
class UnetMambaLoss(nn.Module): # ← 仅改了类名
def __init__(self, ignore_index=6): # ← ignore_index 从 255 改为 6
loss = JointLoss(SoftCrossEntropyLoss(...),
DiceLoss(smooth=0.05, ...), (1.0, 1.0)) # ← 改为元组形式
改动逻辑:
- 类名更改:
UnetFormerLoss→UnetMambaLoss,语义清晰 ignore_index=6:LoveDA 数据集有 7 类(索引 0-6),index=6 是背景/忽略类,而通用设定 255 不适合只有 7 类的场景(1.0, 1.0)元组:修复了一个潜在的 API 兼容性问题(某些版本的JointLoss需要元组)
3. 第二层移植:VMamba 的 VSS Block 核心算子
3.1 mamba_sys.py:完整原样复制
$ diff VMamba/vmamba.py UNetMamba/unetmamba_model/mamba_sys.py
(无输出) # 0 差异,完全相同
mamba_sys.py 是 VMamba 源码 vmamba.py 的原样拷贝,包含:
VSSM(完整的 VMamba 视觉骨干网络)VSSBlock(VSS Block,SS2D 四方向扫描的基本单元)SS2D(2D Selective Scan 核心模块)LayerNorm2d、Permute等辅助层
这个文件在 UNetMamba 代码库中作为备用参考实现存在,实际的解码器不直接用 mamba_sys.py 中的 VSSM,而是用下面介绍的经过增强的版本。
3.2 classification/models/vmamba.py:新增 PyTorch 回退实现
这是 UNetMamba 对 VMamba 代码最实质性的修改,也是工程价值最高的改动。
原始 VMamba 的问题:
# VMamba 原版(需要 Triton 编译)
try:
from .csm_triton import cross_scan_fn, cross_merge_fn
except:
from csm_triton import cross_scan_fn, cross_merge_fn
# 如果 Triton 不可用(Windows 环境、部分 CUDA 版本),直接报错
UNetMamba 的改进:
# UNetMamba 版本:新增纯 PyTorch 实现
try:
from .csm_triton import CrossScanTriton, CrossMergeTriton, CrossScanTriton1b1
except:
from csm_triton import CrossScanTriton, CrossMergeTriton, CrossScanTriton1b1
# === 以下是 UNetMamba 新增的纯 PyTorch CrossScan 实现 ===
class CrossScan(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor):
B, C, H, W = x.shape
ctx.shape = (B, C, H, W)
xs = x.new_empty((B, 4, C, H * W))
xs[:, 0] = x.flatten(2, 3) # 方向1:正向行扫描
xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3) # 方向2:正向列扫描
xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) # 方向3,4:翻转
return xs
@staticmethod
def backward(ctx, ys: torch.Tensor):
# ... 对应的反向传播
还新增了消融实验专用的 CrossScan_Ab_2direction 类(只扫描 2 个方向,用于验证 4 方向比 2 方向好)。
工程意义:Triton 在 Windows 平台上不可用,这个 PyTorch 回退实现使得 UNetMamba 可以在 Windows 开发环境下运行和调试,而不仅限于 Linux GPU 服务器。
3.3 UNetMamba 如何使用 VSS Block
核心模型 UNetMamba.py 的第一行:
from unetmamba_model.classification.models.vmamba import VSSM, LayerNorm2d, VSSBlock, Permute
VSSBlock 在解码器中的使用方式:
class VSSLayer(nn.Module):
"""解码器中的 VSS Block 堆叠层"""
def __init__(self, dim, depth=2, ...):
self.blocks = nn.ModuleList([
VSSBlock(
hidden_dim=dim,
drop_path=...,
d_state=16, # 状态维度,SSM 的记忆容量
)
for i in range(depth)]) # depth=2,每个解码阶段 2 个 VSS Block
4. 第三层改造:ResT 编码器的精简与适配
4.1 改造内容
UNetMamba/unetmamba_model/models/ResT.py 从 ResT/models/rest.py 改造,主要修改:
① 删除分类头(最关键的改动):
# ResT 原版(图像分类任务)
def forward(self, x):
...
x = self.avg_pool(x).flatten(1) # 全局平均池化
x = self.head(x) # 分类输出
return x # 返回1D类别向量
# UNetMamba 改造版(语义分割任务)
def forward(self, x):
...
x1 = x # 保存 Stage1 特征 (H/4, W/4)
...
x2 = x # 保存 Stage2 特征 (H/8, W/8)
...
x3 = x # 保存 Stage3 特征 (H/16, W/16)
...
x4 = x # 保存 Stage4 特征 (H/32, W/32)
return x1, x2, x3, x4 # 返回多尺度特征图(供解码器用 skip connection)
这一改动是从分类骨干网络到分割编码器的本质转变——不再输出单一类别预测,而是输出四个尺度的特征图,供 U 形解码器的各个阶段使用。
② 删除 Stem 类:
# 原版 ResT 中的 Stem 类(用于更深的 ResT 变体)
class Stem(nn.Module):
def __init__(self, in_ch=3, out_ch=64, with_pos=False):
self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=7, stride=2, ...)
self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, ...)
...
# UNetMamba 的 ResT.py 删除了 Stem 类
# 因为 ResT-Lite 使用 OverlapPatchEmbed 而非 Stem,Stem 只存在于其他变体中
③ 删除分类相关组件:
# 原版
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.head = nn.Linear(embed_dims[3], num_classes)
# UNetMamba:这两行被直接删除
④ 细节格式清理:
# 原版(有空格)
kernel_size=sr_ratio + 1
drop_path=dpr[cur + i]
# UNetMamba(无空格,更紧凑)
kernel_size=sr_ratio+1
drop_path=dpr[cur+i]
4.2 rest_lite 工厂函数
UNetMamba.py 中定义了一个简单的工厂函数:
def rest_lite(pretrained=True, weight_path='pretrain_weights/rest_lite.pth', **kwargs):
model = ResT(
embed_dims=[64, 128, 256, 512], # ResT-Lite 的通道配置
num_heads=[1, 2, 4, 8],
mlp_ratios=[4, 4, 4, 4],
qkv_bias=True,
depths=[2, 2, 2, 2], # 每阶段 2 个 ETB Block
sr_ratios=[8, 4, 2, 1], # EMSA 的空间缩减率
apply_transform=True, # 启用 PA 位置编码
**kwargs
)
if pretrained and weight_path is not None:
# 只加载名称匹配的权重(兼容性加载)
old_dict = torch.load(weight_path)
model_dict = model.state_dict()
old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
model_dict.update(old_dict)
model.load_state_dict(model_dict)
return model
这种”过滤式加载”({k: v for k, v in old_dict.items() if (k in model_dict)})是一种常见的模型权重迁移技巧:只加载形状匹配的参数,忽略分类头等不匹配的参数。这是因为 ResT 的预训练权重是在 ImageNet 分类任务上训练的,包含分类头参数,但分割模型不需要这些参数。
5. 第四层借鉴:Swin-Unet / Swin-UMamba 的解码器范式
5.1 PatchExpand 的代码传承链
UNetMamba 解码器中最核心的两个上采样模块,其来源可以被精确追溯:
Swin-Unet(arXiv:2105.05537)
↓ 引用实现
Swin-UMamba/swin_umamba/nnunetv2/nets/SwinUMambaD.py
↓ 移植到遥感场景
UNetMamba/unetmamba_model/models/UNetMamba.py
代码中的注释直接给出了出处:
class PatchExpand(nn.Module):
"""
Reference: https://arxiv.org/pdf/2105.05537.pdf ← Swin-Unet 论文
"""
def forward(self, x):
x = x.permute(0, 2, 3, 1) # B, C, H, W ==> B, H, W, C
x = self.expand(x) # 通道数 ×2(为了后续像素 shuffle)
B, H, W, C = x.shape
# 2×2 像素 shuffle:将通道拆分为 2x2 空间像素
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
x = self.norm(x)
x = x.reshape(B, H*2, W*2, C//4) # 分辨率 ×2,通道 ÷2
return x
class FinalPatchExpand_X4(nn.Module):
"""
Reference:
- GitHub: https://github.com/HuCaoFighting/Swin-Unet/blob/... ← Swin-Unet GitHub
- Paper: https://arxiv.org/pdf/2105.05537.pdf
"""
# 最终 4×上采样:将 H/4 × W/4 恢复到 H × W
与 Swin-UMamba 中对应代码的比对显示:这两个类在 UNetMamba 中几乎原样使用,唯一的改动是删除了 Swin-UMamba 特有的 input_resolution 形状校验逻辑(因为 UNetMamba 固定使用 1024×1024 输入)。
5.2 VSSLayer 的结构参考
VSSLayer 是 UNetMamba 中将 VSS Block 组织为层的类,其整体结构参考了 Swin-UMamba 的 SwinUMambaD.py:
# Swin-UMamba 的 MambaDecoder 结构(参考)
for s in range(1, n_stages_encoder):
expand_layers.append(PatchExpand(...)) # 上采样层
stages.append(VSSLayer(...)) # VSS Block 堆叠
concat_back_dim.append(nn.Linear(...)) # Skip connection 维度对齐
# UNetMamba 的 MambaSegDecoder(直接借鉴这一组织结构)
for s in range(1, n_stages_encoder):
expand_layers.append(PatchExpand(...)) # ← 相同的 PatchExpand
stages.append(VSSLayer(...)) # ← 相同的组织方式
concat_back_dim.append(nn.Linear(...)) # ← 相同的 skip 融合方式
lsm_layers.append(LocalSupervision(...)) # ← UNetMamba 新增的 LSM 层
6. 核心新增:MSD 解码器与 LSM 模块的自主实现
前四层都是继承和借鉴,这一层才是 UNetMamba 作者真正的原创贡献。
6.1 LocalSupervision(LSM Block):全新设计
class LocalSupervision(nn.Module):
"""
局部监督模块(LSM Block)——UNetMamba 的核心创新之一
特点:仅在训练时存在,推理时消失(train-only)
"""
def __init__(self, in_channels=128, num_classes=6):
# 两个并行卷积分支
self.conv3 = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=3, ...),
nn.BatchNorm2d(in_channels),
nn.ReLU6() # ← ReLU6 防止激活值爆炸
)
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size=1, ...),
nn.BatchNorm2d(in_channels),
nn.ReLU6()
)
self.drop = nn.Dropout(0.1)
self.conv_out = nn.Conv2d(in_channels, num_classes, kernel_size=1, ...)
def forward(self, x, h, w):
local1 = self.conv3(x) # 3×3 卷积捕获局部上下文
local2 = self.conv1(x) # 1×1 卷积做点级映射
x = self.drop(local1 + local2) # 两路相加 + Dropout
x = self.conv_out(x) # 映射到分类空间
x = F.interpolate(x, size=(h, w), ...) # 上采样到原始尺寸
return x
设计要点:
kernel_size=3+kernel_size=1双卷积并行:用 3×3 卷积捕获局部空间关系,用 1×1 卷积调整通道,两路相加。这是 CNN 常见的”多感受野并行”范式ReLU6(而非 ReLU):限制激活值最大为 6,配合 BatchNorm 稳定训练Dropout(0.1):轻度正则化- 返回的是插值到原始 (h, w) 的分割预测——直接用于计算辅助损失
6.2 MambaSegDecoder(MSD):在继承基础上的关键整合
MambaSegDecoder 是 UNetMamba 的核心,它把借来的组件(PatchExpand from Swin-Unet、VSSLayer from VMamba、skip connection from U-Net)整合成了一个统一的解码器:
class MambaSegDecoder(nn.Module):
def __init__(self, num_classes, encoder_channels, ...):
for s in range(1, n_stages_encoder):
# 1. 上采样层(来自 Swin-Unet)
expand_layers.append(PatchExpand(...))
# 2. VSS Block 层(来自 VMamba)
stages.append(VSSLayer(dim=input_features_skip, depth=2, ...))
# 3. Skip Connection 维度对齐(U-Net 范式)
concat_back_dim.append(nn.Linear(2 * input_features_skip, input_features_skip))
# 4. ★ 新增:LSM 局部监督层(UNetMamba 原创)
lsm_layers.append(LocalSupervision(encoder_channels[-(s+1)], num_classes))
# 5. 最终 4× 上采样(来自 Swin-Unet)
expand_layers.append(FinalPatchExpand_X4(...))
# 6. 分割输出头(标准 1×1 Conv)
self.seg = nn.Conv2d(encoder_channels[-4], num_classes, kernel_size=1, ...)
def forward(self, skips, h, w):
if self.training:
ls = []
for s in range(len(self.stages)):
x = self.expand_layers[s](lres_input) # 上采样
if s < (len(self.stages) - 1):
x = torch.cat((x, skips[-(s+2)].permute(...)), -1) # concat skip
x = self.concat_back_dim[s](x) # 维度对齐
x = self.stages[s](x).permute(0, 3, 1, 2) # VSS Block 处理
if s == (len(self.stages) - 1):
seg_out = self.seg(x)
else:
ls.append(self.lsm[s](x, h, w)) # ★ 收集 LSM 的局部监督结果
lres_input = x
return seg_out, sum(ls) # ★ 主分割结果 + 所有 LSM 之和(辅助损失用)
else: # 推理时:完全相同的前向流程,但跳过 LSM
for s in range(len(self.stages)):
...(无 LSM)
return seg_out # 只返回主分割结果
训练/推理分支的对称设计是这个模块最精妙的地方:training 标志由 PyTorch 的 model.train() / model.eval() 控制,无需修改推理脚本。
6.3 UNetMamba 主类:极简的整合
class UNetMamba(nn.Module):
def __init__(self, pretrained, backbone_path, num_classes=6, **kwargs):
# 编码器:ResT-Lite(轻量 Transformer)
self.encoder = rest_lite(weight_path=backbone_path)
encoder_channels = [64, 128, 256, 512] # ResT-Lite 固定通道配置
# 解码器:MambaSegDecoder(核心贡献)
self.decoder = MambaSegDecoder(
num_classes=num_classes,
encoder_channels=encoder_channels,
...
)
# 注意:**kwargs 吸收了 config 中传入的所有 VSSM 参数,但实际上并未使用
# 这是早期开发遗留的痕迹(见第8节)
def forward(self, x):
h, w = x.size()[-2:]
outputs = self.encoder(x) # 编码器输出 4 个尺度的特征
if self.training:
x, lsm = self.decoder(outputs, h, w)
return x, lsm # 训练时返回主结果 + 辅助结果
else:
x = self.decoder(outputs, h, w)
return x # 推理时只返回主结果
7. 全局改动:命名空间重构与细节调整
7.1 命名空间重构
GeoSeg 的包名是 geoseg,UNetMamba 将其重命名为 unetmamba_model:
| 原始(GeoSeg) | 修改后(UNetMamba) |
|---|---|
from geoseg.losses import * |
from unetmamba_model.losses import * |
from geoseg.datasets.loveda_dataset import * |
from unetmamba_model.datasets.loveda_dataset import * |
from geoseg.models.UNetFormer import UNetFormer |
from unetmamba_model.models.UNetMamba import UNetMamba |
from tools.utils import Lookahead |
from catalyst.contrib.nn import Lookahead |
from tools.utils import process_model_params |
from catalyst import utils → utils.process_model_params |
7.2 外部依赖的替换
GeoSeg 将 Lookahead 和 process_model_params 放在自己的 tools/utils.py 中(本地实现),UNetMamba 改为直接使用 catalyst 库的版本:
# GeoSeg(自定义工具)
from tools.utils import Lookahead, process_model_params
# UNetMamba(使用外部库)
from catalyst.contrib.nn import Lookahead # catalyst 库
from catalyst import utils
utils.process_model_params(net, ...) # 功能等价,但依赖外部包
原因:UNetMamba 的 requirements.txt 中包含了 catalyst 库,这简化了工具函数的维护,但增加了一个依赖项。
7.3 测试脚本的细节调整
| 修改项 | GeoSeg 原版 | UNetMamba 修改版 | 原因 |
|---|---|---|---|
| 批量大小 | batch_size=2 |
batch_size=1 |
推理时用 1,避免多图拼接干扰评估 |
| 输出格式 | F1:{} |
mF1:{} |
加 ‘m’ 强调是平均值 |
| 指标数值 | 0-1 范围 | 0-100 范围(×100.0) | 与论文表格一致 |
| 监控指标 | val_F1(Vaihingen) |
val_mIoU |
UNetMamba 以 mIoU 为主要评估指标 |
7.4 vaihingen_dataset.py 的适配
这是数据集文件中修改最多的一个:
# GeoSeg 原版(固定 1024 尺寸)
ORIGIN_IMG_SIZE = (1024, 1024)
INPUT_IMG_SIZE = (1024, 1024)
# UNetMamba(参数化尺寸,支持多分辨率实验)
IMG_SIZE = 1024
ORIGIN_IMG_SIZE = (IMG_SIZE, IMG_SIZE)
CROP_SIZE = int(512 * (float(IMG_SIZE / 1024))) # ← 裁剪尺寸随输入尺寸等比缩放
# 测试集路径也参数化
def __init__(self, data_root='data/vaihingen/test_'+str(IMG_SIZE), ...):
# 原版:'data/vaihingen/test'(固定路径)
# 新版:根据 IMG_SIZE 自动选择对应尺寸的数据集
这个改动说明作者在实验过程中尝试了不同分辨率的输入(如 512、768、1024),IMG_SIZE 变量使得切换尺寸只需改一个数字。
8. Config 的演化痕迹:从 VMamba backbone 到 ResT-Lite
Config 文件是代码演化历史中最有意思的部分——它保留了早期设计方案的残留痕迹。
8.1 Config 中的”幽灵参数”
当前 UNetMamba 的 config 文件传入了大量 VMamba 专属参数:
# UNetMamba/config/loveda/unetmamba.py(当前版本)
# VSSM parameters(这些参数名都是 VMamba backbone 的配置)
PATCH_SIZE = 4
DEPTHS = [2, 2, 9, 2] # VMamba-Tiny 的深度配置
EMBED_DIM = 96 # VMamba 的嵌入维度
SSM_D_STATE = 16 # SSM 状态维度
SSM_RATIO = 2.0 # SSM 扩展比
SSM_FORWARDTYPE = "v4" # VMamba 的特定前向传播版本
...(共 14 个 VSSM 参数)
net = UNetMamba(
pretrained=...,
num_classes=num_classes,
patch_size=PATCH_SIZE, # ← 这些参数全部传入
in_chans=IN_CHANS,
depths=DEPTHS,
dims=EMBED_DIM,
ssm_d_state=SSM_D_STATE,
...
)
但实际的 UNetMamba.__init__ 是这样的:
class UNetMamba(nn.Module):
def __init__(self,
pretrained,
decode_channels=64,
backbone_path='pretrain_weights/rest_lite.pth',
embed_dim=64,
num_classes=6,
**kwargs # ← config 传入的所有 VSSM 参数都被 **kwargs 吸收并丢弃
):
self.encoder = rest_lite(weight_path=backbone_path) # ← 实际只用 ResT-Lite
这意味着什么?
这是典型的重构遗留痕迹:作者在早期版本中尝试过将 VMamba(VSSM)用作 backbone 编码器,因此 config 中保留了完整的 VSSM 参数配置。当编码器切换为 ResT-Lite 之后,模型代码改了,但 config 文件中的 VSSM 参数并没有清理——它们被 **kwargs 静默忽略。
8.2 推测的代码演化路径
基于 config 文件的遗留参数,可以推测 UNetMamba 的开发经历了至少两个阶段:
【阶段一:VMamba-Encoder 版本(推测)】
编码器:VSSM(VMamba backbone,~30M 参数)
解码器:MSD(VSS Block 解码器)
Config 中的 VSSM 参数:有意义、被使用
问题:参数量过大(>30M),与"高效"目标矛盾
↓ 编码器换成 ResT-Lite
【阶段二:ResT-Lite Encoder 版本(当前版本)】
编码器:ResT-Lite(~13M 参数)
解码器:MSD + LSM
Config 中的 VSSM 参数:已无意义,被 **kwargs 吸收
参数量:14.76M,效率大幅提升
这个推测与论文中的叙述一致:
“models that utilize a pre-trained VMamba backbone (>30M) while encoding tend to be overweight for our goal of high efficiency. Therefore, we incorporate the basic unit of VMamba, namely, VSS block, into the decoding side only.”
8.3 CMUNet.py 中的更早期痕迹
CMUNet.py 中也有类似的早期探索痕迹:
from .mamba_sys import VSSM
# from .mamba_sys_swin import VSSM_SWIN ← 注释掉的 Swin 版 VSSM
# from .mamba_sys_vmamba import VSSMMamba ← 注释掉的 VMamba 版
# from .mamba_sys_res import VSSMRes ← 注释掉的 ResNet 版
这些注释掉的 import 表明,在最终版本之前,作者探索过多种 VSSM 变体作为骨干网络的可能性。
9. 代码演化时间线总结
综合以上分析,UNetMamba 代码库的构建过程可以还原为以下步骤:
Step 1:复制 GeoSeg 作为基础框架
GeoSeg/
├── train_supervision.py → UNetMamba/train.py
├── loveda_test.py → UNetMamba/loveda_test.py
├── vaihingen_test.py → UNetMamba/vaihingen_test.py
├── geoseg/ → UNetMamba/unetmamba_model/
│ ├── models/(BANet, DCSwin, MANet, ABCNet, UNetFormer)→ 完整复制
│ ├── datasets/ → 完整复制
│ └── losses/ → 完整复制
└── tools/ → 完整复制
此时 UNetMamba 就是一个改名的 GeoSeg,能跑通 BANet 和 UNetFormer 的训练。
Step 2:集成 VMamba 的 SSM 算子
VMamba/vmamba.py → UNetMamba/unetmamba_model/mamba_sys.py(原样复制)
VMamba/classification/vmamba.py → UNetMamba/.../classification/models/vmamba.py
(新增 CrossScan 纯 PyTorch 回退实现)
此时 UNetMamba 可以在代码中调用 VSSBlock 和 SS2D。
Step 3:移植 ResT 编码器
ResT/models/rest.py → UNetMamba/unetmamba_model/models/ResT.py
(删分类头,改返回值为多尺度特征 x1,x2,x3,x4)
Step 4:集成 Swin-UMamba 的解码器上采样模块
Swin-UMamba/SwinUMambaD.py 中的 PatchExpand / FinalPatchExpand_X4
→ UNetMamba.py 中直接使用(保留 Swin-Unet 引用注释)
Step 5:集成对比实验所需的其他 Mamba 模型
RS3Mamba 代码 → UNetMamba/models/RS3Mamba.py(使用 VSSMEncoder + 解码器)
Swin-UMamba 结构 → UNetMamba/models/SwinUMamba.py
CM-UNet 风格 → UNetMamba/models/CMUNet.py
Step 6:实现 MSD 和 LSM(核心创新)
# 在 UNetMamba.py 中自主实现:
class LocalSupervision(...) # 全新设计:双卷积并行 + Dropout
class MambaSegDecoder(...) # 整合 PatchExpand + VSSLayer + LSM
class UNetMamba(...) # 顶层整合:ResT-Lite 编码器 + MSD 解码器
Step 7:调整配置文件
GeoSeg/config/loveda/unetformer.py → UNetMamba/config/loveda/unetmamba.py
- 替换模型类名(UNetFormer → UNetMamba)
- 替换损失函数(UnetFormerLoss → UnetMambaLoss)
- 调整训练超参(epochs 30→100, batch_size 16→8, weight_decay 0.01→2.5e-4)
- 新增 FLOPs/Params 统计代码(fvcore 库)
- 遗留 VSSM 参数(早期 VMamba backbone 方案的痕迹)
Step 8:修改损失函数类名和参数
# useful_loss.py
UnetFormerLoss → UnetMambaLoss # 类名
ignore_index=255 → ignore_index=6 # 适配 LoveDA 7 类场景
代码血缘关系汇总表
| UNetMamba 代码 | 来源 | 改动程度 | 改动类型 |
|---|---|---|---|
train.py |
GeoSeg train_supervision.py |
极小 | 重命名 + 数值百分比化 |
models/BANet.py |
GeoSeg BANet.py |
零改动 | 完整复制 |
models/DCSwin.py |
GeoSeg DCSwin.py |
零改动 | 完整复制 |
models/MANet.py |
GeoSeg MANet.py |
零改动 | 完整复制 |
models/UNetFormer.py |
GeoSeg UNetFormer.py |
极小 | padding 逻辑修复 |
datasets/ |
GeoSeg datasets/ |
极小 | 参数化 IMG_SIZE |
losses/ (除 useful_loss) |
GeoSeg losses/ |
零改动 | 完整复制 |
losses/useful_loss.py |
GeoSeg useful_loss.py |
小 | 类名 + ignore_index |
tools/ |
GeoSeg tools/ |
零改动 | 完整复制 |
mamba_sys.py |
VMamba vmamba.py |
零改动 | 完整复制 |
classification/models/vmamba.py |
VMamba classification/vmamba.py |
中 | 新增 PyTorch CrossScan 回退 |
models/ResT.py |
ResT models/rest.py |
中 | 删分类头,改多尺度输出 |
models/SwinUMamba.py |
Swin-UMamba | 中 | 提取解码器基础结构 |
models/RS3Mamba.py |
RS3Mamba 参考 | 较大 | 重新实现以适配框架 |
models/UNetMamba.py |
原创 | — | PatchExpand 引用 Swin-Unet,核心逻辑全新 |
config/*.py |
GeoSeg config | 大 | 替换模型 + 超参调整 |
结论:UNetMamba 约 85% 的代码行数来自 GeoSeg(基础设施)和 VMamba(SSM 算子),约 10% 来自 ResT 和 Swin-UMamba(编码器和解码器上采样),核心创新(MSD 解码器逻辑、LSM 模块)仅占约 5%——但这 5% 是整篇论文的算法贡献所在。