TrajGRU (Hong Kong)

This example demonstrates how to use a pre-trained TrajGRU model to forecast the next two hours of radar echoes from the preceding thirty minutes of radar echo maps.

Definitions

import os
import warnings
import datetime

import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
from pyresample.geometry import AreaDefinition
from pyresample import utils
import xarray as xr
import cartopy.feature as cfeature
import cartopy.crs as ccrs
import cartopy.io.shapereader as shpreader
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap

from swirlspy.core.resample import grid_resample
from swirlspy.utils import locate_cuda, standardize_attr, FrameType
from swirlspy.qpf.dl.config import cfg, save_cfg, cfg_from_file
from swirlspy.qpf.dl.utils import logging_config, parse_ctx
from swirlspy.qpf.dl.hko_benchmark import *
from swirlspy.qpf.dl.encoder_forecaster import EncoderForecasterBaseFactory, encoder_forecaster_build_networks, EncoderForecasterStates, load_encoder_forecaster_params
from swirlspy.qpf.dl.hko_evaluation import rainfall_to_pixel, pixel_to_dBZ
from swirlspy.qpf.dl.operators import *
from swirlspy.qpf.dl.ops import *

# Suppress all warnings so the example output stays readable.
warnings.filterwarnings("ignore")

# Base directory for every relative path used below (model config, shapefiles,
# output image): the current working directory. It is already the working
# directory, so no chdir is needed.
THIS_DIR = os.getcwd()

def get_loss_weight_symbol(data, mask, seq_len):
    """Build the per-pixel loss weight symbol used by the forecaster losses.

    Two optional factors, both configured through ``cfg``, are combined:

    * Intensity balancing (``cfg.MODEL.USE_BALANCED_LOSS``): every pixel
      starts at ``BALANCING_WEIGHTS[0]`` and, for each rainfall threshold in
      ``cfg.HKO.EVALUATION.THRESHOLDS`` (converted to pixel units) that
      ``data`` reaches, is stepped up by ``BALANCING_WEIGHTS[i + 1] -
      BALANCING_WEIGHTS[i]``. The result is multiplied by ``mask``. When
      balancing is off, ``mask`` itself is the weight.
    * Temporal weighting (``cfg.MODEL.TEMPORAL_WEIGHT_TYPE``): ``"same"``
      leaves the weights unchanged; ``"linear"`` and ``"exponential"``
      multiply frame ``t`` by a factor rising from 1 at the first frame to
      ``cfg.MODEL.TEMPORAL_WEIGHT_UPPER`` at the last frame.

    Parameters
    ----------
    data : mx.sym.Symbol
        Target symbol compared against the thresholds. The temporal
        multiplier is reshaped to ``(seq_len, 1, 1, 1, 1)``, so this is
        presumably ``(seq_len, batch, C, H, W)`` — confirm with callers.
    mask : mx.sym.Symbol
        Mask multiplied into the weights.
    seq_len : int
        Number of frames along the leading (time) axis.

    Returns
    -------
    mx.sym.Symbol
        The weight symbol.

    Raises
    ------
    ValueError
        If a "linear"/"exponential" schedule is configured with
        ``TEMPORAL_WEIGHT_UPPER`` below 1.0.
    NotImplementedError
        If ``TEMPORAL_WEIGHT_TYPE`` is not one of the supported values.
    """
    if cfg.MODEL.USE_BALANCED_LOSS:
        balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS
        weights = mx.sym.ones_like(data) * balancing_weights[0]
        thresholds = [rainfall_to_pixel(ele) for ele in cfg.HKO.EVALUATION.THRESHOLDS]
        for i, threshold in enumerate(thresholds):
            # Add the weight increment wherever the target reaches this threshold.
            weights = weights + (balancing_weights[i + 1] - balancing_weights[i]) * (data >= threshold)
        weights = weights * mask
    else:
        weights = mask

    weight_type = cfg.MODEL.TEMPORAL_WEIGHT_TYPE
    if weight_type == "same":
        return weights
    if weight_type not in ("linear", "exponential"):
        raise NotImplementedError(
            "Unsupported TEMPORAL_WEIGHT_TYPE: %r" % (weight_type,))

    upper = cfg.MODEL.TEMPORAL_WEIGHT_UPPER
    # Written as `not >=` so a NaN upper bound is rejected as well.
    if not upper >= 1.0:
        raise ValueError(
            "TEMPORAL_WEIGHT_UPPER must be >= 1.0, got %r" % (upper,))
    if seq_len <= 1:
        # A single frame gets multiplier 1 (the value both schedules give the
        # first frame); this also avoids dividing by seq_len - 1 == 0 below,
        # which would raise for "linear" and yield NaN for "exponential".
        return weights

    steps = mx.sym.arange(start=0, stop=seq_len)
    if weight_type == "linear":
        # 1 at the first frame, rising linearly to `upper` at the last frame.
        temporal_mult = 1 + steps * (upper - 1.0) / (seq_len - 1.0)
    else:
        # exp(0) == 1 at the first frame, rising geometrically to `upper`.
        base_factor = np.log(upper) / (seq_len - 1.0)
        temporal_mult = mx.sym.exp(steps * base_factor)
    # Broadcast one multiplier per frame over the remaining four axes.
    temporal_mult = mx.sym.reshape(temporal_mult, shape=(seq_len, 1, 1, 1, 1))
    return mx.sym.broadcast_mul(weights, temporal_mult)

