"""
TrajGRU (Hong Kong)
========================================================
This example demonstrates how to use a pre-trained TrajGRU model to produce
a two-hour forecast from thirty-minute radar echo maps.
"""

###########################################################
# Definitions
# --------------------------------------------------------
#

########################################################################
# Import all required modules and methods:

# Python package to allow system command line functions
import os
# Python package to manage warning message
import warnings
# Python package for datetime function
import datetime

# Python package for numerical arrays (np.log, np.array, np.tile are used
# below; previously it was only reachable through a star import)
import numpy as np
# Python package for xarrays to read and handle netcdf data
import xarray as xr
# Python package for deep learning
import mxnet as mx
import mxnet.ndarray as nd
# Python package for projection description
from pyresample.geometry import AreaDefinition
from pyresample import get_area_def
# Python package for projection
import cartopy.crs as ccrs
# Python package for land/sea features
import cartopy.feature as cfeature
# Python package for reading map shape file
import cartopy.io.shapereader as shpreader
# Python package for creating plots
from matplotlib import pyplot as plt
# Python package for colorbars
from matplotlib.colors import BoundaryNorm, ListedColormap

# swirlspy regrid function
from swirlspy.core.resample import grid_resample
# swirlspy standardize data function
from swirlspy.utils import standardize_attr, FrameType
# swirlspy cuda checking function
from swirlspy.utils import locate_cuda
# swirlspy deep learning function
from swirlspy.qpf.dl.config import cfg, cfg_from_file
from swirlspy.qpf.dl.utils import parse_ctx
from swirlspy.qpf.dl.hko_benchmark import *
from swirlspy.qpf.dl.encoder_forecaster import (
    EncoderForecasterBaseFactory,
    encoder_forecaster_build_networks,
    EncoderForecasterStates,
    load_encoder_forecaster_params,
)
from swirlspy.qpf.dl.hko_evaluation import rainfall_to_pixel, pixel_to_dBZ
from swirlspy.qpf.dl.operators import *
from swirlspy.qpf.dl.ops import *

# directory constants
from swirlspy.tests.samples import DATA_DIR
from swirlspy.tests.outputs import OUTPUT_DIR
from swirlspy.tests.models import DL_MODEL_DIR

warnings.filterwarnings("ignore")


def get_loss_weight_symbol(data, mask, seq_len):
    """Build the per-pixel weight symbol used by the weighted MSE/MAE losses.

    Parameters
    ----------
    data : mx.sym.Symbol
        Target frames with the time axis first, i.e.
        (seq_len, batch_size, C, H, W) as documented on ``loss_sym``.
    mask : mx.sym.Symbol
        Mask with the same layout as ``data``; multiplied into the weights.
    seq_len : int
        Number of frames along the time axis.

    Returns
    -------
    weights : mx.sym.Symbol
        ``mask``, optionally scaled by rainfall-intensity balancing weights
        (when ``cfg.MODEL.USE_BALANCED_LOSS`` is set) and then by a temporal
        multiplier chosen by ``cfg.MODEL.TEMPORAL_WEIGHT_TYPE``.

    Raises
    ------
    ValueError
        If ``cfg.MODEL.TEMPORAL_WEIGHT_UPPER`` is below 1.0 for the
        ``"linear"`` or ``"exponential"`` schemes.
    NotImplementedError
        If ``cfg.MODEL.TEMPORAL_WEIGHT_TYPE`` is not one of
        ``"same"``, ``"linear"`` or ``"exponential"``.
    """
    if cfg.MODEL.USE_BALANCED_LOSS:
        balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS
        weights = mx.sym.ones_like(data) * balancing_weights[0]
        thresholds = [rainfall_to_pixel(ele)
                      for ele in cfg.HKO.EVALUATION.THRESHOLDS]
        # Every threshold a pixel reaches adds the step to the next balancing
        # weight, so a pixel reaching the first k thresholds ends up with
        # balancing_weights[k] (assumes THRESHOLDS is ascending - confirm).
        for i, threshold in enumerate(thresholds):
            weights = weights + \
                (balancing_weights[i + 1] - balancing_weights[i]) * (data >= threshold)
        weights = weights * mask
    else:
        weights = mask

    weight_type = cfg.MODEL.TEMPORAL_WEIGHT_TYPE
    if weight_type == "same":
        return weights
    if weight_type not in ("linear", "exponential"):
        raise NotImplementedError(
            "Unsupported cfg.MODEL.TEMPORAL_WEIGHT_TYPE: %r" % (weight_type,))

    upper = cfg.MODEL.TEMPORAL_WEIGHT_UPPER
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if upper < 1.0:
        raise ValueError(
            "cfg.MODEL.TEMPORAL_WEIGHT_UPPER must be >= 1.0, got %r" % (upper,))

    # Number of frame steps from first to last frame. Guarded so a
    # single-frame sequence (whose only multiplier is 1) does not divide by 0.
    span = max(seq_len - 1.0, 1.0)
    steps = mx.sym.arange(start=0, stop=seq_len)
    if weight_type == "linear":
        # Rises linearly from 1 at the first frame to `upper` at the last.
        temporal_mult = 1 + steps * (upper - 1.0) / span
    else:
        # Rises geometrically from 1 at the first frame to `upper` at the last.
        temporal_mult = mx.sym.exp(steps * (np.log(upper) / span))
    # Broadcast the per-frame multiplier over the remaining four axes.
    temporal_mult = mx.sym.reshape(temporal_mult, shape=(seq_len, 1, 1, 1, 1))
    return mx.sym.broadcast_mul(weights, temporal_mult)


