🍅

Tokioのreadとread_bufの違い

2023/05/01に公開

問題が発生した最小限のコードを以下に記載します。

main.rs
use tokio::{  
    io::{AsyncReadExt, AsyncWriteExt},
    net::TcpStream,
};

#[tokio::main]
async fn main() -> std::io::Result<()> {  
    let mut stream = TcpStream::connect("93.184.216.34:80").await?;
    stream
        .write_all(
            "GET / HTTP/1.1\\r\\nHost: example.com\\r\\nUser-Agent: curl/8.0.1\\r\\nAccept: */*\\r\\n\\r\\n"
                .as_bytes(),
        )
        .await?;

    let mut buf = vec![0; 2048];

    let n = stream.read_buf(&mut buf).await?;

    let bad = String::from_utf8_lossy(&buf);
    assert!(!bad.starts_with("HTTP/1.1 200 OK\\r\\n"));

    let ok = String::from_utf8_lossy(&buf[buf.len() - n..]);
    assert!(ok.starts_with("HTTP/1.1 200 OK\\r\\n"));

    Ok(())
}

なぜこのデータはインデックス0からではなく、bufの末尾から入力されるのでしょうか。以下、その理由を調べてみましょう。

stream.read_buf を呼び出すと、ReadBuf データ構造が作成されます。ReadBuf は、tokio が定義した、低レベルバッファーに対するラッパーの一種で、2つのカーソルを使用してデータの境界を追跡します。詳しい使い方は、tokio::io::ReadBuf を参照してください。ここでは、ReadBuf が提供する poll 関数を見てみましょう。この関数は、データをラップされたバッファーに読み込むために使用されます。

read_buf.rs
// https://github.com/tokio-rs/tokio/blob/52bc6b6f2d773def6bfaabf6925fef4e789782b7/tokio/src/io/util/read_buf.rs#L35
impl<R, B> Future for ReadBuf<'_, R, B>  
where  
    R: AsyncRead + Unpin,
    B: BufMut,
{
    type Output = io::Result<usize>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
        use crate::io::ReadBuf;
        use std::mem::MaybeUninit;

        let me = self.project();

        if !me.buf.has_remaining_mut() {
            return Poll::Ready(Ok(0));
        }

        let n = {
            let dst = me.buf.chunk_mut();
            let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit<u8>]) };
            let mut buf = ReadBuf::uninit(dst);
            let ptr = buf.filled().as_ptr();
            ready!(Pin::new(me.reader).poll_read(cx, &mut buf)?);

            // Ensure the pointer does not change from under us
            assert_eq!(ptr, buf.filled().as_ptr());
            buf.filled().len()
        };

        // Safety: This is guaranteed to be the number of initialized (and read)
        // bytes due to the invariants provided by `ReadBuf::filled`.
        unsafe {
            me.buf.advance_mut(n);
        }

        Poll::Ready(Ok(n))
    }

}

chunk_mut は、異なるタイプに対して異なる定義がありますが、Vec<u8> の場合は以下のように定義されています:

buf_mut.rs
// https://github.com/tokio-rs/bytes/blob/b29112ce4484424a0137173310ec8b9f84db27ae/src/buf/buf_mut.rs#L1480-L1490
fn chunk_mut(&mut self) -> &mut UninitSlice {  
    if self.capacity() == self.len() {
        self.reserve(64); // Grow the vec
    }

    let cap = self.capacity();
    let len = self.len();

    let ptr = self.as_mut_ptr();
    unsafe { &mut UninitSlice::from_raw_parts_mut(ptr, cap)[len..] }
}

chunk_mut は、Vec<u8> に対しては、len() から capacity() までのスライスを返します。もし capacity() の値と len() が等しい場合、追加の少なくとも64バイトのスペースを割り当てるために、reserve関数が呼び出されます。返されたスライスの実際のアドレスは len() から始まることに特に注意する必要があります。つまり、vec![0; 2048] を渡すと、返されるのは vec[2048..cap] の領域のスライスになります。これが stream.read_buf(&mut buf).await?; を呼び出した後に、インデックス0からではなく、バッファの末尾から埋められる理由です。

次に、poll関数の中で、このスライスを新しいReadBufに初期化し、ready!(Pin::new(me.reader).poll_read(cx, &mut buf)?); を呼び出してストリームからデータを読み込みます。buf.filled().len()は、すでに埋められたデータを返します。そして、advance_mut関数を使用して、元のバッファの長さをnだけ増やします。

要約すると、read_buf 関数は、現在の len() から capacity() までの範囲にデータを書き込みます。一方で、もう一つの関数 read は、バッファの先頭に書き込みます。

