🙄

RustでChainedHashTableを実装してみる

に公開

概要

https://www.amazon.co.jp/みんなのデータ構造-Pat-Morin/dp/4908686068/ref=sr_1_1?__mk_ja_JP=カタカナ&crid=3HFHM4XE0TZ37&keywords=みんなのデータ構造&qid=1705709264&sprefix=みんなのデータこうぞう,aps,205&sr=8-1

「みんなのデータ構造」という本を読んでいて、Rustの練習に良さそうなので、色々書いてみる。

この記事ではChainedHashTableを実装。

ChainedHashTableとは

ChainedHashTableはハッシュテーブルの一種で、ハッシュの衝突を「チェイン法」を使って解決する。
ハッシュテーブルをtとした時、ハッシュテーブルに格納されるデータは、各ハッシュ値のデータ格納領域(バケット)が持つリスト(リンクリスト)に格納される。

https://opendatastructures.org/ods-java/5_1_ChainedHashTable_Hashin.html
5.1 ChainedHashTable: Hashing with Chaining

実装

インターフェース

trait USet<T> {
    fn new() -> Self;
    fn add(&mut self, element: T) -> bool;
    fn find(&self, element: &T) -> Option<&T>;
    fn remove(&mut self, element: &T) -> Option<T>;
}

USetインタフェースを持つ。

データ構造

array<List> t;
int n;

リストを持つ配列t(ハッシュテーブル)とデータ数nを持つ。

#[derive(Debug)]
struct ChainedHashTable<T: std::fmt::Debug + Hash + Eq> {
    table: Vec<Vec<T>>,
    size: usize,
}

tableが持つのは、リンクリストが一般的のようだが、rustのVectorを使ってデータを持つようにする。

ハッシュ(乗算ハッシュ法)

impl<T: std::fmt::Debug + Hash + Eq> USet<T> for ChainedHashTable<T> {
    fn new() -> Self {
        let mut hash_table = Self {
            table: Vec::new(),
            size: 0,
        };
        // テーブルの初期化(ハッシュテーブルは2^D個のバケットを持つ)
        for _ in 0..(1 << (Self::D)) {
            hash_table.table.push(Vec::new());
        }
        hash_table
    }
...

impl<T: std::fmt::Debug + Hash + Eq> ChainedHashTable<T> {
    // zは{1, 3, ..., 2^W - 1}の奇数から選択した定数
    const Z: u64 = 4102541685;
    // Wはハッシュ値のビット数を表す(32ビット)
    const W: u64 = 32;
    // Dは2^D個のバケットを持つハッシュテーブルを表す(2^D = 8ビット)
    const D: u64 = 8;

    // 乗算ハッシュ法によるハッシュ値の計算
    // 計算式: (Z * hash_code) % 2^W >> (W - D)
    fn hash(&self, element: &T) -> u64 {
        // データのハッシュ値を計算
        let mut hasher = DefaultHasher::new();
        element.hash(&mut hasher);
        let hash_code = hasher.finish();

        let hash = Self::Z.wrapping_mul(hash_code);
        (hash % (1 << Self::W)) >> (Self::W - Self::D)
    }
}

ハッシュの計算は、乗算ハッシュ法を使う。
ある定数Zと格納する値のハッシュ値(hash)を乗算し、2^Wで剰余を出す。
その値を2^{W-D}で割ることで、dビット(D桁)のハッシュ値を計算できる。

\text{hash}(x) = \left( (z \cdot x) \mod 2^w \right) \div 2^{w-d}

テーブルは、D^2個のバケットを持っているため、求めたハッシュからバケットを特定することができる。

データの追加

fn add(&mut self, element: T) -> bool {
    if self.find(&element).is_some() {
        return false;
    }
    let hash = self.hash(&element);
    self.table[hash as usize].push(element);
    self.size += 1;
    true
}

データの追加は、データからhashを求めてテーブルのバケットを特定しデータを追加する。

データの検索

fn find(&self, element: &T) -> Option<&T> {
    let hash = self.hash(element);
    let list = &self.table[hash as usize];
    for i in 0..list.len() {
        let node = list.get(i).unwrap();
        if *node == *element {
            return Some(node);
        }
    }
    None
}

データの検索は、hashからバケットを求め、キーが衝突しているデータから対象を検索する。

データの削除

fn remove(&mut self, element: &T) -> Option<T> {
    let hash = self.hash(element);
    let list = &mut self.table[hash as usize];
    for i in 0..list.len() {
        let node = list.get(i).unwrap();
        if *node == *element {
            self.size -= 1;
            return Some(list.remove(i));
        }
    }
    None
}

データの検索とほぼ変わらない。

コード全体

コード
use std::{
    collections::hash_map::DefaultHasher,
    hash::{Hash, Hasher},
};

trait USet<T> {
    fn new() -> Self;
    fn add(&mut self, element: T) -> bool;
    fn find(&self, element: &T) -> Option<&T>;
    fn remove(&mut self, element: &T) -> Option<T>;
}

#[derive(Debug)]
struct ChainedHashTable<T: std::fmt::Debug + Hash + Eq> {
    table: Vec<Vec<T>>,
    size: usize,
}

impl<T: std::fmt::Debug + Hash + Eq> USet<T> for ChainedHashTable<T> {
    fn new() -> Self {
        let mut hash_table = Self {
            table: Vec::new(),
            size: 0,
        };
        for _ in 0..(1 << (Self::D)) {
            hash_table.table.push(Vec::new());
        }
        hash_table
    }

