Implementation of Mamba in PyTorch
In this post, I introduce an implementation of Mamba. Rather than implementing Mamba from scratch, we wrap the official mamba_ssm module as a PyTorch model.
1. Model
・Official Implementation
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
print(y)
# output
# tensor([[[ 0.0046,  0.0108, -0.0157, ..., -0.0179, -0.0226,  0.0497],
#          [ 0.0063,  0.0288, -0.0023, ...,  0.0114,  0.0188,  0.0328],
#          [ 0.0366, -0.0113, -0.1003, ...,  0.0279, -0.0380,  0.0105],
#          ...,
#          [-0.0461,  0.0007,  0.0366, ...,  0.0532, -0.0284, -0.0347],
#          [-0.0073,  0.0311, -0.0177, ...,  0.0332,  0.0037,  0.0007],
#          [-0.0008, -0.0456,  0.0199, ..., -0.0099, -0.0158, -0.0012]],
#
#         [[ 0.0254,  0.0037, -0.0638, ...,  0.0042,  0.0123, -0.0212],
#          [-0.0255, -0.0134, -0.0106, ..., -0.0006, -0.0259, -0.0111],
#          [-0.0008,  0.0295,  0.0479, ..., -0.0059,  0.0114, -0.0043],
#          ...,
#          [ 0.0026, -0.0040, -0.0070, ..., -0.0498, -0.0085,  0.0358],
#          [ 0.0534, -0.0341, -0.0274, ..., -0.0974, -0.0472,  0.0185],
#          [-0.0236, -0.0280,  0.0078, ...,  0.0251,  0.0085, -0.0092]]],
#        device='cuda:0', grad_fn=<UnsafeViewBackward0>)
This is the official implementation, provided by the mamba_ssm package (installable with pip install mamba-ssm). We can use it as-is, but it is not convenient for downstream tasks on its own.
So I wrapped it in a PyTorch model.
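As a side note, the "roughly 3 * expand * d_model^2 parameters" comment in the snippet can be sanity-checked. Here is a minimal sketch (my own addition) that reuses model and dim from the example above:

n_params = sum(p.numel() for p in model.parameters())
# Rough estimate from the comment: 3 * expand * d_model^2, with expand=2, d_model=16.
# The true count is somewhat larger because of the SSM, conv, and projection extras,
# so expect the two numbers to agree only in order of magnitude.
estimate = 3 * 2 * dim**2
print(f"actual: {n_params}, rough estimate: {estimate}")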
2. Mamba with peripheral layers
import torch
import torch.nn as nn
from mamba_ssm import Mamba


class MambaModel(nn.Module):
    def __init__(
        self,
        dim_model=384,  # Model dimension d_model (embedding size)
        d_state=64,  # SSM state expansion factor
        d_conv=8,  # Local convolution width
        expand=4,  # Block expansion factor
        output=3,  # Number of classes (or simply the number of outputs)
        dropout_rate=0.1,  # Dropout rate
        is_test=False,
    ):
        super().__init__()
        self.model = Mamba(
            d_model=dim_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
        ).to("cuda")
        self.output = nn.Linear(dim_model, output)
        self.dropout = nn.Dropout(dropout_rate)
        self.sigmoid = nn.Sigmoid()
        self.is_test = is_test

    def forward(self, x):
        # Add the length dimension if the input has only 2 dimensions
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        x = self.model(x)
        x = self.dropout(x)  # Apply dropout
        x = self.output(x)
        if self.is_test:
            x = self.sigmoid(x)
            x = x.squeeze()
        return x
This is a configurable Mamba model; you can specify the parameters as you need.
The model expects input of shape [bs, length, input_size].
For example, in NLP, length is the sequence length and input_size is the embedding dimension of the embedded input.
The model keeps the batch and length dimensions of the input, while the linear head maps the last dimension to the number of outputs (output).
Please try it out.
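For example, here is a minimal usage sketch (my own example; it assumes a CUDA device, since the wrapper moves the Mamba block to "cuda"):

import torch

# Training-style usage: returns raw logits of shape [bs, length, output]
model = MambaModel(dim_model=384, output=3).to("cuda")
x = torch.randn(8, 128, 384).to("cuda")  # [bs, length, input_size]
print(model(x).shape)  # torch.Size([8, 128, 3])

# A 2-D input gets a length dimension of 1 added in forward()
x2d = torch.randn(8, 384).to("cuda")
print(model(x2d).shape)  # torch.Size([8, 1, 3])

# Test-time usage: sigmoid is applied and size-1 dimensions are squeezed
test_model = MambaModel(dim_model=384, output=3, is_test=True).to("cuda")
print(test_model(x2d).shape)  # torch.Size([8, 3])

Note that x.squeeze() removes all size-1 dimensions, so with a batch size of 1 at test time the batch dimension would be squeezed away as well.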
Reference
[1] GitHub, state-spaces/mamba, https://github.com/state-spaces/mamba