class HKONowcastingFactory(EncoderForecasterBaseFactory):
    """Encoder-forecaster network factory configured for HKO nowcasting.

    The frame height and width are read from ``cfg.HKO.ITERATOR`` and the
    central evaluation window from ``cfg.HKO.EVALUATION.CENTRAL_REGION``.
    """

    def __init__(self,
                 batch_size,
                 in_seq_len,
                 out_seq_len,
                 ctx_num=1,
                 name="hko_nowcasting"):
        super().__init__(batch_size=batch_size,
                         in_seq_len=in_seq_len,
                         out_seq_len=out_seq_len,
                         ctx_num=ctx_num,
                         height=cfg.HKO.ITERATOR.HEIGHT,
                         width=cfg.HKO.ITERATOR.WIDTH,
                         name=name)
        # Unpacked as (x_begin, y_begin, x_end, y_end) by _slice_central.
        self._central_region = cfg.HKO.EVALUATION.CENTRAL_REGION

    def _slice_central(self, data):
        """Slice the central region in the given symbol.

        Rows ``y_begin:y_end`` and columns ``x_begin:x_end`` are taken from
        the last two axes; the first three axes are kept whole.

        Parameters
        ----------
        data : mx.sym.Symbol
            Five-dimensional symbol.

        Returns
        -------
        ret : mx.sym.Symbol
        """
        x_begin, y_begin, x_end, y_end = self._central_region
        return mx.sym.slice(data,
                            begin=(0, 0, 0, y_begin, x_begin),
                            end=(None, None, None, y_end, x_end))

    def _concat_month_code(self):
        """Placeholder for month-code conditioning; not implemented yet."""
        # TODO: implement month-code concatenation.
        raise NotImplementedError("_concat_month_code is not implemented")

    @staticmethod
    def _loss_or_monitor(value, loss_lambda, global_grad_scale, name):
        """Wrap ``value`` as a trainable loss or as a gradient-free output.

        With ``loss_lambda > 0`` the value becomes a ``MakeLoss`` whose
        gradient is scaled by ``global_grad_scale * loss_lambda``; otherwise
        it is wrapped in ``BlockGrad`` so it is still output but contributes
        no gradient.
        """
        if loss_lambda > 0:
            return mx.sym.MakeLoss(value,
                                   grad_scale=global_grad_scale * loss_lambda,
                                   name=name)
        return mx.sym.BlockGrad(value, name=name)

    def loss_sym(self, pred=None, mask=None, target=None):
        """Construct the grouped loss symbol ``[mse, mae, gdl]``.

        Optional args (each defaults to an ``mx.sym.Variable`` named
        ``'pred'``, ``'mask'`` and ``'target'`` respectively):
            pred: Shape (out_seq_len, batch_size, C, H, W)
            mask: Shape (out_seq_len, batch_size, C, H, W)
            target: Shape (out_seq_len, batch_size, C, H, W)

        Each term is a mean over its element-wise loss. A term only
        back-propagates when its ``cfg.MODEL.*_LAMBDA`` is positive; its
        gradient is then scaled by ``cfg.MODEL.NORMAL_LOSS_GLOBAL_SCALE``
        times that lambda.
        """
        # Build the default input variables per call rather than once at
        # class-definition time, so no symbol objects are shared between calls.
        if pred is None:
            pred = mx.sym.Variable('pred')
        if mask is None:
            mask = mx.sym.Variable('mask')
        if target is None:
            target = mx.sym.Variable('target')

        self.reset_all()
        weights = get_loss_weight_symbol(data=target, mask=mask, seq_len=self._out_seq_len)
        mse = weighted_mse(pred=pred, gt=target, weight=weights)
        mae = weighted_mae(pred=pred, gt=target, weight=weights)
        gdl = masked_gdl_loss(pred=pred, gt=target, mask=mask)
        avg_mse = mx.sym.mean(mse)
        avg_mae = mx.sym.mean(mae)
        avg_gdl = mx.sym.mean(gdl)

        global_grad_scale = cfg.MODEL.NORMAL_LOSS_GLOBAL_SCALE
        avg_mse = self._loss_or_monitor(avg_mse, cfg.MODEL.L2_LAMBDA, global_grad_scale, "mse")
        avg_mae = self._loss_or_monitor(avg_mae, cfg.MODEL.L1_LAMBDA, global_grad_scale, "mae")
        avg_gdl = self._loss_or_monitor(avg_gdl, cfg.MODEL.GDL_LAMBDA, global_grad_scale, "gdl")
        return mx.sym.Group([avg_mse, avg_mae, avg_gdl])

Initialising

This section shows the settings needed to run the pre-trained model.

# Directory that holds the pre-trained model configuration and parameters.
# Point it at '/../tests/models/hko/convgru' instead to run ConvGRU.
base_dir = THIS_DIR + '/../tests/models/hko/trajgru'

# Read the model section of the config; load and save from the same directory.
cfg_from_file(os.path.join(base_dir, 'cfg0.yml'), target=cfg.MODEL)
cfg.MODEL.LOAD_DIR = cfg.MODEL.SAVE_DIR = base_dir

# Checkpoint iteration to restore (use 49999 for ConvGRU).
cfg.MODEL.LOAD_ITER = 79999

# mode='fixed',  finetune=0 -> no training on the inference data
# mode='online', finetune=1 -> keep training on the inference data
mode, finetune = 'fixed', 0

