📖

【情報符号理論】Pythonで実装して理解する情報源、符号化、通信路、誤り訂正【実装】

2023/08/12に公開

概要

情報学の一分野である情報符号理論に登場する「情報源」「符号化」「通信路」「誤り訂正」について、それに相当するクラスを実際にPythonで実装することで理解を深めることを目的とした記事です。
情報源の定義など、情報符号理論の基礎的な内容について詳しく説明することはしませんので予めご了承ください。

想定する読者像

  • 計算機科学を学習している人
  • Pythonにある程度慣れている人

前提する知識

  • 情報符号理論の基礎的な理解
  • Pythonの初級的な内容(クラスの継承等まで)

この記事の内容

  • 情報符号理論の各種内容(ハフマン符号化、加法的通信路、ハミング符号等)をシミュレートするための実装方法
  • Pythonにおけるクラスを用いた実装のシンプルな具体例

注意

  • 本記事で紹介する実装はシンプルな分かりやすさを重視したもので、実用目的のものとは異なります。
  • 記事の内容には誤りが含まれている場合があります。予めご了承ください。誤りに気付いた方はご指摘いただけると大変助かります。

目次

  1. 情報源
    1.1. 情報源クラス
    1.2. 無記憶定常情報源
    1.3. マルコフ情報源
    1.4. 連続分布に対する情報源(応用)
  2. 符号化
    2.1. 符号クラス
    2.2. 汎用的な符号
    2.3 ハフマン符号化
    2.4. ブロックハフマン符号化
  3. 通信路
    3.1. 通信路クラス
    3.2. 無記憶定常通信路
    3.3. 2元対称通信路
    3.4. 加法的2元通信路
  4. 誤り検出と訂正
    4.1. 反復による冗長化
    4.2. パリティ検査符号
    4.3. ハミング符号
    4.4. 巡回符号
  5. まとめ
    5.1. 通信全体のシミュレーション

本編

情報源

情報源クラス

では最初に、情報源クラスの実装に取り組んでいきましょう。設計は極めてシンプルであり、コンストラクタ以外に外部からアクセス可能なメソッドは「step」のみです。このstepメソッドを使うと、情報源アルファベットから適当な情報源記号が1つ返されます。
抽象クラスは以下のようになります。

