[Python] 標準モジュールのみでキーボードイベントをハンドルする

2022/12/07に公開

概要

コンソールにおけるキーボードイベント (キー押下イベント) をハンドルする Python モジュールを標準モジュールのみで作成する. キー入力の処理は OS 依存だが, 本記事では Windows のみ考える.

なお, 実装については TUI フレームワークの Textual を参考にした. ソースの一部はここから複製・改変されているため, 注記のあるものには以下のライセンスが適用される.

LICENSE

Copyright (c) 2021 Riku Kanayama (k_kuroguro)
Released under the MIT license
https://github.com/k-kuroguro/zenn-docs/blob/master/MIT_LICENSE

The original copyright notice for Will McGugan:

Copyright (c) 2021 Will McGugan
Released under the MIT license
https://github.com/Textualize/textual/blob/main/LICENSE

サンプルコードは次のリポジトリに載せておく.
https://github.com/k-kuroguro/py-cli-keyboard

環境

  • Windows 11
  • Python 3.9.7

構成

大まかな流れ

  1. clikeyboard モジュールが KeyEvent クラス, Keys 列挙体と on_press 関数を提供する.
  2. 呼び出し側は, 引数に対象のキー (文字列または Keys 列挙体) とハンドラを渡して, on_press 関数を実行する.
  3. 対象のキー押下イベントが発生すると, ハンドラが起動される.
from clikeyboard import KeyEvent, Keys, on_press

def on_event(event: KeyEvent) -> None:
    print(f'Pressed "{event.key}"')

on_press('a', on_event)
on_press(Keys.CONTROL_D, on_event)

while 1:
    pass

ファイル構成

モジュールのファイル構成を以下に示す.

clikeyboard
   ├─ event.py
   ├─ keys.py
   ├─ _ansi_sequences.py
   ├─ _listener.py
   ├─ _parser.py
   ├─ _windows.py
   └─ __init__.py

それぞれの責務は

  • event
    キーボードイベントを表す KeyEvent クラスを定義する.
  • keys
    特殊キー (e.g. ctrl+c, esc) を表す Keys 列挙体を定義する.
    (なお, それ以外の文字・数字・記号キーは, 文字列 (e.g. 'a', '1') で対応する.)
  • _ansi_sequences
    ANSI エスケープシーケンスと Keys 列挙体の対応表を持つ.
  • _listener
    それぞれキー入力の読取とイベントハンドラの起動を行う2つのスレッドを管理する.
  • _parser
    string 型のキー入力のうち, 特殊キーを Keys 列挙体に変換する.
  • _windows
    Windows 依存のキー入力処理を行う.
  • __init__
    ユーザがハンドルの追加を行える on_press 関数を提供する.

となっている.

キー入力の処理

本プログラムで必要なキー入力に関する OS 依存の処理は, 以下の2つである.

  • 仮想ターミナルシーケンスの有効化.
  • コンソールからの入力読取.

どちらも Windows のコンソール API を使用する. 各定数・構造体・関数の詳細は Microsoft Docs を参照のこと.

仮想ターミナルシーケンスの有効化

仮想ターミナルシーケンス (ANSI エスケープシーケンス) は, 多くのプラットフォームでサポートされているコンソール制御の標準である. 有効化すれば, 入力がこのシーケンスに変換されるため, 使用しているターミナルの種類を問わずに特殊キーを処理できるようになる.

シーケンスの有効化は, SetConsoleMode 関数で行う. この関数はコンソールの設定を書き換えるため, GetConsoleMode 関数で元の設定を取得し, 保存しておく必要がある.

import msvcrt
import sys
from ctypes import WinDLL, byref
from ctypes.wintypes import DWORD
from typing import IO, Any, Callable

# Adapted from Textual.
# See above LICENSE.

_ENABLE_VIRTUAL_TERMINAL_INPUT = 0x0200

_kernel32 = WinDLL('kernel32', use_last_error=True)

_SetConsoleMode = _kernel32.SetConsoleMode
_GetConsoleMode = _kernel32.GetConsoleMode