# Run on the GPU when CUDA is available, otherwise on the CPU.
ctx = parse_ctx('gpu' if locate_cuda() else 'cpu')

# Network factory: one sequence per batch, sequence lengths from the config.
hko_nowcasting_online = HKONowcastingFactory(
    batch_size=1,
    in_seq_len=cfg.MODEL.IN_LEN,
    out_seq_len=cfg.MODEL.OUT_LEN,
)

# Build the encoder, forecaster and loss networks and print their layouts.
t_encoder_net, t_forecaster_net, t_loss_net = encoder_forecaster_build_networks(
    factory=hko_nowcasting_online,
    context=ctx,
    for_finetune=True,
)
t_encoder_net.summary()
t_forecaster_net.summary()
t_loss_net.summary()

# Restore the pre-trained parameters into the encoder and forecaster.
load_encoder_forecaster_params(
    load_dir=cfg.MODEL.LOAD_DIR,
    load_iter=cfg.MODEL.LOAD_ITER,
    encoder_net=t_encoder_net,
    forecaster_net=t_forecaster_net,
)

# Sample rainy episode: 2019-04-20 10:30 to 2019-04-20 15:24
pd_path = cfg.HKO_PD.RAINY_EXAMPLE

# Output directory: <base_dir>/iter<LOAD_ITER + 1>_example_finetune<finetune>
save_dir = os.path.join(
    base_dir,
    "iter%d_%s_finetune%d" % (cfg.MODEL.LOAD_ITER + 1, 'example', finetune),
)

# Benchmark environment that serves the radar frames, and the holder of the
# encoder/forecaster recurrent states.
env = HKOBenchmarkEnv(pd_path=pd_path, save_dir=save_dir, mode=mode)
states = EncoderForecasterStates(factory=hko_nowcasting_online, ctx=ctx[0])

Out:

Loading YAML config file from <_io.TextIOWrapper name='/tmp/build/docs/swirlspy/swirlspy/examples/../tests/models/hko/trajgru/cfg0.yml' mode='r' encoding='UTF-8'>
ebrnn1_0_ 96 96
ebrnn2_0_ 32 32
ebrnn3_0_ 16 16
fbrnn1_0_ 96 96
fbrnn2_0_ 32 32
fbrnn3_0_ 16 16
dbrnn1_0_ 96 96
dbrnn2_0_ 32 32
dbrnn3_0_ 16 16

Inference and Evaluation

The pre-trained model loops through the radar echo maps and generates a forecast for the next two hours. The forecast is then evaluated by comparing it with the observed (ground-truth) radar data.

# Bookkeeping variables; only `counter` is updated, and none of them is read
# anywhere else in this example.
stored_data = []
stored_prediction = []
counter = 0
finetune_iter = 0

while not env.done:
    # Every observation window starts from freshly reset recurrent states.
    states.reset_all()

    # Next input window from the benchmark environment.
    (in_frame_dat, in_datetime_clips, out_datetime_clips,
     begin_new_episode, need_upload_prediction) = env.get_observation(batch_size=1)

    # Encode the input frames; the encoder outputs become the new states.
    in_frame_nd = nd.array(in_frame_dat, ctx=ctx[0])
    encoder_batch = mx.io.DataBatch(data=[in_frame_nd] + states.get_encoder_states())
    t_encoder_net.forward(is_train=False, data_batch=encoder_batch)
    outputs = t_encoder_net.get_outputs()
    states.update(states_nd=outputs)

    if not need_upload_prediction:
        continue

    counter += 1
    # Run the forecaster from the encoded states.
    forecaster_batch = mx.io.DataBatch(data=states.get_forecaster_state())
    t_forecaster_net.forward(is_train=False, data_batch=forecaster_batch)

    # First forecaster output, clipped to the [0, 1] range.
    pred_nd = nd.clip(t_forecaster_net.get_outputs()[0], a_min=0, a_max=1)

    # Hand the prediction back for movie generation and evaluation.
    env.upload_prediction(prediction=pred_nd.asnumpy())

fingerprint = env._fingerprint
print("Saving videos to %s" % os.path.join(save_dir, fingerprint))

Out:

Saving videos to /tmp/build/docs/swirlspy/swirlspy/examples/../tests/models/hko/trajgru/iter80000_example_finetune0/hko7_rainy_example_in5_out20_stride5_fixed

Printing the Evaluation

# Save the accumulated evaluation results and print the per-threshold metrics.
env.save_eval(print_out=True)

Out:

Saving evaluation result to /tmp/build/docs/swirlspy/swirlspy/examples/../tests/models/hko/trajgru/iter80000_example_finetune0/hko7_rainy_example_in5_out20_stride5_fixed
Total Sequence Num: 6, Out Seq Len: 20, Use Central: 0

