{
 "cells": [
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "%matplotlib inline"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "TrajGRU (Hong Kong)\n",
    "========================================================\n",
    "This example demonstrates how to use a pre-trained\n",
    "TrajGRU model to produce a two-hour forecast\n",
    "from thirty minutes of past radar echo maps.\n"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "Definitions\n",
    "--------------------------------------------------------\n"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "Import all required modules and methods:\n"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "# Python package to allow system command line functions\n",
    "import os\n",
    "# Python package to manage warning message\n",
    "import warnings\n",
    "# Python package for datetime function\n",
    "import datetime\n",
    "# Python package for numerical arrays (np is used by the helpers and plots below)\n",
    "import numpy as np\n",
    "# Python package for xarrays to read and handle netcdf data\n",
    "import xarray as xr\n",
    "# Python package for deep learning\n",
    "import mxnet as mx\n",
    "import mxnet.ndarray as nd\n",
    "# Python package for projection description\n",
    "from pyresample.geometry import AreaDefinition\n",
    "from pyresample import get_area_def\n",
    "# Python package for projection\n",
    "import cartopy.crs as ccrs\n",
    "# Python package for land/sea features\n",
    "import cartopy.feature as cfeature\n",
    "# Python package for reading map shape file\n",
    "import cartopy.io.shapereader as shpreader\n",
    "# Python package for creating plots\n",
    "from matplotlib import pyplot as plt\n",
    "# Python package for colorbars\n",
    "from matplotlib.colors import BoundaryNorm, ListedColormap\n",
    "\n",
    "# swirlspy regrid function\n",
    "from swirlspy.core.resample import grid_resample\n",
    "# swirlspy standardize data function\n",
    "from swirlspy.utils import standardize_attr, FrameType\n",
    "# swirlspy cuda checking function\n",
    "from swirlspy.utils import locate_cuda\n",
    "# swirlspy deep learning function\n",
    "from swirlspy.qpf.dl.config import cfg, cfg_from_file\n",
    "from swirlspy.qpf.dl.utils import parse_ctx\n",
    "# NOTE(review): the star imports are kept (in their original order) because the\n",
    "# defining module of HKOBenchmarkEnv, weighted_mse, weighted_mae and\n",
    "# masked_gdl_loss is not visible here; replace them with explicit imports\n",
    "# once confirmed.\n",
    "from swirlspy.qpf.dl.hko_benchmark import *\n",
    "from swirlspy.qpf.dl.encoder_forecaster import (\n",
    "    EncoderForecasterBaseFactory,\n",
    "    encoder_forecaster_build_networks,\n",
    "    EncoderForecasterStates,\n",
    "    load_encoder_forecaster_params,\n",
    ")\n",
    "from swirlspy.qpf.dl.hko_evaluation import rainfall_to_pixel, pixel_to_dBZ\n",
    "from swirlspy.qpf.dl.operators import *\n",
    "from swirlspy.qpf.dl.ops import *\n",
    "# directory constants\n",
    "from swirlspy.tests.samples import DATA_DIR\n",
    "from swirlspy.tests.outputs import OUTPUT_DIR\n",
    "from swirlspy.tests.models import DL_MODEL_DIR\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "Define the loss-weighting helper and the HKO nowcasting model factory:\n"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "def get_loss_weight_symbol(data, mask, seq_len):\n",
    "    \"\"\"Build the per-pixel loss-weight symbol used by `loss_sym`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : mx.sym.Symbol\n",
    "        Target frames, compared against the rainfall thresholds (converted\n",
    "        with `rainfall_to_pixel`) when cfg.MODEL.USE_BALANCED_LOSS is set.\n",
    "    mask : mx.sym.Symbol\n",
    "        Mask multiplied into the weights.\n",
    "    seq_len : int\n",
    "        Number of output frames. The temporal weighting schemes reshape\n",
    "        their multiplier to (seq_len, 1, 1, 1, 1), i.e. they expect the\n",
    "        time axis first.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    weights : mx.sym.Symbol\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    NotImplementedError\n",
    "        If cfg.MODEL.TEMPORAL_WEIGHT_TYPE is not same, linear or exponential.\n",
    "    \"\"\"\n",
    "    if cfg.MODEL.USE_BALANCED_LOSS:\n",
    "        balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS\n",
    "        weights = mx.sym.ones_like(data) * balancing_weights[0]\n",
    "        thresholds = [rainfall_to_pixel(ele)\n",
    "                      for ele in cfg.HKO.EVALUATION.THRESHOLDS]\n",
    "        # Every threshold a pixel reaches raises its weight to the next\n",
    "        # balancing weight.\n",
    "        for i, threshold in enumerate(thresholds):\n",
    "            weights = weights + (\n",
    "                (balancing_weights[i + 1] - balancing_weights[i])\n",
    "                * (data >= threshold))\n",
    "        weights = weights * mask\n",
    "    else:\n",
    "        weights = mask\n",
    "\n",
    "    weight_type = cfg.MODEL.TEMPORAL_WEIGHT_TYPE\n",
    "    if weight_type == \"same\":\n",
    "        return weights\n",
    "    if weight_type == \"linear\":\n",
    "        upper = cfg.MODEL.TEMPORAL_WEIGHT_UPPER\n",
    "        assert upper >= 1.0\n",
    "        # Multiplier rises linearly from 1 (first frame) to `upper` (last).\n",
    "        temporal_mult = 1 + (mx.sym.arange(start=0, stop=seq_len)\n",
    "                             * (upper - 1.0) / (seq_len - 1.0))\n",
    "    elif weight_type == \"exponential\":\n",
    "        upper = cfg.MODEL.TEMPORAL_WEIGHT_UPPER\n",
    "        assert upper >= 1.0\n",
    "        # Multiplier rises geometrically from 1 (first frame) to `upper`.\n",
    "        base_factor = np.log(upper) / (seq_len - 1.0)\n",
    "        temporal_mult = mx.sym.exp(\n",
    "            mx.sym.arange(start=0, stop=seq_len) * base_factor)\n",
    "    else:\n",
    "        raise NotImplementedError(\n",
    "            \"Unsupported cfg.MODEL.TEMPORAL_WEIGHT_TYPE: %r\" % (weight_type,))\n",
    "    temporal_mult = mx.sym.reshape(temporal_mult, shape=(seq_len, 1, 1, 1, 1))\n",
    "    return mx.sym.broadcast_mul(weights, temporal_mult)\n",
    "\n",
    "\n",
    "def _wrap_loss(value, lam, grad_scale, name):\n",
    "    \"\"\"Return `value` as a trainable loss when `lam > 0`, else gradient-blocked.\n",
    "\n",
    "    The MakeLoss gradient is scaled by `grad_scale * lam`; in both cases the\n",
    "    returned symbol is named `name`.\n",
    "    \"\"\"\n",
    "    if lam > 0:\n",
    "        return mx.sym.MakeLoss(value, grad_scale=grad_scale * lam, name=name)\n",
    "    return mx.sym.BlockGrad(value, name=name)\n",
    "\n",
    "\n",
    "class HKONowcastingFactory(EncoderForecasterBaseFactory):\n",
    "    \"\"\"Encoder-forecaster factory configured from the HKO settings in `cfg`.\n",
    "\n",
    "    Frame height/width come from cfg.HKO.ITERATOR and the central crop\n",
    "    from cfg.HKO.EVALUATION.CENTRAL_REGION.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 batch_size,\n",
    "                 in_seq_len,\n",
    "                 out_seq_len,\n",
    "                 ctx_num=1,\n",
    "                 name=\"hko_nowcasting\"):\n",
    "        super(HKONowcastingFactory, self).__init__(\n",
    "            batch_size=batch_size,\n",
    "            in_seq_len=in_seq_len,\n",
    "            out_seq_len=out_seq_len,\n",
    "            ctx_num=ctx_num,\n",
    "            height=cfg.HKO.ITERATOR.HEIGHT,\n",
    "            width=cfg.HKO.ITERATOR.WIDTH,\n",
    "            name=name)\n",
    "        self._central_region = cfg.HKO.EVALUATION.CENTRAL_REGION\n",
    "\n",
    "    def _slice_central(self, data):\n",
    "        \"\"\"Slice the central region in the given symbol\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        data : mx.sym.Symbol\n",
    "            Five-dimensional symbol; only the last two axes (y, x) are cut.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        ret : mx.sym.Symbol\n",
    "        \"\"\"\n",
    "        x_begin, y_begin, x_end, y_end = self._central_region\n",
    "        return mx.sym.slice(data,\n",
    "                            begin=(0, 0, 0, y_begin, x_begin),\n",
    "                            end=(None, None, None, y_end, x_end))\n",
    "\n",
    "    def _concat_month_code(self):\n",
    "        # TODO\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def loss_sym(self,\n",
    "                 pred=mx.sym.Variable('pred'),\n",
    "                 mask=mx.sym.Variable('mask'),\n",
    "                 target=mx.sym.Variable('target')):\n",
    "        \"\"\"Construct loss symbol.\n",
    "\n",
    "        Optional args:\n",
    "            pred: Shape (out_seq_len, batch_size, C, H, W)\n",
    "            mask: Shape (out_seq_len, batch_size, C, H, W)\n",
    "            target: Shape (out_seq_len, batch_size, C, H, W)\n",
    "\n",
    "        Returns a group of the mean MSE, MAE and GDL losses named mse, mae\n",
    "        and gdl; a term only back-propagates when its cfg.MODEL lambda\n",
    "        (L2_LAMBDA, L1_LAMBDA, GDL_LAMBDA) is positive.\n",
    "        \"\"\"\n",
    "        self.reset_all()\n",
    "        weights = get_loss_weight_symbol(\n",
    "            data=target, mask=mask, seq_len=self._out_seq_len)\n",
    "        mse = weighted_mse(pred=pred, gt=target, weight=weights)\n",
    "        mae = weighted_mae(pred=pred, gt=target, weight=weights)\n",
    "        gdl = masked_gdl_loss(pred=pred, gt=target, mask=mask)\n",
    "        avg_mse = mx.sym.mean(mse)\n",
    "        avg_mae = mx.sym.mean(mae)\n",
    "        avg_gdl = mx.sym.mean(gdl)\n",
    "        global_grad_scale = cfg.MODEL.NORMAL_LOSS_GLOBAL_SCALE\n",
    "        return mx.sym.Group([\n",
    "            _wrap_loss(avg_mse, cfg.MODEL.L2_LAMBDA, global_grad_scale, \"mse\"),\n",
    "            _wrap_loss(avg_mae, cfg.MODEL.L1_LAMBDA, global_grad_scale, \"mae\"),\n",
    "            _wrap_loss(avg_gdl, cfg.MODEL.GDL_LAMBDA, global_grad_scale, \"gdl\"),\n",
    "        ])"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "Initialising\n",
    "---------------------------------------------------\n",
    "\n",
    "This section configures the pre-trained model, loads its\n",
    "parameters and creates the benchmark environment that\n",
    "supplies the radar data.\n"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "# set the directory containing the pre-trained model configuration\n",
    "# set the path to '[DL_MODEL_DIR]/hko/convgru' when using ConvGRU\n",
    "base_dir = os.path.join(DL_MODEL_DIR, 'hko/trajgru')\n",
    "\n",
    "cfg_from_file(os.path.join(base_dir, 'cfg0.yml'), target=cfg.MODEL)\n",
    "cfg.MODEL.LOAD_DIR = base_dir\n",
    "cfg.MODEL.SAVE_DIR = base_dir\n",
    "\n",
    "# set the load_iter = 49999 when using ConvGRU\n",
    "cfg.MODEL.LOAD_ITER = 79999\n",
    "\n",
    "# no training on inference data: mode = 'fixed' and finetune = 0\n",
    "# train on inference data: mode = 'online' and finetune = 1\n",
    "mode = 'fixed'\n",
    "finetune = 0\n",
    "\n",
    "# use GPU if CUDA installed; otherwise use CPU\n",
    "ctx = parse_ctx('gpu' if locate_cuda() else 'cpu')\n",
    "\n",
    "# Initialize the model\n",
    "hko_nowcasting_online = HKONowcastingFactory(batch_size=1,\n",
    "                                             in_seq_len=cfg.MODEL.IN_LEN,\n",
    "                                             out_seq_len=cfg.MODEL.OUT_LEN)\n",
    "\n",
    "# Initialize the encoder network and forecaster network\n",
    "t_encoder_net, t_forecaster_net, t_loss_net = encoder_forecaster_build_networks(\n",
    "    factory=hko_nowcasting_online,\n",
    "    context=ctx,\n",
    "    for_finetune=True)\n",
    "t_encoder_net.summary()\n",
    "t_forecaster_net.summary()\n",
    "t_loss_net.summary()\n",
    "\n",
    "# Load the pre-trained model params to the networks created\n",
    "load_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,\n",
    "                               load_iter=cfg.MODEL.LOAD_ITER,\n",
    "                               encoder_net=t_encoder_net,\n",
    "                               forecaster_net=t_forecaster_net)\n",
    "\n",
    "# 2019-04-20 10:30 to 2019-04-20 15:24\n",
    "pd_path = cfg.HKO_PD.RAINY_EXAMPLE\n",
    "\n",
    "# Set the save directory\n",
    "save_dir = os.path.join(base_dir, \"iter%d_%s_finetune%d\"\n",
    "                        % (cfg.MODEL.LOAD_ITER + 1, 'example', finetune))\n",
    "\n",
    "# Create the environment to run the model\n",
    "env = HKOBenchmarkEnv(pd_path=pd_path, save_dir=save_dir, mode=mode)\n",
    "states = EncoderForecasterStates(factory=hko_nowcasting_online, ctx=ctx[0])"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "Inference and Evaluation\n",
    "---------------------------------------------------\n",
    "\n",
    "The pre-trained model loops through the radar echo maps\n",
    "and generates forecasts for the next two hours.\n",
    "Each forecast is evaluated by comparing it with the\n",
    "observed (ground-truth) radar echoes.\n"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "# Number of forecasts handed to the environment for evaluation\n",
    "counter = 0\n",
    "while not env.done:\n",
    "    states.reset_all()\n",
    "\n",
    "    # Get radar data\n",
    "    (in_frame_dat, in_datetime_clips, out_datetime_clips,\n",
    "     begin_new_episode, need_upload_prediction) = env.get_observation(batch_size=1)\n",
    "\n",
    "    # Run the encoder on the input frames and keep its output states\n",
    "    in_frame_nd = nd.array(in_frame_dat, ctx=ctx[0])\n",
    "    t_encoder_net.forward(is_train=False,\n",
    "                          data_batch=mx.io.DataBatch(\n",
    "                              data=[in_frame_nd] + states.get_encoder_states()))\n",
    "    outputs = t_encoder_net.get_outputs()\n",
    "    states.update(states_nd=outputs)\n",
    "    if need_upload_prediction:\n",
    "        counter += 1\n",
    "        t_forecaster_net.forward(is_train=False,\n",
    "                                 data_batch=mx.io.DataBatch(\n",
    "                                     data=states.get_forecaster_state()))\n",
    "\n",
    "        # Get inference from model, clipped to [0, 1]\n",
    "        pred_nd = t_forecaster_net.get_outputs()[0]\n",
    "        pred_nd = nd.clip(pred_nd, a_min=0, a_max=1)\n",
    "\n",
    "        # Generate prediction movies and inference evaluation\n",
    "        env.upload_prediction(prediction=pred_nd.asnumpy())\n",
    "\n",
    "# The map-plotting cells below use the last forecast (`pred_nd`) and the\n",
    "# datetimes of the last observation, so at least one forecast is required.\n",
    "assert counter > 0, \"the environment produced no forecast to plot\"\n",
    "fingerprint = env._fingerprint\n",
    "print(\"Saving videos to %s\" % os.path.join(save_dir, fingerprint))"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "Printing the Evaluation\n",
    "--------------------------------------------------------------------\n"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "env.save_eval(print_out=True)"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "Generating radar reflectivity maps\n",
    "-----------------------------------\n"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "Step 1: Retrieve the observation and the prediction for the one-hour\n",
    "and two-hour forecasts.\n"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "# Output-frame indices of the one-hour and two-hour forecasts\n",
    "# (the 10th and 20th output frames)\n",
    "ONE_HOUR_IDX = 9\n",
    "TWO_HOUR_IDX = 19\n",
    "\n",
    "# Ground truth at the one-hour and two-hour forecast, translated to dBZ\n",
    "gt_10 = pixel_to_dBZ(env._out_frame_dat[ONE_HOUR_IDX:ONE_HOUR_IDX + 1, 0, 0])\n",
    "gt_20 = pixel_to_dBZ(env._out_frame_dat[TWO_HOUR_IDX:TWO_HOUR_IDX + 1, 0, 0])\n",
    "\n",
    "# Prediction translated to dBZ; a new name keeps `pred_nd` an NDArray so\n",
    "# this cell can be re-run\n",
    "pred_dbz = pixel_to_dBZ(pred_nd.asnumpy())\n",
    "pd_10 = pred_dbz[ONE_HOUR_IDX:ONE_HOUR_IDX + 1, 0, 0]\n",
    "pd_20 = pred_dbz[TWO_HOUR_IDX:TWO_HOUR_IDX + 1, 0, 0]\n",
    "\n",
    "# Observation and forecast, in the order of the 2x2 map panels\n",
    "reflec_np = [gt_10, pd_10, gt_20, pd_20]"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "Step 2: Generate xarray DataArrays for the one-hour and two-hour forecasts.\n"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "# Forecast base time: the last input frame before the forecast starts\n",
    "basetime = in_datetime_clips[0][-1]\n",
    "\n",
    "# Distinct time labels used only to facet the four panels; the panel\n",
    "# titles in Step 4 carry the actual base and valid times.\n",
    "datetime_list = [basetime + datetime.timedelta(hours=i)\n",
    "                 for i in range(len(reflec_np))]\n",
    "\n",
    "# Define source grid\n",
    "area_id = 'aeqd'\n",
    "description = ('Azimuthal Equidistant Projection centered at the radar site '\n",
    "               'extending up to 256000.000000m in the x direction and '\n",
    "               '256000.000000m in the y direction with a '\n",
    "               '480.000000x480.000000 grid resolution')\n",
    "proj_id = 'aeqd'\n",
    "projection = {'datum': 'WGS84', 'ellps': 'WGS84', 'lat_0': '22.411667',\n",
    "              'lon_0': '114.123333', 'no_defs': 'True', 'proj': 'aeqd',\n",
    "              'units': 'm'}\n",
    "width = 480\n",
    "height = 480\n",
    "area_extent = (-256000.00000000003, -256000.00000000003,\n",
    "               256000.00000000003, 256000.00000000003)\n",
    "area_def = AreaDefinition(area_id, description, proj_id,\n",
    "                          projection, width, height, area_extent)\n",
    "\n",
    "# Pixel-centre coordinates of the source grid, derived from its extent\n",
    "# (x increases eastwards; y decreases from the top row downwards)\n",
    "x_min, y_min, x_max, y_max = area_extent\n",
    "res_x = (x_max - x_min) / width\n",
    "res_y = (y_max - y_min) / height\n",
    "x = [x_min + res_x * (i + 0.5) for i in range(width)]\n",
    "y = [y_max - res_y * (i + 0.5) for i in range(height)]\n",
    "\n",
    "reflec_xr_list = [\n",
    "    xr.DataArray(\n",
    "        frame,\n",
    "        dims=('time', 'y', 'x'),\n",
    "        coords={\n",
    "            'time': [label_time],\n",
    "            'y': y,\n",
    "            'x': x\n",
    "        },\n",
    "        attrs={\n",
    "            'long_name': 'Reflectivity',\n",
    "            'units': 'dBZ',\n",
    "            'projection': 'Centered Azimuthal',\n",
    "            'site': (114.12333341315389, 22.41166664287448, 948),\n",
    "            'radar_height': 20,\n",
    "            'area_def': area_def\n",
    "        }\n",
    "    )\n",
    "    for frame, label_time in zip(reflec_np, datetime_list)\n",
    "]\n",
    "\n",
    "# Defining target grid\n",
    "area_id = \"hk1980_250km\"\n",
    "description = (\"A 1km resolution rectangular grid \"\n",
    "               \"centred at HKO and extending to 250 km \"\n",
    "               \"in each direction in HK1980 easting/northing coordinates\")\n",
    "proj_id = 'hk1980'\n",
    "projection = ('+proj=tmerc +lat_0=22.31213333333334 '\n",
    "              '+lon_0=114.1785555555556 +k=1 +x_0=836694.05 '\n",
    "              '+y_0=819069.8 +ellps=intl +towgs84=-162.619,-276.959,'\n",
    "              '-161.764,0.067753,-2.24365,-1.15883,-1.09425 +units=m '\n",
    "              '+no_defs')\n",
    "x_size = 500\n",
    "y_size = 500\n",
    "area_extent = (587000, 569000, 1087000, 1069000)\n",
    "area_def_tgt = get_area_def(\n",
    "    area_id, description, proj_id, projection, x_size, y_size, area_extent\n",
    ")"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "Step 3: Reproject the observed and forecast reflectivity from the Centered\n",
    "Azimuthal (source) projection to the HK 1980 (target) projection.\n"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "reproj_xr_list = [\n",
    "    grid_resample(reflec_xr, area_def, area_def_tgt,\n",
    "                  coord_label=['easting', 'northing'])\n",
    "    for reflec_xr in reflec_xr_list\n",
    "]\n",
    "\n",
    "# Combining the DataArrays\n",
    "reflec_concat = xr.concat(reproj_xr_list, dim='time')\n",
    "standardize_attr(reflec_concat, frame_type=FrameType.dBZ, zero_value=9999.)"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "Step 4: Generate the radar reflectivity maps.\n"
  ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "# Defining colour scale and format\n",
    "levels = [\n",
    "    -32768,\n",
    "    10, 15, 20, 24, 28, 32,\n",
    "    34, 38, 41, 44, 47, 50,\n",
    "    53, 56, 58, 60, 62\n",
    "]\n",
    "cmap = ListedColormap([\n",
    "    '#FFFFFF', '#08C5F5', '#0091F3', '#3898FF', '#008243', '#00A433',\n",
    "    '#00D100', '#01F508', '#77FF00', '#E0D100', '#FFDC01', '#EEB200',\n",
    "    '#F08100', '#F00101', '#E20200', '#B40466', '#ED02F0'\n",
    "])\n",
    "\n",
    "norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)\n",
    "\n",
    "# Defining the crs\n",
    "crs = area_def_tgt.to_cartopy_crs()\n",
    "\n",
    "# Valid times of the one-hour and two-hour forecasts\n",
    "datetime_one_hr = out_datetime_clips[0][ONE_HOUR_IDX]\n",
    "datetime_two_hr = out_datetime_clips[0][TWO_HOUR_IDX]\n",
    "\n",
    "# Valid time of each panel: ground truth and forecast at +1 h, then at +2 h\n",
    "timelist = [\n",
    "    datetime_one_hr,\n",
    "    datetime_one_hr,\n",
    "    datetime_two_hr,\n",
    "    datetime_two_hr\n",
    "]\n",
    "\n",
    "# Defining coastlines\n",
    "map_shape_file = os.path.join(DATA_DIR, \"shape/rsmc\")\n",
    "ocean_color = np.array([[[178, 208, 254]]], dtype=np.uint8)\n",
    "land_color = cfeature.COLORS['land']\n",
    "coastline = cfeature.ShapelyFeature(\n",
    "    list(shpreader.Reader(map_shape_file).geometries()),\n",
    "    ccrs.PlateCarree()\n",
    ")\n",
    "\n",
    "# Plotting\n",
    "p = reflec_concat.plot(\n",
    "    col='time', col_wrap=2,\n",
    "    subplot_kws={'projection': crs},\n",
    "    cbar_kwargs={\n",
    "        'extend': 'max',\n",
    "        'ticks': levels[1:],\n",
    "        'format': '%.3g'\n",
    "    },\n",
    "    cmap=cmap,\n",
    "    norm=norm\n",
    ")\n",
    "for idx, ax in enumerate(p.axes.flat):\n",
    "    # ocean: a uniform image over the whole globe (latitudes limited to +/-90)\n",
    "    ax.imshow(np.tile(ocean_color, [2, 2, 1]),\n",
    "              origin='upper',\n",
    "              transform=ccrs.PlateCarree(),\n",
    "              extent=[-180, 180, -90, 90],\n",
    "              zorder=-1)\n",
    "    # coastline, color\n",
    "    ax.add_feature(coastline,\n",
    "                   facecolor=land_color, edgecolor='none', zorder=0)\n",
    "    # overlay coastline without color\n",
    "    ax.add_feature(coastline, facecolor='none',\n",
    "                   edgecolor='gray', linewidth=0.5, zorder=3)\n",
    "    ax.gridlines()\n",
    "\n",
    "    valid_time = timelist[idx]\n",
    "    # Even panels (left column) show the ground truth, odd panels the forecast\n",
    "    if idx % 2 == 0:\n",
    "        left_title = (\"Ground Truth\\n\"\n",
    "                      f\"Based @ {valid_time.strftime('%H:%MH')}\")\n",
    "    else:\n",
    "        left_title = (\"Reflectivity\\n\"\n",
    "                      f\"Based @ {basetime.strftime('%H:%MH')}\")\n",
    "    ax.set_title(left_title, loc='left', fontsize=9)\n",
    "    # clear the default (centre) facet title\n",
    "    ax.set_title('')\n",
    "    ax.set_title(\n",
    "        f\"{basetime.strftime('%Y-%m-%d')} \\n\"\n",
    "        f\"Valid @ {valid_time.strftime('%H:%MH')} \",\n",
    "        loc='right',\n",
    "        fontsize=9\n",
    "    )\n",
    "\n",
    "plt.savefig(\n",
    "    os.path.join(OUTPUT_DIR, 'traj-output-map-hk.png'),\n",
    "    dpi=300\n",
    ")"
  ] }
 ],
 "metadata": {
  "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" },
  "language_info": {
   "codemirror_mode": { "name": "ipython", "version": 3 },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}