メインコンテンツへスキップ

関数 pr_curve

pr_curve(
    y_true: 'Iterable[T] | None' = None,
    y_probas: 'Iterable[numbers.Number] | None' = None,
    labels: 'list[str] | None' = None,
    classes_to_plot: 'list[T] | None' = None,
    interp_size: 'int' = 21,
    title: 'str' = 'Precision-Recall Curve',
    split_table: 'bool' = False
) → CustomChart
Precision-Recall (PR) 曲線を作成します。 Precision-Recall 曲線は、特に不均衡なデータセットで分類器を評価する際に有用です。PR 曲線下面積が大きいほど、適合率が高く (偽陽性率が低い) 、再現率も高い (偽陰性率が低い) ことを示します。この曲線により、さまざまなしきい値における偽陽性と偽陰性のバランスを把握でき、モデルのパフォーマンス評価に役立ちます。 引数:
  • y_true: 真の二値ラベル。shape は (num_samples,) である必要があります。
  • y_probas: 各クラスの予測スコアまたは確率。確率推定値、信頼度スコア、またはしきい値を適用していない決定値を指定できます。shape は (num_samples, num_classes) である必要があります。
  • labels: プロットを解釈しやすくするために、y_true 内の数値を置き換える任意のクラス名のリスト。たとえば、labels = ['dog', 'cat', 'owl'] とすると、プロット内で 0 は ‘dog’、1 は ‘cat’、2 は ‘owl’ に置き換えられます。指定しない場合は、y_true の数値が使用されます。
  • classes_to_plot: プロットに含める y_true の一意なクラス値の任意のリスト。指定しない場合は、y_true 内のすべての一意なクラスがプロットされます。
  • interp_size: 再現率の値を補間する点の数。再現率の値は、[0, 1] の範囲に一様に分布する interp_size 個の点に固定され、それに応じて適合率が補間されます。
  • title: プロットのタイトル。デフォルトは “Precision-Recall Curve” です。
  • split_table: 表を W&B UI の別セクションに分けるかどうか。True の場合、表は “Custom Chart Tables” という名前のセクションに表示されます。デフォルトは False です。
戻り値:
  • CustomChart: W&B にログできるカスタムチャート object。チャートをログするには、wandb.log() に渡します。
例外:
  • wandb.Error: NumPy、pandas、または scikit-learn がインストールされていない場合。
例:
import wandb

# スパム検出の例(二値分類)
y_true = [0, 1, 1, 0, 1]  # 0 = スパムでない, 1 = スパム
y_probas = [
    [0.9, 0.1],  # 最初のサンプルの予測確率(スパムでない)
    [0.2, 0.8],  # 2番目のサンプル(スパム)、以下同様
    [0.1, 0.9],
    [0.8, 0.2],
    [0.3, 0.7],
]

labels = ["not spam", "spam"]  # 読みやすさのためのオプションのクラス名

with wandb.init(project="spam-detection") as run:
    pr_curve = wandb.plot.pr_curve(
         y_true=y_true,
         y_probas=y_probas,
         labels=labels,
         title="Precision-Recall Curve for Spam Detection",
    )
    run.log({"pr-curve": pr_curve})