pyyamlで引数を保存する方法と読み込み方に関して

argparseで読み込んだ引数をpyyamlで保存

import argparse
import yaml
from datetime import datetime
from dateutil import tz

parser = argparse.ArgumentParser(description='there are arguments for time series forecasting')

parser.add_argument('--n_steps', type=int, help='how many steps ahead do we predict?')
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--architecture', type=str, default='bidirectional', help='model architecture. "bidirectional": BLSTM, "unidirectional": LSTM.')

args = parser.parse_args()

jst = tz.gettz('Asia/Tokyo')
time_datetime = datetime.now(tz=jst)
time_str = time_datetime.strftime('%Y%m%d_%H%M%S')
dir_config = '../../reports/config'
os.makedirs(dir_config, exist_ok=True)
with open('{}/{}_config.yaml'.format(dir_config, time_str), 'w') as f:
    yaml.dump(args, f, allow_unicode=True)

yamlの読み込み

pyyamlで保存することに関しては特に問題は発生しないが、読み込むときにエラーが頻発する。下記に筆者がはまったポイントとその対処方法について述べる

はまったポイント

yaml.safe_loadとyaml.loadを使って読み込んだら・・・

yaml.constructor.ConstructorError: could not determine a constructor for the tag 'tag:yaml.org,2002:python/object:argparse.Namespace’

のエラーが発生

obj['x’]で読み込んだら・・・

TypeError: 'Namespace’ object is not subscriptable

のエラーが発生

対処法

yaml.unsafe_loadを使う

obj.xの形で読み込む

最終的なyaml読み込みコード

# Fetch names of model and args from config
yaml_path = '../../reports/config/{}'.format(yaml_name)
with open(yaml_path) as f:
    obj = yaml.unsafe_load(f)
epoch = obj.epoch

Python

Posted by vastee