def _set_console_mode(file: IO[Any], mode: int) -> bool:
    '''Set the console mode for a given file (stdout or stdin).'''

    windows_filehandle = msvcrt.get_osfhandle(file.fileno())
    success: bool = _SetConsoleMode(windows_filehandle, mode)
    return success

def _get_console_mode(file: IO[Any]) -> int:
    '''Get the console mode for a given file (stdout or stdin).'''

    windows_filehandle = msvcrt.get_osfhandle(file.fileno())
    mode = DWORD()
    _GetConsoleMode(windows_filehandle, byref(mode))
    return mode.value

def enable_virtual_terminal_sequences() -> Callable[[], None]:
    terminal_in = sys.stdin
    terminal_out = sys.stdout

    current_console_mode_in = _get_console_mode(terminal_in)
    current_console_mode_out = _get_console_mode(terminal_out)

    def restore() -> None:
        '''Restore console mode to previous settings.'''

        _set_console_mode(terminal_in, current_console_mode_in)
        _set_console_mode(terminal_out, current_console_mode_out)

    _set_console_mode(terminal_in, current_console_mode_in | _ENABLE_VIRTUAL_TERMINAL_INPUT)
    return restore

コンソールからの入力読取

コンソールからの入力は, ReadConsoleInput 関数で読み取れる. このとき, 標準入力がシグナル状態になるまで待機するために, WaitForMultipleObjects 関数を利用する.

from ctypes import Structure, Union, WinDLL, byref
from ctypes.wintypes import BOOL, CHAR, DWORD, HANDLE, SHORT, UINT, WCHAR, WORD
from typing import Callable, Optional

# Adapted from Textual.
# See above LICENSE.

class COORD(Structure):
    _fields_ = [
        ('X', SHORT),
        ('Y', SHORT)
    ]

class uChar(Union):
    _fields_ = [
        ('UnicodeChar', WCHAR),
        ('AsciiChar', CHAR)
    ]

class KEY_EVENT_RECORD(Structure):
    _fields_ = [
        ('bKeyDown', BOOL),
        ('wRepeatCount', WORD),
        ('wVirtualKeyCode', WORD),
        ('wVirtualScanCode', WORD),
        ('uChar', uChar),
        ('dwControlKeyState', DWORD)
    ]

class MOUSE_EVENT_RECORD(Structure):
    _fields_ = [
        ('dwMousePosition', COORD),
        ('dwButtonState', DWORD),
        ('dwControlKeyState', DWORD),
        ('dwEventFlags', DWORD)
    ]

class WINDOW_BUFFER_SIZE_RECORD(Structure):
    _fields_ = [
        ('dwSize', COORD)
    ]

class MENU_EVENT_RECORD(Structure):
    _fields_ = [
        ('dwCommandId', UINT)
    ]

class FOCUS_EVENT_RECORD(Structure):
    _fields_ = [
        ('bSetFocus', BOOL)
    ]

class InputEvent(Union):
    _fields_ = [
        ('KeyEvent', KEY_EVENT_RECORD),
        ('MouseEvent', MOUSE_EVENT_RECORD),
        ('WindowBufferSizeEvent', WINDOW_BUFFER_SIZE_RECORD),
        ('MenuEvent', MENU_EVENT_RECORD),
        ('FocusEvent', FOCUS_EVENT_RECORD)
    ]

class INPUT_RECORD(Structure):
    _fields_ = [
        ('EventType', WORD),
        ('Event', InputEvent)
    ]

_WAIT_TIMEOUT = 0x00000102

_MAX_EVENTS = 1024
_KEY_EVENT = 0x0001

_kernel32 = WinDLL('kernel32', use_last_error=True)

_GetStdHandle = _kernel32.GetStdHandle
_GetStdHandle.argtypes = [DWORD]
_GetStdHandle.restype = HANDLE
_ReadConsoleInputW = _kernel32.ReadConsoleInputW
_WaitForMultipleObjects = _kernel32.WaitForMultipleObjects

_hIn = _GetStdHandle(_STD_INPUT_HANDLE)
_arrtype = INPUT_RECORD * _MAX_EVENTS
_input_records = _arrtype()
_read_count = DWORD(0)

