{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# TrajGRU (Hong Kong)\nThis example demonstrates how to use pre-trained\nTrajGRU model to predict the next two-hour forecast\nusing thirty-minute radar echo maps.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Definitions\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import os\nimport warnings\nimport datetime\nimport mxnet as mx\nimport mxnet.ndarray as nd\nfrom pyresample.geometry import AreaDefinition\nfrom pyresample import utils\nimport xarray as xr\nimport cartopy.feature as cfeature\nimport cartopy.crs as ccrs\nimport cartopy.io.shapereader as shpreader\nimport matplotlib.pyplot as plt\nfrom matplotlib.colors import BoundaryNorm, ListedColormap\n\nfrom swirlspy.core.resample import grid_resample\nfrom swirlspy.utils import locate_cuda, standardize_attr, FrameType\nfrom swirlspy.qpf.dl.config import cfg, save_cfg, cfg_from_file\nfrom swirlspy.qpf.dl.utils import logging_config, parse_ctx\nfrom swirlspy.qpf.dl.hko_benchmark import *\nfrom swirlspy.qpf.dl.encoder_forecaster import EncoderForecasterBaseFactory, encoder_forecaster_build_networks, EncoderForecasterStates, load_encoder_forecaster_params\nfrom swirlspy.qpf.dl.hko_evaluation import rainfall_to_pixel, pixel_to_dBZ\nfrom swirlspy.qpf.dl.operators import *\nfrom swirlspy.qpf.dl.ops import *\n\nwarnings.filterwarnings(\"ignore\")\nTHIS_DIR = os.getcwd()\nos.chdir(THIS_DIR)\n\ndef get_loss_weight_symbol(data, mask, seq_len):\n if cfg.MODEL.USE_BALANCED_LOSS:\n balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS\n weights = mx.sym.ones_like(data) * balancing_weights[0]\n thresholds = [rainfall_to_pixel(ele) for ele in cfg.HKO.EVALUATION.THRESHOLDS]\n for i, threshold in enumerate(thresholds):\n 
weights = weights + (balancing_weights[i + 1] - balancing_weights[i]) * (data >= threshold)\n weights = weights * mask\n else:\n weights = mask\n if cfg.MODEL.TEMPORAL_WEIGHT_TYPE == \"same\":\n return weights\n elif cfg.MODEL.TEMPORAL_WEIGHT_TYPE == \"linear\":\n upper = cfg.MODEL.TEMPORAL_WEIGHT_UPPER\n assert upper >= 1.0\n temporal_mult = 1 + \\\n mx.sym.arange(start=0, stop=seq_len) * (upper - 1.0) / (seq_len - 1.0)\n temporal_mult = mx.sym.reshape(temporal_mult, shape=(seq_len, 1, 1, 1, 1))\n weights = mx.sym.broadcast_mul(weights, temporal_mult)\n return weights\n elif cfg.MODEL.TEMPORAL_WEIGHT_TYPE == \"exponential\":\n upper = cfg.MODEL.TEMPORAL_WEIGHT_UPPER\n assert upper >= 1.0\n base_factor = np.log(upper) / (seq_len - 1.0)\n temporal_mult = mx.sym.exp(mx.sym.arange(start=0, stop=seq_len) * base_factor)\n temporal_mult = mx.sym.reshape(temporal_mult, shape=(seq_len, 1, 1, 1, 1))\n weights = mx.sym.broadcast_mul(weights, temporal_mult)\n return weights\n else:\n raise NotImplementedError\n\nclass HKONowcastingFactory(EncoderForecasterBaseFactory):\n def __init__(self,\n batch_size,\n in_seq_len,\n out_seq_len,\n ctx_num=1,\n name=\"hko_nowcasting\"):\n super(HKONowcastingFactory, self).__init__(batch_size=batch_size,\n in_seq_len=in_seq_len,\n out_seq_len=out_seq_len,\n ctx_num=ctx_num,\n height=cfg.HKO.ITERATOR.HEIGHT,\n width=cfg.HKO.ITERATOR.WIDTH,\n name=name)\n self._central_region = cfg.HKO.EVALUATION.CENTRAL_REGION\n\n def _slice_central(self, data):\n \"\"\"Slice the central region in the given symbol\n\n Parameters\n ----------\n data : mx.sym.Symbol\n\n Returns\n -------\n ret : mx.sym.Symbol\n \"\"\"\n x_begin, y_begin, x_end, y_end = self._central_region\n return mx.sym.slice(data,\n begin=(0, 0, 0, y_begin, x_begin),\n end=(None, None, None, y_end, x_end))\n\n def _concat_month_code(self):\n # TODO\n raise NotImplementedError\n\n def loss_sym(self,\n pred=mx.sym.Variable('pred'),\n mask=mx.sym.Variable('mask'),\n 
target=mx.sym.Variable('target')):\n \"\"\"Construct loss symbol.\n\n Optional args:\n pred: Shape (out_seq_len, batch_size, C, H, W)\n mask: Shape (out_seq_len, batch_size, C, H, W)\n target: Shape (out_seq_len, batch_size, C, H, W)\n \"\"\"\n self.reset_all()\n weights = get_loss_weight_symbol(data=target, mask=mask, seq_len=self._out_seq_len)\n mse = weighted_mse(pred=pred, gt=target, weight=weights)\n mae = weighted_mae(pred=pred, gt=target, weight=weights)\n gdl = masked_gdl_loss(pred=pred, gt=target, mask=mask)\n avg_mse = mx.sym.mean(mse)\n avg_mae = mx.sym.mean(mae)\n avg_gdl = mx.sym.mean(gdl)\n global_grad_scale = cfg.MODEL.NORMAL_LOSS_GLOBAL_SCALE\n if cfg.MODEL.L2_LAMBDA > 0:\n avg_mse = mx.sym.MakeLoss(avg_mse,\n grad_scale=global_grad_scale * cfg.MODEL.L2_LAMBDA,\n name=\"mse\")\n else:\n avg_mse = mx.sym.BlockGrad(avg_mse, name=\"mse\")\n if cfg.MODEL.L1_LAMBDA > 0:\n avg_mae = mx.sym.MakeLoss(avg_mae,\n grad_scale=global_grad_scale * cfg.MODEL.L1_LAMBDA,\n name=\"mae\")\n else:\n avg_mae = mx.sym.BlockGrad(avg_mae, name=\"mae\")\n if cfg.MODEL.GDL_LAMBDA > 0:\n avg_gdl = mx.sym.MakeLoss(avg_gdl,\n grad_scale=global_grad_scale * cfg.MODEL.GDL_LAMBDA,\n name=\"gdl\")\n else:\n avg_gdl = mx.sym.BlockGrad(avg_gdl, name=\"gdl\")\n loss = mx.sym.Group([avg_mse, avg_mae, avg_gdl])\n return loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initialising\n\nThis section demonstrates the settings to run \nthe pre-trained model.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# set the directory containing the pre-trained model configuration\n# set the path to '/../tests/models/hko/convgru' when using ConvGRU\nbase_dir = THIS_DIR + '/../tests/models/hko/trajgru'\n\ncfg_from_file(os.path.join(base_dir, 'cfg0.yml'), target=cfg.MODEL)\ncfg.MODEL.LOAD_DIR = base_dir\ncfg.MODEL.SAVE_DIR = base_dir\n\n# set the load_iter = 49999 when using ConvGRU\ncfg.MODEL.LOAD_ITER = 
79999\n\n# no training on inference data: mode = 'fixed' and finetune = 0\n# train on inference data: mode = 'online' and finetune = 1\nmode = 'fixed'\nfinetune = 0\n\n# use GPU if CUDA installed; otherwise use CPU\nif locate_cuda():\n ctx = 'gpu'\nelse:\n ctx = 'cpu'\nctx = parse_ctx(ctx)\n\n# Initialize the model\nhko_nowcasting_online = HKONowcastingFactory(batch_size=1,\n in_seq_len=cfg.MODEL.IN_LEN,\n out_seq_len=cfg.MODEL.OUT_LEN)\n\n# Initialize the encoder network and forecaster network\nt_encoder_net, t_forecaster_net, t_loss_net =\\\n encoder_forecaster_build_networks(\n factory=hko_nowcasting_online,\n context=ctx,\n for_finetune=True)\nt_encoder_net.summary()\nt_forecaster_net.summary()\nt_loss_net.summary()\n\n# Load the pre-trained model params to the networks created\nload_encoder_forecaster_params(load_dir=cfg.MODEL.LOAD_DIR,\n load_iter=cfg.MODEL.LOAD_ITER,\n encoder_net=t_encoder_net,\n forecaster_net=t_forecaster_net)\n\n# 2019-04-20 10:30 to 2019-04-20 15:24\npd_path = cfg.HKO_PD.RAINY_EXAMPLE \n\n# Set the save directory\nsave_dir=os.path.join(base_dir, \"iter%d_%s_finetune%d\"\n % (cfg.MODEL.LOAD_ITER + 1, 'example',\n finetune)) \n\n# Create the environment to run the model\nenv = HKOBenchmarkEnv(pd_path=pd_path, save_dir=save_dir, mode=mode)\nstates = EncoderForecasterStates(factory=hko_nowcasting_online, ctx=ctx[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference and Evaluation\n\nThe pre-trained model loops through the radar echo maps\nand generates the inference over the next two hours.\nEvaluation of the inference is to compare the \ninference with the ground truth data.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "stored_data = []\nstored_prediction = []\ncounter = 0\nfinetune_iter = 0\nwhile not env.done: \n states.reset_all()\n\n # Get radar data \n in_frame_dat, in_datetime_clips, out_datetime_clips, begin_new_episode, 
need_upload_prediction =\\\n env.get_observation(batch_size=1)\n \n in_frame_nd = nd.array(in_frame_dat, ctx=ctx[0])\n t_encoder_net.forward(is_train=False,\n data_batch=mx.io.DataBatch(data=[in_frame_nd] +\n states.get_encoder_states()))\n outputs = t_encoder_net.get_outputs()\n states.update(states_nd=outputs)\n if need_upload_prediction:\n counter += 1\n t_forecaster_net.forward(is_train=False,\n data_batch=mx.io.DataBatch(\n data=states.get_forecaster_state()))\n \n # Get inference from model\n pred_nd = t_forecaster_net.get_outputs()[0] \n pred_nd = nd.clip(pred_nd, a_min=0, a_max=1)\n\n # Generate prediction movies and inference evaluation\n env.upload_prediction(prediction=pred_nd.asnumpy())\n\nfingerprint = env._fingerprint\nprint(\"Saving videos to %s\" % os.path.join(save_dir, fingerprint))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Printing of Evaluation\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "env.save_eval(print_out=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generating radar reflectivity maps\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Step 1: Retrieve observation and prediction for one-hour \nand two-hour forecasts.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Ground truth at one-hour forecast\ngt_10 = env._out_frame_dat[9:10, 0, 0]\n\n# Ground truth at two-hour forecast\ngt_20 = env._out_frame_dat[19:, 0, 0] \n\n# translate pixel to dBZ\ngt_10 = pixel_to_dBZ(gt_10)\ngt_20 = pixel_to_dBZ(gt_20)\npred_nd = pixel_to_dBZ(pred_nd.asnumpy())\n\n# Prediction at one-hour forecast\npd_10 = pred_nd[9:10, 0, 0]\n\n# Prediction at two-hour forecast\npd_20 = pred_nd[19:, 0, 0]\n\n# Save observation and forecast in list\nreflec_np = 
[]\nreflec_np.append(gt_10)\nreflec_np.append(pd_10)\nreflec_np.append(gt_20)\nreflec_np.append(pd_20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Step 2: Generate xarrays for one-hour and two-hour forecast.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Get forecast base time\nbasetime = in_datetime_clips[0][0]\n\n# Create a datetime list\ndatetime_list = []\ncurr = basetime\nfor i in range(len(reflec_np)):\n datetime_list.append(curr)\n curr = curr + datetime.timedelta(hours=1)\n\ny = []\ny_init = 255466.666667\nfor i in range(480):\n y.append(y_init)\n y_init -= 1066.666667\n\nx = []\nx_init = -255466.666667\nfor i in range(480):\n x.append(x_init)\n x_init += 1066.666667\n\n# Define source grid\narea_id = 'aeqd'\ndescription = 'Azimuthal Equidistant Projection centered at the radar site extending up to 256000.000000m in the x direction and 256000.000000m in the y direction with a 480.000000x480.000000 grid resolution'\nproj_id = 'aeqd'\nprojection = {'datum': 'WGS84', 'ellps': 'WGS84', 'lat_0': '22.411667', 'lon_0': '114.123333', 'no_defs': 'True', 'proj': 'aeqd', 'units': 'm'}\nwidth = 480\nheight = 480\narea_extent = (-256000.00000000003, -256000.00000000003, 256000.00000000003, 256000.00000000003)\narea_def = AreaDefinition(area_id, description, proj_id, projection, width, height, area_extent)\n\nreflec_xr_list = []\nfor i in range(len(reflec_np)):\n reflec_xr = xr.DataArray(\n reflec_np[i],\n dims=('time', 'y', 'x'),\n coords={\n 'time': [datetime_list[i]],\n 'y': y,\n 'x': x\n },\n attrs={\n 'long_name': 'Reflectivity',\n 'units': 'dBZ',\n 'projection': 'Centered Azimuthal',\n 'site': (114.12333341315389, 22.41166664287448, 948),\n 'radar_height': 20,\n 'area_def': area_def\n }\n )\n reflec_xr_list.append(reflec_xr) \n\n# Defining target grid\narea_id = \"hk1980_250km\"\ndescription = (\"A 1km resolution rectangular grid \"\n \"centred at HKO and extending to 250 
km \"\n \"in each direction in HK1980 easting/northing coordinates\")\nproj_id = 'hk1980'\nprojection = ('+proj=tmerc +lat_0=22.31213333333334 '\n '+lon_0=114.1785555555556 +k=1 +x_0=836694.05 '\n '+y_0=819069.8 +ellps=intl +towgs84=-162.619,-276.959,'\n '-161.764,0.067753,-2.24365,-1.15883,-1.09425 +units=m '\n '+no_defs')\nx_size = 500\ny_size = 500\narea_extent = (587000, 569000, 1087000, 1069000)\narea_def_tgt = utils.get_area_def(\n area_id, description, proj_id, projection, x_size, y_size, area_extent\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Step 3: Reproject the radar data from read_iris_grid() from Centered\nAzimuthal (source) projection to HK 1980 (target) projection.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "reproj_xr_list = []\nfor i in range(len(reflec_xr_list)):\n reproj_xr = grid_resample(\n reflec_xr_list[i], area_def, area_def_tgt,\n coord_label=['easting', 'northing']\n ) \n reproj_xr_list.append(reproj_xr)\n\n# Combining the DataArrays\nreflec_concat = xr.concat(reproj_xr_list, dim='time')\nstandardize_attr(reflec_concat, frame_type=FrameType.dBZ, zero_value=9999.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Step 4: Generating radar reflectivity maps\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Defining colour scale and format\nlevels = [\n -32768,\n 10, 15, 20, 24, 28, 32,\n 34, 38, 41, 44, 47, 50,\n 53, 56, 58, 60, 62\n]\ncmap = ListedColormap([\n '#FFFFFF', '#08C5F5', '#0091F3', '#3898FF', '#008243', '#00A433',\n '#00D100', '#01F508', '#77FF00', '#E0D100', '#FFDC01', '#EEB200',\n '#F08100', '#F00101', '#E20200', '#B40466', '#ED02F0'\n])\n\nnorm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)\n\n# Defining the crs\ncrs = area_def_tgt.to_cartopy_crs()\n\n# Get one-hour forecast date time\ndatetime_one_hr = out_datetime_clips[0][9]\n\n# Get 
two-hour forecast date time\ndatetime_two_hr = out_datetime_clips[0][-1]\n\ntimelist = [\n datetime_one_hr,\n datetime_one_hr,\n datetime_two_hr,\n datetime_two_hr\n]\n\n# Obtaining the slice of the xarray to be plotted\nda_plot = reflec_concat\n\n# Defining coastlines\nmap_shape_file = os.path.join(THIS_DIR, \"./../tests/samples/shape/rsmc\")\nocean_color = np.array([[[178, 208, 254]]], dtype=np.uint8)\nland_color = cfeature.COLORS['land']\ncoastline = cfeature.ShapelyFeature(\n list(shpreader.Reader(map_shape_file).geometries()),\n ccrs.PlateCarree()\n)\n\n# Plotting\np = da_plot.plot(\n col='time', col_wrap=2,\n subplot_kws={'projection': crs},\n cbar_kwargs={\n 'extend': 'max',\n 'ticks': levels[1:],\n 'format': '%.3g'\n },\n cmap=cmap,\n norm=norm\n)\nfor idx, ax in enumerate(p.axes.flat):\n # ocean\n ax.imshow(np.tile(ocean_color, [2, 2, 1]),\n origin='upper',\n transform=ccrs.PlateCarree(),\n extent=[-180, 180, -180, 180],\n zorder=-1)\n # coastline, color\n ax.add_feature(coastline,\n facecolor=land_color, edgecolor='none', zorder=0)\n # overlay coastline without color\n ax.add_feature(coastline, facecolor='none',\n edgecolor='gray', linewidth=0.5, zorder=3)\n ax.gridlines()\n if idx == 0 or idx == 2:\n ax.set_title(\n \"Ground Truth\\n\"\n f\"Based @ {timelist[idx].strftime('%H:%MH')}\",\n loc='left',\n fontsize=9\n )\n ax.set_title(\n ''\n )\n ax.set_title(\n f\"{basetime.strftime('%Y-%m-%d')} \\n\"\n f\"Valid @ {timelist[idx].strftime('%H:%MH')} \",\n loc='right',\n fontsize=9\n )\n else:\n ax.set_title(\n \"Reflectivity\\n\"\n f\"Based @ {basetime.strftime('%H:%MH')}\",\n loc='left',\n fontsize=9\n )\n ax.set_title(\n ''\n )\n ax.set_title(\n f\"{basetime.strftime('%Y-%m-%d')} \\n\"\n f\"Valid @ {timelist[idx].strftime('%H:%MH')} \",\n loc='right',\n fontsize=9\n )\n\nplt.savefig(\n THIS_DIR +\n f\"/../tests/outputs/traj-output-map-hk.png\",\n dpi=300\n)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": 
"python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.15" } }, "nbformat": 4, "nbformat_minor": 0 }