Threshold = 0.5:

   POD: [0.9949990623241858, 0.9918333506062804, 0.9879483471286296, 0.9835357957698712, 0.9815608919382505, 0.9803458002719322, 0.9763011105374515, 0.9738750651119926, 0.9709277727410094, 0.9684372362233209, 0.9651381657989587, 0.9614454714596993, 0.9560960295113169, 0.9549910786063018, 0.9553564472146474, 0.9534676883380898, 0.9518710296291221, 0.9511550963547714, 0.9490770764569803, 0.9474938728434812]

   FAR: [0.046496071246717716, 0.0891856138113502, 0.1348375159430183, 0.1700057240984545, 0.19915031863051355, 0.22666186416936807, 0.25931046135649716, 0.287582896713443, 0.31619759778082057, 0.33715963878209315, 0.35671624859028517, 0.3777800672238606, 0.39334781715523626, 0.409948573193853, 0.42273719642447316, 0.4375580141713126, 0.4535575010395478, 0.4638890252081595, 0.4768551123631474, 0.48599792484605797]

   CSI: [0.9489561700732321, 0.9040345358536194, 0.8561270653903243, 0.8186204064873742, 0.7889799837506463, 0.7615313071442084, 0.7276074360212458, 0.6990573296038771, 0.6700825228206536, 0.6488238699861315, 0.6286757752633123, 0.6070726298397382, 0.5902104000872125, 0.5740865157944366, 0.5620999673384255, 0.5474159699231244, 0.5317505016978439, 0.5217467443563537, 0.5088614295574735, 0.4997668005028998]

   GSS: [0.943182998181027, 0.8931192648410953, 0.8395793763158353, 0.797794985421462, 0.7645093750142621, 0.7336281527411268, 0.6956896462756536, 0.6633676695748449, 0.6311404476579361, 0.6071824659777206, 0.5844684792499397, 0.5600591866095282, 0.5405972815568172, 0.5224201325828103, 0.5085016657873864, 0.49220862854570174, 0.47473985519984613, 0.46285356382837206, 0.44843375962538284, 0.4377804348867319]

   HSS: [0.9707608589246828, 0.9435425241589963, 0.9127949433715427, 0.8875260993504582, 0.8665404512323237, 0.8463500683017294, 0.8205388855250004, 0.7976200111481078, 0.7738640146704173, 0.7555862247518292, 0.7377470576462551, 0.7179973573011715, 0.7018022010405319, 0.6863021861074788, 0.6741811127161943, 0.6597048417089044, 0.6438286095353581, 0.6328091550285562, 0.6191981602822678, 0.608967022035211]

   POD stat: avg 0.967793/final 0.947494

   FAR stat: avg 0.317249/final 0.485998

   CSI stat: avg 0.669775/final 0.499767

   GSS stat: avg 0.630063/final 0.43778

   HSS stat: avg 0.762883/final 0.608967

Threshold = 2:

   POD: [0.9852889083369798, 0.9801533837703389, 0.9740197134270798, 0.9700417257610326, 0.9662766478804214, 0.9624266899364815, 0.9592937913161438, 0.9548222943462756, 0.9502367281152413, 0.9449535753044737, 0.9402282739287144, 0.9343528311460894, 0.9282525729198348, 0.9234230701358797, 0.9242890533696922, 0.9219379178232143, 0.9228487808697865, 0.9212618992338054, 0.9190034168013705, 0.9154884905569706]

   FAR: [0.017536204782979077, 0.04274378023846637, 0.07551523057602633, 0.10435170424352842, 0.12914534947918585, 0.15433925928567607, 0.18342022491274937, 0.20794070233168802, 0.2340173286019375, 0.25689269599437364, 0.27999693322088476, 0.30362466074421224, 0.32165159847143254, 0.3397055730809674, 0.35321938719877877, 0.36967344020420845, 0.38615144215118175, 0.39719706033156726, 0.4101549395524521, 0.4183332620895595]

   CSI: [0.9682604689259752, 0.9390545411218015, 0.9022364945750281, 0.8715408267851043, 0.845167397458452, 0.8186337548008458, 0.7892325887428581, 0.7634479223400507, 0.7364410683196839, 0.7122741904942246, 0.6884897360703812, 0.6638930520600628, 0.6445534436950746, 0.6260160440246975, 0.6142385606109637, 0.5983899821109123, 0.5838845124851367, 0.573268077728816, 0.5606966114306636, 0.5520254998551145]

   GSS: [0.9647095082312246, 0.9322123859498771, 0.8911644447912733, 0.8570114906524168, 0.8275104320050067, 0.797813166628722, 0.7649794848564586, 0.7359488590529116, 0.7059869485631705, 0.6789478278566798, 0.6522955135234061, 0.6247045047258717, 0.6026429343115548, 0.5819447182225841, 0.5683267745936315, 0.550749142348133, 0.5344983834646916, 0.5220620384952848, 0.5080226858850864, 0.4980031333854099]

   HSS: [0.9820378067999749, 0.9649170999300895, 0.9424505068776613, 0.9230007406699742, 0.9056150022597953, 0.8875373497511865, 0.8668423530358172, 0.8478923272594864, 0.8276580886597898, 0.8087777554391486, 0.7895627727420635, 0.7690069214540338, 0.7520613873612857, 0.7357333180092869, 0.7247555596197616, 0.7103007537559463, 0.6966424848983755, 0.6859931136728131, 0.6737600045942524, 0.664889308021604]

   POD stat: avg 0.94493/final 0.915488

   FAR stat: avg 0.249281/final 0.418333

   CSI stat: avg 0.722587/final 0.552025

   GSS stat: avg 0.689977/final 0.498003

   HSS stat: avg 0.807972/final 0.664889