def _wait_for_handles(handles: list[HANDLE], timeout: int = 0) -> Optional[HANDLE]:
    arrtype = HANDLE * len(handles)
    handle_array = arrtype(*handles)

    ret: int = _WaitForMultipleObjects(
        len(handle_array), handle_array, BOOL(False), DWORD(timeout)
    )

    if ret == _WAIT_TIMEOUT:
        return None
    else:
        return handles[ret]

def listen(on_press: Callable[[list[str]], None]) -> None:
    '''Listen key press events. '''

    if _wait_for_handles([_hIn], 200) is None:
        return

    _ReadConsoleInputW(
        _hIn, byref(_input_records), _MAX_EVENTS, byref(_read_count)
    )
    read_input_records = _input_records[:_read_count.value]

    del _keys[:]

    for input_record in read_input_records:
        event_type = input_record.EventType

        if event_type == _KEY_EVENT:
            key_event = input_record.Event.KeyEvent
            key = key_event.uChar.UnicodeChar
            if key_event.bKeyDown or key == '\x1b':
                _keys.append(key)

    handler(_keys[:])

キー入力を読み取る listen 関数は handler 引数を持ち, 読取後にこれを実行する. この引数にqueue.putなどを渡すことで, listen 関数が他スレッドへデータを送ることができるようにしている.

データクラス・列挙体の定義

キーボードイベントを表す KeyEvent クラス, 特殊キーを表す Keys 列挙体, ANSI シーケンスと Keys 列挙体の対応表である ANSI_SEQUENCES クラスを定義する.

event.py
class Event:
    __slots__ = ()

class KeyEvent(Event):
    '''Event with key press.'''

    __slots__ = ('_key',)

    def __init__(self, key: str) -> None:
        self._key: str = key

    @property
    def key(self) -> str:
        return self._key

KeyEvent クラスは大量に生成されることが考えられるため, slotsを利用して高速化を図っている.

keys.py
from enum import Enum

# Adapted from Textual.
# See above LICENSE.

class Keys(str, Enum):
    '''List of keys for use in key bindings.'''

    value: str

    ESCAPE = 'escape'
    SHIFT_ESCAPE = 'shift+escape'
   ()
    CONTROL_Z = 'ctrl+z'
_ansi_sequences.py
from .keys import Keys

# Adapted from Textual.
# See above LICENSE.

# Mapping of vt100 escape codes to Keys.
# https://docs.microsoft.com/en-us/windows/console/console-virtual-terminal-sequences
ANSI_SEQUENCES: dict[str, tuple[Keys, ...]] = {
    # Control keys.
    '\r': (Keys.ENTER,),
    '\x00': (Keys.CONTROL_AT,),
    ()
    '\x0c': (Keys.CONTROL_L,)
}

入力のパース処理

string 型の特殊キー入力を ANSI_SEQUENCES クラスを使用して, Keys 列挙体にパースする Parser クラスを定義する.

_parser.py
from typing import Generator

from .event import KeyEvent
from ._ansi_sequences import ANSI_SEQUENCES

_ESC = '\x1b'
_get_ansi_sequence = ANSI_SEQUENCES.get

class Parser():

    def parse(self, data: str) -> Generator[KeyEvent, None, None]:

        pos = 0
        data_size = len(data)

        def read1() -> str:
            nonlocal pos
            character = data[pos]
            pos += 1
            return character

        while pos < data_size:
            character = read1()

            if character == _ESC and pos != data_size:
                sequence: str = character
                while True:
                    sequence += read1()
                    keys = _get_ansi_sequence(sequence, None)
                    if keys is not None:
                        for key in keys:
                            yield KeyEvent(key=key)
                        break
            else:
                keys = _get_ansi_sequence(character, None)
                if keys is not None:
                    for key in keys:
                        yield KeyEvent(key=key)
                else:
                    yield KeyEvent(key=character)

特殊キーの入力の1文字目はエスケープ (\x1b) となる. したがって, 1文字ずつ確認し, エスケープがあれば対応表から変換するという処理を行っている.