import abc
class Source(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def step(self):
        raise NotImplementedError()

Pythonに不慣れな方にはやや難しいかもしれませんが、これは情報源の具体的なクラスがどんなメソッドを持つべきか、というルールを定めているにすぎません。情報符号理論とは本質的に無関係な話なので無理に書かれている内容を理解しようとしなくても構いません。

このクラスを継承して、順にいくつかの具体的な情報源の役割を果たすクラスを作成していきます。

無記憶定常情報源

無記憶定常情報源とは要するに、毎回同じ確率分布に従って情報源記号を返すような情報源のことです。
実装は簡単ですね。次のようになります

from random import randint, random
from typing import List, Dict
class StaticSource(Source):
    def __init__(self, probability_dict:Dict[int,float]) -> None:
        self.probabilities = probability_dict

    def step(self):
        return self.__choose()

    def __choose(self):
        r = random()
        for out,p in self.probabilities.items():
            r -= p
            if r <= 0:
                return out
        assert False, ('error in choosing symbol')
	
# 使用例
static = StaticSource({"a":0.4, "b":0.6})
print(static.step())

インスタンスの生成時に、各情報源記号とそれに対応する確率の値を辞書の形でprobability_dictに渡します。stepメソッドが呼ばれた際には、内部的にはchooseメソッドを呼び出し、randomライブラリによって生成した疑似乱数をもとに情報源記号を選択して返しています。

使用例においては、"a"が確率0.4、"b"が確率0.6で返されることとなります。

マルコフ情報源

マルコフ情報源とは、非決定的な状態遷移図で表現されるいわゆるマルコフ過程により、次に生成される情報源記号が決まるような情報源のことです。
まずは実装を見てみましょう。

class MarkovSource(Source):
    def __init__(self, neighbor_matrix, output_matrix, initial_state : int = 0) -> None:
        self.neighbors = neighbor_matrix
        self.state = initial_state
        self.outputs = output_matrix
        assert len(self.outputs) == len(self.neighbors), ('output matrix must have the same number of rows as the neighbor matrix')
        self.__assert_output()
        assert self.state < len(self.neighbors) , ('initial_state must be less than the number of states')
        self.__assert_neighbors()
        
    def __assert_output(self) -> None:
        for row in self.outputs:
            assert len(row) == len(self.outputs[0]), ('output matrix must has the same number of columns for each row')

    def __assert_neighbors(self) -> None:
        for row in self.neighbors:
            assert len(row) == len(self.neighbors), ('neighbor matrix must be square')
        for row in self.neighbors:
            try :
                sum(row)
            except TypeError:
                raise TypeError('neighbor matrix must be a matrix of numbers')
            assert sum(row) == 1, ('neighbor matrix must be a stochastic matrix')

    def step(self):
        self.state, output = self.__choose_neighbor()
        return output

    def __choose_neighbor(self) -> tuple[int,int]:
        r = random()
        for i in range(len(self.neighbors)):
            r -= self.neighbors[self.state][i]
            if r <= 0:
                return i,self.outputs[self.state][i]
        assert False, ('error in choosing neighbor')

# 使用例
neighbors = [[0.5, 0.5, 0],
	 [0, 0, 1],
	 [1, 0, 0]]
outputs = [[1,2,3],
       [4,5,6],
       [7,8,9]]
markov = MarkovSource(neighbors, outputs, 0)
for i in range(10):
    print(markov.step(),markov.state)

コンストラクタにおいては、状態遷移の確率行列neighbor_matrix、各状態から各状態へ遷移するときの出力を表す行列output_matrix、そして初期状態initial_stateをそれぞれ渡します。
neighbor_matrix[i][j]は状態iから状態jに遷移する確率を表し、output_matrix[i][j]は状態iから状態jに遷移する際に出力する情報源記号を表します。
stepメソッドが呼ばれた際には、内部的にはchoose_neighborメソッドを呼び出し、randomライブラリによって生成した疑似乱数をもとにして遷移先状態を決定し、状態の更新しながら出力記号を返しています。

連続分布に対する情報源(応用)

ここまで情報源からの出力が離散的であることを前提して説明してきましたが、例えば指数分布のような連続確率分布に従って出力を発生させる情報源を作成したい場合はどのようにすればよいのでしょうか。ここでは逆関数法と呼ばれる手法を紹介します。

例えば指数分布の確率密度関数は、以下のようにあらわすことができます。

f(x) = λe^{-λx}  (x\ge0)

指数分布の累積分布関数は

F(x) = 1-e^{-λx}  (x\ge0)

従ってこの逆関数は

F^{-1}(x) = -\frac{1}{λ} \log(1-x)  (x\ge0)

累積分布関数の値域が0から1までであることに注意してください。

ではここで、0から1までの範囲の連続一様分布を、randomライブラリによって生成し、これをuに代入したとします。
このとき、以下の等式が成り立ちます。

P(u \le F(x)) = F(x) (0\le x \le 1)

よって、r=-\frac{1}{λ} \log(1-u)すなわちr=F^{-1}(u)とおけば、

P(r \le x) = F(x)

となります。これは、rがx以下になる確率はF(x)である、ということを表しているため、累積分布関数の定義から明らかに、rは指数分布に従います。

よってクラスの定義としては、コンストラクタとしてF^{-1}、すなわち生成したい分布の分布関数の逆関数を受け取り、生成した0から1までの疑似乱数をこの関数に入れたときの値を返せばよいということになります。

実装は以下のようになります。

from typing import Callable
import math
class DistributionSource(Source):
    def __init__(self,Finv: Callable[[float],float]) -> None:
        self.Finv = Finv

    def step(self) -> float:
        return self.__generate()

    def __generate(self) -> float:
        r = random()
        return self.Finv(r)
#使用例
lam = 0.5
distribution = DistributionSource(lambda x: -math.log(1-x)/lam)
samples = []
for i in range(100):
    samples.append(distribution.step())

符号化

符号クラス

続いて符号クラスの実装に取り組んでいきましょう。コンストラクタ以外に必要なメソッドは主として符号化を行う「encoding」と復号を行う「decoding」の2つです。記号列をencodingメソッドによって符号化し、decodingメソッドによって元に戻します。今回は簡単のため、符号化前の記号列とそれに対応する符号語列が逐次的に復号できる瞬時符号である場合のみを考えます。
符号の抽象クラスの実装は以下のようになります。

class Coder(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def encode(self, series:str) -> str:
        raise NotImplementedError()
    @abc.abstractmethod
    def decode(self, series:str) -> str:
        raise NotImplementedError()

汎用的な符号

続いて、符号化前の記号と符号語の対応を列挙して与えることで、汎用的に符号をシミュレートすることができるクラスを考えます。

class GeneralCoder(Coder):
    def __init__(self, encoding:Dict[str,str]) -> None:
        self.encoding = encoding
        self.__assert_encoding()
        self.decoding = self.__invert_encoding()
        self.encodable = self.__prefix_condition(self.encoding, 'encoding')
        self.decodable = self.__prefix_condition(self.decoding, 'decoding')
        
    def __invert_encoding(self) -> Dict[str,str]:
        decoding = {}
        for symbol, code in self.encoding.items():
            if code not in decoding:
                decoding[code] = symbol
            else:
                print(f'Warning: code {code} is used by multiple symbols')
        return decoding
    
    def __assert_encoding(self) -> None:
        for key,value in self.encoding.items():
            assert isinstance(key, str), ('encoding keys must be strings')
            assert isinstance(value, str), ('encoding values must be strings')
        
    def __prefix_condition(self, encoding:Dict[str,str], method_name:str) -> bool:
        for i,symbol in enumerate(encoding.keys()):
            for j,prefix in enumerate(encoding.keys()):
                if i == j: continue
                if prefix == symbol[:len(prefix)]:
                    print(f'Warning: You cannot use {method_name} because it is not prefix-free')
                    return False
        return True

    def __translate(self, series:str, encoding:Dict[str,str]) -> str:
        result = ""
        i = 0
        while i < len(series):
            for symbol in encoding.keys():
                if series[i:i+len(symbol)] == symbol:
                    result += encoding[symbol]
                    i += len(symbol)
                    break
            else:
                #print('Warning: could not find symbol in encoding')
                i += 1
        return result
    
    def encode(self, series:str) -> str:
        assert self.encodable, ('cannot encode because encoding is not prefix-free')
        return self.__translate(series, self.encoding)
    
    def decode(self, series:str) -> str:
        assert self.decodable, ('cannot decode because decoding is not prefix-free')
        return self.__translate(series, self.decoding)
#使用例
encoding = {
    '0':'0',
    '1':'10',
    '2':'110',
    '3':'111'
}
coder = GeneralCoder(encoding)
series = "012322102"
code = coder.encode(series)
decoded = coder.decode(code)
print(code, decoded)

インスタンスを生成する際には記号とそれに対する符号語をまず辞書形式でencodingに与えます。assert_encodingメソッドにおいて記号と符号語がそれぞれ文字列データであることを確認した後、invert_encodingにおいて与えられた符号化に対して復号を行うための辞書decodingを作成します。最後にencoding、decodingのそれぞれについて瞬時符号となるための語頭条件が維持されていることをprefix_conditionメソッドで確認しています。
encode、decodeメソッドが呼ばれた際にはそれぞれ与えられた記号列をtranslateメソッドに渡し、事前に定義された符号に基づいた置き換える処理を行っています。

ハフマン符号化

では次に、各情報記号の確率分布を与えることによってハフマン符号を自動的に構成するようなプログラムを考えてみましょう。
ハフマン符号化を行うには、ハフマン木を構成する必要があります。そこでまずはハフマン木にあたるクラスを実装していきます。

class Tree():
    def __init__(self, prob: float, char : str = None):
        self.char = char
        self.isleaf = (char != None)
        self.prob = prob
        self.left = None
        self.right = None

    def __lt__(self, other):
        return self.prob < other.prob

情報源記号と確率を情報として持っているシンプルなノードのクラスです。葉ノード以外については対応する情報源記号は存在しないため、単純に確率の値のみをもちます。後でヒープを使って簡単にソートを行うために、不等号「<」に対する定義も実装しておきましょう。

ではこれを用いてハフマン符号のクラスを実装します。復号のための辞書の構成など、基本的な部分はすべてGeneralCoderと同じですので、先程作ったクラスを継承してしまいます。

import heapq
class Huffman(GeneralCoder):
    def __init__(self, distribution:Dict[str,float], symbol0:str = '0', symbol1:str = '1'):
        self.distribution = distribution
        self.symbol0 = symbol0
        self.symbol1 = symbol1
        self.tree = self.__build_huffman_tree()
        self.encoding = self.__build_encoding(self.tree)
        # print(self.average_length())
        super().__init__(self.encoding)
    
    def __build_huffman_tree(self):
        priority_queue = [Tree(prob, char) for char, prob in self.distribution.items()]
        heapq.heapify(priority_queue)
        while len(priority_queue) > 1:
            left = heapq.heappop(priority_queue)
            right = heapq.heappop(priority_queue)
            parent = Tree(left.prob + right.prob)
            parent.left = left
            parent.right = right
            heapq.heappush(priority_queue, parent)
        return heapq.heappop(priority_queue)
    
    def __build_encoding(self, tree:Tree):
        encoding = {}
        def recurse(tree:Tree, code:str):
            if tree.isleaf:
                encoding[tree.char] = code
            else:
                recurse(tree.left, code + self.symbol0)
                recurse(tree.right, code + self.symbol1)
        recurse(tree, '')
        return encoding

    def average_length(self):
        return sum([len(code)*self.distribution[symbol] for symbol, code in self.encoding.items()])
#使用例
distribution = {'a':0.5, 'b':0.25, 'c':0.125, 'd':0.125}
coder = Huffman(distribution)
print(coder.encoding)
series = 'abacabad'
code = coder.encode(series)
decoded = coder.decode(code)
print(code, decoded)

コンストラクタには各符号に対する確率分布distributionの他、必要であれば「0」「1」に対応する記号を渡します。これをもとに、build_huffman_treeでハフマン木を構成します。具体的には、まず各情報源記号とそれに対する確率をもつノードのリストをdistributionをもとに作成し、priority_queueに格納します。次に、heapqライブラリを用いてpriority_queueを優先度付きキューとして扱えるようにします。ノードの大小関係は確率の大小関係をもとに決定すると先に定義しましたので、この優先度付きキューでは確率の値が最も小さいものから順に取り出されることとなります。
確率の小さい方から2つを選択し、これを統合して新たなノードとすることを、未統合のノードが1つになるまで繰り返します。これによって得られたハフマン木について、今度はbuild_encodingにおいて再帰的にたどり、葉ノードまで到達したらそのノードに対応する情報源記号にそのノードに対応する符号語を割り当てることで符号を構成します。
後は構成した辞書型データのencodingをスーパークラスのコンストラクタに渡せば完成です。
情報源記号の確率分布が情報として与えられているため、ついでにaverage_lengthという情報源記号1つあたりの符号長の期待値を求める関数も実装しておきましょう。

ブロックハフマン符号化

続いて、いくつかの情報源記号のまとまりを1つの記号と捉えることによって平均符号長を改善することができるブロックハフマン符号も実装していきましょう。

class BlockHuffman(Huffman):
    def __init__(self, distribution:Dict[str,float], block_length:int, symbol0:str = '0', symbol1:str = '1'):
        self.length = block_length
        self.block_distribution = self.__compose_block_probabilities(distribution, block_length)
        self.symbol0 = symbol0
        self.symbol1 = symbol1
        super().__init__(self.block_distribution, symbol0, symbol1)
        print(self.average_length())

    def __compose_block_probabilities(self, dist, length):
        if length == 1:
            return dist
        prev_dist = self.__compose_block_probabilities(dist, length - 1) # length-1までの結果を再帰的に計算
        result = {}
        for k1, p1 in dist.items():
            for k2, p2 in prev_dist.items():
                key = k1 + k2
                prob = p1 * p2
                result[key] = prob
        return result

    def average_length(self):
        return sum([len(code)*self.distribution[symbol] for symbol, code in self.encoding.items()])/self.length
    
    def encode(self, series:str) -> str:
        assert len(series)%self.length==0, (f"Error: Encoding and Input is not match. Only accept input whose length is a multiple of {self.length}")
        return super().encode(series)

#使用例
#シンプルな確率分布を定義してブロックハフマン符号を生成
distribution = {'a':0.6,'b':0.3,'c':0.1}
length = 3
coder = BlockHuffman(distribution, block_length = length)
print(coder.encoding)
#符号化と復号のテスト
static = StaticSource(distribution)
series = ""
for i in range(10*length):
    series += str(static.step())
code = coder.encode(series)
decoded = coder.decode(code)
print(code, decoded)

Huffman符号のクラスを継承しているので、ブロック化したときの確率分布さえ得られればそれでOKです。
記号一つずつの確率分布distributionと、いくつの記号を1つにまとめるかを表すblock_length、「0」「1」に対応する記号をコンストラクタに渡し、compose_block_probabilitiesメソッドで確率分布をブロックごとにまとめたものを得ます。この処理は一見すると難しそうに感じますが、長さlength-1の時のブロック化した確率分布から長さlengthのときの確率分布を得ることは簡単にできますので、再帰的に処理を行うことで簡単に構成できます。
また、ブロックハフマン符号化は単純には入力系列がブロック長の整数倍でなければ処理できないため、そのとことを確認する処理をencodeメソッドに追加しています。

通信路

通信路クラス

次に、通信路クラスの実装を行います。通信路とは、簡単に言えば入力にノイズを付加して出力するものといえます。コンストラクタ以外に必要なメソッドとしては、入力信号から通信路を通過した出力信号を得るpassageメソッドのみです。
通信路の抽象クラスの実装は次のようになります。

class Channel(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def passage(self):
        raise NotImplementedError()

無記憶定常通信路

具体的な通信路を表現するクラスとして、まずは通信路行列によって各入力に対する確率分布が指定される無記憶定常通信路について実装していきます。


class StaticChannel(Channel):
    def __init__(self, channel_matrix:Dict[str,Dict[str,float]]):
        self.channel_matrix = channel_matrix
        self.__assert_matrix()
    
    def __assert_matrix(self):
        for dict in self.channel_matrix.values():
            prob_sum = 0
            for prob in dict.values():
                assert 0<=prob and prob<=1, ("probability must be between 0 and 1")
                prob_sum += prob
            assert prob_sum==1, ("Sum of probabilities for each input signal must be 1")
        
    
    def passage(self,input_signal: str):
        output_signal = ""
        for input in input_signal:
            assert (input in self.channel_matrix), ("input char is not in channel_matrix")
            output_signal += self.__choose(input)
        return output_signal

    def __choose(self, input):
        r = random()
        for out,p in self.channel_matrix[input].items():
            r -= p
            if r <= 0:
                return out
        assert False, ('error in choosing symbol')

#使用例
input_signal = "0000011111"
static = StaticChannel({"0":{"0":0.8,"1":0.2},"1":{"0":0.4,"1":0.6}})
print(static.passage(input_signal))

インスタンスを生成する際には、通信路行列を表す辞書型データのchannel_matrixをコンストラクタに渡します。このchannel_matrixにおいては、channel_matrix[i][j]が入力記号がiのとき出力記号がjである確率を表しています。assert_matrixメソッドではchannel_matrixが正しく確率行列となっているかを確認しています。
passageメソッドが呼ばれると、入力記号を1つずつ順に取り出して生成した疑似乱数と通信路行列をもとに出力記号を選択し、これらを繋げたものを出力として返しています。

2元対称通信路

先に実装した無記憶定常通信路を利用して、代表的な通信路である2元対象通信路を実装してみましょう。

class BinarySymmetricChannel(StaticChannel):
    def __init__(self, error_rate:float):
        self.error_rate = error_rate
        assert 0<=self.error_rate and self.error_rate<=1, ("error_rate must be between 0 and 1")
        channel_matrix = {"0":{"0":1-self.error_rate,"1":self.error_rate},"1":{"0":self.error_rate,"1":1-self.error_rate}}
        super().__init__(channel_matrix)

実装はこのようになります。このように、先に作成したクラスを継承することで簡単に様々な通信路のクラスを実装することができます。

加法的2元通信路

続いて、誤り源と呼ばれる情報源を用いた通信路の実装について説明します。この通信路においては入力記号は2種類のみであり、誤り源の出力が0のときは入力がそのまま出力され、1のときには入力が反転して出力されます。
実装はこのようになります。

class AbbitiveBinaryChannel(Channel):
    def __init__(self, symbol0:str, symbol1:str, error_source:Source):
        self.symbol0 = symbol0
        self.symbol1 = symbol1
        self.error_source = error_source
    def passage(self, input_signal:str) -> str:
        output_signal = ""
        for input in input_signal:
            error = self.error_source.step()
            if (input == self.symbol0 and int(error) == 0) or (input == self.symbol1 and int(error) == 1):
                output_signal += self.symbol0
            elif (input == self.symbol0 and int(error) == 1) or (input == self.symbol1 and int(error) == 0):
                output_signal += self.symbol1
            else:
                if input != self.symbol0 and input != self.symbol1:
                    assert False, ("all of input must be symbol0 or symbol1")
                elif int(error) != 0 and int(error) != 1:
                    assert False, ("output of error source must be 0 or 1")
                else:
                    assert False, ("unexpected error")
        return output_signal
    
#使用例
channel = AbbitiveBinaryChannel("a","b",StaticSource({0:0.3, 1:0.7}))
input_signal = "aaaaabbbbb"
print(channel.passage(input_signal))

インスタンスの生成時には2つの記号symbol0,symbol1と誤り源error_sourceをそれぞれコンストラクタに渡します。
passageメソッドでは入力記号を1つずつ順に取り出し、同時に誤り源から出力を1つ得て、これらから出力記号を1つずつ順に決定していきます。
この加法的2元通信路のerror_sourceにマルコフ情報源などを渡すことによって、ギルバートモデルのような記憶のあるより複雑な通信路を表現することができます。

誤り検出と訂正

最後に誤り検出、誤り訂正を行うための符号化、復号についてその実装を説明していきます。通信路の章で実装したように、実際の通信路には確率的なノイズがつきものです。このようなノイズに対し、できる限り正しく情報を送信をするための技術が誤り訂正です。
誤り訂正においては、符号化の章で示した符号の抽象クラスCoderを利用して、幾つかの誤り訂正符号を実装します。

反復による冗長化

もっとも単純な誤り訂正の手法は、同じメッセージを何度も繰り返して送信することです。例えば3回連続で同じメッセージを送信すれば、そのうちの1つがノイズによって書き換えられてしまっても多数決を取ることによって正しいメッセージを判断することが可能になります。
では実装を見てみましょう。

class RepeatChecker(Coder):
    def __init__(self, repeat:int = 3, threshold:int = -1, fail_str:str = "FAIL") -> None:
        self.threshold = threshold if threshold != -1 else repeat//2+1
        self.repeat = repeat
        self.fail_str = fail_str
    def encode(self, series:str) -> str:
        return series*self.repeat
    def decode(self, series:str) -> str:
        # 同じ文字列のrepeat回の繰り返しであるかを判定
        if len(series)%self.repeat != 0:
            return self.fail_str
        #文字列をrepeat個に分けて各文字列の出現回数を数える
        counts = {}
        for i in range(0,len(series),len(series)//self.repeat):
            if series[i:i+len(series)//self.repeat] not in counts:
                counts[series[i:i+len(series)//self.repeat]] = 1
            else:
                counts[series[i:i+len(series)//self.repeat]] += 1
        max_count = 0
        max_str = ""
        for key,value in counts.items():
            if value >= self.threshold and value > max_count:
                max_count = value
                max_str = key
        return max_str if max_str != "" else self.fail_str
# 使用例
channel = StaticChannel({'0':{'0':0.9,'1':0.1},'1':{'0':0.1,'1':0.9}})
check = RepeatChecker(repeat = 3)
series = "00011111000"
encoded_series = check.encode(series)
print("符号化:", encoded_series)
output_series = channel.passage(encoded_series)
print("雑音を含む系列:", output_series)
decoded_series = check.decode(output_series)
print("復号:", decoded_series)

繰り返し回数repeat、判断の閾値threshold、受信失敗に出力する文字列fail_strをインスタンス生成時に指定します。ただしthresholdを指定しなかった場合には過半数を閾値としています。閾値というのは受信語において同一のメッセージがいくつ以上あればそのメッセージを送ったものと判断してよいかという基準です。例えば7回繰り返してメッセージを送ったとき、うち3回はA、2回はB、2回はCというメッセージが受信されたときに、閾値が3回であればAだと解釈しますが、閾値が4回ならば「受信失敗」と判断します。閾値を高くすると「受信失敗」の確率が高くなりますが、一方閾値を低くしすぎると誤って復号してしまう確率が高くなってしまいます。この実装では閾値を超えるものが複数ある場合はその中で多数決を採り、同率首位がある場合にはシステムがその中で適当に1つ選んでしまう仕様となっています。

パリティ検査符号

続いては1ビットのパリティビットを末尾に付加するパリティ検査符号を実装します。これは、符号語に含まれる1の個数が偶数個になるように末尾に1か0のいずれかを追加するものです。受信語に対して1の個数を数え、もし奇数個になってしまっていれば受信に失敗したと判断します。
この符号化が適用可能なのは、記号が2種類の場合のみであることに注意してください。

class ParityChecker(Coder):
    def __init__(self, symbol0:str='0', symbol1:str='1', fail_str:str = "FAIL") -> None:
        self.fail_str = fail_str
        self.symbol0 = symbol0
        self.symbol1 = symbol1
    def encode(self, series:str) -> str:
        return series + self.__parity(series)
    def decode(self, series:str) -> str:
        if series[-1] != self.__parity(series[:-1]):
            return self.fail_str
        return series[:-1]
    def __parity(self, series:str) -> str:
        return self.__num2symbol(series.count(self.symbol1)%2)
    def __num2symbol(self, num:int) -> str:
        return self.symbol0 if num == 0 else self.symbol1

# 使用例
channel = StaticChannel({'0':{'0':0.9,'1':0.1},'1':{'0':0.1,'1':0.9}})
check = ParityChecker()
series = "00011111000"
encoded_series = check.encode(series)
print("符号化:", encoded_series)
output_series = channel.passage(encoded_series)
print("雑音を含む系列:", output_series)
decoded_series = check.decode(output_series)
print("復号:", decoded_series)

「0」「1」以外に、「a」「b」などからなる入力系列にも対応できるように、コンストラクタには「0」「1」それぞれに対応する記号を渡せるようになっています。また、受信失敗時にはfail_strに渡した文字列を返します。
encode時にはparityメソッドによって計算したパリティビットを末尾に付加し、decode時にはparity検査を行った上で、誤りが検出されなければ末尾のパリティビットを取り除いた記号列を返しています。

ハミング符号

パリティ検査符号においては1ビットの誤りを検出することしかできませんでした。ここでは、1ビットの誤りを検出するだけでなく訂正することもできるハミング符号の実装を行います。
ハミング符号の実装にはさまざまなものが考えられるかもしれませんが、ここでは符号語のうち「2の累乗」番目の記号を検査ビットとする実装を示します。

class HammingCoder(Coder):
    def __init__(self, symbol0:str='0', symbol1:str='1', fail_str:str='FAIL') -> None:
        self.symbol0 = symbol0
        self.symbol1 = symbol1
        self.fail_str = fail_str
    def encode(self, series:str) -> str:
        return self.__add_check_bits(series)
    def decode(self, series:str) -> str:
        self.__calc_syndrome(series)
        corrected_series = self.__correct_error(series)
        return self.__remove_check_bits(corrected_series)
    def __is_power_of_2(self, num:int) -> bool:
        return num >= 1 and num & (num-1) == 0
    def __symbol2num(self, symbol:str) -> int:
        return 0 if symbol == self.symbol0 else 1
    def __num2symbol(self, num:int) -> str:
        return self.symbol0 if num == 0 else self.symbol1
    def __parity(self, series:str, place:int) -> str:
        assert self.__is_power_of_2(place), ("place must be power of 2")
        parity = 0
        for i in range(place,len(series)):
            if (i+1)&place != 0:
                parity += self.__symbol2num(series[i])
        return self.__num2symbol(parity%2)
    def __add_check_bits(self, series:str) -> str:
        # 2の累乗の位置に0を挿入
        result = ""
        i = 0
        while i < len(series):
            if self.__is_power_of_2(len(result)+1):
                result += self.symbol0
            else:
                result += series[i]
                i += 1
        # 2の累乗の位置をパリティビットに置き換え
        i = 1
        while i < len(result):
            result = result[:i-1] + self.__parity(result, i) + result[i:]            
            i *= 2
        return result
    def __remove_check_bits(self, series:str) -> str:
        # パリティビットを削除
        result = ""
        for i in range(len(series)):
            if not self.__is_power_of_2(i+1):
                result += series[i]
        return result
    def __calc_syndrome(self, series:str) -> bool:
        syndrome = 0
        i = 1
        while i < len(series):
            if series[i-1] != self.__parity(series, i):
                syndrome += i
            i *= 2
        return syndrome
    def __correct_error(self, series:str) -> str:
        syndrome = self.__calc_syndrome(series)
        if syndrome == 0:
            return series
        else:
            if 0 <= syndrome-1 and syndrome-1 < len(series):
                return self.__flip_bit(series, syndrome-1)
            else:
                return self.fail_str
    def __flip_bit(self, series:str, place:int) -> str:
        result = ""
        for i in range(len(series)):
            if i == place:
                result += self.__num2symbol(1-self.__symbol2num(series[i]))
            else:
                result += series[i]
        return result
    
# 使用例・テスト
if __name__ == "__main__":
    series = "11110111011"
    print("元の系列:", series)
    hamming = HammingCoder()
    encoded_series = hamming.encode(series)
    print("符号化:", encoded_series)
    for i in range(len(encoded_series)):
        flip_char = str(1-int(encoded_series[i]))
        error_series = encoded_series[:i] + flip_char + encoded_series[i+1:]
        print("誤りを含む系列:", error_series)
        decoded_series = hamming.decode(error_series)
        print("復号:", decoded_series)

少し複雑ですが、方針は単純です。インスタンス生成時には「0」「1」に対応する記号および受信失敗時の出力を必要に応じて指定します。encode時にはadd_check_bitsメソッドで「2の累乗」番目の位置に一旦0を挿入します。そして、番号を2進数表記したときの各桁について、その桁に1が立つすべての位置を見たときに、1の個数が偶数になるようにパリティビットを変更します。
例えば7までの2進数表記は下のようになります。このうち1,2,4が検査ビット、残りは情報ビットです。例えば2番目の位置にある検査ビットの値は、2,3,6,7番目の中での1の個数が偶数になるように決めることになります。

1 2 3 4 5 6 7
0 0 0 1 1 1 1
0 1 1 0 0 1 1
1 0 1 0 1 0 1

decode時には、まずcalc_syndromeメソッドによりシンドロームを計算します。例えば1番目と4番目の検査ビットがともに誤った値になっていた場合、シンドロームの値は1+4で5となります。シンドロームは0から7までの値を取り、0であれば誤りなし、1~7のときは単一誤りを仮定するならばそのシンドロームの値の位置が誤っているということになります。
このことを利用してcorrect_errorメソッドで誤りを訂正し、最後にremove_check_bitsメソッドで検査ビットを取り除いて元の記号列を得ます。

巡回符号

最後に巡回符号による誤り検出について説明します。

class CyclicCoder(Coder):
    def __init__(self, cyclic_polynomial:int, cycle:int = None, error_correct:bool = True, symbol0:str='0', symbol1:str='1', fail_str:str = "FAIL") -> None:
        self.fail_str = fail_str
        self.symbol0 = symbol0
        self.symbol1 = symbol1
        self.cyclic_polynomial = cyclic_polynomial
        self.error_correct = error_correct
        self.parity_length = self.cyclic_polynomial.bit_length() - 1
        if cycle is None:
            self.cycle = self.__calc_cycle()
        else:
            self.cycle = cycle
        self.valid_info_length = self.cycle - self.parity_length
        if self.error_correct:
                self.remainder_dict = self.__build_remainder_dict()
        print("周期:", self.cycle)
    def __calc_cycle(self) -> int:  
        cycle = 1
        while True:
            gx = (1 << cycle) + 1
            if self.__calc_remainder(gx) == 0:
                break
            cycle += 1
        return cycle
    def __build_remainder_dict(self) -> List[int]:
        remainder_dict = {}
        for place in range(self.cycle):
            num = (1 << place)
            remainder_dict[self.__calc_remainder(num)]=place
        return remainder_dict
    def encode(self, series:str) -> str:
        if len(series) > self.valid_info_length:
            print("Warning: length of series is longer than valid_info_length")
        dividend = self.__series2bin(series) << self.parity_length
        return series + self.__bin2series(self.__calc_remainder(dividend), self.parity_length)
    def decode(self, series:str) -> str:
        dividend = self.__series2bin(series)
        remainder = self.__calc_remainder(dividend)
        if remainder != 0:
            if not self.error_correct:
                return self.fail_str
            if remainder not in self.remainder_dict:
                return self.fail_str
            error_place = self.remainder_dict[remainder]
            series = self.__flip_bit(series, len(series) - error_place - 1)
        return series[:-self.parity_length]
    def __symbol2num(self, symbol:str) -> int:
        return 0 if symbol == self.symbol0 else 1
    def __num2symbol(self, num:int) -> str:
        return self.symbol0 if num == 0 else self.symbol1
    def __series2bin(self, series:str) -> int:
        num = 0
        for i in range(len(series)):
            num += self.__symbol2num(series[i])*(2**(len(series)-i-1))
        return num
    def __bin2series(self, num:int, max_len:int) -> str:
        series = ""
        while num > 0:
            series = self.__num2symbol(num%2) + series
            num //= 2
        while len(series) < max_len:
            series = self.symbol0 + series
        assert len(series) == max_len, ("length of series must be max_len")
        return series
    def __calc_remainder(self, dividend:int) -> int:
        divisor = self.cyclic_polynomial
        quotient, remainder = 0, dividend
        shift = remainder.bit_length() - divisor.bit_length()
        while shift >= 0:
            quotient ^= (1 << shift)
            remainder ^= (divisor << shift)
            shift = remainder.bit_length() - divisor.bit_length()
        return remainder
    def __flip_bit(self, series:str, place:int) -> str:
        assert 0 <= place and place < len(series), ("place must be between 0 and len(series)")
        result = ""
        for i in range(len(series)):
            if i == place:
                result += self.__num2symbol(1-self.__symbol2num(series[i]))
            else:
                result += series[i]
        return result


# 使用例・テスト
if __name__ == "__main__":
    series = "1100"
    print("元の系列:", series)
    cyclic = CyclicCoder(0b1101)
    encoded_series = cyclic.encode(series)
    print("符号化:", encoded_series)
    for i in range(len(encoded_series)):
        flip_char = str(1-int(encoded_series[i]))
        error_series = encoded_series[:i] + flip_char + encoded_series[i+1:]
        print("誤りを含む系列:", error_series)
        decoded_series = cyclic.decode(error_series)
        print("復号:", decoded_series)

コンストラクタには「0」「1」に対応する記号の他に、巡回多項式と単一誤り訂正を行うか否かを指定します。単一誤り訂正を敢えて行わないことにより、誤りを検出しやすくすることが可能になります。なお、巡回多項式は2進数の数値として与えることとしています。
巡回多項式に原始多項式を与えれば、巡回ハミング符号とすることも可能です。
検査ビットの長さは巡回多項式の次数となりencode時には入力系列を検査ビット長だけ左シフトしたものを巡回多項式で割った余りを検査ビット長として入力系列の末尾に加えています。またdecode時には受信系列を再び巡回多項式で割り、この余りが0となるかどうかで誤りがあるかを判定しています。
また予め各iに対してx^iを巡回多項式で割った余りとiの対応を辞書型で記録しておくことにより、クエリごとの単一誤りの訂正を高速に行うことができます。
ここではシンプルな実装を行うことに重きを置き最適化等を行っていないため、周期の計算や対応関係の辞書の構築などにはO(n^2)の時間がかかります。このため、例えば周期が32767となるCRC-16-CCITT(x^{16}+x^{12}+x^5+1)などを使った誤り訂正を行おうとすると、最初のインスタンス生成時にかなり長い時間がかかることに注意してください。

まとめ

通信全体のシミュレーション

最後に、この記事内で作った情報源、符号化、通信路等を用いて、通信の流れを疑似的に再現してみましょう。

class Communication():
    def __init__(self, source:Source, encoding:Coder, channel:Channel, checker:Coder, fail_str:str = 'FAIL') -> None:
        self.source = source
        self.encoding = encoding
        self.channel = channel
        self.checker = checker
        self.fail_str = fail_str
    def steps(self, length:int) -> bool:
        input_series = ""
        for i in range(length):
            input_series += self.source.step()
        encoding_series = self.encoding.encode(input_series)
        check_encoding_series = self.checker.encode(encoding_series)
        received_series = self.channel.passage(check_encoding_series)
        checked_series = self.checker.decode(received_series)

        if checked_series == self.fail_str:
            return "Error Detected"
        decoded_series = self.encoding.decode(checked_series)
        if input_series == decoded_series:
            return "Success"
        else:
            return "Fail"
#使用例
distribution = {'a':0.6,'b':0.3,'c':0.1}
source = StaticSource(distribution)
encoding = Huffman(distribution)
channel = BinarySymmetricChannel(0.01)
checker = HammingCoder()
communication = Communication(source, encoding, channel, checker)

error_detect = 0
success = 0
fail = 0
length = 20
times = 100
for i in range(times):
    result = communication.steps(length)
    if result == "Error Detected":
        error_detect += 1
    elif result == "Success":
        success += 1
    elif result == "Fail":
        fail += 1
print(f"success rate:{success/times}")
print(f"error detect rate:{error_detect/times}")
print(f"fail rate:{fail/times}")

コンストラクタには情報源、符号化、通信路、誤り検査のそれぞれに対応するクラスを指定します。stepメソッドを長さを指定して呼ぶと、その長さの系列を情報源に出力させ、これを符号化し、誤り検査のための冗長化を行った上で通信路に通し、復号によって元の系列を得ています。
これによって各符号化の平均符号長や通信路誤りの発見率などを様々に状況を変えながらシミュレートすることができます。ここまでで紹介した事柄を踏まえながら様々にアレンジしてみてください。

拙い文章とコードでしたが、ここまで読んでいただきましてありがとうございました。
バグなどを見つけた際にはご指摘いただけると幸いです。

参考文献

「基礎から学ぶ情報理論第2版」ムイスリ出版(2020) 著 中村篤祥・喜田拓也・湊真一・廣瀬善大

Discussion