Threshold = 5:

   POD: [0.9204261130838569, 0.9123557911660547, 0.9132637148726112, 0.9172258786079924, 0.9174234311415721, 0.9138993623361781, 0.9105009521175307, 0.9039197293394745, 0.8963952529034714, 0.8900604958118284, 0.8847186403416057, 0.8787300121710664, 0.871170845179633, 0.8640429153556507, 0.8584284189788212, 0.8601423854725547, 0.8611418123973515, 0.8609210660675993, 0.8632869058058201, 0.8648629456043175]

   FAR: [0.04149552276166528, 0.05882221101585701, 0.0825329256566224, 0.10195421904942405, 0.12197396740801476, 0.14529453484771296, 0.16247027624741905, 0.18338840063605116, 0.19980549155871133, 0.21883731075900606, 0.24013689427912635, 0.2612749292147022, 0.2809379960063489, 0.29966386273558876, 0.31812142897896306, 0.33362908963014476, 0.34851954019395626, 0.35710418557125534, 0.36617757149273844, 0.37231818114885656]

   CSI: [0.885155297776657, 0.8631388549171146, 0.8439307507304565, 0.8307214305889686, 0.8137175748710593, 0.7910102803047112, 0.7738237315106145, 0.7513908480011214, 0.7324529182475822, 0.7124221183800623, 0.6914054433568744, 0.6703807762809275, 0.6499494781984928, 0.6308210253590715, 0.6129493134827015, 0.6012274805636891, 0.5895481224807297, 0.5824082453146912, 0.5760062045881297, 0.571619184272812]

   GSS: [0.873810091293894, 0.8495994789775003, 0.8282223533958314, 0.8135058492532931, 0.7945724989755995, 0.7694482850387966, 0.7503913955363465, 0.7254491098110649, 0.7047207217697593, 0.6826409884804459, 0.6594526797338668, 0.6360795807695043, 0.613112562936445, 0.5919862380023466, 0.5720597607969308, 0.5591386930419584, 0.5461380194770304, 0.537709106448511, 0.5303740256169345, 0.525163341804218]

   HSS: [0.9326559776295312, 0.9186848165065205, 0.906041162725584, 0.8971637445649842, 0.8855284469467434, 0.8697041801613612, 0.8573984052365787, 0.8408814907215675, 0.8267873004302453, 0.8113923209453291, 0.7947833497001263, 0.7775655759609619, 0.7601609175002141, 0.7437077329828953, 0.7277837332429833, 0.7172404809620246, 0.7064544207531454, 0.6993638838367845, 0.6931299365240162, 0.6886650464375403]

   POD stat: avg 0.888146/final 0.864863

   FAR stat: avg 0.224723/final 0.372318

   CSI stat: avg 0.708704/final 0.571619

   GSS stat: avg 0.678179/final 0.525163

   HSS stat: avg 0.802755/final 0.688665

Threshold = 10:

   POD: [0.8797434737276966, 0.8623155925399022, 0.8467330816826867, 0.8365058516790496, 0.8319829676503637, 0.8227749100578758, 0.8187597809076682, 0.8155337892434126, 0.8100978753879208, 0.8045177411294353, 0.8002547821413643, 0.8084692235882972, 0.8052185672805378, 0.8048080160003168, 0.8043121352981901, 0.8016070933776669, 0.8075337965636012, 0.8012230588422772, 0.8012173191312768, 0.7959853005097404]

   FAR: [0.2281404259725914, 0.25932373619015736, 0.29862196792650497, 0.3310916729370454, 0.36039039767216297, 0.3810763505861242, 0.3909608137013809, 0.4146462217444196, 0.4307004054242975, 0.44239854801391026, 0.45772862152684113, 0.4679911122729055, 0.4877335931272695, 0.5020400161731012, 0.5153804898396996, 0.5255877150855657, 0.5348623853211009, 0.5484265714638555, 0.5500210872122705, 0.5545383177983436]

   CSI: [0.6981936269535214, 0.6623456328119153, 0.6223649209118732, 0.5915682360117562, 0.5664434214234904, 0.546117607361167, 0.536685131045802, 0.5169135833312514, 0.5022694755538014, 0.4910683781541315, 0.4776239976239976, 0.47246178310679837, 0.4557868311738019, 0.4443011128845355, 0.43350604490500866, 0.42456261727307987, 0.4187184490074019, 0.40607963493922383, 0.40478829085172874, 0.39981343098435]

   GSS: [0.6826728570275701, 0.6449739655564909, 0.6030081781669192, 0.5707881963747381, 0.5444019313141256, 0.5228770528406542, 0.512971314394788, 0.4922912653057141, 0.47710156116471986, 0.4654397384838655, 0.45115436217901367, 0.44544060861785756, 0.42791364201846643, 0.41559700528456817, 0.4041143638109575, 0.39468033534166413, 0.38853463057562554, 0.3753452114872746, 0.3737072857551623, 0.3685510230943937]

   HSS: [0.811414832272872, 0.784175286735675, 0.7523457289612514, 0.7267538649603742, 0.7050003244309545, 0.6866963447447326, 0.6780978720670385, 0.6597790615692738, 0.6459969628473172, 0.6352219422757109, 0.621786866976126, 0.6163388602231006, 0.5993550722207225, 0.5871685285192073, 0.5756145998166932, 0.5659796375417836, 0.5596326112724407, 0.5458196361935674, 0.5440857592157645, 0.5386003398851333]

   POD stat: avg 0.81798/final 0.795985

   FAR stat: avg 0.434083/final 0.554538

   CSI stat: avg 0.503581/final 0.399813

   GSS stat: avg 0.478078/final 0.368551

   HSS stat: avg 0.641993/final 0.5386

