PyTorch implementation of FusionCount in Google Colab

standfsk
4 min read · Jan 22, 2024


Guide to FusionCount

Intro

In this guide, we’ll walk through FusionCount and implement the model in PyTorch.

Why FusionCount?

FusionCount delivers precise crowd counts while remaining computationally efficient, making it a strong choice for applications ranging from event planning to public safety.

FusionCount adaptively fuses a large majority of the already-encoded features instead of relying on additional extraction components to obtain multiscale features. It therefore covers a wider range of receptive field sizes at a lower computational cost. The authors also introduce a new channel reduction block, which extracts saliency information during decoding and further improves the model’s performance.
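
To make the fusion idea concrete, here is a deliberately naive sketch: resize a group of encoder features to a common size, concatenate them, and mix them with a 1x1 convolution. NaiveFuser is a hypothetical stand-in for illustration only, not the paper’s FeatureFuser (which weights the feature maps adaptively).

import torch
from torch import nn
import torch.nn.functional as F

class NaiveFuser(nn.Module):
    # Hypothetical sketch, not the paper's FeatureFuser: fuse a group of
    # same-stage encoder features by resize + concatenate + 1x1 convolution.
    def __init__(self, in_channels: list[int]) -> None:
        super().__init__()
        # Mix the concatenated maps down to the channel count of the first input.
        self.mix = nn.Conv2d(sum(in_channels), in_channels[0], kernel_size=1)

    def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
        size = feats[0].shape[-2:]
        feats = [F.interpolate(f, size=size, mode="bilinear", align_corners=False)
                 for f in feats]
        return self.mix(torch.cat(feats, dim=1))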

FusionCount paper review

Implementation

Preprocess

  • Images are resized so that width and height are at most 1920 pixels, preserving the aspect ratio.
  • Images with no people to count are removed from the training set.
  • Each image’s head annotations are saved to a matching “.npy” file.
import os
import numpy as np
from PIL import Image
from scipy.io import loadmat

# image_paths, dataset, save_dir, max_size, and reform_data are defined in earlier cells.
count = 0
for image_path in image_paths:
    name = f'{dataset}_{count:05d}.jpg'
    image_save_path = os.path.join(save_dir, name)

    # Load the head annotations stored alongside each image.
    label_path = image_path.replace('.jpg', '_ann.mat')
    points = loadmat(label_path)['annPoints'].astype(np.float32)

    # Skip images with no annotated people.
    if not points.any():
        continue

    # Resize image and points together, then save both.
    image = Image.open(image_path).convert('RGB')
    image, points = reform_data(image, points, max_size)
    image.save(image_save_path)
    label_save_path = image_save_path.replace('jpg', 'npy')
    np.save(label_save_path, points)
    count += 1
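
reform_data is a notebook helper rather than a library function; here is a minimal sketch of what it has to do, assuming the signature used above: cap the longer side at max_size, keep the aspect ratio, and rescale the point coordinates by the same factor.

from PIL import Image

def reform_data(image, points, max_size):
    # Hypothetical sketch: resize so the longer side is at most max_size,
    # preserving aspect ratio, and scale the head coordinates accordingly.
    w, h = image.size
    ratio = min(max_size / max(w, h), 1.0)  # never upscale
    if ratio < 1.0:
        image = image.resize((int(w * ratio), int(h * ratio)), Image.BILINEAR)
        points = points * ratio
    return image, points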

Model

from torch import Tensor, nn
import torch.nn.functional as F
from helpers import _initialize_weights, ConvNormActivation, FeatureFuser, ChannelReducer
from vgg import VGG

class FusionCount(nn.Module):
    """
    Based on the official PyTorch implementation of the model proposed in
    "FusionCount: Efficient Crowd Counting via Multiscale Feature Fusion".
    """
    def __init__(self, batch_norm: bool = True) -> None:
        super(FusionCount, self).__init__()
        # VGG-16 encoder whose intermediate feature maps are returned as a list.
        if batch_norm:
            self.encoder = VGG(name="vgg16_bn", pretrained=True, start_idx=2)
        else:
            self.encoder = VGG(name="vgg16", pretrained=True, start_idx=2)

        # One fuser per encoder stage: each adaptively fuses that stage's feature maps.
        self.fuser_1 = FeatureFuser([64, 128, 128], batch_norm=batch_norm)
        self.fuser_2 = FeatureFuser([128, 256, 256, 256], batch_norm=batch_norm)
        self.fuser_3 = FeatureFuser([256, 512, 512, 512], batch_norm=batch_norm)
        self.fuser_4 = FeatureFuser([512, 512, 512, 512], batch_norm=batch_norm)

        # Channel reduction blocks used while decoding coarse-to-fine.
        self.reducer_1 = ChannelReducer(in_channels=64, out_channels=32, dilation=2, batch_norm=batch_norm)
        self.reducer_2 = ChannelReducer(in_channels=128, out_channels=64, dilation=2, batch_norm=batch_norm)
        self.reducer_3 = ChannelReducer(in_channels=256, out_channels=128, dilation=2, batch_norm=batch_norm)
        self.reducer_4 = ChannelReducer(in_channels=512, out_channels=256, dilation=2, batch_norm=batch_norm)

        output_layer = ConvNormActivation(
            in_channels=32,
            out_channels=1,
            kernel_size=1,
            stride=1,
            dilation=1,
            norm_layer=None,
            activation_layer=nn.ReLU(inplace=True)
        )

        self.output_layer = _initialize_weights(output_layer)
        # 8x max pooling downsamples the density map for DM-Count-style supervision.
        self.maxPool = nn.MaxPool2d(kernel_size=8)
        self.density_layer = nn.Sequential(nn.Conv2d(128, 1, 1), nn.ReLU())  # defined but unused in forward

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        feats = self.encoder(x)

        # Group the encoder outputs by stage.
        feat_1, feat_2, feat_3, feat_4 = feats[0: 3], feats[3: 7], feats[7: 11], feats[11:]

        feat_1 = self.fuser_1(feat_1)
        feat_2 = self.fuser_2(feat_2)
        feat_3 = self.fuser_3(feat_3)
        feat_4 = self.fuser_4(feat_4)

        # Decode coarse-to-fine: reduce channels, upsample, and add to the next shallower feature.
        feat_4 = self.reducer_4(feat_4)
        feat_4 = F.interpolate(feat_4, size=feat_3.shape[-2:], mode="bilinear", align_corners=False)

        feat_3 = feat_3 + feat_4
        feat_3 = self.reducer_3(feat_3)
        feat_3 = F.interpolate(feat_3, size=feat_2.shape[-2:], mode="bilinear", align_corners=False)

        feat_2 = feat_2 + feat_3
        feat_2 = self.reducer_2(feat_2)
        feat_2 = F.interpolate(feat_2, size=feat_1.shape[-2:], mode="bilinear", align_corners=False)

        feat_1 = feat_1 + feat_2
        feat_1 = self.reducer_1(feat_1)
        feat_1 = F.interpolate(feat_1, size=x.shape[-2:], mode="bilinear", align_corners=False)

        output = self.output_layer(feat_1)
        # Downsample, then normalize the map to sum to one for the OT loss.
        mu = self.maxPool(output)
        B, C, H, W = mu.size()
        mu_sum = mu.view([B, -1]).sum(1).unsqueeze(1).unsqueeze(2).unsqueeze(3)
        mu_normed = mu / (mu_sum + 1e-6)
        return mu, mu_normed
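
As a quick sanity check of the forward pass (a sketch; it assumes helpers.py and vgg.py from the accompanying repository are importable):

import torch

model = FusionCount(batch_norm=True)
model.eval()
x = torch.randn(1, 3, 512, 512)  # dummy batch of one RGB image
with torch.no_grad():
    mu, mu_normed = model(x)
print(mu.shape)                # expected: torch.Size([1, 1, 64, 64]), i.e. input size / 8
print(float(mu.sum()))         # the sum of mu is the predicted count (see the training loop)
print(float(mu_normed.sum()))  # ~1.0 by construction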

Train
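The training loop below follows the DM-Count recipe: the normalized density map is supervised with an optimal-transport (OT) loss, the unnormalized map with an L1 counting loss against the ground-truth head counts, and a total-variation (TV) loss regularizes the normalized prediction; the weights wot and wtv scale the OT and TV terms. AverageMeter tracks running averages of each term (a minimal sketch of it follows the code).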

def train_epoch(self):  # method of the trainer class; time, os, np, and torch are module-level imports
    # Running averages for each loss term and metric.
    epoch_ot_loss = AverageMeter()
    epoch_ot_obj_value = AverageMeter()
    epoch_wd = AverageMeter()
    epoch_count_loss = AverageMeter()
    epoch_tv_loss = AverageMeter()
    epoch_loss = AverageMeter()
    epoch_mae = AverageMeter()
    epoch_mse = AverageMeter()
    epoch_start = time.time()
    self.model.train()  # Set model to training mode

    for step, (inputs, points, gt_discrete) in enumerate(self.dataloaders['train']):
        inputs = inputs.to(self.device)
        gd_count = np.array([len(p) for p in points], dtype=np.float32)  # ground-truth counts
        points = [p.to(self.device) for p in points]
        gt_discrete = gt_discrete.to(self.device)
        N = inputs.size(0)

        with torch.set_grad_enabled(True):
            outputs, outputs_normed = self.model(inputs)

            # Compute OT loss on the normalized density map.
            ot_loss, wd, ot_obj_value = self.ot_loss(outputs_normed, outputs, points)
            ot_loss = ot_loss * self.args.wot
            ot_obj_value = ot_obj_value * self.args.wot
            epoch_ot_loss.update(ot_loss.item(), N)
            epoch_ot_obj_value.update(ot_obj_value.item(), N)
            epoch_wd.update(wd, N)

            # Compute counting loss (L1 between predicted and ground-truth counts).
            count_loss = self.mae(outputs.sum(1).sum(1).sum(1),
                                  torch.from_numpy(gd_count).float().to(self.device))
            epoch_count_loss.update(count_loss.item(), N)

            # Compute TV loss between the normalized prediction and ground truth.
            gd_count_tensor = torch.from_numpy(gd_count).float().to(self.device) \
                .unsqueeze(1).unsqueeze(2).unsqueeze(3)
            gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6)
            tv_loss = (self.tv_loss(outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1)
                       * torch.from_numpy(gd_count).float().to(self.device)).mean(0) * self.args.wtv
            epoch_tv_loss.update(tv_loss.item(), N)

            loss = ot_loss + count_loss + tv_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Track prediction error for MAE/MSE metrics.
            pred_count = torch.sum(outputs.view(N, -1), dim=1).detach().cpu().numpy()
            pred_err = pred_count - gd_count
            epoch_loss.update(loss.item(), N)
            epoch_mse.update(np.mean(pred_err * pred_err), N)
            epoch_mae.update(np.mean(np.abs(pred_err)), N)

    self.logger.info(
        'Epoch {} Train, Loss: {:.2f}, OT Loss: {:.2e}, Wass Distance: {:.2f}, OT obj value: {:.2f}, '
        'Count Loss: {:.2f}, TV Loss: {:.2f}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
        .format(self.epoch, epoch_loss.get_avg(), epoch_ot_loss.get_avg(), epoch_wd.get_avg(),
                epoch_ot_obj_value.get_avg(), epoch_count_loss.get_avg(), epoch_tv_loss.get_avg(),
                np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(),
                time.time() - epoch_start))

    # Save a checkpoint at the end of each epoch.
    model_state_dic = self.model.state_dict()
    save_path = os.path.join(self.save_dir, '{}_ckpt.tar'.format(self.epoch))
    torch.save({
        'epoch': self.epoch,
        'optimizer_state_dict': self.optimizer.state_dict(),
        'model_state_dict': model_state_dic
    }, save_path)
    self.save_list.append(save_path)
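
AverageMeter is not defined in the post; here is a minimal sketch consistent with how it is used above:

class AverageMeter:
    # Hypothetical minimal running-average tracker matching the usage above.
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # val is the average over n samples; accumulate the weighted total.
        self.sum += val * n
        self.count += n

    def get_avg(self):
        return self.sum / max(self.count, 1)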

If you would like to see the full process, please refer to the link below.

Result

Reference

Y. Ma, V. Sanchez, and T. Guha, “FusionCount: Efficient Crowd Counting via Multiscale Feature Fusion,” ICIP 2022.
B. Wang, H. Liu, D. Samaras, and M. Hoai, “Distribution Matching for Crowd Counting,” NeurIPS 2020.
