#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 29 12:45:09 2025

@author: chrisk
"""

import numpy as np
import xarray as xr


file_path = "MARcst-AN35km-176x148.cdf2"

# Open the MAR grid/constant file and show a summary of its contents
ds = xr.open_dataset(file_path, engine="netcdf4")
print(ds)

# Pixels with at least 30 % ice cover
is_icy = ds["ICE"] >= 30

# Ice mask: the ICE value on icy pixels, NaN everywhere else
ice_msk = ds["ICE"].where(is_icy)

# Ice-shelf fraction: ICE minus GROUND, kept only on icy pixels that are not
# fully grounded (GROUND < 100) and contain no rock (ROCK == 0); NaN elsewhere.
# NOTE(review): GROUND < 100 keeps partially grounded pixels; if only pixels
# with no grounded ice at all are wanted this should be GROUND == 0 — confirm.
shelf_pixels = is_icy & (ds["GROUND"] < 100) & (ds["ROCK"] == 0)
ice_shelf_fraction = (ds["ICE"] - ds["GROUND"]).where(shelf_pixels)

# Longitude and latitude fields of the MAR grid
lon = ds["LON"]
lat = ds["LAT"]

########## Opening the SMB file
file_path_SMB = "mon-SMB-MAR_ERA5-1979-2024.nc"

mon_SMB = xr.open_dataset(file_path_SMB, engine="netcdf4")
# Keep SECTOR 1 only and drop the now-scalar SECTOR coordinate
mon_SMB = mon_SMB.sel(SECTOR=1, drop=True)

# Attach the MAR grid longitude/latitude (read from the grid file) to the SMB
# dataset so it can be cut with the same lon/lat box. They are stored on the two
# trailing (spatial) dimensions of SMB only. The values are the same at every
# time step, so broadcasting them over TIME would only build a full
# TIME x Y x X copy of each field.
# NOTE(review): assumes SMB's last two dims are the same grid as
# ds["LON"]/ds["LAT"] (a plain NumPy broadcast relies on the same thing) — confirm.
spatial_dims = mon_SMB["SMB"].dims[-2:]
mon_SMB["LON"] = (spatial_dims, lon.values)
mon_SMB["LAT"] = (spatial_dims, lat.values)



# Bounding box of the study area, in degrees. lon_min is negative, so
# longitudes are presumably in the [-180, 180) convention.
lat_min, lat_max = -80, -60.005972
lon_min, lon_max = -14.878589, 65.439792


def _in_box(lat_field, lon_field):
    """Return a mask that is True where (lat, lon) lies inside the box, bounds inclusive."""
    lat_ok = (lat_field >= lat_min) & (lat_field <= lat_max)
    lon_ok = (lon_field >= lon_min) & (lon_field <= lon_max)
    return lat_ok & lon_ok


# Cut the SMB dataset and the ice mask to the box. drop=True trims the
# coordinate labels whose values all fall outside it.
ds_box = mon_SMB.where(_in_box(mon_SMB["LAT"], mon_SMB["LON"]), drop=True)
ice_msk_cut = ice_msk.where(_in_box(ds["LAT"], ds["LON"]), drop=True)




import geopandas as gpd

# -----------------------------
# AWS station locations

shapefile_path = "./antaws-dataset-cmd9p6/Shp/267AWS.shp"
stations_gdf = gpd.read_file(shapefile_path)

# Check the column names: "zhandian" holds the station name, "lat"/"lon" its coordinates
print(stations_gdf.columns)

stations_list_name = []
stations_list_lon = []
stations_list_lat = []

for _, row in stations_gdf.iterrows():
    # Distinct names so the MAR grid `lat`/`lon` fields are not overwritten
    st_lat = row["lat"]
    # Wrap to [-180, 180), the same convention as lon_min/lon_max (lon_min is
    # negative). Using 0-360 here would push every station between lon_min and
    # 0 deg to ~345-360 deg and exclude it from the box.
    st_lon = (row["lon"] + 180.0) % 360.0 - 180.0
    name = row["zhandian"]

    if lat_min <= st_lat <= lat_max and lon_min <= st_lon <= lon_max:
        print(name, st_lon, st_lat)
        stations_list_name.append(name)
        stations_list_lon.append(st_lon)
        stations_list_lat.append(st_lat)


print("Stations inside the box :")
for s in stations_list_name:
    print(s)


######
from pyproj import Transformer

# WGS84 lon/lat (EPSG:4326) -> Antarctic Polar Stereographic (EPSG:3031).
# always_xy=True: inputs are (lon, lat) and outputs (x, y), whatever the CRS axis order.
transformer = Transformer.from_crs("EPSG:4326", "EPSG:3031", always_xy=True)

projected = [
    transformer.transform(st_lon, st_lat)
    for st_lon, st_lat in zip(stations_list_lon, stations_list_lat)
]
for x, y in projected:
    print(x, y)

# Projected station coordinates as NumPy arrays (one entry per station in the box)
x_station = np.array([xy[0] for xy in projected])
y_station = np.array([xy[1] for xy in projected])



######### A quick map with MAR and the stations

import matplotlib.colors as colors

# Annual SMB for 2000: monthly values summed over the calendar year
smb_2000 = ds_box["SMB"].sel(TIME=slice("2000-01-01", "2000-12-31"))
SMB_cut = smb_2000.sum("TIME")
# Keep only pixels with at least 30 % ice (positional mask on the same box cut)
SMB_cut = SMB_cut.where(ice_msk_cut.values >= 30)

# Symmetric colour scale centred on 0, so positive and negative SMB
# get opposite colours of equal intensity
vmax = abs(SMB_cut).max().item()
norm = colors.TwoSlopeNorm(vcenter=0.0, vmin=-vmax, vmax=vmax)

import matplotlib.pyplot as plt

# MAR grid coordinates of the box, scaled by 1000 so they share units (m)
# with the EPSG:3031 station coordinates.
# NOTE(review): assumes MAR X/Y are in km on a projection matching EPSG:3031 — confirm.
X = ds_box["X"].values * 1000
Y = ds_box["Y"].values * 1000

plt.figure(figsize=(8, 8))

# Ice-masked SMB of year 2000 with the zero-centred diverging norm
pcm = plt.pcolormesh(
    X, Y,
    SMB_cut.squeeze(),
    shading="auto",
    cmap="RdBu_r",
    norm=norm
)

# SMB_cut is a sum of the monthly values over the whole of 2000, so it is a
# yearly total: the unit is per year, not per month.
plt.colorbar(pcm, label="SMB (mm w.e. / yr)")


# Stations: black crosses
plt.scatter(
    x_station, y_station,
    marker="x",
    color="black",
    s=60,
    zorder=5
)

# Station names, offset 10 km up and right of each cross
for name, x, y in zip(stations_list_name, x_station, y_station):
    plt.text(
        x + 10_000, y + 10_000, name,
        fontsize=8,
        color="black"
    )

plt.xlabel("X (m, EPSG:3031)")
plt.ylabel("Y (m, EPSG:3031)")
plt.title("SMB MAR – Stations AWS")

plt.axis("equal")
plt.tight_layout()
plt.show()