Threshold = 20:

   POD: [0.847693200833248, 0.827469052744887, 0.8135791728014992, 0.8020088984764729, 0.778738933567615, 0.7748988404910061, 0.7686053572633426, 0.7588077695914267, 0.7512246207898382, 0.7267514843087363, 0.7134427124194154, 0.7064948278785823, 0.7019424653851324, 0.6901295041266859, 0.6805022554366121, 0.6754052770806485, 0.66625678860539, 0.6494918458993146, 0.6403772827613887, 0.6252173331550087]

   FAR: [0.27861086893344955, 0.3236272649784157, 0.368755841728113, 0.39816875758802106, 0.42820989951618904, 0.45098653304102726, 0.46439570256482593, 0.4782868984572876, 0.486881879225604, 0.5042582735477898, 0.5171522230943257, 0.5253383767032767, 0.5289992333017635, 0.5397798460712367, 0.551276359600444, 0.5595293389069516, 0.5706739446229696, 0.577527892471229, 0.5759937991363083, 0.5764186204553177]

   CSI: [0.6386158991510162, 0.5927755554484554, 0.551477759883866, 0.5239815018718343, 0.4918788551426803, 0.4734983066342538, 0.4612316786784355, 0.44750355506399114, 0.4385909548135145, 0.4178565855181024, 0.4044161720064193, 0.39647887323943665, 0.392502114065583, 0.38140621523399837, 0.3706386021525091, 0.36351854556074764, 0.3533375600036229, 0.344034482141899, 0.3424615879943478, 0.3378075658488998]

   GSS: [0.6283683233911987, 0.5811386413816828, 0.5385984144320107, 0.5104097828146703, 0.47755553634259684, 0.45867847128078054, 0.4459448528913043, 0.43175107519283673, 0.4227397333870319, 0.4016801544266556, 0.3880386148354192, 0.3798340281179515, 0.3756414119179976, 0.36428825770304335, 0.3533324680904233, 0.34614719143839545, 0.3358669253746832, 0.3263262264167268, 0.3246830898193664, 0.32013460589787246]

   HSS: [0.771776648273991, 0.7350887849706248, 0.7001156499057485, 0.6758560340671448, 0.6464129768342831, 0.6288959223180167, 0.6168213843005079, 0.6031091335268401, 0.5942615131449805, 0.5731409596663073, 0.559117895839561, 0.5505503131214016, 0.5461327474785118, 0.534034146590648, 0.5221665428436544, 0.5142783696165166, 0.5028448852126186, 0.49207535810906344, 0.49020492873301513, 0.48500297540512854]

   POD stat: avg 0.729952/final 0.625217

   FAR stat: avg 0.485244/final 0.576419

   CSI stat: avg 0.436201/final 0.337808

   GSS stat: avg 0.420558/final 0.320135

   HSS stat: avg 0.587094/final 0.485003

Threshold = 30:

   POD: [0.8079593740283967, 0.7816115283133753, 0.7634045555331587, 0.748710879665084, 0.7305992393873985, 0.7230745147150908, 0.7058912922344359, 0.6756948311648864, 0.6609632769804382, 0.6214223542061338, 0.5952331172904035, 0.5930957570841994, 0.5780462237308441, 0.5638950331791676, 0.5385965792245574, 0.5241141232612448, 0.5171144278606965, 0.49129574678536103, 0.4766483516483517, 0.45601523958384216]

   FAR: [0.3426644182124789, 0.3981872167526176, 0.4473386596621794, 0.48342667934763467, 0.5111584883600977, 0.5370506481357744, 0.5439073086112657, 0.5657817684950225, 0.5811160743186379, 0.6018567117176223, 0.6215102366391917, 0.6106748302941568, 0.6237816764132553, 0.6361188607020049, 0.648706120980142, 0.6559010765590108, 0.6479712795502269, 0.658531555066685, 0.6543086885362556, 0.6504810752124592]

   CSI: [0.5685116312987676, 0.5151839464882944, 0.4718432691708715, 0.4402449641259644, 0.41418290941933983, 0.39322909276653706, 0.38326136582414255, 0.35933125134611243, 0.34479910476393477, 0.320423464808068, 0.3010150137449778, 0.3072559678562988, 0.2951596869981903, 0.28396749449381026, 0.27002967359050445, 0.2621835081392691, 0.2649367862969005, 0.25227284270404793, 0.25058028575849794, 0.2466708941027267]

   GSS: [0.5607126857067353, 0.5063021181523711, 0.4620997694274454, 0.43003435479207536, 0.40359421342590696, 0.3823774915111035, 0.3720010191427285, 0.34779242605426863, 0.33313522363364567, 0.30859673892925893, 0.2892277733360422, 0.29529991078692164, 0.28288534221918565, 0.27164217440832705, 0.25762087891414104, 0.2498922576171006, 0.25259318984640045, 0.23989758258652077, 0.23831065830950024, 0.2346185053837437]

   HSS: [0.7185341553789301, 0.6722451121205364, 0.6321042914990713, 0.6014322010531022, 0.5750867445382382, 0.5532171839590925, 0.5422751352986123, 0.5160919728157974, 0.49977709346789173, 0.4716452819250702, 0.44868374591035726, 0.45595604280945373, 0.4410142245913254, 0.4272305210932744, 0.4096956137315037, 0.3998620778618409, 0.40331241123444833, 0.3869635459504262, 0.3848964017395019, 0.3800664000428532]

   POD stat: avg 0.627669/final 0.456015

   FAR stat: avg 0.571024/final 0.650481

   CSI stat: avg 0.347254/final 0.246671

   GSS stat: avg 0.335932/final 0.234619

   HSS stat: avg 0.496005/final 0.380066