class HKONowcastingFactory(EncoderForecasterBaseFactory):
    """Encoder-forecaster factory configured for the HKO nowcasting setup.

    Frame height/width are read from ``cfg.HKO.ITERATOR`` and the
    evaluation crop from ``cfg.HKO.EVALUATION.CENTRAL_REGION``.
    """

    def __init__(self, batch_size, in_seq_len, out_seq_len, ctx_num=1,
                 name="hko_nowcasting"):
        super(HKONowcastingFactory, self).__init__(
            batch_size=batch_size,
            in_seq_len=in_seq_len,
            out_seq_len=out_seq_len,
            ctx_num=ctx_num,
            height=cfg.HKO.ITERATOR.HEIGHT,
            width=cfg.HKO.ITERATOR.WIDTH,
            name=name)
        self._central_region = cfg.HKO.EVALUATION.CENTRAL_REGION

    def _slice_central(self, data):
        """Slice the central region in the given symbol.

        Parameters
        ----------
        data : mx.sym.Symbol
            5-D symbol whose last two axes are (y, x).

        Returns
        -------
        ret : mx.sym.Symbol
            ``data`` cropped to ``self._central_region``, which is unpacked
            as (x_begin, y_begin, x_end, y_end).
        """
        x_begin, y_begin, x_end, y_end = self._central_region
        return mx.sym.slice(data,
                            begin=(0, 0, 0, y_begin, x_begin),
                            end=(None, None, None, y_end, x_end))

    def _concat_month_code(self):
        # TODO: month-code conditioning is not implemented in this example.
        raise NotImplementedError

    @staticmethod
    def _loss_or_monitor(value, weight, global_grad_scale, name):
        """Return `value` as a trainable loss if `weight` > 0, else as a
        gradient-blocked output that is only reported."""
        if weight > 0:
            return mx.sym.MakeLoss(value,
                                   grad_scale=global_grad_scale * weight,
                                   name=name)
        return mx.sym.BlockGrad(value, name=name)

    def loss_sym(self,
                 pred=mx.sym.Variable('pred'),
                 mask=mx.sym.Variable('mask'),
                 target=mx.sym.Variable('target')):
        """Construct loss symbol.

        Optional args:
            pred: Shape (out_seq_len, batch_size, C, H, W)
            mask: Shape (out_seq_len, batch_size, C, H, W)
            target: Shape (out_seq_len, batch_size, C, H, W)

        Returns:
            mx.sym.Group of the averaged MSE, MAE and GDL terms (named
            "mse", "mae", "gdl"). Each term contributes gradient only when
            its cfg.MODEL.*_LAMBDA is positive.
        """
        self.reset_all()
        weights = get_loss_weight_symbol(
            data=target, mask=mask, seq_len=self._out_seq_len)
        mse = weighted_mse(pred=pred, gt=target, weight=weights)
        mae = weighted_mae(pred=pred, gt=target, weight=weights)
        gdl = masked_gdl_loss(pred=pred, gt=target, mask=mask)
        avg_mse = mx.sym.mean(mse)
        avg_mae = mx.sym.mean(mae)
        avg_gdl = mx.sym.mean(gdl)

        global_grad_scale = cfg.MODEL.NORMAL_LOSS_GLOBAL_SCALE
        avg_mse = self._loss_or_monitor(
            avg_mse, cfg.MODEL.L2_LAMBDA, global_grad_scale, "mse")
        avg_mae = self._loss_or_monitor(
            avg_mae, cfg.MODEL.L1_LAMBDA, global_grad_scale, "mae")
        avg_gdl = self._loss_or_monitor(
            avg_gdl, cfg.MODEL.GDL_LAMBDA, global_grad_scale, "gdl")
        return mx.sym.Group([avg_mse, avg_mae, avg_gdl])


