#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Created on Thu Nov 28 14:17:30 2024

@author: mbaudewyn
"""


#%%

def pression_vapeur_saturante(temperature):
    """Return the saturation vapour pressure (kPa) for temperature(s) in degC.

    Formula (FAO-56, eq. 11):
        e0(T) = 0.6108 * exp(17.27 * T / (T + 237.3))

    Parameters
    ----------
    temperature : array-like
        Air temperature in degC. Any shape is accepted (scalar, 2-D grid,
        (time, y, x) cube, xarray DataArray, ...).

    Returns
    -------
    numpy.ndarray
        Saturation vapour pressure in kPa (float64), same shape as
        *temperature* (a numpy scalar for a scalar input).
    """
    # Vectorised over the whole array instead of a hard-coded 50 x 80 loop.
    t = np.asarray(temperature, dtype=float)
    return 0.6108 * np.exp((17.27 * t) / (t + 237.3))  # en kPa


#%%

def derivee_vapeur_saturante(temperature):
    """Return the slope of the saturation vapour pressure curve (kPa/degC).

    Formula (FAO-56, eq. 13):
        Delta = 4098 * e0(T) / (T + 237.3)**2
    with e0(T) = 0.6108 * exp(17.27 * T / (T + 237.3)).

    Parameters
    ----------
    temperature : array-like
        Air temperature in degC, any shape.

    Returns
    -------
    numpy.ndarray
        Slope in kPa/degC (float64), same shape as *temperature*.
    """
    # Vectorised over the whole array instead of a hard-coded 50 x 80 loop.
    t = np.asarray(temperature, dtype=float)
    e0 = 0.6108 * np.exp(17.27 * t / (t + 237.3))
    return 4098 * e0 / (t + 237.3) ** 2


#%%    

def compute_evapotranspiration(rn,temp,tmin,tmax,vent,pression_surface,rh_min,rh_max):
    """Compute daily reference evapotranspiration (FAO-56 Penman-Monteith).

        ET0 = (0.408*Delta*Rn + gamma*900/(T+273.15)*u*(es - ea))
              / (Delta + gamma*(1 + 0.34*u))

    The soil heat flux G does not appear (taken as 0). Every input is indexed
    as (time, y, x), and the computation runs one time step at a time.

    Parameters
    ----------
    rn : array-like
        Net radiation. The 0.408 factor expects MJ m-2 day-1.
    temp : array-like
        Mean air temperature (degC). Its (time, y, x) shape sets the output
        shape.
    tmin, tmax : array-like
        Minimum / maximum air temperature (degC).
    vent : array-like
        Wind speed (m/s).
    pression_surface : array-like
        Surface pressure in hPa (divided by 10 below to get kPa).
    rh_min, rh_max : array-like
        Minimum / maximum relative humidity (%). Both are multiplied by a
        correction factor of 0.9147 before use.

    Returns
    -------
    numpy.ndarray
        Evapotranspiration, same shape as ``temp`` (mm/day with the units
        above).
    """
    # Take the grid size from the input instead of hard-coding 50 x 80.
    shape = np.shape(temp)
    len_time = shape[0]
    print(len_time, 'pas de temps :')

    evapotranspiration = np.zeros(shape)

    # Correction factor for relative humidity
    rh_max = rh_max * 0.9147
    rh_min = rh_min * 0.9147

    for t in range(len_time):
        print(t)
        # Saturation vapour pressure (kPa) at Tmin and Tmax
        es_tmin_t = pression_vapeur_saturante(tmin[t,:,:])
        es_tmax_t = pression_vapeur_saturante(tmax[t,:,:])

        # es: mean saturation vapour pressure; ea: actual vapour pressure
        # estimated from the relative humidity extremes
        es_t = (es_tmin_t + es_tmax_t) / 2
        ea_t = (es_tmin_t*rh_max[t,:,:]/100 +
                es_tmax_t*rh_min[t,:,:]/100) / 2

        # Slope of the saturation vapour pressure curve at mean temperature
        delta_t = derivee_vapeur_saturante(temp[t,:,:])

        # Psychrometric constant (kPa/degC); pressure converted hPa -> kPa
        psychro = 0.665 * 10**(-3) * pression_surface[t,:,:]/10

        rn_t = rn[t,:,:]
        vent_t = vent[t,:,:]
        temp_t = temp[t,:,:]

        denom = delta_t + psychro*(1 + 0.34*vent_t)
        num1 = 0.408*delta_t*rn_t  # radiative term
        num2 = psychro*(900/(temp_t + 273.15))*vent_t*(es_t - ea_t)  # aerodynamic term

        evapotranspiration[t,:,:] = (num1 + num2)/denom

    return evapotranspiration


#%%    

def compute_deficit_hydrique_cumule(bilan_hydrique):
    """Return the daily cumulative water deficit (DHC) of a daily water balance.

    The balance is accumulated along the ``TIME`` dimension. The deficit on
    each day is the running total minus the highest running total reached so
    far. It is therefore <= 0, and exactly 0 on days that set a new maximum.

    Parameters
    ----------
    bilan_hydrique : xarray.DataArray
        Daily water balance (P - ETP) with a ``TIME`` dimension, at any axis
        position.

    Returns
    -------
    xarray.DataArray
        Cumulative deficit with the same dims and coords as the cumulated
        balance.

    NOTE(review): ``np.maximum.accumulate`` propagates NaN. If the cumulated
    balance ever contains NaN, every later value at that point becomes NaN.
    """
    bilan_cumule = bilan_hydrique.cumsum(dim='TIME')

    # Running maximum along TIME, looked up by name so it does not rely on
    # TIME being the first axis.
    axis_time = bilan_cumule.get_axis_num('TIME')
    max_values = np.maximum.accumulate(bilan_cumule.values, axis=axis_time)

    # copy(data=...) keeps the dims/coords of bilan_cumule and swaps in the deficit.
    return bilan_cumule.copy(data=bilan_cumule.values - max_values)

#%%

import xarray as xr 
import numpy as np
import math
import pandas as pd
import os
from datetime import datetime
from dateutil.relativedelta import relativedelta


# Folder holding the MAR NetCDF forecast files
directory = "/srv6_tmp5/mbaudewyn/PREVISIONS_DHC/MAR/"

# Keep only the NetCDF entries; the one with the latest date suffix is chosen below
nc_files = list(filter(lambda name: name.endswith(".nc"), os.listdir(directory)))

# Function to extract the date from the filename (format : MAR-ER5-GFS-20240101-YYYYMMDD.nc)
def extract_date(filename):
    """Return the 8-character date suffix of a NetCDF file name.

    Expected pattern: ``MAR-ER5-GFS-20240101-YYYYMMDD.nc`` -> ``"YYYYMMDD"``.

    Only a trailing ``.nc`` extension is stripped, so ``.nc`` appearing
    elsewhere in the name does not truncate it. The suffix is not validated.
    Since ``YYYYMMDD`` strings sort chronologically, the result can be used
    directly as a ``max``/``sorted`` key.
    """
    stem = filename[:-len(".nc")] if filename.endswith(".nc") else filename
    return stem[-8:]

# Find the file with the most recent date suffix; fail with a clear message
# rather than max()'s "empty sequence" error when the folder has no .nc file.
if not nc_files:
    raise FileNotFoundError(f"No .nc file found in {directory}")
most_recent_file = max(nc_files, key=extract_date)

# Open the most recent file using xarray
file_path = os.path.join(directory, most_recent_file)
dataset = xr.open_dataset(file_path)

print(f"Opened file: {most_recent_file}")


#%%


# Grid coordinates and the calendar days covered by the MAR file.
# `dates` is computed before the archive slicing below, so it spans the whole file.
lon = np.array(dataset["LON"].load())
lat = np.array(dataset["LAT"].load())
dates = np.array(dataset["TIME"].load())
dates = pd.DatetimeIndex(dates).normalize().unique()


# ARCHIVE
# When a DHC archive exists, keep only the MAR time steps after its last
# archived day; the archived daily balances are concatenated further down.
archive = False
file_path_archive = '/srv6_tmp5/mbaudewyn/PREVISIONS_DHC/DHC/DHC-archive.nc'
# Check the archive path itself. The MAR file path always exists at this
# point, so testing it forced the script to open a possibly missing archive.
if os.path.exists(file_path_archive):
    archive = True
    dataset_archive = xr.open_dataset(file_path_archive)
    # Align the archive's dimension names with the MAR dataset
    dataset_archive = dataset_archive.rename({"time": "TIME", "x": "X21_100", "y": "Y21_70"})
    last_date_archive = pd.to_datetime(dataset_archive['TIME'].max().item())
    DHC_last_date_archive = dataset_archive['DHC'].sel(TIME=last_date_archive)
    first_date_new = last_date_archive + relativedelta(days=1)

    dataset = dataset.sel(TIME=slice(first_date_new, None))


def _load_near_surface(name, level_coord, singleton_dim):
    """Load variable *name* from the module-level `dataset` at its 2 m level.

    If *level_coord* is a dimension of the variable, the level whose
    coordinate value is 2 is selected. Otherwise the variable is loaded as is,
    and *singleton_dim* is squeezed out when present. This replaces the
    previous try/bare-except chains. One of those (RHZ) referenced its result
    before assignment in the fallback branch.
    """
    field = dataset[name]
    if level_coord in field.dims:
        return field.sel({level_coord: 2}).load()
    field = field.load()
    if singleton_dim in field.dims:
        field = field.squeeze(dim=singleton_dim)
    return field


uwind = _load_near_surface("U2Z", "ZUVLEV", "ZUVLEV1_1")  # m/s
vwind = _load_near_surface("V2Z", "ZUVLEV", "ZUVLEV1_1")  # m/s
temp_2m = _load_near_surface("TTZ", "ZTQLEV", "ZTQLEV1_1")  # °C
rela_humi = _load_near_surface("RHZ", "ZTQLEV", "ZTQLEV1_1")  # %


# Hourly precipitation components
rain = dataset["MBRR"].load() # mmWE
snow = dataset["MBSF"].load() # mmWe

swd = dataset["SWD"].load() # W/m²
swd_data = np.array(swd)

albedo = np.array(dataset["SAL"].load()) # no units
# Upward shortwave = downward shortwave * surface albedo. Broadcasting applies
# the albedo to every time step at once: a (y, x) albedo field is reused for
# every hour, and one that also carries the time axis is applied hour by hour.
# No grid size is hard-coded. The result is stored as float64, as before.
swu_data = np.asarray(swd_data * albedo, dtype=np.float64)

swu = xr.DataArray(swu_data,
                   dims=("TIME","Y21_70","X21_100"),
                   coords={"TIME":swd["TIME"].data,
                           "Y21_70":swd["Y21_70"].data,
                           "X21_100":swd["X21_100"].data})
lwd = dataset["LWD"].load() # W/m²
lwd_data = np.array(lwd)

lwu= dataset["LWU"].load() # W/m²
lwu_data = np.array(lwu)

pressure = dataset["SP"].load() # hPa




# Raw data processing: basic hourly -> daily aggregation

def _per_day(field, reduction):
    """Resample *field* into daily bins along TIME and reduce each bin with *reduction*."""
    return getattr(field.resample(TIME="1D"), reduction)(dim="TIME")

# Relative humidity (%): daily extremes and mean
rela_humi_min = _per_day(rela_humi, "min")
rela_humi_max = _per_day(rela_humi, "max")
rela_humi_mean = _per_day(rela_humi, "mean")

# 2 m temperature (°C): extremes stay DataArrays, the mean becomes an ndarray
daily_tmin = _per_day(temp_2m, "min")
daily_tmax = _per_day(temp_2m, "max")
temp_daily = np.array(_per_day(temp_2m, "mean"))

# Daily precipitation (mmWE): snowfall sum + rainfall sum
precip = _per_day(snow, "sum") + _per_day(rain, "sum")

# Wind speed from its two horizontal components
vent = np.sqrt(uwind**2 + vwind**2)

# Daily means, as plain ndarrays
vent_daily, pressure_data_daily, swd_daily, swu_daily, lwd_daily, lwu_daily = (
    np.array(_per_day(field, "mean"))
    for field in (vent, pressure, swd, swu, lwd, lwu)
)

# Daily-mean W/m² -> MJ/(m² day): 86400 s/day * 1e-6 MJ/J
W_M2_TO_MJ_M2_DAY = 0.0864
swd_daily_conv, swu_daily_conv, lwd_daily_conv, lwu_daily_conv = (
    flux * W_M2_TO_MJ_M2_DAY
    for flux in (swd_daily, swu_daily, lwd_daily, lwu_daily)
)

# Net radiation = net shortwave + net longwave
rn_daily_conv = (swd_daily_conv - swu_daily_conv) + (lwd_daily_conv - lwu_daily_conv)

precip_daily = np.array(precip)


print("Calcul de l'ETP")
evapot_daily = compute_evapotranspiration(
    rn=rn_daily_conv,
    temp=temp_daily,
    tmin=daily_tmin,
    tmax=daily_tmax,
    vent=vent_daily,
    pression_surface=pressure_data_daily,
    rh_min=rela_humi_min,
    rh_max=rela_humi_max,
)

print("Calcul de l'index DHC")
# Daily water balance P - ETP for the days processed in this run
bilan_hydrique_daily = precip - evapot_daily

if archive:
    # Put the archived daily balances in front of the new ones, so the
    # cumulative deficit is computed over archive + new days.
    bilan_hydrique_daily_archive = dataset_archive['BILAN']
    bilan_hydrique_daily = xr.concat(
        [bilan_hydrique_daily_archive, bilan_hydrique_daily],
        dim="TIME",
    )

deficit_daily = compute_deficit_hydrique_cumule(bilan_hydrique_daily)


print("Save NetCDF")

# NOTE: daily ETP (evapot_daily) and precipitation (precip_daily) are not
# written. In archive mode they only cover the days processed in this run
# (the MAR dataset is sliced after the last archived day), whereas BILAN and
# DHC span the archived days plus the new ones.

# Output time axis: the days actually present in the water balance (archive +
# new days in archive mode). Taking it from the balance itself guarantees the
# same length as BILAN/DHC. The MAR-only `dates` does not match when the
# archive and the MAR file start on different days.
time_index = pd.DatetimeIndex(bilan_hydrique_daily["TIME"].values)
y_coord = dataset["Y21_70"].data
x_coord = dataset["X21_100"].data


def _grid_field(values, long_name):
    """Wrap a (time, y, x) field in mm as a DataArray on the output grid."""
    return xr.DataArray(np.asarray(values),
                        dims=("time", "y", "x"),
                        coords={"time": time_index, "y": y_coord, "x": x_coord},
                        attrs={"long_name": long_name, "units": 'mm'})


def _date_part(values, long_name):
    """Wrap one calendar component of the output time axis as a 1-D DataArray."""
    return xr.DataArray(np.asarray(values),
                        dims=("time",),
                        coords={"time": time_index},
                        attrs={"long_name": long_name, "units": '1'})


def _lonlat_field(values, long_name):
    """Wrap a (y, x) coordinate field in degrees as a DataArray."""
    return xr.DataArray(values,
                        dims=("y", "x"),
                        coords={"y": y_coord, "x": x_coord},
                        attrs={"long_name": long_name, "units": 'degrees'})


da_bilan = _grid_field(bilan_hydrique_daily, 'Bilan hydrique journalier (P-ETP)')
da_deficit = _grid_field(deficit_daily, 'Deficit hydrique cumule (P-ETP)')

# Calendar components of each output day (integers)
da_jour = _date_part(time_index.day, 'Day of month')
da_month = _date_part(time_index.month, 'Month')
da_year = _date_part(time_index.year, 'Year')

da_lon = _lonlat_field(lon, 'Longitude')
da_lat = _lonlat_field(lat, 'Latitude')


ds_to_save = xr.Dataset({"YYYY":da_year,"MM":da_month,"DD":da_jour,
                         "LON":da_lon,"LAT":da_lat,"BILAN":da_bilan,"DHC":da_deficit})

today_date = datetime.today()

ds_to_save.to_netcdf(f"/srv6_tmp5/mbaudewyn/PREVISIONS_DHC/DHC/DHC-{today_date.strftime('%Y%m%d')}.nc",unlimited_dims="time",engine="netcdf4",format="NETCDF4")

