""" TrajGRU (Hong Kong) ======================================================== This example demonstrates how to use pre-trained TrajGRU model to predict the next two-hour forecast using thirty-minute radar echo maps. """ ########################################################### # Definitions # -------------------------------------------------------- # import os import warnings import datetime import mxnet as mx import mxnet.ndarray as nd from pyresample.geometry import AreaDefinition from pyresample import utils import xarray as xr import cartopy.feature as cfeature import cartopy.crs as ccrs import cartopy.io.shapereader as shpreader import matplotlib.pyplot as plt from matplotlib.colors import BoundaryNorm, ListedColormap from swirlspy.core.resample import grid_resample from swirlspy.utils import locate_cuda, standardize_attr, FrameType from swirlspy.qpf.dl.config import cfg, save_cfg, cfg_from_file from swirlspy.qpf.dl.utils import logging_config, parse_ctx from swirlspy.qpf.dl.hko_benchmark import * from swirlspy.qpf.dl.encoder_forecaster import EncoderForecasterBaseFactory, encoder_forecaster_build_networks, EncoderForecasterStates, load_encoder_forecaster_params from swirlspy.qpf.dl.hko_evaluation import rainfall_to_pixel, pixel_to_dBZ from swirlspy.qpf.dl.operators import * from swirlspy.qpf.dl.ops import * warnings.filterwarnings("ignore") THIS_DIR = os.getcwd() os.chdir(THIS_DIR) def get_loss_weight_symbol(data, mask, seq_len): if cfg.MODEL.USE_BALANCED_LOSS: balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS weights = mx.sym.ones_like(data) * balancing_weights[0] thresholds = [rainfall_to_pixel(ele) for ele in cfg.HKO.EVALUATION.THRESHOLDS] for i, threshold in enumerate(thresholds): weights = weights + (balancing_weights[i + 1] - balancing_weights[i]) * (data >= threshold) weights = weights * mask else: weights = mask if cfg.MODEL.TEMPORAL_WEIGHT_TYPE == "same": return weights elif cfg.MODEL.TEMPORAL_WEIGHT_TYPE == "linear": upper = 
cfg.MODEL.TEMPORAL_WEIGHT_UPPER assert upper >= 1.0 temporal_mult = 1 + \ mx.sym.arange(start=0, stop=seq_len) * (upper - 1.0) / (seq_len - 1.0) temporal_mult = mx.sym.reshape(temporal_mult, shape=(seq_len, 1, 1, 1, 1)) weights = mx.sym.broadcast_mul(weights, temporal_mult) return weights elif cfg.MODEL.TEMPORAL_WEIGHT_TYPE == "exponential": upper = cfg.MODEL.TEMPORAL_WEIGHT_UPPER assert upper >= 1.0 base_factor = np.log(upper) / (seq_len - 1.0) temporal_mult = mx.sym.exp(mx.sym.arange(start=0, stop=seq_len) * base_factor) temporal_mult = mx.sym.reshape(temporal_mult, shape=(seq_len, 1, 1, 1, 1)) weights = mx.sym.broadcast_mul(weights, temporal_mult) return weights else: raise NotImplementedError class HKONowcastingFactory(EncoderForecasterBaseFactory): def __init__(self, batch_size, in_seq_len, out_seq_len, ctx_num=1, name="hko_nowcasting"): super(HKONowcastingFactory, self).__init__(batch_size=batch_size, in_seq_len=in_seq_len, out_seq_len=out_seq_len, ctx_num=ctx_num, height=cfg.HKO.ITERATOR.HEIGHT, width=cfg.HKO.ITERATOR.WIDTH, name=name) self._central_region = cfg.HKO.EVALUATION.CENTRAL_REGION def _slice_central(self, data): """Slice the central region in the given symbol Parameters ---------- data : mx.sym.Symbol Returns ------- ret : mx.sym.Symbol """ x_begin, y_begin, x_end, y_end = self._central_region return mx.sym.slice(data, begin=(0, 0, 0, y_begin, x_begin), end=(None, None, None, y_end, x_end)) def _concat_month_code(self): # TODO raise NotImplementedError def loss_sym(self, pred=mx.sym.Variable('pred'), mask=mx.sym.Variable('mask'), target=mx.sym.Variable('target')): """Construct loss symbol. 
Optional args: pred: Shape (out_seq_len, batch_size, C, H, W) mask: Shape (out_seq_len, batch_size, C, H, W) target: Shape (out_seq_len, batch_size, C, H, W) """ self.reset_all() weights = get_loss_weight_symbol(data=target, mask=mask, seq_len=self._out_seq_len) mse = weighted_mse(pred=pred, gt=target, weight=weights) mae = weighted_mae(pred=pred, gt=target, weight=weights) gdl = masked_gdl_loss(pred=pred, gt=target, mask=mask) avg_mse = mx.sym.mean(mse) avg_mae = mx.sym.mean(mae) avg_gdl = mx.sym.mean(gdl) global_grad_scale = cfg.MODEL.NORMAL_LOSS_GLOBAL_SCALE if cfg.MODEL.L2_LAMBDA > 0: avg_mse = mx.sym.MakeLoss(avg_mse, grad_scale=global_grad_scale * cfg.MODEL.L2_LAMBDA, name="mse") else: avg_mse = mx.sym.BlockGrad(avg_mse, name="mse") if cfg.MODEL.L1_LAMBDA > 0: avg_mae = mx.sym.MakeLoss(avg_mae, grad_scale=global_grad_scale * cfg.MODEL.L1_LAMBDA, name="mae") else: avg_mae = mx.sym.BlockGrad(avg_mae, name="mae") if cfg.MODEL.GDL_LAMBDA > 0: avg_gdl = mx.sym.MakeLoss(avg_gdl, grad_scale=global_grad_scale * cfg.MODEL.GDL_LAMBDA, name="gdl") else: avg_gdl = mx.sym.BlockGrad(avg_gdl, name="gdl") loss = mx.sym.Group([avg_mse, avg_mae, avg_gdl]) return loss ############################################################# # Initialising # --------------------------------------------------- # # This section demonstrates the settings to run # the pre-trained model. 
#

# Directory holding the pre-trained model configuration.
# Use '/../tests/models/hko/convgru' instead when running ConvGRU.
base_dir = THIS_DIR + '/../tests/models/hko/trajgru'
cfg_from_file(os.path.join(base_dir, 'cfg0.yml'), target=cfg.MODEL)
cfg.MODEL.LOAD_DIR = base_dir
cfg.MODEL.SAVE_DIR = base_dir
# Use LOAD_ITER = 49999 when running ConvGRU.
cfg.MODEL.LOAD_ITER = 79999

# No training on inference data: mode = 'fixed' and finetune = 0.
# Training on inference data:    mode = 'online' and finetune = 1.
mode = 'fixed'
finetune = 0

# Run on GPU when CUDA can be located, otherwise fall back to CPU.
ctx = parse_ctx('gpu' if locate_cuda() else 'cpu')

# Initialize the model
hko_nowcasting_online = HKONowcastingFactory(
    batch_size=1,
    in_seq_len=cfg.MODEL.IN_LEN,
    out_seq_len=cfg.MODEL.OUT_LEN,
)

# Initialize the encoder network and forecaster network
t_encoder_net, t_forecaster_net, t_loss_net = encoder_forecaster_build_networks(
    factory=hko_nowcasting_online,
    context=ctx,
    for_finetune=True,
)
for net in (t_encoder_net, t_forecaster_net, t_loss_net):
    net.summary()

# Load the pre-trained model params into the networks created above
load_encoder_forecaster_params(
    load_dir=cfg.MODEL.LOAD_DIR,
    load_iter=cfg.MODEL.LOAD_ITER,
    encoder_net=t_encoder_net,
    forecaster_net=t_forecaster_net,
)

# 2019-04-20 10:30 to 2019-04-20 15:24
pd_path = cfg.HKO_PD.RAINY_EXAMPLE

# Set the save directory
save_dir = os.path.join(
    base_dir,
    "iter%d_%s_finetune%d" % (cfg.MODEL.LOAD_ITER + 1, 'example', finetune),
)

# Create the environment to run the model
env = HKOBenchmarkEnv(pd_path=pd_path, save_dir=save_dir, mode=mode)
states = EncoderForecasterStates(factory=hko_nowcasting_online, ctx=ctx[0])

#############################################################
# Inference and Evaluation
# ---------------------------------------------------
#
# The pre-trained model loops through the radar echo maps
# and generates the inference over the next two hours.
# Evaluation of the inference is to compare the
# inference with the truth forecast data.
#

stored_data = []
stored_prediction = []
counter = 0
finetune_iter = 0

while not env.done:
    states.reset_all()

    # Get radar data
    observation = env.get_observation(batch_size=1)
    (in_frame_dat, in_datetime_clips, out_datetime_clips,
     begin_new_episode, need_upload_prediction) = observation

    # Run the encoder on the input frames and keep its output states
    in_frame_nd = nd.array(in_frame_dat, ctx=ctx[0])
    encoder_batch = mx.io.DataBatch(
        data=[in_frame_nd] + states.get_encoder_states())
    t_encoder_net.forward(is_train=False, data_batch=encoder_batch)
    outputs = t_encoder_net.get_outputs()
    states.update(states_nd=outputs)

    if not need_upload_prediction:
        continue

    counter += 1
    forecaster_batch = mx.io.DataBatch(data=states.get_forecaster_state())
    t_forecaster_net.forward(is_train=False, data_batch=forecaster_batch)

    # Get inference from model, clipped to [0, 1]
    pred_nd = t_forecaster_net.get_outputs()[0]
    pred_nd = nd.clip(pred_nd, a_min=0, a_max=1)

    # Generate prediction movies and inference evaluation
    env.upload_prediction(prediction=pred_nd.asnumpy())
    # NOTE(review): assumed to run once per uploaded prediction; confirm
    # that this report belongs inside the loop rather than after it.
    fingerprint = env._fingerprint
    print("Saving videos to %s" % os.path.join(save_dir, fingerprint))

######################################################################
# Printing of Evaluation
# --------------------------------------------------------------------
#

env.save_eval(print_out=True)

#############################################
# Generating radar reflectivity maps
# -----------------------------------
#

######################################################################
# Step 1: Retrieve observation and prediction for one-hour
# and two-hour forecasts.

# Ground truth at the one-hour (frame index 9) and two-hour (frame index
# 19 onwards) forecasts, translated from pixel value to dBZ
gt_10 = pixel_to_dBZ(env._out_frame_dat[9:10, 0, 0])
gt_20 = pixel_to_dBZ(env._out_frame_dat[19:, 0, 0])

# Predictions for the same two lead times, also translated to dBZ
pred_nd = pixel_to_dBZ(pred_nd.asnumpy())
pd_10 = pred_nd[9:10, 0, 0]
pd_20 = pred_nd[19:, 0, 0]

# Observation / forecast pairs, in plotting order
reflec_np = [gt_10, pd_10, gt_20, pd_20]

###########################################################################
# Step 2: Generate xarrays for one-hour and two-hour forecast.

# Get forecast base time
basetime = in_datetime_clips[0][0]

# One distinct timestamp per frame, one hour apart, starting at the base time
datetime_list = [
    basetime + datetime.timedelta(hours=offset)
    for offset in range(len(reflec_np))
]

# Coordinates of the 480 rows (decreasing) and 480 columns (increasing)
grid_step = 1066.666667

y = []
y_value = 255466.666667
for _ in range(480):
    y.append(y_value)
    y_value -= grid_step

x = []
x_value = -255466.666667
for _ in range(480):
    x.append(x_value)
    x_value += grid_step

# Define source grid
area_id = 'aeqd'
description = 'Azimuthal Equidistant Projection centered at the radar site extending up to 256000.000000m in the x direction and256000.000000m in the y directionwith a 480.000000x480.000000 grid resolution'
proj_id = 'aeqd'
projection = {
    'datum': 'WGS84',
    'ellps': 'WGS84',
    'lat_0': '22.411667',
    'lon_0': '114.123333',
    'no_defs': 'True',
    'proj': 'aeqd',
    'units': 'm',
}
width = 480
height = 480
area_extent = (-256000.00000000003, -256000.00000000003,
               256000.00000000003, 256000.00000000003)
area_def = AreaDefinition(area_id, description, proj_id, projection,
                          width, height, area_extent)

# Wrap every frame in a DataArray carrying its coordinates and metadata
reflec_xr_list = [
    xr.DataArray(
        frame,
        dims=('time', 'y', 'x'),
        coords={
            'time': [frame_time],
            'y': y,
            'x': x
        },
        attrs={
            'long_name': 'Reflectivity',
            'units': 'dBZ',
            'projection': 'Centered Azimuthal',
            'site': (114.12333341315389, 22.41166664287448, 948),
            'radar_height': 20,
            'area_def': area_def
        }
    )
    for frame, frame_time in zip(reflec_np, datetime_list)
]

# Defining target grid
area_id = "hk1980_250km"
description = ("A 1km resolution rectangular grid "
               "centred at HKO and extending to 250 km "
               "in each direction in HK1980 easting/northing coordinates")
proj_id = 'hk1980'
projection = ('+proj=tmerc +lat_0=22.31213333333334 '
              '+lon_0=114.1785555555556 +k=1 +x_0=836694.05 '
              '+y_0=819069.8 +ellps=intl +towgs84=-162.619,-276.959,'
              '-161.764,0.067753,-2.24365,-1.15883,-1.09425 +units=m '
              '+no_defs')
x_size = 500
y_size = 500
area_extent = (587000, 569000, 1087000, 1069000)
area_def_tgt = utils.get_area_def(
    area_id, description, proj_id, projection, x_size, y_size, area_extent
)

##############################################################################
# Step 3: Reproject the reflectivity fields from the Centered
# Azimuthal (source) projection to HK 1980 (target) projection.

reproj_xr_list = [
    grid_resample(
        source_xr, area_def, area_def_tgt,
        coord_label=['easting', 'northing']
    )
    for source_xr in reflec_xr_list
]

# Combining the DataArrays
reflec_concat = xr.concat(reproj_xr_list, dim='time')
standardize_attr(reflec_concat, frame_type=FrameType.dBZ, zero_value=9999.)

#############################################
# Step 4: Generating radar reflectivity maps

# Defining colour scale and format
levels = [
    -32768,
    10, 15, 20, 24, 28, 32,
    34, 38, 41, 44, 47, 50,
    53, 56, 58, 60, 62
]
cmap = ListedColormap([
    '#FFFFFF', '#08C5F5', '#0091F3', '#3898FF', '#008243', '#00A433',
    '#00D100', '#01F508', '#77FF00', '#E0D100', '#FFDC01', '#EEB200',
    '#F08100', '#F00101', '#E20200', '#B40466', '#ED02F0'
])
norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)

# Defining the crs
crs = area_def_tgt.to_cartopy_crs()

# Valid times of the one-hour and two-hour forecasts, one entry per panel
datetime_one_hr = out_datetime_clips[0][9]
datetime_two_hr = out_datetime_clips[0][-1]
timelist = [
    datetime_one_hr,
    datetime_one_hr,
    datetime_two_hr,
    datetime_two_hr
]

# Obtaining the slice of the xarray to be plotted
da_plot = reflec_concat

# Defining coastlines
map_shape_file = os.path.join(THIS_DIR, "./../tests/samples/shape/rsmc")
ocean_color = np.array([[[178, 208, 254]]], dtype=np.uint8)
land_color = cfeature.COLORS['land']
coastline = cfeature.ShapelyFeature(
    list(shpreader.Reader(map_shape_file).geometries()),
    ccrs.PlateCarree()
)

# Plotting: one panel per entry along 'time', two panels per row
p = da_plot.plot(
    col='time', col_wrap=2,
    subplot_kws={'projection': crs},
    cbar_kwargs={
        'extend': 'max',
        'ticks': levels[1:],
        'format': '%.3g'
    },
    cmap=cmap,
    norm=norm
)

for idx, ax in enumerate(p.axes.flat):
    # ocean
    ax.imshow(np.tile(ocean_color, [2, 2, 1]),
              origin='upper',
              transform=ccrs.PlateCarree(),
              extent=[-180, 180, -180, 180],
              zorder=-1)
    # coastline, color
    ax.add_feature(coastline,
                   facecolor=land_color, edgecolor='none', zorder=0)
    # overlay coastline without color
    ax.add_feature(coastline, facecolor='none',
                   edgecolor='gray', linewidth=0.5, zorder=3)
    ax.gridlines()

    # Panels 0 and 2 hold the ground truth, panels 1 and 3 the forecasts;
    # only the left-hand title differs between the two kinds of panel.
    valid_label = timelist[idx].strftime('%H:%MH')
    if idx in (0, 2):
        left_title = (
            "Ground Truth\n"
            f"Based @ {valid_label}"
        )
    else:
        left_title = (
            "Reflectivity\n"
            f"Based @ {basetime.strftime('%H:%MH')}"
        )
    right_title = (
        f"{basetime.strftime('%Y-%m-%d')} \n"
        f"Valid @ {valid_label} "
    )

    ax.set_title(left_title, loc='left', fontsize=9)
    # Clear the default centre title
    ax.set_title('')
    ax.set_title(right_title, loc='right', fontsize=9)

plt.savefig(
    THIS_DIR + "/../tests/outputs/traj-output-map-hk.png",
    dpi=300
)