#############################################################
# Initialising
# ---------------------------------------------------
#
# This section demonstrates the settings to run
# the pre-trained model.
#
# Point at the directory holding the pre-trained model configuration.
# For ConvGRU, use '[DL_MODEL_DIR]/hko/convgru' instead.
base_dir = os.path.join(DL_MODEL_DIR, 'hko/trajgru')
cfg_from_file(os.path.join(base_dir, 'cfg0.yml'), target=cfg.MODEL)
cfg.MODEL.LOAD_DIR = cfg.MODEL.SAVE_DIR = base_dir

# Checkpoint iteration to restore (use 49999 for ConvGRU).
cfg.MODEL.LOAD_ITER = 79999

# Inference only:                     mode = 'fixed',  finetune = 0
# Keep training on the inference data: mode = 'online', finetune = 1
mode, finetune = 'fixed', 0

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
ctx = parse_ctx('gpu' if locate_cuda() else 'cpu')

# Model factory for single-sample batches.
hko_nowcasting_online = HKONowcastingFactory(
    batch_size=1,
    in_seq_len=cfg.MODEL.IN_LEN,
    out_seq_len=cfg.MODEL.OUT_LEN,
)

# Build the encoder, forecaster and loss networks, then print their summaries.
t_encoder_net, t_forecaster_net, t_loss_net = encoder_forecaster_build_networks(
    factory=hko_nowcasting_online, context=ctx, for_finetune=True)
for net in (t_encoder_net, t_forecaster_net, t_loss_net):
    net.summary()

# Restore the pre-trained weights into the freshly built networks.
load_encoder_forecaster_params(
    load_dir=cfg.MODEL.LOAD_DIR,
    load_iter=cfg.MODEL.LOAD_ITER,
    encoder_net=t_encoder_net,
    forecaster_net=t_forecaster_net,
)

# Sample period: 2019-04-20 10:30 to 2019-04-20 15:24
pd_path = cfg.HKO_PD.RAINY_EXAMPLE

# Output directory, e.g. '<base_dir>/iter80000_example_finetune0'.
save_dir = os.path.join(
    base_dir, f"iter{cfg.MODEL.LOAD_ITER + 1:d}_example_finetune{finetune:d}")

# Benchmark environment that serves observations and receives predictions,
# plus the holder for the RNN states shared by encoder and forecaster.
env = HKOBenchmarkEnv(pd_path=pd_path, save_dir=save_dir, mode=mode)
states = EncoderForecasterStates(factory=hko_nowcasting_online, ctx=ctx[0])

#############################################################
# Inference and Evaluation
# ---------------------------------------------------
#
# The pre-trained model loops through the radar echo maps
# and generates the inference over the next two hours.
# Evaluation compares the inference with the observed (ground-truth)
# radar frames.
#

# Most recent clipped forecast. It stays None if the environment never
# requests a prediction, which is checked after the loop.
pred_nd = None

while not env.done:
    # RNN states are cleared at the start of every iteration.
    states.reset_all()

    # Get radar data
    (in_frame_dat, in_datetime_clips, out_datetime_clips,
     begin_new_episode, need_upload_prediction) = \
        env.get_observation(batch_size=1)

    # Run the encoder on the observed frames and keep its output states.
    in_frame_nd = nd.array(in_frame_dat, ctx=ctx[0])
    t_encoder_net.forward(
        is_train=False,
        data_batch=mx.io.DataBatch(
            data=[in_frame_nd] + states.get_encoder_states()))
    states.update(states_nd=t_encoder_net.get_outputs())

    if not need_upload_prediction:
        continue

    # Run the forecaster starting from the encoder states.
    t_forecaster_net.forward(
        is_train=False,
        data_batch=mx.io.DataBatch(data=states.get_forecaster_state()))

    # Get inference from model, clipped to [0, 1].
    pred_nd = nd.clip(t_forecaster_net.get_outputs()[0], a_min=0, a_max=1)

    # Generate prediction movies and inference evaluation
    env.upload_prediction(prediction=pred_nd.asnumpy())

if pred_nd is None:
    raise RuntimeError(
        "The benchmark environment never requested a prediction, "
        "so there is no forecast to evaluate or plot.")

