0.13.4 以降で wandb.integration.keras モジュールから利用できます。
W&B Keras インテグレーションでは、次のコールバックを提供しています。
WandbMetricsLogger: 実験管理 にはこのコールバックを使用します。トレーニングおよび検証のメトリクスに加えて、システム メトリクスも W&B にログします。WandbModelCheckpoint: モデル チェックポイントを W&B Artifacts にログするには、このコールバックを使用します。WandbEvalCallback: このベース コールバックは、モデルの予測を W&B Tables にログし、インタラクティブに可視化できるようにします.
Keras インテグレーションのインストールとインポート
wandb.integration.keras から必須クラスをインポートします:
WandbMetricsLogger で実験を管理する
wandb.integration.keras.WandbMetricsLogger() は、on_epoch_end、on_batch_end などのコールバックが引数として受け取る Keras の logs 辞書を自動的にログします。
以下の抜粋例では、Keras の workflow で WandbMetricsLogger() を使用する方法を示します。まず、使用したい optimizer、損失関数、メトリクスを指定してモデルをコンパイルします。次に、wandb.init() を使用して W&B run を初期化します。最後に、WandbMetricsLogger() コールバックを model.fit() に渡します。
loss、accuracy、top@5_accuracy などのトレーニングおよび検証メトリクスを W&B にログします。さらに、以下の内容もログします。
WandbMetricsLogger リファレンス
| パラメーター | 説明 |
|---|---|
log_freq | (epoch、batch、または int): epoch の場合は各エポックの終了時にメトリクスをログします。batch の場合は各バッチの終了時にメトリクスをログします。int の場合は指定したバッチ数ごとにメトリクスをログします。デフォルトは epoch です。 |
initial_global_step | (int): initial_epoch からトレーニングを再開し、学習率スケジューラを使用している場合に、学習率を正しくログするためにこの引数を使用します。これは step_size * initial_step として計算できます。デフォルトは 0 です。 |
WandbModelCheckpoint を使用してモデルをチェックポイントする
WandbModelCheckpoint コールバックを使用すると、Keras モデル (SavedModel 形式) またはモデルの重みを定期的に保存し、モデルのバージョン管理のために wandb.Artifact として W&B にアップロードできます。
このコールバックは tf.keras.callbacks.ModelCheckpoint() のサブクラスであるため、チェックポイントのロジックは親コールバックによって処理されます。
このコールバックでは、次を保存できます。
- monitor に基づいて最高のパフォーマンスを達成したモデル
- パフォーマンスに関係なく、各エポックの終了時のモデル
- エポックの終了時、または一定数のトレーニングバッチごとのモデル
- モデルの重みのみ、またはモデル全体
SavedModel形式または.h5形式のモデル
WandbMetricsLogger() と併用してください。
WandbModelCheckpoint リファレンス
| パラメーター | 説明 | |
|---|---|---|
filepath | (str): モデルファイルを保存するパス。 | |
monitor | (str): 監視するメトリクスの名前。 | |
verbose | (int): 詳細表示モード。0 または 1。モード 0 ではメッセージを表示せず、モード 1 ではコールバックがアクションを実行した際にメッセージを表示します。 | |
save_best_only | (Boolean): save_best_only=True の場合、monitor 属性と mode 属性で定義された内容に基づき、最新のモデル、または最良と判断されたモデルのみを保存します。 | |
save_weights_only | (Boolean): True の場合、モデルの重みのみを保存します。 | |
mode | (auto, min, or max): val_acc の場合は max、val_loss の場合は min に設定します。 | |
save_freq | (“epoch” or int): ‘epoch’ を使用すると、コールバックは各エポックの終了後にモデルを保存します。整数を使用すると、その数のバッチの終了時にコールバックがモデルを保存します。val_acc や val_loss などの検証メトリクスを監視する場合、これらのメトリクスはエポックの終了時にのみ利用できるため、save_freq は “epoch” に設定する必要があります。 | |
options | (str): save_weights_only が true の場合は省略可能な tf.train.CheckpointOptions オブジェクト、save_weights_only が false の場合は省略可能な tf.saved_model.SaveOptions オブジェクト。 | |
initial_value_threshold | (float): 監視するメトリクスの初期「最良」値となる浮動小数点数。 |
N エポックごとにチェックポイントをログする
save_freq="epoch") 、コールバックは各エポックの後にチェックポイントを作成し、それを artifact としてアップロードします。特定のバッチ数ごとにチェックポイントを作成するには、save_freq を整数に設定します。N エポックごとにチェックポイントを作成するには、train データローダーの要素数を計算して、それを save_freq に渡します。
TPU アーキテクチャでチェックポイントを効率的にログする
UnimplementedError: File system scheme '[local]' not implemented というエラーメッセージが表示されることがあります。これは、モデルディレクトリ (filepath) にはクラウドストレージのバケットパス (gs://bucket-name/...) を使用する必要があり、さらにそのバケットに TPU サーバーからアクセスできなければならないためです。一方、W&B ではチェックポイント作成にローカルパスを使用し、その後 artifact としてアップロードします。
WandbEvalCallback を使用してモデルの予測を可視化する
WandbEvalCallback() は、主にモデル予測、次いでデータセットの可視化を目的とした Keras コールバックを構築するための抽象基底クラスです。
この抽象コールバックは、データセットやタスクに依存しません。これを使用するには、基底コールバッククラス WandbEvalCallback() を継承し、add_ground_truth メソッドと add_model_prediction メソッドを実装します。
WandbEvalCallback() は、次の機能を提供するユーティリティクラスです。
- データと予測用の
wandb.Table()インスタンスを作成する。 - データと予測のテーブルを
wandb.Artifact()としてログする。 on_train_beginでデータテーブルをログする。on_epoch_endで予測テーブルをログする。
WandbClfEvalCallback を使用します。このコールバックは、検証データ (data_table) を W&B にログし、推論を実行して、各エポックの終了時に予測結果 (pred_table) を W&B にログします。
WandbEvalCallback リファレンス
| パラメーター | 説明 |
|---|---|
data_table_columns | (list) data_table の列名の一覧 |
pred_table_columns | (list) pred_table の列名の一覧 |
メモリ使用量の詳細
on_train_begin method が呼び出されると、data_table を W&B にログします。これが W&B Artifact としてアップロードされると、この表への参照を取得でき、data_table_ref クラス変数を使ってアクセスできます。data_table_ref は 2 次元リストで、self.data_table_ref[idx][n] のようにインデックス指定できます。ここで、idx は行番号、n は列番号です。以下の例で使い方を見てみましょう。
コールバックをカスタマイズする
on_train_begin または on_epoch_end の method をオーバーライドします。N バッチごとにサンプルをログしたい場合は、on_train_batch_end method を実装できます。
WandbEvalCallback を継承してモデルの予測可視化用コールバックを実装していて、不明な点や修正が必要な点があれば、issue を作成してお知らせください。WandbCallback [レガシー]
WandbCallback() クラスを使用すると、model.fit() でトラッキングされるすべてのメトリクスと損失値を自動的に保存できます。
スクリプトについては、example repoを参照してください。これには、Fashion MNIST の例と、それによって生成される W&B ダッシュボード が含まれます。
WandbCallback クラスは、さまざまなログ設定オプションをサポートします。たとえば、監視するメトリクスの指定、重みと勾配のトラッキング、training_data と validation_data に対する予測のログなどです。
詳しくは、keras.WandbCallback のリファレンスドキュメントを参照してください。
WandbCallback は次のことを行います
- Keras が収集したすべてのメトリクスの履歴データを自動的にログします。これには、損失や
keras_model.compile()に渡されたすべての項目が含まれます。 monitor属性とmode属性で定義される「最良」のトレーニング step に関連付けられた run の summary メトリクスを設定します。デフォルトでは、これはval_lossが最小のエポックです。WandbCallbackはデフォルトで、最良のepochに対応するモデルを保存します。- 必要に応じて、勾配とパラメーターのヒストグラムをログします。
- 必要に応じて、wandb が可視化できるようにトレーニングデータと検証データを保存します。
WandbCallback リファレンス
| 引数 | |
|---|---|
monitor | (str) 監視するメトリクスの名前。デフォルトは val_loss です。 |
mode | (str) {auto, min, max} のいずれかです。min - monitor が最小になるときにモデルを保存します max - monitor が最大になるときにモデルを保存します auto - モデルを保存するタイミングを自動的に推定します (デフォルト) 。 |
save_model | True - monitor がそれまでのすべてのエポックを上回った場合にモデルを保存します False - モデルを保存しません |
save_graph | (boolean) True の場合、モデルのグラフを wandb に保存します (デフォルトは True) 。 |
save_weights_only | (boolean) True の場合は、モデルの重みのみを保存します (model.save_weights(filepath)) 。それ以外の場合は、モデル全体を保存します) 。 |
log_weights | (boolean) True の場合、モデルの各レイヤーの重みのヒストグラムを保存します。 |
log_gradients | (boolean) True の場合、トレーニング中の勾配のヒストグラムをログします |
training_data | (tuple) model.fit に渡す (X,y) と同じ形式です。勾配の計算に必要で、log_gradients が True の場合は必須です。 |
validation_data | (tuple) model.fit に渡した (X,y) と同じ形式。wandb が可視化するためのデータです。このフィールドを設定すると、各エポックで wandb は少数の予測を実行し、後で可視化できるようにその結果を保存します。 |
generator | (generator) wandb が可視化するための検証データを返すジェネレーターです。このジェネレーターは (X,y) のタプルを返す必要があります。wandb で特定のデータ例を可視化するには、validate_data または generator のいずれかを設定する必要があります。 |
validation_steps | (int) validation_data がジェネレーターである場合、検証セット全体に対してジェネレーターを何 step 実行するか。 |
labels | (list) wandb でデータを可視化する場合、このラベルのリストを指定すると、複数クラスの分類器を構築しているときに、数値出力が分かりやすい文字列に変換されます。二値分類器の場合は、2 つのラベル [label for false, label for true] からなるリストを渡せます。validate_data と generator が両方とも false の場合、これは何もしません。 |
predictions | (int) 各エポックで可視化用に行う予測数です。最大値は 100 です。 |
input_type | (string) 可視化しやすくするためのモデル入力のタイプ。次のいずれかを指定できます: (image, images, segmentation_mask)。 |
output_type | (string) 視覚化に役立つモデル出力のタイプ。次のいずれかを指定できます: (image, images, segmentation_mask)。 |
log_evaluation | (boolean) True の場合、各エポックで、検証データとモデルの予測を含む表を保存します。詳しくは validation_indexes、validation_row_processor、output_row_processor を参照してください。 |
class_colors | ([float, float, float]) 入力または出力がセグメンテーションマスクの場合、各クラスに対応する RGB タプル (範囲 0~1) を含む配列。 |
log_batch_frequency | (integer) None の場合、callback は各エポックでログします。整数を設定すると、callback は log_batch_frequency バッチごとにトレーニングメトリクスをログします。 |
log_best_prefix | (string) None の場合、追加の summary メトリクスは保存されません。文字列を設定すると、監視対象のメトリクスとエポックの先頭にそのプレフィックスを付け、結果を summary メトリクスとして保存します。 |
validation_indexes | ([wandb.data_types._TableLinkMixin]) 各検証例に関連付けるインデックスキーの順序付きリスト。log_evaluation が True で validation_indexes を指定すると、検証データのTableは作成されません。代わりに、各予測が TableLinkMixin で表される行に関連付けられます。行キーのリストを取得するには、Table.get_index() を使用します。 |
validation_row_processor | (Callable) 検証データに適用する関数で、通常はデータの可視化に使用します。この関数は ndx (int) と row (dict) を受け取ります。モデルの入力が 1 つだけの場合、row["input"] にはその行の入力データが含まれます。それ以外の場合は、入力スロットの名が含まれます。fit 関数が単一のターゲットを受け取る場合、row["target"] にはその行のターゲットデータが含まれます。それ以外の場合は、出力スロットの名が含まれます。たとえば、入力データが単一の配列で、データを Image として可視化する場合は、プロセッサとして lambda ndx, row: {"img": wandb.Image(row["input"])} を指定します。log_evaluation が False の場合、または validation_indexes が指定されている場合は無視されます。 |
output_row_processor | (Callable) validation_row_processor と同様ですが、モデルの出力に対して適用されます。row["output"] にはモデルの出力結果が含まれます。 |
infer_missing_processors | (Boolean) validation_row_processor と output_row_processor が存在しない場合に、それらを推論するかどうかを指定します。デフォルトは True です。labels を指定すると、W&B は必要に応じて分類用のプロセッサを推論します。 |
log_evaluation_frequency | (int) 評価結果をどの頻度でログするかを指定します。デフォルトは 0 で、この場合はトレーニング終了時にのみログします。1 に設定すると毎エポック、2 に設定すると 1 エポックおき、というようにログします。log_evaluation が False の場合は効果はありません。 |
よくある質問
wandb で Keras のマルチプロセシングを使用するにはどうすればよいですか?
use_multiprocessing=True を設定すると、次のエラーが発生することがあります。
Sequenceクラスの構築時に、wandb.init(group='...')を追加します。mainでは、if __name__ == "__main__":を使用していることを確認し、スクリプトの残りのロジックはその中に記述します。