每日Attention学习20——Group Shuffle Attention

news/2025/2/6 23:04:38 标签: 划水
模块出处

[MICCAI 24] [link] LB-UNet: A Lightweight Boundary-Assisted UNet for Skin Lesion Segmentation


模块名称

Group Shuffle Attention (GSA)


模块作用

轻量特征学习


模块结构

在这里插入图片描述


模块特点
  • 使用分组(Group)卷积降低计算量
  • 引入External Attention机制更好的学习特征
  • Shuffle操作促进不同Group之间信息的交互

模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError 
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
        

class GSA(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        c_dim = dim_in // 4
        self.share_space1 = nn.Parameter(torch.Tensor(1, c_dim, 8, 8), requires_grad=True)
        nn.init.ones_(self.share_space1)
        self.conv1 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.share_space2 = nn.Parameter(torch.Tensor(1, c_dim, 8, 8), requires_grad=True)
        nn.init.ones_(self.share_space2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.share_space3 = nn.Parameter(torch.Tensor(1, c_dim, 8, 8), requires_grad=True)
        nn.init.ones_(self.share_space3)
        self.conv3 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.share_space4 = nn.Parameter(torch.Tensor(1, c_dim, 8, 8), requires_grad=True)
        nn.init.ones_(self.share_space4)
        self.conv4 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.norm1 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
        self.norm2 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
        self.ldw = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size=3, padding=1, groups=dim_in),
            nn.GELU(),
            nn.Conv2d(dim_in, dim_out, 1),
        )

    def forward(self, x):
        x = self.norm1(x)
        x1, x2, x3, x4 = torch.chunk(x, 4, dim=1)
        B, C, H, W = x1.size()
        x1 = x1 * self.conv1(F.interpolate(self.share_space1, size=x1.shape[2:4],mode='bilinear', align_corners=True))
        x2 = x2 * self.conv2(F.interpolate(self.share_space2, size=x1.shape[2:4],mode='bilinear', align_corners=True))
        x3 = x3 * self.conv3(F.interpolate(self.share_space3, size=x1.shape[2:4],mode='bilinear', align_corners=True))
        x4 = x4 * self.conv4(F.interpolate(self.share_space4, size=x1.shape[2:4],mode='bilinear', align_corners=True))
        x = torch.cat([x2,x4,x1,x3], dim=1)
        x = self.norm2(x)
        x = self.ldw(x)
        return x
    

if __name__ == '__main__':
    x = torch.randn([1, 64, 44, 44])
    gsa = GSA(dim_in=64, dim_out=64)
    out = gsa(x)
    print(out.shape)  # [1, 64, 44, 44]


http://www.niftyadmin.cn/n/5843379.html

相关文章

代码讲解系列-CV(二)——卷积神经网络

文章目录 一、系列大纲二、卷积神经网络(图像分类为例)2.1 pytorch简介训练框架张量自动微分动态计算图更深入学习 2.2 数据输入和增强Dataset—— torch.utils.data.DatasetDataLoader——torch.utils.data.Dataloader数据增强 2.3 CNN设计与训练nn.Mod…

Django框架丨从零开始的Django入门学习

Django 是一个用于构建 Web 应用程序的高级 Python Web 框架,Django是一个高度模块化的框架,使用 Django,只要很少的代码,Python 的程序开发人员就可以轻松地完成一个正式网站所需要的大部分内容,并进一步开发出全功能…

【kafka的零拷贝原理】

kafka的零拷贝原理 一、零拷贝技术概述二、Kafka中的零拷贝原理三、零拷贝技术的优势四、零拷贝技术的实现细节五、注意事项一、零拷贝技术概述 零拷贝(Zero-Copy)是一种减少数据拷贝次数,提高数据传输效率的技术。 在传统的数据传输过程中,数据需要在用户态和内核态之间…

edu小程序挖掘严重支付逻辑漏洞

edu小程序挖掘严重支付逻辑漏洞 一、敏感信息泄露 打开购电小程序 这里需要输入姓名和学号,直接搜索引擎搜索即可得到,这就不用多说了,但是这里的手机号可以任意输入,只要用户没有绑定手机号这里我们输入自己的手机号抓包直接进…

5-Scene层级关系

Fiber里有个scene是只读属性,能从fiber中获取它属于哪个场景,scene实体中又声明了fiber,fiber与scene是互相引用的关系。 scene层级关系 举例 在unity.core中的EntityHelper中,可以通过entity获取对应的scene root fiber等属性…

孟加拉国_行政边界省市边界arcgis数据shp格式wgs84坐标

这篇内容将深入探讨孟加拉国的行政边界省市边界数据,该数据是以arcgis的shp格式提供的,并采用WGS84坐标系统。ArcGIS是一款广泛应用于地理信息系统(GIS)的专业软件,它允许用户处理、分析和展示地理空间数据。在GIS领域…

PyTorch Geometric(PyG)机器学习实战

PyTorch Geometric(PyG)机器学习实战 在图神经网络(GNN)的研究和应用中,PyTorch Geometric(PyG)作为一个基于PyTorch的库,提供了高效的图数据处理和模型构建功能。 本文将通过一个节…

不同数据库与 WebGL 集成

一、引言 在当今数字化时代,数据可视化是一个关键需求,而 WebGL(Web Graphics Library)为在网页上实现高性能 3D 图形渲染提供了强大的工具。然而,WebGL 本身无法直接与数据库进行交互,为了将数据库中的数…