# NOTE(review): the map-generation steps below read the datetime clips and
# env._out_frame_dat left by the final loop iteration. These line up with
# pred_nd only if that iteration also uploaded a prediction - confirm for the
# chosen benchmark mode.
fingerprint = env._fingerprint
print("Saving videos to %s" % os.path.join(save_dir, fingerprint))

######################################################################
# Printing of Evaluation
# --------------------------------------------------------------------
#
env.save_eval(print_out=True)

#############################################
# Generating radar reflectivity maps
# -----------------------------------
#

######################################################################
# Step 1: Retrieve observation and prediction for one-hour
# and two-hour forecasts.
# Output-frame indices of the one-hour and two-hour lead times.
ONE_HOUR_IDX = 9
TWO_HOUR_IDX = 19

# Ground truth at the one-hour and two-hour leads, translated from pixel to
# dBZ. Slicing idx:idx + 1 keeps a length-1 leading axis, which becomes the
# single-entry 'time' dimension of each DataArray below. (The previous `19:`
# slice returned more than one frame whenever OUT_LEN > 20.)
gt_10 = pixel_to_dBZ(env._out_frame_dat[ONE_HOUR_IDX:ONE_HOUR_IDX + 1, 0, 0])
gt_20 = pixel_to_dBZ(env._out_frame_dat[TWO_HOUR_IDX:TWO_HOUR_IDX + 1, 0, 0])

# Prediction at the one-hour and two-hour leads, translated to dBZ.
pred_dbz = pixel_to_dBZ(pred_nd.asnumpy())
pd_10 = pred_dbz[ONE_HOUR_IDX:ONE_HOUR_IDX + 1, 0, 0]
pd_20 = pred_dbz[TWO_HOUR_IDX:TWO_HOUR_IDX + 1, 0, 0]

# Panels in plotting order: observation then prediction, one hour then two.
reflec_np = [gt_10, pd_10, gt_20, pd_20]

###########################################################################
# Step 2: Generate xarrays for one-hour and two-hour forecast.

# Forecast base time: the latest observed frame, i.e. the last input clip
# (assumes in_datetime_clips is chronological - confirm). The first clip is
# the oldest observation, not the time the forecast is issued from.
basetime = in_datetime_clips[0][-1]

# Placeholder timestamps, one hour apart. They only keep the four arrays
# distinct along 'time' so that each becomes its own panel; the observation
# and the prediction of a lead time share the same real valid time.
datetime_list = [basetime + datetime.timedelta(hours=k)
                 for k in range(len(reflec_np))]

# Define source grid
area_id = 'aeqd'
description = ('Azimuthal Equidistant Projection centered at the radar site '
               'extending up to 256000.000000m in the x direction and '
               '256000.000000m in the y direction with a '
               '480.000000x480.000000 grid resolution')
proj_id = 'aeqd'
projection = {'datum': 'WGS84', 'ellps': 'WGS84', 'lat_0': '22.411667',
              'lon_0': '114.123333', 'no_defs': 'True', 'proj': 'aeqd',
              'units': 'm'}
width = 480
height = 480
area_extent = (-256000.00000000003, -256000.00000000003,
               256000.00000000003, 256000.00000000003)
area_def = AreaDefinition(area_id, description, proj_id, projection,
                          width, height, area_extent)

# Pixel-centre coordinates derived from the extent: x ascending, y descending
# (top row first). Each value is computed directly rather than by repeated
# addition, so rounding error does not build up over 480 steps.
x_min, y_min, x_max, y_max = area_extent
pixel_width = (x_max - x_min) / width
pixel_height = (y_max - y_min) / height
x = [x_min + (col + 0.5) * pixel_width for col in range(width)]
y = [y_max - (row + 0.5) * pixel_height for row in range(height)]

reflec_xr_list = [
    xr.DataArray(
        frame,
        dims=('time', 'y', 'x'),
        coords={
            'time': [label_time],
            'y': y,
            'x': x
        },
        attrs={
            'long_name': 'Reflectivity',
            'units': 'dBZ',
            'projection': 'Centered Azimuthal',
            'site': (114.12333341315389, 22.41166664287448, 948),
            'radar_height': 20,
            'area_def': area_def
        }
    )
    for frame, label_time in zip(reflec_np, datetime_list)
]

# Defining target grid
area_id = "hk1980_250km"
description = ("A 1km resolution rectangular grid "
               "centred at HKO and extending to 250 km "
               "in each direction in HK1980 easting/northing coordinates")
