PythonのSynthetic Data Vault (SDV)ライブラリで時系列データを生成してみる

PythonのSynthetic Data Vault (SDV)ライブラリで時系列データを生成してみる

Mediumの新着記事を眺めていたら気になるタイトルがありました。


Synthetic Data Vault(SDV)という、統計モデルや機械学習モデルを使ってデータセットをモデリングし、合成データを生成できるPythonライブラリがあるとのことです。

ということで、今回はこのSDVライブラリを試してみます。
公式のチュートリアルを参考に時系列データセットでモデリングし、合成データを生成します。
なお、実装したコードはGitHubに上げています。

※20201218:弊社ブログにて『合成データがモデル構築をよりオープンにする〜MLタスクでのSDVによる合成データの有効性を検証する』という記事を掲載しています。

SDVとは

Synthetic Data Vault(SDV)は、リレーショナル・データベースの生成モデルを構築するシステムのことです。
※詳しい説明はSDVシステムを提案した原論文をご参照ください。
ここで言うリレーショナルデータベースは、実体関連モデル(entity-relationship Model)やER図のイメージで良いと思います。
要は、データ間の関係性も考慮した上で合成データを自動的に生成してくれるシステムということです。

SDVで合成データを生成するメリットは、手元に実データが少量しかない場合でも、本番相当のデータをいくつでも合成できるということです。
例えばシステムテストで負荷試験をする際に、テキトーに水増しした非現実的なデータでテストするよりも、本番と同等のデータセットで検証する方が良いです。

SDVによるデータセットのモデリングは以下の4ステップとなります。

※原論文のfig1より引用

  1. Organize
    DBのデータをテーブルごとに別ファイルにフォーマットする。
  2. Specify Structure
    DBのメタデータを指定する。
  3. Learn Model
    テーブル間の関係を考慮してモデリングする。
  4. Synthesize Data
    fitしたモデルをもとに合成データを得る。

PythonのSDVパッケージでモデリングできるデータセットの種類はSingle Table Data(テーブルデータ)、Relational Data(リレーショナル・データベース)、Timeseries Data(時系列データ)の3つですが、今回は時系列データを対象にモデリングをしてみます。

SDVで合成データを生成してみる

PythonのSDVパッケージを使って、時系列データセットをもとにモデリングし、合成データを生成していきます。

データの読み込み

まず、SDVライブラリがデフォルトで保持している2019年のNASDAQ100の株価データセットをロードし、データを確認します。
INPUT:

from sdv.demo import load_timeseries_demo

# sdvライブラリの時系列データのデモからデータセットをロードする
data = load_timeseries_demo()
data.head()

OUTPUT:

SymbolDateOpenCloseVolumeMarketCapSectorIndustry
0AAPL2018/12/3139.632539.4350011400140007.38E+11TechnologyComputer Manufacturing
1AAPL2019/1/238.722539.481481588007.38E+11TechnologyComputer Manufacturing
2AAPL2019/1/335.99499935.5475013652488007.38E+11TechnologyComputer Manufacturing
3AAPL2019/1/436.132537.0649992344284007.38E+11TechnologyComputer Manufacturing
4AAPL2019/1/737.17499936.9824982191112007.38E+11TechnologyComputer Manufacturing

ここで、各カラムは以下を意味しています。

カラム名説明
Symbol銘柄コード
Date日付
Open始値
Close終値
Volume取引量
MarketCap企業の時価総額
Sector事業を展開している業種
Industry事業を展開している業界

次に、データセットのサイズと、銘柄数を確認します。
INPUT:

# データ量の確認
print(f'data size: {data.shape}')
print(f'銘柄数:{len(set(data.Symbol))}')

OUTPUT:

data size: (25784, 8)
銘柄数:103

103の銘柄を一気に可視化しても見にくいので、ランダムに10銘柄をサンプリングしたうえで、Open(始値)のチャートを可視化します。
INPUT:

import random
random.seed(0)

# 銘柄をサンプリングして可視化する
N = 10
symbol_sample = random.sample(list(set(data.Symbol)), N)

# 開始値の可視化
fig = plt.figure(figsize=(20,10))
ax = fig.add_subplot(111)
for i, v in data.query('Symbol in @symbol_sample').groupby('Symbol'):
    ax.plot(v.Date, v.Open, label=i)

ax.legend()
ax.set_title('stock chart [open]')

OUTPUT:

どの銘柄についても短期的には増減はあるものの、全体的には始値は増加トレンドにあるように見えます。

PARモデルでモデリング

