{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\nVerification of Probabilistic Forecasts\n====================================================\n\nThis example demonstrates how to perform verification\non probabilistic forecasts.\n\nReliability diagrams, sharpness histograms,\nReceiver Operating Characteristic (ROC) Curves and\nPrecision / Recall Curves are included in this\nexample.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Definitions\n--------------------------------------------------------\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import all required modules and methods:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Python package for operating system interfaces (used here for file paths)\nimport os\n# Python package to manage warning messages\nimport warnings\n# Python package for time calculations\nimport pandas as pd\n# Python package for numerical calculations\nimport numpy as np\n# Python package for labelled arrays, used to read and handle netCDF data\nimport xarray as xr\n# Python package for creating plots\nfrom matplotlib import pyplot as plt\n\n# swirlspy verification metric package\nimport swirlspy.ver.metric as mt\n# directory constants\nfrom swirlspy.tests.samples import DATA_DIR\nfrom swirlspy.tests.outputs import OUTPUT_DIR\n\nwarnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initialising\n-----------------------------------------------------------\n\nIn a probabilistic forecast, the probability of an event\nis forecast, instead of a simple Yes or No.\n\nAs always, the forecast and observed values are stored as xarray.DataArrays.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"# Extracting a set of probabilities and binary values\n# for verification\n\n# This particular forecast xarray.DataArray also includes\n# observation data at basetime, so data at the first\n# timestep (basetime) is removed\n# The array also contains some data above 1, so clipping\n# is required\nforecast = xr.open_dataarray(\n os.path.join(DATA_DIR, 'ltg/ltgvf_201904201300.nc')\n).isel(time=slice(1, None)).clip(min=0, max=1)\n\n# Observation\nobservation = xr.open_dataarray(\n os.path.join(DATA_DIR, 'ltg/lobs_201904201300.nc')\n)\n\ntimelist = [pd.Timestamp(t) for t in forecast.time.values]\n\n# Define basetime\nbasetime = pd.Timestamp('201904201300')\n\nprint(forecast)\nprint(observation)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Reliability diagram and Sharpness Histogram\n-----------------------------------------------------------------\n\nThis section demonstrates how to obtain data and plot\nthe reliability diagram along with a Sharpness Histogram.\n\nThe reliability diagram plots observed frequency against\nthe forecast probability, where the range of forecast\nprobabilities is divided into K bins (in this example, K=10).\n\nThe sharpness histogram displays the number of forecasts in each\nforecast probability bin.\n\nIn this example, multiple curves will be plotted for different\nlead times from basetime, so multiple sets of reliability\ndiagram data will be plotted.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Obtaining data to plot reliability diagram\n# Data is stored as an xarray.DataSet\n# For loop to generate data for different lead times\nreliabilityDataList = []\nfor time in timelist:\n reliabilityData = mt.reliability(forecast.sel(time=time),\n observation.sel(time=time),\n n_bins=10)\n reliabilityDataList.append(reliabilityData)\n\n# Concatenate reliability diagram along the time dimension\nreliability_data = xr.concat(reliabilityDataList,\n 
dim=xr.IndexVariable('time', timelist))\n\n# Plotting reliability diagram\n\nplt.figure(figsize=(20, 30))\nax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)\nax2 = plt.subplot2grid((3, 1), (2, 0))\n\nfor time in reliability_data.time.values:\n # Extracting DataArrays from DataSet\n observed_rf = reliability_data.observed_rf.sel(time=time)\n nforecast = reliability_data.nforecast.sel(time=time)\n time = pd.Timestamp(time)\n\n # Minutes from basetime\n timeDiff = time - basetime\n timeDiffMins = int(timeDiff.total_seconds() // 60)\n\n # Format axes\n ax1.set_ylim(bottom=0, top=1)\n ax1.set_xlim(left=0, right=1)\n ax1.set_aspect('equal', adjustable='datalim')\n\n # Plot reliability data\n observed_rf.plot(\n ax=ax1, marker='s', markersize=10,\n label=f\"t + {timeDiffMins}\"\n )\n\n # Plot perfect reliability and sample climatology\n ax1.plot([0, 1], [0, 1], ls='--', c='red', dashes=(9, 2))\n c = reliability_data.attrs['climatology']\n ax1.plot([0, 1], [c, c], ls='--', c='black', dashes=(15, 3))\n\n # Plot skill area\n x0 = np.array([c, 1])\n y01 = np.array([c, 0.5*(1+c)])\n y02 = np.array([1, 1])\n ax1.fill_between(x0, y01, y02, where=y01 <= y02, facecolor='lightgreen')\n x1 = np.array([0, c])\n y11 = np.array([0.5*c, c])\n y12 = np.array([0, 0])\n ax1.fill_between(x1, y11, y12, where=y11 >= y12, facecolor='lightgreen')\n\n # Labelling\n ax1.text(0.7, 0.85, 'Perfect Reliability', rotation=45, fontsize=20)\n ax1.text(0.6, c + 0.01, 'Climatology', fontsize=20)\n ax1.text(c + 0.4, 0.95, 'Skill', fontsize=20)\n ax1.grid(True, ls='--', dashes=(2, 0.1))\n ax1.set_xlabel('')\n ax1.set_ylabel('Observed Relative Frequency', fontsize=20)\n ax1.set_title('Reliability diagram', fontsize=32)\n ax1.xaxis.set_tick_params(labelsize=17)\n ax1.yaxis.set_tick_params(labelsize=17)\n lgd1 = ax1.legend(loc=\"upper left\",\n title='Minutes from basetime',\n fontsize=20)\n plt.setp(lgd1.get_title(), fontsize=20)\n\n # Plot and label sharpness histogram\n coords = 
nforecast.coords['forecast_probability'].values\n data = nforecast.values\n ax2.step(coords, data, where='mid', label=f\"t + {timeDiffMins}\")\n ax2.set_xlim(left=0, right=1)\n ax2.set_xlabel('Forecast Probability', fontsize=20)\n ax2.set_ylabel('Count', fontsize=20)\n ax2.xaxis.set_tick_params(labelsize=17)\n ax2.yaxis.set_tick_params(labelsize=17)\n ax2.set_title('Sharpness Histogram', fontsize=25)\n lgd2 = ax2.legend(\n loc=\"upper center\",\n title='Minutes from basetime',\n fontsize=20)\n plt.setp(lgd2.get_title(), fontsize=20)\n\n# Saving\nplt.savefig(os.path.join(OUTPUT_DIR, 'reliability.png'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Receiver Operating Characteristic Curve (ROC Curve)\n----------------------------------------------------------------------------\n\nThis section demonstrates how to obtain data and plot\nthe ROC Curve.\n\nThe ROC Curve plots the Probability of Detection (POD) against the\nProbability of False Detection (POFD).\n\nAs in the previous section, multiple curves will be plotted for\ndifferent lead times.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Obtaining data to plot the ROC Curve\n# Data is stored as a dictionary\nrocDataList = []\nfor time in timelist:\n rocData = mt.roc(forecast.sel(time=time), observation.sel(time=time))\n rocDataList.append(rocData)\n\n# Plotting ROC Curve\n\n# Initialising figure and axes\nplt.figure(figsize=(20, 20))\nax = plt.axes()\nax.set_ylim(bottom=0, top=1)\nax.set_xlim(left=0, right=1)\nax.xaxis.set_tick_params(labelsize=25)\nax.yaxis.set_tick_params(labelsize=25)\nax.set_aspect('equal', adjustable='box')\n\nfor roc_data, time in zip(rocDataList, timelist):\n # Time from basetime\n time_diff = time - basetime\n time_diff_min = int(time_diff.total_seconds() // 60)\n # Plot ROC data\n pod = roc_data['pod']\n pofd = roc_data['pofd']\n label = f\"t + {time_diff_min} : {roc_data['auc']:.3f}\"\n ax.plot(pofd, pod,
linewidth=2, label=label)\n\n# Plot no discrimination line\nax.plot([0, 1], [0, 1], ls='--', c='red', dashes=(9, 2))\nax.text(0.65, 0.65 + 0.22, 'No Discrimination', rotation=45, fontsize=25)\n\n# Plotting grid\nax.grid(True, ls='--', dashes=(2, 0.1))\n# Plotting labels and titles\nax.set_xlabel('Probability of False Detection', fontsize=30)\nax.set_ylabel('Probability of Detection', fontsize=30)\nlgd = ax.legend(\n loc=\"upper left\",\n title='Minutes from basetime : Area under curve',\n fontsize=24\n)\nplt.setp(lgd.get_title(), fontsize=24)\n\nplt.title('ROC Curve', fontsize=40)\n# Saving\nplt.savefig(os.path.join(OUTPUT_DIR, 'roc.png'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Precision-Recall Curve\n---------------------------------------------------------------------------\n\nThis section demonstrates how to obtain data and plot\nthe Precision-Recall Curve.\n\nThe Precision-Recall Curve plots precision, which\nis equivalent to 1 - FAR (False Alarm Ratio), against recall,\nwhich is equivalent to the Probability of Detection.\n\nAs in the previous sections, multiple curves will be plotted for\ndifferent lead times.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Obtaining data to plot the Precision-Recall Curve\n# Data is stored as a dictionary\nprDataList = []\nfor time in timelist:\n prData = mt.precision_recall(forecast.sel(time=time),\n observation.sel(time=time))\n prDataList.append(prData)\n\n# Plotting Precision-Recall Curve\n\n# Initialising figure and axes\nplt.figure(figsize=(20, 20))\nax = plt.axes()\nax.set_ylim(bottom=0, top=1)\nax.set_xlim(left=0, right=1)\nax.xaxis.set_tick_params(labelsize=25)\nax.yaxis.set_tick_params(labelsize=25)\nax.set_aspect('equal', adjustable='box')\n\nfor time, pr_data in zip(timelist, prDataList):\n # Time from basetime\n time_diff = time - basetime\n time_diff_min = int(time_diff.total_seconds() // 60)\n # Plot data\n p = pr_data['precision']\n r =
pr_data['recall']\n label = (f\"t + {time_diff_min} : {pr_data['ap']:.3f}\"\n f\" : {pr_data['auc']:.3f}\")\n ax.plot(r, p, linewidth=2, label=label)\n\n# Drawing grid and labelling\nax.grid(True, ls='--', dashes=(2, 0.1))\nax.set_xlabel('Recall', fontsize=30)\nax.set_ylabel('Precision', fontsize=30)\nlgd = ax.legend(\n loc=\"upper left\",\n title='Minutes from basetime : AP : AUC',\n fontsize=20\n)\nplt.setp(lgd.get_title(), fontsize=24)\n\n# Title\nplt.title('Precision-Recall Curve', fontsize=40)\n\n# Saving\nplt.savefig(os.path.join(OUTPUT_DIR, 'precision_recall.png'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Brier Skill Score\n------------------------------------------------------------------------\n\nThis section demonstrates how to compute the Brier Skill Score.\n\nA Brier Skill Score of 1 means a perfect forecast, 0 indicates\nthat the forecast is no better than climatology, and a negative\nscore indicates that the forecast is worse than climatology.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Calculate the Brier Skill Score\nfor time in timelist:\n bss = mt.brier_skill_score(forecast.sel(time=time),\n observation.sel(time=time))\n # Time from basetime\n time_diff = time - basetime\n time_diff_min = int(time_diff.total_seconds() // 60)\n\n print(f\"For t + {time_diff_min:3} min, Brier Skill Score: {bss:8.5f}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.15" } }, "nbformat": 4, "nbformat_minor": 0 }