proj_id = 'hk1980'
projection = ('+proj=tmerc +lat_0=22.31213333333334 '
              '+lon_0=114.1785555555556 +k=1 +x_0=836694.05 '
              '+y_0=819069.8 +ellps=intl +towgs84=-162.619,-276.959,'
              '-161.764,0.067753,-2.24365,-1.15883,-1.09425 +units=m '
              '+no_defs')
x_size = 500
y_size = 500
area_extent = (587000, 569000, 1087000, 1069000)
area_def_tgt = get_area_def(
    area_id, description, proj_id, projection, x_size, y_size, area_extent
)

##############################################################################
# Step 3: Reproject the model reflectivity from the Centered Azimuthal
# (source) projection to the HK 1980 (target) projection.
reproj_xr_list = [
    grid_resample(reflec_xr, area_def, area_def_tgt,
                  coord_label=['easting', 'northing'])
    for reflec_xr in reflec_xr_list
]

# Combining the DataArrays
reflec_concat = xr.concat(reproj_xr_list, dim='time')
standardize_attr(reflec_concat, frame_type=FrameType.dBZ, zero_value=9999.)
#############################################
# Step 4: Generating radar reflectivity maps

# Defining colour scale and format
levels = [
    -32768,
    10, 15, 20, 24, 28, 32,
    34, 38, 41, 44, 47, 50,
    53, 56, 58, 60, 62
]
cmap = ListedColormap([
    '#FFFFFF', '#08C5F5', '#0091F3', '#3898FF', '#008243', '#00A433',
    '#00D100', '#01F508', '#77FF00', '#E0D100', '#FFDC01', '#EEB200',
    '#F08100', '#F00101', '#E20200', '#B40466', '#ED02F0'
])

norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)

# Defining the crs
crs = area_def_tgt.to_cartopy_crs()

# Valid times of the one-hour and two-hour leads. Output frames 9 and 19
# are the frames sliced out for the panels; indexing [19] rather than [-1]
# keeps the label correct even if OUT_LEN is longer than 20.
datetime_one_hr = out_datetime_clips[0][9]
datetime_two_hr = out_datetime_clips[0][19]

# Valid time per panel, in order: obs 1h, pred 1h, obs 2h, pred 2h.
timelist = [
    datetime_one_hr, datetime_one_hr,
    datetime_two_hr, datetime_two_hr
]

# Obtaining the slice of the xarray to be plotted
da_plot = reflec_concat

# Defining coastlines
map_shape_file = os.path.join(DATA_DIR, "shape/rsmc")
ocean_color = np.array([[[178, 208, 254]]], dtype=np.uint8)
land_color = cfeature.COLORS['land']
coastline = cfeature.ShapelyFeature(
    list(shpreader.Reader(map_shape_file).geometries()),
    ccrs.PlateCarree()
)

# Plotting
p = da_plot.plot(
    col='time', col_wrap=2,
    subplot_kws={'projection': crs},
    cbar_kwargs={
        'extend': 'max',
        'ticks': levels[1:],
        'format': '%.3g'
    },
    cmap=cmap,
    norm=norm
)

for idx, ax in enumerate(p.axes.flat):
    # ocean
    ax.imshow(np.tile(ocean_color, [2, 2, 1]),
              origin='upper',
              transform=ccrs.PlateCarree(),
              extent=[-180, 180, -180, 180],
              zorder=-1)
    # coastline, color
    ax.add_feature(coastline,
                   facecolor=land_color, edgecolor='none', zorder=0)
    # overlay coastline without color
    ax.add_feature(coastline, facecolor='none',
                   edgecolor='gray', linewidth=0.5, zorder=3)
    ax.gridlines()

    valid_time = timelist[idx]
    # Even panels hold observations, odd panels hold the model prediction.
    if idx % 2 == 0:
        left_title = ("Ground Truth\n"
                      f"Based @ {valid_time.strftime('%H:%MH')}")
    else:
        left_title = ("Reflectivity\n"
                      f"Based @ {basetime.strftime('%H:%MH')}")
    ax.set_title(left_title, loc='left', fontsize=9)
    # Blank the centre title that the facet plot assigns to each panel.
    ax.set_title('')
    # Show the valid date (not the base date) so a forecast that crosses
    # midnight is labelled with the correct day.
    ax.set_title(
        f"{valid_time.strftime('%Y-%m-%d')} \n"
        f"Valid @ {valid_time.strftime('%H:%MH')} ",
        loc='right',
        fontsize=9
    )

plt.savefig(
    os.path.join(OUTPUT_DIR, 'traj-output-map-hk.png'),
    dpi=300
)