import torch
import torch.nn as nn


# 1D convolution with kernel size 5 ("same" padding at stride 1)
def conv1d_5(inplanes, outplanes, stride=1):
    return nn.Conv1d(inplanes, outplanes, kernel_size=5, stride=stride, padding=2, bias=False)


# Transformer-based self-attention module: attention and feed-forward sublayers,
# each with a residual connection followed by LayerNorm
class TransformerLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super(TransformerLayer, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # [B, C, L] -> [B, L, C]
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + self.dropout(attn_output))
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))
        x = x.permute(0, 2, 1)  # [B, L, C] -> [B, C, L]
        return x


# Encoder (downsampling) block with a ResNet-style residual connection; BatchNorm is optional
class Block(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None, bn=False):
        super(Block, self).__init__()
        self.bn = bn
        self.conv1 = conv1d_5(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=False)
        self.conv2 = conv1d_5(planes, planes)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        if self.bn:
            out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        if self.bn:
            out = self.bn2(out)
        out = self.relu(out)
        if self.downsample is not None:
            residual = self.downsample(x)  # match the main branch's channels and length
        out = out + residual
        out = self.relu(out)
        return out


# Decoder (upsampling) block: transposed convolution, then concatenation with the encoder skip feature
class Decoder_block(nn.Module):
    def __init__(self, inplanes, outplanes, kernel_size=5, stride=5):
        super(Decoder_block, self).__init__()
        self.upsample = nn.ConvTranspose1d(inplanes, outplanes, kernel_size=kernel_size, stride=stride, bias=False)
        # After concatenation the channel count is 2 * outplanes, which equals inplanes
        # because every decoder stage halves the channel count (256->128, 128->64, 64->32).
        self.conv1 = conv1d_5(inplanes, outplanes)
        self.relu = nn.ReLU(inplace=False)
        self.conv2 = conv1d_5(outplanes, outplanes)

    def forward(self, x1, x2):
        x1 = self.upsample(x1)            # upsample the deeper feature map
        out = torch.cat((x1, x2), dim=1)  # concatenate with the encoder skip feature
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.relu(out)
        return out


# Main model: U-Net-style 1D encoder-decoder with a Transformer bottleneck
class Model(nn.Module):
    def __init__(self, inplanes=1, outplanes=2, layers=[2, 2, 2, 2]):
        super(Model, self).__init__()
        self.inplanes = inplanes
        self.outplanes = outplanes
        # Encoder: downsampling strides 5, 5, 5, 4 (an overall factor of 500)
        self.encoder1 = self._make_encoder(Block, 32, layers[0], 5)
        self.encoder2 = self._make_encoder(Block, 64, layers[1], 5)
        self.encoder3 = self._make_encoder(Block, 128, layers[2], 5)
        self.encoder4 = self._make_encoder(Block, 256, layers[3], 4)
        # Self-attention layer between encoder and decoder
        self.self_attention = TransformerLayer(embed_dim=256, num_heads=8, ff_dim=512)
        # Decoder: mirrors the encoder strides
        self.decoder3 = Decoder_block(256, 128, stride=4, kernel_size=4)
        self.decoder2 = Decoder_block(128, 64)
        self.decoder1 = Decoder_block(64, 32)
        # Final stride-5 transposed convolution back to the input resolution
        self.conv1x1 = nn.ConvTranspose1d(32, outplanes, kernel_size=5, stride=5, bias=False)

    def _make_encoder(self, block, planes, blocks, stride=1):
        downsample = None
        if self.inplanes != planes or stride != 1:
            # 1x1 convolution so the residual branch matches the main branch's channels and length
            downsample = nn.Conv1d(self.inplanes, planes, kernel_size=1, stride=stride, bias=False)
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        down1 = self.encoder1(x)
        down2 = self.encoder2(down1)
        down3 = self.encoder3(down2)
        down4 = self.encoder4(down3)
        # Apply the self-attention layer at the bottleneck
        attention_out = self.self_attention(down4)
        up3 = self.decoder3(attention_out, down3)
        up2 = self.decoder2(up3, down2)
        up1 = self.decoder1(up2, down1)
        out = self.conv1x1(up1)
        return out


# Quick shape check to verify input-output compatibility
if __name__ == "__main__":
    model = Model(inplanes=1, outplanes=1, layers=[3, 3, 3, 3])
    model.eval()
    image = torch.randn(1, 1, 1000)
    with torch.no_grad():
        output = model(image)
    print(output.size())  # expected: torch.Size([1, 1, 1000])
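

# --- Usage sketch (added; not part of the original source) -------------------
# The encoder downsamples by strides 5, 5, 5, 4 (a factor of 500 overall) and the
# decoder upsamples by the same factors, so the skip concatenations only line up
# when the input length is a multiple of 500 (the check above uses 1000).
# `pad_to_multiple` is an assumed convenience helper, not original code: it
# zero-pads the signal on the right before inference and returns the original
# length so the output can be cropped back.
def pad_to_multiple(signal, multiple=500):
    """Right-pad a [B, C, L] tensor with zeros so that L is a multiple of `multiple`."""
    length = signal.size(-1)
    remainder = length % multiple
    if remainder == 0:
        return signal, length
    return nn.functional.pad(signal, (0, multiple - remainder)), length

# Example (assumed usage):
#   model = Model(inplanes=1, outplanes=1)
#   x, orig_len = pad_to_multiple(torch.randn(1, 1, 1234))
#   with torch.no_grad():
#       y = model(x)[..., :orig_len]  # crop back to the original length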