SDVで時系列データをモデリングするにはPARクラスを使用します。
PARはProbabilistic AutoRegressive model:確率的自己回帰モデルの略称です。
PARモデルの詳細については『Adversarial Attacks on Probabilistic Autoregressive Forecasting Models』という論文をご参照ください。

まず、Entityカラム、Contextカラム、Sequence Indexを定義します。

  • Entityカラム
    行間に依存関係が存在するグループです。このデータセットでは、銘柄ごとに行間(日付)に依存関係があるので、銘柄(Symbol)をEntityカラムとします。
  • Contextカラム
    Entityに関する属性情報を保持する変数です。このデータセットでは、時価総額(MarketCap)、業種(Sector)、業界(Industry)がContextに該当します。
  • Sequence Index
    行間の依存関係において、順序が意味をもつような変数です。このデータセットでは、日付(Date)がSequence Indexに該当します。

INPUT:

# Entityカラムの定義
entity_columns = ['Symbol']
# Contextカラムの定義
context_columns = ['MarketCap', 'Sector', 'Industry']
# Sequence Indexの定義
sequence_index = 'Date'

次に、定義したEntityカラム、Contextカラム、Sequence Indexをもとに時系列データセットをPARモデルでモデリングします。
なお、ローカルのCPUで全量データをモデリングすると時間がかかりすぎるため、記事冒頭でサンプリングした10銘柄を対象にモデリングしています。

# ローカルのCPU環境で検証するため、データ量を削減
from sdv.timeseries import PAR

model = PAR(
    entity_columns=entity_columns,
    context_columns=context_columns,
    sequence_index=sequence_index,
)
model.fit(data.query('Symbol in @symbol_sample'))

合成データを生成

モデリング済みのモデルをもとに合成データを生成します。
試しに5シーケンス(5銘柄)を生成します。
INPUT:

# 5つシーケンス(銘柄)を生成する
seq_sample = model.sample(num_sequences=5)

print(f"""
サンプリング銘柄: {symbol_sample}
生成した銘柄: {list(set(seq_sample.Symbol))}
""")

OUTPUT:

サンプリング銘柄: ['FB', 'INTC', 'ADBE', 'XEL', 'TCOM', 'PDD', 'KLAC', 'ADSK', 'CSCO', 'MSFT']
生成した銘柄: ['rQNFBQUE', 'nYCBW', 'yCVUONVUD', 'oZZRLZTB', 'iZQPFTOZBYM']

どうやら、生成した銘柄名はランダムな文字列のようです。

では、生成した5銘柄と、モデリングに使用した10銘柄の分布を比較してみましょう。

import numpy as np

# 可視化で始値の軸を揃えるため
open_max = np.max([data.query('Symbol in @symbol_sample').Open.max(), seq_sample.Open.max()])
open_min = np.min([data.query('Symbol in @symbol_sample').Open.min(), seq_sample.Open.min()])
# 擬似生成したシーケンス
fig = plt.figure(figsize=(20,10))

ax1 = fig.add_subplot(211)
for i, v in data.query('Symbol in @symbol_sample').groupby('Symbol'):
    ax1.plot(v.Date, v.Open, label=i)
ax1.set_ylim(open_min, open_max)
ax1.legend()
ax1.set_title('stock chart [original]')

ax2 = fig.add_subplot(212)
for i, v in seq_sample.groupby('Symbol'):
    ax2.plot(v.Date, v.Open, label=i)
ax2.set_ylim(open_min, open_max)
ax2.legend()
ax2.set_title('stock chart [synthetic data]')

OUTPUT:

上がモデリングに使用した10銘柄の分布で、下が生成した5銘柄の分布になります。
このグラフの情報だけでは生成した5銘柄の分布の良し悪しは一概に言えませんが、一応それっぽい(実際にありそうな)分布にはなっていそうです。
仮に独自に時系列データセットを生成したとしても、このように自然な分布にはならないと思います。

最後に

今回はSDVライブラリの時系列データセットに対するモデリングを紹介しましたが、これ以外にもテーブルデータやリレーショナル・データベースのデータセットのモデリングに対応してるので、今度はそちらも試してみようと思います。
また、SDVライブラリでのモデリングはかなり処理負荷がかかります。
今回はローカルのCPU環境で検証しましたが、CPU環境ではデモ用の時系列データセットでさえサンプリングしないとモデリングに時間がかかり過ぎてしまいます。
SDVライブラリはGPU環境でのモデリングにも対応しているので、可能ならGPU環境でモデリングした方が良さそうです。

参考