
Implementation of Mamba in PyTorch

Published on 2024/07/08

This time, I introduce an implementation of Mamba. Rather than implementing Mamba from scratch, we use the official mamba_ssm package and wrap it as a PyTorch model.

1. Model

・Official implementation [1]

import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
print(y)

# output
# tensor([[[ 0.0046,  0.0108, -0.0157,  ..., -0.0179, -0.0226,  0.0497],
#          [ 0.0063,  0.0288, -0.0023,  ...,  0.0114,  0.0188,  0.0328],
#          [ 0.0366, -0.0113, -0.1003,  ...,  0.0279, -0.0380,  0.0105],
#          ...,
#          [-0.0461,  0.0007,  0.0366,  ...,  0.0532, -0.0284, -0.0347],
#          [-0.0073,  0.0311, -0.0177,  ...,  0.0332,  0.0037,  0.0007],
#          [-0.0008, -0.0456,  0.0199,  ..., -0.0099, -0.0158, -0.0012]],

#         [[ 0.0254,  0.0037, -0.0638,  ...,  0.0042,  0.0123, -0.0212],
#          [-0.0255, -0.0134, -0.0106,  ..., -0.0006, -0.0259, -0.0111],
#          [-0.0008,  0.0295,  0.0479,  ..., -0.0059,  0.0114, -0.0043],
#          ...,
#          [ 0.0026, -0.0040, -0.0070,  ..., -0.0498, -0.0085,  0.0358],
#          [ 0.0534, -0.0341, -0.0274,  ..., -0.0974, -0.0472,  0.0185],
#          [-0.0236, -0.0280,  0.0078,  ...,  0.0251,  0.0085, -0.0092]]],
#        device='cuda:0', grad_fn=<UnsafeViewBackward0>)

This is the official implementation. We can use it as is, but it is not very convenient on its own, so I wrapped it in a PyTorch model.

2. Mamba with peripheral layers

import torch
import torch.nn as nn
from mamba_ssm import Mamba

class MambaModel(nn.Module):
    def __init__(self, 
                 dim_model=384, # Model dimension d_model (embedding size)
                 d_state=64, # SSM state expansion factor
                 d_conv=8, # Local convolution width
                 expand=4, # Block expansion factor
                 output=3, # number of classes (or output number simply)
                 dropout_rate=0.1, # Dropout rate
                 is_test=False,
                ):
        super().__init__()
        self.model = Mamba(
            d_model=dim_model,  
            d_state=d_state,  
            d_conv=d_conv,    
            expand=expand,    
        ).to("cuda")
        self.output = nn.Linear(dim_model, output)  # linear head: d_model -> number of outputs
        self.dropout = nn.Dropout(dropout_rate)
        self.sigmoid = nn.Sigmoid()  # applied only when is_test=True (see forward)
        self.is_test = is_test

    def forward(self, x):
        # Add the length dimension if input has only 2 dimensions
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
            
        x = self.model(x)
        x = self.dropout(x)  # Apply dropout
        x = self.output(x)
        if self.is_test:
            x = self.sigmoid(x)
        x = x.squeeze()  # drop any size-1 dimensions (e.g., a length of 1)
        
        return x
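
Two details of forward are worth illustrating: a 2D input is treated as a length-1 sequence, and is_test=True appends a sigmoid to the logits. Here is a minimal sketch of both cases (my own example, assuming mamba_ssm is installed and a CUDA GPU is available; d_conv=4 mirrors the official example above):

import torch

# 2D input [bs, input_size] is unsqueezed to [bs, 1, input_size] inside forward.
model = MambaModel(dim_model=384, d_conv=4, output=3, is_test=False).to("cuda")
x = torch.randn(8, 384).to("cuda")
logits = model(x)    # [8, 1, 3] before squeeze -> [8, 3] after squeeze
print(logits.shape)  # torch.Size([8, 3])

# With is_test=True the same architecture ends with a sigmoid, so values lie in (0, 1).
model_test = MambaModel(dim_model=384, d_conv=4, output=3, is_test=True).to("cuda")
probs = model_test(x)
print(float(probs.min()) > 0.0, float(probs.max()) < 1.0)  # True True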

This is a configurable Mamba model; you can specify the parameters as you need.
The model expects the input to have shape [bs, length, input_size].
For example, in NLP, length is the sequence length and input_size is the embedding dimension of the embedded tokens.

The Mamba block itself returns an output with the same shape as its input; the linear head then maps the last dimension to output, so the final output has shape [bs, length, output] (after squeezing out any size-1 dimensions).
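
For example, here is a minimal NLP-style usage sketch (my own example; the vocabulary size of 1000 and the embedding layer are only for illustration, and d_conv=4 again mirrors the official example):

import torch
import torch.nn as nn

vocab_size, dim_model = 1000, 384
embedding = nn.Embedding(vocab_size, dim_model).to("cuda")
model = MambaModel(dim_model=dim_model, d_conv=4, output=3).to("cuda")

tokens = torch.randint(0, vocab_size, (2, 64)).to("cuda")  # [bs, length] token ids
x = embedding(tokens)  # [bs, length, input_size] = [2, 64, 384]
y = model(x)
print(y.shape)  # torch.Size([2, 64, 3])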

Please try it out.

Reference

[1] GitHub, state-spaces/mamba, https://github.com/state-spaces/mamba