    fn add(&mut self, element: T) -> bool {
        if self.find(&element).is_some() {
            return false;
        }
        let hash = self.hash(&element);
        self.table[hash as usize].push(element);
        self.size += 1;
        true
    }

    fn find(&self, element: &T) -> Option<&T> {
        let hash = self.hash(element);
        let list = &self.table[hash as usize];
        for i in 0..list.len() {
            let node = list.get(i).unwrap();
            if *node == *element {
                return Some(node);
            }
        }
        None
    }

    fn remove(&mut self, element: &T) -> Option<T> {
        let hash = self.hash(element);
        let list = &mut self.table[hash as usize];
        for i in 0..list.len() {
            let node = list.get(i).unwrap();
            if *node == *element {
                self.size -= 1;
                return Some(list.remove(i));
            }
        }
        None
    }
}

impl<T: std::fmt::Debug + Hash + Eq> ChainedHashTable<T> {
    // zは{1, 3, ..., 2^W - 1}の奇数から選択した定数
    const Z: u64 = 4102541685;
    // Wはハッシュ値のビット数を表す(32ビット)
    const W: u64 = 32;
    // Dは2^D個のバケットを持つハッシュテーブルを表す(2^D = 8ビット)
    const D: u64 = 8;

    // 乗算ハッシュ法によるハッシュ値の計算
    // 計算式: (Z * hash_code) % 2^W >> (W - D)
    fn hash(&self, element: &T) -> u64 {
        // データのハッシュ値を計算
        let mut hasher = DefaultHasher::new();
        element.hash(&mut hasher);
        let hash_code = hasher.finish();

        let hash = Self::Z.wrapping_mul(hash_code);
        (hash % (1 << Self::W)) >> (Self::W - Self::D)
    }
}

fn main() {
    let mut hash_table = ChainedHashTable::<i32>::new();
    // 数値の格納
    for i in 1..=1000 {
        hash_table.add(i);
    }

    for i in 1..=500 {
        hash_table.remove(&i);
    }

    println!("{:?}", hash_table.find(&501));
    println!("{:?}", hash_table.find(&1));

    let mut hash_table = ChainedHashTable::<String>::new();
    hash_table.add("hello".to_string());
    hash_table.add("world".to_string());
    hash_table.add("rust".to_string());
    hash_table.add("go".to_string());
    hash_table.add("c".to_string());
    hash_table.add("c++".to_string());
    hash_table.add("java".to_string());
    hash_table.add("kotlin".to_string());
    hash_table.add("swift".to_string());
    hash_table.add("python".to_string());

    hash_table.remove(&"python".to_string());
    println!("{:?}", hash_table.find(&"hello".to_string()));
    println!("{:?}", hash_table.find(&"python".to_string()));
}

Discussion