#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Nov 29 13:29:00 2024

@author: mbaudewyn
"""

import pandas as pd
import xarray as xr
import numpy as np
from datetime import datetime
from dateutil.relativedelta import relativedelta
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.dates as mdates
from mpl_toolkits.basemap import Basemap

# File locations. The input NetCDF file is named after the date the script runs.
today_date = datetime.today()
directory = '/srv6_tmp5/mbaudewyn/PREVISIONS_DHC/DHC/'
directory_fig = '/srv6_tmp5/mbaudewyn/PREVISIONS_DHC/FIG/'
file_name = f"{directory}DHC-{today_date:%Y%m%d}.nc"

# Open the NetCDF data and extract the longitude / latitude arrays
dataset = xr.open_mfdataset(file_name, engine="netcdf4", decode_times=True)
lons = np.array(dataset["LON"])
lats = np.array(dataset["LAT"])

# Reference dates, all derived from the last time step found in the file
nine_days = pd.Timedelta(days=9)
last_date = pd.to_datetime(dataset['time'].max().item())
date_J0 = last_date - nine_days
one_year_ago = (last_date - relativedelta(years=1)) - nine_days

# Data extraction: the window [one_year_ago, date_J0] plus the two fields drawn as maps
dhc = dataset['DHC']
dhc_data_year = dhc.sel(time=slice(one_year_ago, date_J0))
# NOTE(review): date_J0 is last_date - 9 days, yet the field paired with it on the
# maps is the 8th time step from the end — confirm this index against the time
# step of the file.
dhc_prev_today = np.array(dhc.isel(time=-8))
dhc_prev_10_days = np.array(dhc.isel(time=-1))

# Map colours: four classes delimited by five boundaries
colors = ['#b35806', '#f0a23f', '#fddfb4', '#f7f6f6']
bounds = [-400, -300, -200, -100, 0]
cmap = mcolors.ListedColormap(colors)
norm = mcolors.BoundaryNorm(bounds, cmap.N)

# Bioclimatic zones: (full name, number used in the mask file name, acronym, line colour)
_bioclim_table = [
    ('Plaines et Vallées Scaldisiennes', '1', 'PVS', '#e31a1c'),
    ('Hesbigno-Brabançon', '2', 'HBR', '#fdbf6f'),
    ('Sambre-et-Meuse et Condroz', '3', 'SMC', '#ffff39'),
    ('Fagne, Famenne et Calestienne', '4', 'FFC', '#71c4f0'),
    ('Thiérache', '5', 'THI', '#1f78b4'),
    ('Ardenne centro-orientale et Haute Ardenne', '6_7', 'AHA', '#c2d364'),
    ('Basse et Moyenne Ardenne', '8', 'BMA', '#4a8b00'),
    ('Basse et Haute Lorraine', '9_10', 'BHL', '#bf1695'),
]
# Unzip the table into the four parallel lists used by the plotting sections
zones_bioclim, num_bioclim, acronymes, color_list = (
    list(column) for column in zip(*_bioclim_table)
)



# MAPS FOR D+0 AND D+10
# One map is drawn per (data field, file suffix, date label) triple and saved twice:
# once in the dated archive and once under a fixed name that is overwritten each run.

# Belgium mask: True where MSK == 0; those pixels are set to NaN on the maps.
# The mask is the same for both maps, so it is read once and the file is closed
# as soon as the values are in memory.
fn_mask = '/srv6_tmp5/mbaudewyn/PREVISIONS_DHC/MASKS/Belgium_pix.nc4'
with xr.open_dataset(fn_mask) as mask_data:
    mask = np.array(mask_data["MSK"]) == 0

# Cities drawn on the maps: name -> (latitude, longitude)
villes = {
    "Bruges": (51.2093, 3.2242),
    "Courtrai": (50.828, 3.2649),
    "Gand": (51.05, 3.733),
    "Anvers": (51.2194, 4.4025),
    "Bruxelles": (50.8503, 4.3517),
    "Louvain": (50.8796, 4.7009),
    "Hasselt": (50.9311, 5.3378),
    "Tournai": (50.605, 3.387),
    "Mons": (50.4541, 3.9523),
    "Charleroi": (50.4114, 4.4448),
    "Namur": (50.4669, 4.8675),
    "Liège": (50.633, 5.567),
    "Philippeville": (50.1954, 4.5438),
    "Verviers": (50.5836, 5.8624),
    "Marche-en-Famenne": (50.2277, 5.3426),
    "Neufchâteau": (49.8406, 5.4369),
    "Arlon": (49.6833, 5.8167),
}
color_villes = '#4F4F4F'
# Cities whose name is written below the marker rather than level with it
villes_label_below = {"Bruxelles", "Verviers", "Charleroi", "Philippeville"}

# Category names written next to the second colour bar:
# (vertical position in figure coordinates, text)
category_labels = [
    (0.80, 'Déficit faible'),
    (0.60, 'Déficit modéré'),
    (0.40, 'Déficit important'),
    (0.20, 'Déficit grave'),
]

run_stamp = today_date.strftime('%Y%m%d')

for data_fig, nom_file, date_fig in zip([dhc_prev_today, dhc_prev_10_days],
                                        ['00', '10'], [date_J0, last_date]):

    # Apply the mask in place: masked pixels become NaN
    data_fig[mask] = float('nan')

    # Create the figure and the Lambert conformal map
    fig, ax = plt.subplots(figsize=(15, 6))
    m = Basemap(width=34000, height=30000, rsphere=(649328.00, 665262.0),
                area_thresh=1000., projection='lcc', lat_1=49.83, lat_2=51.17,
                lat_0=np.mean(lats), lon_0=np.mean(lons) - 0.2, resolution='i')
    m.drawcountries()
    m.drawcoastlines()

    # City markers and names
    for ville, (lat, lon) in villes.items():
        x_ville, y_ville = m(lon, lat)  # geographic -> map coordinates
        m.plot(x_ville, y_ville, marker='o', color=color_villes, markersize=4, alpha=0.7)
        y_offset = -800 if ville in villes_label_below else 100
        plt.text(x_ville + 500, y_ville + y_offset, ville, fontsize=10, color=color_villes)

    # Data field and its colour bar
    x, y = m(lons, lats)
    cmesh = m.pcolormesh(x, y, data_fig, cmap=cmap, norm=norm)
    cb = m.colorbar(cmesh, location='right', pad="5%")
    cb.set_label('Déficit hydrique cumulé (en mm)')

    # Second colour bar, on its own axes: [left, bottom, width, height]
    cbar_ax = fig.add_axes([0.78, 0.125, 0.017, 0.75])
    cbar = mpl.colorbar.ColorbarBase(
        cbar_ax, cmap=cmap, norm=norm, orientation='vertical', ticks=[-400, -300, -200, -100, 0]
    )

    # Category names, positioned in figure coordinates
    for y_label, label in category_labels:
        plt.text(
            0.82, y_label, label,
            ha='left', va='center', transform=fig.transFigure, fontsize=12, color='black'
        )

    # Date shown on the map
    plt.text(
        0.34, 0.17,
        date_fig.strftime('%d-%m-%Y'),
        ha='left', va='center', transform=fig.transFigure, fontsize=16, color='black'
    )

    plt.savefig(f"{directory_fig}ARCHIVE/carto_J{nom_file}_{run_stamp}.png", format="png", dpi=300)
    plt.savefig(f"{directory_fig}carto_secheresse_belgique_T{nom_file}.png", format="png", dpi=300)

    # Release the figure: pyplot keeps every figure alive until it is explicitly closed
    plt.close(fig)
    

# TIME SERIES PER BIOCLIMATIC ZONE
# One figure per zone, each showing the spatial-mean DHC over the last 41 time
# steps for the Belgium mask (gray) and for the zone mask (zone colour).
output_directory = '/srv6_tmp5/mbaudewyn/PREVISIONS_DHC/FIG/'


def _mean_dhc_series(dhc_window, mask_file, transposed, thierache_override=False):
    """Return the spatial mean of *dhc_window* over the pixels kept by a mask file.

    Parameters
    ----------
    dhc_window : xarray.DataArray
        DHC values; the mean is taken over its "x" and "y" dimensions.
    mask_file : str
        File name inside the MASKS directory. Pixels where its "MSK" variable
        equals 0 are excluded from the mean.
    transposed : bool
        Transpose the mask array before applying it.
    thierache_override : bool
        Replace the first 50 x 80 mask cells by the manual rectangle
        (rows 8-12, columns 38-45 kept, everything else excluded).

    Returns
    -------
    numpy.ndarray
        One mean value per time step of *dhc_window*.
    """
    fn_mask = '/srv6_tmp5/mbaudewyn/PREVISIONS_DHC/MASKS/' + mask_file
    # The context manager closes the file once the values are copied into memory
    with xr.open_dataset(fn_mask) as mask_data:
        excluded = np.array(mask_data["MSK"]) == 0
    if transposed:
        excluded = np.transpose(excluded)
    if thierache_override:
        # Slice form of: excluded[x, y] = not (7 < x < 13 and 37 < y < 46)
        # for x in range(50), y in range(80)
        excluded[:50, :80] = True
        excluded[8:13, 38:46] = False
    # Leading axis of length 1 so the mask lines up with the time dimension
    kept = np.logical_not(np.expand_dims(excluded, axis=0))
    return dhc_window.where(kept, float('nan')).mean(dim=["x", "y"]).values


# Inputs shared by every figure: the last 41 time steps and the Belgium-wide series
dhc_last_steps = dataset['DHC'].isel(time=slice(-41, None))
time_values = pd.to_datetime(dhc_last_steps['time'].values)
belgium_series = _mean_dhc_series(dhc_last_steps, 'Belgium_pix.nc4', transposed=False)

for zone_bc, num, acronyme, color_bc in zip(zones_bioclim, num_bioclim, acronymes, color_list):
    fig = plt.figure(figsize=(12, 6))

    zone_series = _mean_dhc_series(
        dhc_last_steps, 'zonebioclim_' + num + '_MAR_mask.nc', transposed=True,
        thierache_override=(zone_bc == 'Thiérache'),
    )

    for zone, color, series in (('Belgique', 'gray', belgium_series),
                                (zone_bc, color_bc, zone_series)):
        # Solid line for the first 32 steps, dashed line from step index 31 on;
        # the two segments share one point so the curve is continuous.
        plt.plot(time_values[:32], series[:32], label=zone, color=color, linestyle="-")
        plt.plot(time_values[31:], series[31:], color=color, linestyle="--")

    # X axis: one tick every 3 days
    plt.gca().xaxis.set_major_locator(mdates.DayLocator(interval=3))
    plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%d-%m-%Y'))
    plt.xticks(rotation=45)

    # Labels and layout
    plt.ylabel("Déficit hydrique cumulé (en mm)")
    plt.ylim(-410, 10)
    plt.legend()
    plt.grid()
    plt.tight_layout()

    # Save the dated archive copy and the fixed-name copy
    plt.savefig(output_directory + "ARCHIVE/evol_" + acronyme + f"_{today_date.strftime('%Y%m%d')}.png",
                format="png", dpi=300)
    plt.savefig(output_directory + "evol_secheresse_zonebioclim_" + acronyme + ".png",
                format="png", dpi=300)

    # Release the figure: pyplot keeps every figure alive until it is explicitly closed
    plt.close(fig)


# TIME SERIES FOR BELGIUM
# Spatial-mean DHC over the last 41 time steps, using the Belgium mask.
fig = plt.figure(figsize=(12, 6))

# Load the mask (True where MSK == 0); the file is closed once the values are in memory
fn_mask = '/srv6_tmp5/mbaudewyn/PREVISIONS_DHC/MASKS/Belgium_pix.nc4'
with xr.open_dataset(fn_mask) as mask_data:
    mask = np.array(mask_data["MSK"]) == 0
# Leading axis of length 1 so the mask lines up with the time dimension
mask_3d = np.expand_dims(mask, axis=0)
dhc_data_mask = dataset['DHC'].isel(time=slice(-41, None)).where(mask_3d == 0, float('nan'))

# Spatial mean, computed once and then split into the two plotted segments.
# The segments share the point at index 31 so the curve is continuous.
dhc_data_mask_mean = dhc_data_mask.mean(dim=["x", "y"]).values
dhc_data_mask_ERA = dhc_data_mask_mean[:32]
dhc_data_mask_GFS = dhc_data_mask_mean[31:]

# Plot the curves: solid for the first segment, dashed for the second
time_values = pd.to_datetime(dhc_data_mask['time'].values)
time_values_ERA = time_values[:32]
time_values_GFS = time_values[31:]
plt.plot(time_values_ERA, dhc_data_mask_ERA, label='Observations (MAR-ERA5)',
         color="black", linestyle="-")
plt.plot(time_values_GFS, dhc_data_mask_GFS, label='Prévisions (MAR-GFS)',
         color="black", linestyle="--")

# X axis: one tick every 3 days
plt.gca().xaxis.set_major_locator(mdates.DayLocator(interval=3))
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%d-%m-%Y'))
plt.xticks(rotation=45)

# Labels and layout
plt.ylabel("Déficit hydrique cumulé (en mm)")
plt.ylim(-410, 10)
plt.legend()
plt.grid()
plt.tight_layout()

# Save the dated archive copy and the fixed-name copy
output_directory = '/srv6_tmp5/mbaudewyn/PREVISIONS_DHC/FIG/'
plt.savefig(output_directory + f"ARCHIVE/evol_Belgique_{today_date.strftime('%Y%m%d')}.png",
            format="png", dpi=300)
plt.savefig(output_directory + "evol_secheresse_zonebioclim_BEL.png", format="png", dpi=300)

# Release the figure: pyplot keeps every figure alive until it is explicitly closed
plt.close(fig)


# TIME SERIES OVER ONE ROLLING YEAR
# A single figure with one curve per bioclimatic zone: the spatial mean of
# dhc_data_year under the zone mask.
fig = plt.figure(figsize=(12, 6))

# The time axis is the same for every zone
time_values_year = pd.to_datetime(dhc_data_year['time'].values)

for zone_bc, num, color_bc in zip(zones_bioclim, num_bioclim, color_list):

    # Load the zone mask (True where MSK == 0) and transpose it;
    # the file is closed once the values are in memory
    fn_mask = '/srv6_tmp5/mbaudewyn/PREVISIONS_DHC/MASKS/' + 'zonebioclim_' + num + '_MAR_mask.nc'
    with xr.open_dataset(fn_mask) as mask_data:
        mask = np.transpose(np.array(mask_data["MSK"]) == 0)

    # Manual mask for the Thiérache. Slice form of:
    # mask[x, y] = not (7 < x < 13 and 37 < y < 46) for x in range(50), y in range(80)
    if zone_bc == 'Thiérache':
        mask[:50, :80] = True
        mask[8:13, 38:46] = False

    # Apply the mask; the leading axis of length 1 lines it up with the time dimension
    mask_3d = np.expand_dims(mask, axis=0)
    dhc_data_year_mask = dhc_data_year.where(mask_3d == 0, float('nan'))

    # Spatial mean
    data_values_year = dhc_data_year_mask.mean(dim=["x", "y"]).values

    # Plot the curve for this zone
    plt.plot(time_values_year, data_values_year, label=zone_bc, color=color_bc)

# X axis: one tick per month
plt.gca().xaxis.set_major_locator(mdates.MonthLocator(interval=1))
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%d-%m-%Y'))
plt.xticks(rotation=45)

# Labels and layout
plt.ylabel("Déficit hydrique cumulé (en mm)")
plt.ylim(-410, 10)
plt.legend()
plt.grid()
plt.tight_layout()

# Save the dated archive copy and the fixed-name copy
output_directory = '/srv6_tmp5/mbaudewyn/PREVISIONS_DHC/FIG/'
plt.savefig(output_directory + "ARCHIVE/evol_annee" + f"_{today_date.strftime('%Y%m%d')}.png",
            format="png", dpi=300)
plt.savefig(output_directory + "evol_secheresse_annee.png", format="png", dpi=300)

# Release the figure: pyplot keeps every figure alive until it is explicitly closed
plt.close(fig)