リスナーを用意する

キー入力読み取りとハンドラの起動をマルチスレッドで行う Listener クラスを作る. 処理の内容は以下のようになる.

_listener.py
import atexit
import platform
from queue import Queue
from threading import Event, Lock, Thread
from typing import Callable, Optional

from .event import KeyEvent
from ._parser import Parser

if platform.system() == 'Windows':
    from ._windows import listen, enable_virtual_terminal_sequences
    _restore = enable_virtual_terminal_sequences()
    atexit.register(_restore)
else:
    raise NotImplementedError()

Handler = Callable[[KeyEvent], None]

class Listener:

    def __init__(self) -> None:
        self._running: bool = False
        self._handlers: list[Handler] = []
        self._queue: Queue[KeyEvent] = Queue()
        self._lock = Lock()
        self._stop_event = Event()
        self._listening_thread: Optional[Thread] = None
        self._processing_thread: Optional[Thread] = None

    def start(self) -> None:
        '''Start listener if it hasn't started.'''

        with self._lock:
            if self._running:
                return

            self._running = True
            self._stop_event.clear()

            self._listening_thread = Thread(target=self._listen)
            self._listening_thread.daemon = True
            self._listening_thread.start()

            self._processing_thread = Thread(target=self._process)
            self._processing_thread.daemon = True
            self._processing_thread.start()

    def stop(self) -> None:
        '''Stop running listener.'''

        with self._lock:
            if not self._running:
                return

            self._running = False
            self._stop_event.set()

            if self._listening_thread is not None:
                self._listening_thread.join()
            if self._processing_thread is not None:
                self._processing_thread.join()

            self._listening_thread = None
            self._processing_thread = None

    def add_handler(self, handler: Handler) -> None:
        with self._lock:
            self._handlers.append(handler)

    def remove_handler(self, handler: Handler) -> None:
        '''Remove all handlers from list.'''

        with self._lock:
            while handler in self._handlers:
                self._handlers.remove(handler)

    def _listen(self) -> None:
        parser = Parser()

        def put_queue(keys: list[str]) -> None:
            for event in parser.parse(''.join(keys)):
                self._queue.put(event)

        while not self._stop_requested():
            listen(put_queue)

    def _process(self) -> None:
        while not self._stop_requested():
            event = self._queue.get()
            self._invoke_handlers(event)
            self._queue.task_done()

    def _invoke_handlers(self, event: KeyEvent) -> None:
        with self._lock:
            for handler in self._handlers:
                handler(event)

    def _stop_requested(self) -> bool:
        return self._stop_event.is_set()

以下の2点に注意する.

  • atexit.register によって, プログラム終了時にコンソール設定がリストアされるようにすること.
  • Lockを使用して, スレッドセーフを保つこと.

on_press 関数を定義する

最後にユーザがハンドルを追加するための on_press 関数を定義する.

__init__.py
from typing import Callable

from ._listener import Handler, Listener
from .keys import Keys

_listener = Listener()

def on_press(key: str, handler: Handler) -> Callable[[], None]:
    '''
    Invoke handler when pressed key.

    Args:
        key (str): String indicating key.
        handler (Handler): Handler to invoke.

    Returns:
        Callable[[], None]: A callable for removing handler.
    '''

    _listener.start()

    handler_checking_key: Handler = lambda event: (key in [Keys.ANY, event.key] or None) and handler(event)
    _listener.add_handler(handler_checking_key)

    def remove() -> None:
        _listener.remove_handler(handler_checking_key)

    return remove

追加したハンドラを削除するための remove 関数を返すようにしている.

動作確認

main.py
from clikeyboard import KeyEvent, Keys, on_press

def on_event(event: KeyEvent) -> None:
    print(f'Pressed "{event.key}"')

on_press(Keys.ANY, on_event)

while 1:
    pass
$ pip install git+https://github.com/k-kuroguro/py-cli-keyboard.git
$ python main.py

この通り実行すれば, 押したキーに対応する文字列が表示されるはずである.

GitHubで編集を提案

Discussion