🕸️

k-means 法を図で理解する

2023/08/15に公開

アルゴリズム

  1. 任意の数のクラスターを適当に配置する
  2. 各クラスターに近い点と結びつける
  3. グループの中心にクラスターを移動する
  4. 収束するまで 2〜3 を繰り返す

視覚的に理解する

最初の状態

最終的にクラスタリングがうまくいくように狙って配置しているわけではなく何の目論見もなくランダムに配置している。

任意の数のクラスターを適当に配置する

クラスターも同様にランダムに配置している。

各クラスターに近い点と結びつける

グループの中心にクラスターを移動する

再度、各クラスターに近い点と結びつける

ここがいちばん興味深いところでクラスターの中心が移動したことでより近いクラスターに切り替わった点がいくつか確認できる。

再度、グループの中心にクラスターを移動する

収束したら終える

近いクラスターと結びつけてクラスターをグループの中心に移動するを繰り返し、何も動かなくなったら終える。

途中経過を飛ばしてこの結果を見ると点がクラスターに寄って行っているように感じるが、クラスターが移動しているのであって、点はまったく移動していない。

コードの要点

各クラスターに近い点と結びつける

points.each do |point|
  最も近いクラスター = clusters.min_by do |cluster|
    point と cluster の距離を返す
  end
  point.cluster = 最も近いクラスター
end

距離は平方根まで求める必要はない。

クラスターを中心に移動する

cluster.each do |cluster|
  if 所属する点が1つ以上あれば
    cluster.位置ベクトル = 所属する点.sum(0ベクトル) / 所属する点の個数
  end
end

クラスターの初期配置のバランスがいまいちだとクラスターに所属する点がなくなる場合もあるので0除算に注意する。

コード
require "active_support/core_ext/module/delegation"
require "#{__dir__}/../../物理/ベクトル/vec2"

class Cluster
  attr_accessor :v
  attr_accessor :color
  attr_accessor :points

  def initialize(v, color)
    @v = v
    @color = color
    @points = []
  end
end

class Point
  attr_accessor :v
  attr_accessor :cluster

  def initialize(v)
    @v = v
  end
end

class App
  attr_accessor :points
  attr_accessor :clusters

  delegate :write, :animation_write, to: :image_formatter

  def call
    @points = points_count.times.collect do
      Point.new(vec_rand)
    end

    @clusters = []

    snapshot :init1

    @clusters = [
      "blue",
      "green",
      "purple",
      "orange",
      "red",
    ].collect do |color|
      Cluster.new(vec_rand, color)
    end

    snapshot :init2

    @generation = 0
    loop do
      if @generation >= 50
        break
      end

      @clusters.each { |e| e.points.clear }
      @points.each do |point|
        cluster = @clusters.min_by do |cluster|
          point.v.distance_squared_to(cluster.v)
        end
        point.cluster = cluster.dup
        cluster.points << point
      end

      snapshot

      updated = false
      @clusters.each do |e|
        unless e.points.empty?
          v = e.points.sum(V.zero, &:v) / e.points.size
          if e.v != v
            e.v = v
            updated = true
          end
        end
      end
      unless updated
        break
      end

      snapshot
      @generation += 1
    end

    snapshot :done
    puts @generation

    animation_write("images/animation.gif")
  end

  def field_wh
    V[field_w, field_w / 1.618033988749895]
  end

  private

  def snapshot(name = nil)
    name ||= "snapshot#{snapshot_couner}"
    write("images/#{name}.png")
  end

  def snapshot_couner
    @snapshot_couner ||= 0
    @snapshot_couner += 1
  end

  def vec_rand
    x = rand(field_wh.x)
    y = rand(field_wh.y)
    V[x, y]
  end

  def field_w
    1200
  end

  def points_count
    100
  end

  def image_formatter
    @image_formatter ||= ImageFormatter.new(self)
  end
end

class ImageFormatter
  require "rmagick"
  include Magick

  attr_accessor :base

  def initialize(base)
    @base = base
    @image_list = ImageList.new
  end

  def write(path)
    render
    @layer.write(path)
    @image_list << @layer
    open(path)
  end

  def animation_write(path)
    av = @image_list.optimize_layers(Magick::OptimizeLayer)
    av.delay = 50
    av.write(path)
    open(path)
  end

  private

  def open(path)
    system "open -a 'Google Chrome' #{path}"
  end

  def render
    @layer = Image.new(*base.field_wh) do |e|
      e.background_color = bg_color
    end

    base.points.each do |e|
      if cluster = e.cluster
        line_draw(e.v, cluster.v, cluster.color)
      end
    end

    base.points.each do |e|
      color = object_default_color
      if cluster = e.cluster
        color = cluster.color
      end
      point_draw(e.v, color)
    end

    base.clusters.each do |e|
      cluster_draw(e.v, e.color)
    end
  end

  def object_wh
    V[12, 12]
  end

  def object_default_color
    "gray80"
  end

  def bg_color
    "white"
  end

  def line_draw(v0, v1, color)
    g = Draw.new
    g.fill(color)
    g.line(*v0, *v1)
    g.draw(@layer)
  end

  def point_draw(v, color)
    g = Draw.new
    g.fill(color)
    g.stroke_width(1)
    g.stroke(bg_color)
    g.ellipse(*v, *(object_wh / 2), 0, 360)
    g.draw(@layer)
  end

  def cluster_draw(v, color)
    r = object_wh / 2 * 2
    g = Draw.new
    g.fill(color)
    g.stroke_width(1)
    g.stroke(bg_color)
    g.rectangle(*(v - r), *(v + r))
    g.draw(@layer)
  end
end

if $0 == __FILE__
  App.new.call
end

Discussion