MSE: [314.61374, 561.646, 862.4934, 1132.2228, 1381.6515, 1651.9943, 1965.568, 2278.7437, 2560.9277, 2812.3176, 3080.173, 3402.1716, 3704.8875, 3981.3147, 4229.9194, 4482.1465, 4771.235, 5021.6626, 5250.041, 5443.3755]

MAE: [1497.3224, 2071.152, 2699.64, 3233.034, 3736.8633, 4283.821, 4910.3013, 5534.777, 6078.5493, 6570.014, 7081.8594, 7663.645, 8218.234, 8702.816, 9153.38, 9616.606, 10146.479, 10650.319, 11085.315, 11487.362]

Balanced MSE: [1070.8331, 1634.8705, 2263.3179, 2824.0442, 3341.3748, 3771.3533, 4360.8774, 4870.9004, 5281.554, 5947.5796, 6573.6094, 7365.2114, 8027.2095, 8434.544, 8809.261, 9131.528, 9834.229, 10723.3, 10968.574, 11710.198]

Balanced MAE: [10147.587, 12380.348, 14265.581, 15578.719, 16676.086, 17684.6, 19159.518, 20519.938, 21459.555, 22900.639, 24149.828, 25751.506, 27186.436, 28011.447, 28775.604, 29444.352, 30696.69, 32317.549, 33036.81, 34442.59]

GDL: [1246.7345, 1350.3164, 1375.5074, 1386.1107, 1412.33, 1445.2308, 1488.2333, 1488.806, 1480.857, 1481.2399, 1491.5039, 1517.3276, 1509.305, 1503.5643, 1508.6593, 1511.4779, 1517.622, 1521.2721, 1523.7344, 1523.9825]

MSE stat: avg 2944.46/final 5443.38

MAE stat: avg 6721.07/final 11487.4

Balanced MSE stat: avg 6347.22/final 11710.2

Balanced MAE stat: avg 23229.3/final 34442.6

GDL stat: avg 1464.19/final 1523.98

Generating radar reflectivity maps

Step 1: Retrieve the observations and predictions for the one-hour and two-hour forecasts.

# Ground truth from the last observation window, converted from pixel values
# to dBZ: frame index 9 is the one-hour lead time and 19 (the last of 20
# frames) the two-hour lead time. Slicing keeps a length-1 leading time axis.
gt_10 = pixel_to_dBZ(env._out_frame_dat[9:10, 0, 0])
gt_20 = pixel_to_dBZ(env._out_frame_dat[19:, 0, 0])

# Last prediction produced by the inference loop, converted to dBZ.
pred_nd = pixel_to_dBZ(pred_nd.asnumpy())

# Predictions at the one-hour and two-hour lead times.
pd_10 = pred_nd[9:10, 0, 0]
pd_20 = pred_nd[19:, 0, 0]

# Observation/prediction pairs in plotting order.
reflec_np = [gt_10, pd_10, gt_20, pd_20]

Step 2: Generate xarray DataArrays for the one-hour and two-hour forecasts.

# Forecast base time: the latest input frame, from which the lead times are
# counted (assumes in_datetime_clips[0] is in chronological order — the
# first element is the oldest frame, not the base time).
basetime = in_datetime_clips[0][-1]

# Distinct time stamps, one hour apart, so each map becomes its own facet when
# the DataArrays are concatenated and plotted along 'time'. The panel titles
# take their valid times from out_datetime_clips instead.
datetime_list = [basetime + datetime.timedelta(hours=i) for i in range(len(reflec_np))]

# Define source grid
area_id = 'aeqd'
description = ('Azimuthal Equidistant Projection centered at the radar site '
               'extending up to 256000.000000m in the x direction and '
               '256000.000000m in the y direction with a '
               '480.000000x480.000000 grid resolution')
proj_id = 'aeqd'
projection = {'datum': 'WGS84', 'ellps': 'WGS84', 'lat_0': '22.411667', 'lon_0': '114.123333', 'no_defs': 'True', 'proj': 'aeqd', 'units': 'm'}
width = 480
height = 480
area_extent = (-256000.00000000003, -256000.00000000003, 256000.00000000003, 256000.00000000003)
area_def = AreaDefinition(area_id, description, proj_id, projection, width, height, area_extent)

# Pixel-centre coordinates (m) derived from the source grid, computed
# directly rather than by repeated float accumulation: x increases with
# column index, y decreases with row index (row 0 is the top edge).
x_min, y_min, x_max, y_max = area_extent
pixel_width = (x_max - x_min) / width
pixel_height = (y_max - y_min) / height
x = [x_min + (col + 0.5) * pixel_width for col in range(width)]
y = [y_max - (row + 0.5) * pixel_height for row in range(height)]

# One DataArray per map, each with a single time stamp.
reflec_xr_list = [
    xr.DataArray(
        frame,
        dims=('time', 'y', 'x'),
        coords={
            'time': [valid_time],
            'y': y,
            'x': x
        },
        attrs={
            'long_name': 'Reflectivity',
            'units': 'dBZ',
            'projection': 'Centered Azimuthal',
            'site': (114.12333341315389, 22.41166664287448, 948),
            'radar_height': 20,
            'area_def': area_def
        }
    )
    for frame, valid_time in zip(reflec_np, datetime_list)
]

