🙆‍♀️

2023/04/20に公開

より論文実装に近い形で実装していますが、大枠は参考文献のqiita記事を参考にしました。

## 使い方

1. ライブラリのインポート
``````import os
import random
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from tradercompany.activation_funcs import identity, ReLU, sign, tanh

%matplotlib inline

SEED = 2021
def fix_all_seeds(seed):
np.random.seed(seed)
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)

fix_all_seeds(SEED)
``````
1. データ準備

pandas.DataFrame形式の時系列`df_y`を準備します。訓練用のデータと検証用に分割します。

``````def create_dataset(T, sigma_t):
def simulation(y_t, sigma):
y_t1 = np.zeros(2)
y_t1[0] = 1.0*tanh(y_t[0]) + 0.8*y_t[0]*y_t[1] + 1.0*y_t[1] - 1.0*ReLU(min(y_t[0], y_t[1])) + sigma*np.random.randn()
y_t1[1] = +0.6*sign(y_t[1]) + 0.5*y_t[0]*y_t[1] - 1.0*max(y_t[0], y_t[1]) + sigma*np.random.randn()
return y_t1

y = np.zeros((2, T))
y_without_noise = np.zeros((2, T))
y[:,0] = np.array([0.1, 0.1])
y_without_noise[:,0] = np.array([0.1, 0.1])

for t in range(1, T):
y[:,t] = simulation(y[:,t-1], sigma_t)
y_without_noise[:,t] = simulation(y[:,t-1], 0.0)

plt.plot(y[0], color = "#cc0000", label = "stock0")
plt.plot(y[1], color = "#083090", label = "stock1")
plt.plot(y_without_noise[0], color = "#cc0000", linestyle = "--", label = "stock0" + "(w/o noise)")
plt.plot(y_without_noise[1], color = "#083090", linestyle = "--", label = "stock1" + "(w/o noise)")
plt.xlabel("time", fontsize = 18)
plt.ylabel("y", fontsize = 18)
plt.xlim([T-100, T])
plt.legend()
plt.show()
plt.close()

return y, y_without_noise

sigma = 0.1
T_total = 500
y, y_without_noise = create_dataset(T_total, sigma)

T_train = 800
df_y_train = df_y.iloc[:T_train, :]
df_y_test = df_y.iloc[T_train:, :]
``````
``````activation_funcs = [identity, ReLU, sign, tanh]
binary_operators = [max, min, add, diff, multiple, get_x, get_y, x_is_greater_than_y]
stock_names = ["stock0", "stock1"]
time_window = 200
delay_time_max = 2
num_factors_max = 4
``````
1. モデルを構築する
``````model = Company(stock_names,
num_factors_max,
delay_time_max,
activation_funcs,
binary_operators,
Q=0.2,
time_window=time_window,
how_recruit="random")
``````
1. 学習する
``````model.fit(df_y_train)
``````
1. モデルの保存
``````with open("model.pkl", "wb") as f:
pickle.dump(model, f)
``````
1. 次の時刻の予測
``````# 時刻t+1の予測
model.aggregate()
``````

8-1. 検証用データに対する予測(tuningなし)

``````with open("model.pkl", "rb") as f:

errors_test_notuning = []
for i, row in df_y_test.iterrows():
prediction_test = model.aggregate()
errors_test_notuning.append(np.abs(row.values - prediction_test))

# tuning==Falseの場合、データが追加されても重みの更新などパラメータは変わらない
model.fit_new_data(row.to_dict(), tuning=False)
``````

8-2. 検証用データに対する予測(tuningあり)

``````with open("model.pkl", "rb") as f:

errors_test_tuning = []
for i, row in df_y_test.iterrows():
prediction_test = model.aggregate()
errors_test_tuning.append(np.abs(row.values - prediction_test))

# tuning==Trueの場合、データが追加された際に重みの更新などパラメータが調整される
model.fit_new_data(row.to_dict(), tuning=True)
``````
1. 精度比較

また、オンライン学習をしたモデル(tuningあり)の方が僅かに精度が良いこともわかります。

``````days_ma = 5

errors_test_notuning = np.array(errors_test_notuning)
errors_test_notuning_ma = pd.DataFrame(errors_test_notuning).rolling(days_ma).mean()

errors_test_tuning = np.array(errors_test_tuning)
errors_test_tuning_ma = pd.DataFrame(errors_test_tuning).rolling(days_ma).mean()

# baseline method
errors_baseline = np.abs(y[:,T_train+1:] - y[:,T_train:-1])
errors_baseline_ma = pd.DataFrame(errors_baseline.T).rolling(days_ma).mean()

# lower bound
errors_lower_bound = np.abs(y[:,T_train+1:] - y_without_noise[:,T_train+1:])
errors_lower_bound_ma = pd.DataFrame(errors_lower_bound.T).rolling(days_ma).mean()

for i_stock, name in enumerate(stock_names):
print(name)
plt.plot(errors_baseline_ma[i_stock], label="baseline")
plt.plot(errors_lower_bound_ma[i_stock], label="lower-bound")
plt.xlabel("time")
plt.ylabel("mean average error")
plt.legend()
plt.show()

for i_stock, name in enumerate(stock_names):
print(name)
print("baseline", errors_baseline[i_stock].mean())
print("lower bound", errors_lower_bound[i_stock].mean())
``````

[f:id:yamayou_1:20211120105831p:plain]

[f:id:yamayou_1:20211120105846p:plain]

1. モデルの解釈

``````num_stock = len(stock_names)

print(stock_names[0])

print(stock_names[1])
``````

[f:id:yamayou_1:20211120110134p:plain]

# 最後に

## 追記

この実装で現在Kaggkeで行われている仮想通貨コンペに適用してみたのですが、データ量や変数の数が多い（100とか）とかなり速度が遅いです。

ログインするとコメントできます