main.rs
use tokio::{  
    io::{AsyncReadExt, AsyncWriteExt},
    net::TcpStream,
};

#[tokio::main]
async fn main() -> std::io::Result<()> {  
    let mut stream = TcpStream::connect("93.184.216.34:80").await?;
    stream
        .write_all(
            "GET / HTTP/1.1\\r\\nHost: example.com\\r\\nUser-Agent: curl/8.0.1\\r\\nAccept: */*\\r\\n\\r\\n"
                .as_bytes(),
        )
        .await?;

    let mut buf = bytes::BytesMut::with_capacity(4096);

    let n = stream.read(&mut buf).await?;
    assert_eq!(n, 0);

    Ok(())
}

なぜこの例でnは常に0になるのでしょうか? &mut bufを実行すると、次のコードが実行されます

#[inline]
pub fn with_capacity(capacity: usize) -> BytesMut {  
    BytesMut::from_vec(Vec::with_capacity(capacity))
}

#[inline]
fn as_slice_mut(&mut self) -> &mut [u8] {  
    unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
}

impl AsMut<[u8]> for BytesMut {  
    #[inline]
    fn as_mut(&mut self) -> &mut [u8] {
        self.as_slice_mut()
    }
}

as_slice_mut 関数は、現在のlenを長さの値として使用します。何もデータをBytesMutに追加していないので、lenの値は0です。つまり、0長の配列を指定したことになります。そのため、readはデータを配列に書き込めません。正しい使い方は、長さを初期化することで、0で埋めることができます。

let mut buf = bytes::BytesMut::zeroed(16);  
stream.read(&mut buf).await?;  

または、unsafe コードを使用して長さを変更することもできますが、この場合は len()capacity() より小さい必要があります。

let mut buf = bytes::BytesMut::with_capacity(4096);  
unsafe { buf.set_len(4096) };  
stream.read(&mut buf).await?;  

先に言及した read_buf の使用もあります。

let mut buf = bytes::BytesMut::with_capacity(4096);  
stream.read_buf(&mut buf).await?;  

上記まで、私たちは readread_buf 関数のパラメータ要件上の違いを理解しました。以下にいくつかのコード例がありますが、正しいかどうかを判断してみてください。

Example 01 ✅
let mut buf = [0u8; 4096];
stream.read_buf(&mut buf.as_mut()).await?;
assert!(&buf[0..4] == b"HTTP");
Example 02 ✅
let mut buf = vec![0u8; 4096];
stream.read(&mut buf[..]).await?;
assert!(&buf[0..4] == b"HTTP");
Example 03 ✅
let mut buf = [0u8; 4096];
stream.read_buf(&mut buf.as_mut_slice()).await?;
assert!(&buf[0..4] == b"HTTP");
Example 04 ✅
let mut buf = vec![0u8; 2048];

stream.read_buf(&mut buf.as_mut_slice()).await?;
assert!(&buf[0..4] == b"HTTP");

この例は少し理解が難しいかもしれません。前の例と似ていますが、 chunk_mut&mut [u8] タイプで定義されている点が異なります。

fn chunk_mut(&mut self) -> &mut UninitSlice {
        // UninitSlice is repr(transparent), so safe to transmute
        unsafe { &mut *(*self as *mut [u8] as *mut _) }
}

これは、整个のスライスを返します。

Example 05 ❌
let mut buf = Vec::with_capacity(2048);
stream.read(&mut buf[..]).await?;
assert!(&buf[0..4] == b"HTTP");
Example 06 ✅
use std::mem::MaybeUninit;
let mut buf: [u8; 2048] = unsafe { MaybeUninit::uninit().assume_init() };
stream.read(&mut buf[..]).await?;
assert!(&buf[0..4] == b"HTTP");
Example 07 ❌ core dump
use std::mem::MaybeUninit;
let mut buf: Box<[u8; 2048]> = unsafe { MaybeUninit::uninit().assume_init() };
stream.read(&mut buf[..]).await?;
assert!(&buf[0..4] == b"HTTP");
Example 08 ❌ core dump
use std::mem::MaybeUninit;
let mut buf: Vec<u8> = unsafe { MaybeUninit::uninit().assume_init() };
buf.reserve(2048);
stream.read_buf(&mut buf).await?;
assert!(&buf[0..4] == b"HTTP");
Example 09 ✅
let mut buf: Box<[u8; 2048]> = Box::new([0; 2048]);
stream.read(&mut buf[..]).await?;
assert!(&buf[0..4] == b"HTTP");

Discussion