# Defining target grid
area_id = "hk1980_250km"
description = ("A 1km resolution rectangular grid "
               "centred at HKO and extending to 250 km "
               "in each direction in HK1980 easting/northing coordinates")
proj_id = 'hk1980'
projection = ('+proj=tmerc +lat_0=22.31213333333334 '
              '+lon_0=114.1785555555556 +k=1 +x_0=836694.05 '
              '+y_0=819069.8 +ellps=intl +towgs84=-162.619,-276.959,'
              '-161.764,0.067753,-2.24365,-1.15883,-1.09425 +units=m '
              '+no_defs')
x_size = 500
y_size = 500
area_extent = (587000, 569000, 1087000, 1069000)
area_def_tgt = utils.get_area_def(
    area_id, description, proj_id, projection, x_size, y_size, area_extent
)

Step 3: Reproject the radar data from the Centered Azimuthal (source) projection to the HK 1980 (target) projection.

# Reproject every map from the source (area_def) grid onto the HK1980 target
# grid, labelling the new coordinates 'easting' and 'northing'.
reproj_xr_list = [
    grid_resample(da, area_def, area_def_tgt, coord_label=['easting', 'northing'])
    for da in reflec_xr_list
]

# Combine the reprojected maps along 'time', then apply swirlspy's dBZ
# attribute standardisation in place (zero_value=9999.).
reflec_concat = xr.concat(reproj_xr_list, dim='time')
standardize_attr(reflec_concat, frame_type=FrameType.dBZ, zero_value=9999.)

Step 4: Generating radar reflectivity maps

# Colour scale: 18 boundaries give 17 dBZ bins, one per colour below. With
# clip=True, values under 10 dBZ land in the first (white) bin.
levels = [
    -32768,
    10, 15, 20, 24, 28, 32,
    34, 38, 41, 44, 47, 50,
    53, 56, 58, 60, 62
]
cmap = ListedColormap([
    '#FFFFFF', '#08C5F5', '#0091F3', '#3898FF', '#008243', '#00A433',
    '#00D100', '#01F508', '#77FF00', '#E0D100', '#FFDC01', '#EEB200',
    '#F08100', '#F00101', '#E20200', '#B40466', '#ED02F0'
])
norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)

# Cartopy CRS of the HK1980 target grid
crs = area_def_tgt.to_cartopy_crs()

# Valid times of the one-hour (10th) and two-hour (last) forecast frames
datetime_one_hr = out_datetime_clips[0][9]
datetime_two_hr = out_datetime_clips[0][-1]

# Valid time per panel: [observed 1 h, forecast 1 h, observed 2 h, forecast 2 h]
timelist = [datetime_one_hr, datetime_one_hr, datetime_two_hr, datetime_two_hr]

# All four maps are plotted, so no slicing is needed
da_plot = reflec_concat

# Coastline shapefile plus ocean and land fill colours
map_shape_file = os.path.join(THIS_DIR, "./../tests/samples/shape/rsmc")
ocean_color = np.array([[[178, 208, 254]]], dtype=np.uint8)
land_color = cfeature.COLORS['land']
coastline = cfeature.ShapelyFeature(
    list(shpreader.Reader(map_shape_file).geometries()),
    ccrs.PlateCarree()
)

# One panel per 'time' entry, two panels per row
p = da_plot.plot(
    col='time', col_wrap=2,
    subplot_kws={'projection': crs},
    cbar_kwargs={
        'extend': 'max',
        'ticks': levels[1:],
        'format': '%.3g'
    },
    cmap=cmap,
    norm=norm
)
for idx, ax in enumerate(p.axes.flat):
    # Ocean: a flat colour image stretched across the map, drawn beneath all else
    ax.imshow(np.tile(ocean_color, [2, 2, 1]),
              origin='upper',
              transform=ccrs.PlateCarree(),
              extent=[-180, 180, -180, 180],
              zorder=-1)
    # Filled land polygons below the data...
    ax.add_feature(coastline,
                   facecolor=land_color, edgecolor='none', zorder=0)
    # ...and an unfilled coastline outline above it
    ax.add_feature(coastline, facecolor='none',
                   edgecolor='gray', linewidth=0.5, zorder=3)
    ax.gridlines()

    # Panels 0 and 2 are observations; panels 1 and 3 are model forecasts.
    if idx in (0, 2):
        left_title = f"Ground Truth\nBased @ {timelist[idx].strftime('%H:%MH')}"
    else:
        left_title = f"Reflectivity\nBased @ {basetime.strftime('%H:%MH')}"
    ax.set_title(left_title, loc='left', fontsize=9)
    # Blank the centre title (presumably the facet label added by xarray).
    ax.set_title('')
    ax.set_title(
        f"{basetime.strftime('%Y-%m-%d')} \n"
        f"Valid @ {timelist[idx].strftime('%H:%MH')} ",
        loc='right',
        fontsize=9
    )

plt.savefig(
    os.path.join(THIS_DIR, "../tests/outputs/traj-output-map-hk.png"),
    dpi=300
)
../_images/sphx_glr_traj_hk_001.png

Total running time of the script: ( 2 minutes 24.304 seconds)

Gallery generated by Sphinx-Gallery