#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CNN pour prédire ΔM = ds_out - ds_in sur les ice shelves uniquement.
Pixels hors glace sont ignorés via masked loss.
"""

import xarray as xr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, InputLayer
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# === Files ===
infile  = "./data/param/fixedisf94/map_2D_quadratic_local_mixedcoeff_SummerPaper_tuning_fixedisf94.nc"
outfile = "./data/param/isf94/map_2D_quadratic_local_mixedcoeff_SummerPaper_tuning_isf94.nc"
var = "melt_m_we_per_y"

# === Load datasets ===
ds_in = xr.open_dataset(infile)[var].sel(time=slice("2014","2100"))
ds_out = xr.open_dataset(outfile)[var].sel(time=slice("2014","2100"))

X = ds_in.values  # (time, y, x)
Y = ds_out.values

# === Compute ΔM ===
DeltaM = Y - X

# === Mask for ice shelves ===
mask = ~np.isnan(DeltaM)  # True where ice shelf exists

# === Normalize only on ice shelves ===
valid_values = DeltaM[mask]
mean = valid_values.mean()
std = valid_values.std()
epsilon = 1e-8

DeltaM_n = np.zeros_like(DeltaM, dtype=np.float32)
DeltaM_n[mask] = (DeltaM[mask] - mean) / (std + epsilon)
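# Optional sanity check (added for convenience): report the normalization statistics
# computed over ice-shelf pixels only.
print(f"ΔM over ice shelves: mean = {mean:.4f}, std = {std:.4f} (m w.e./yr)")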

# === Add channel dimension ===
Xn = np.nan_to_num(X, nan=0.0)[..., np.newaxis]  # NaN -> 0 for TensorFlow
DeltaM_n = DeltaM_n[..., np.newaxis]
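# Note: the input X stays in physical units (only NaNs are zero-filled), whereas the
# target ΔM is standardized; standardizing the input as well is a common choice but
# is not done here.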

# === Split training / validation (50/50 years) ===
train_idx, val_idx = train_test_split(np.arange(Xn.shape[0]), test_size=0.5, random_state=42)
X_train, X_val = Xn[train_idx], Xn[val_idx]
y_train, y_val = DeltaM_n[train_idx], DeltaM_n[val_idx]
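# Note: train_test_split draws a random (not chronological) 50/50 split of the time
# indices, so training and validation years are interleaved across 2014-2100.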

print(f"Training samples: {X_train.shape[0]}, Validation samples: {X_val.shape[0]}")

# === Custom masked MSE ignoring zeros (pixels outside the ice shelf) ===
def masked_mse_zero(y_true, y_pred):
    mask = tf.cast(tf.not_equal(y_true, 0.0), tf.float32)
    eps = 1e-8
    return tf.reduce_sum(mask * tf.square(y_true - y_pred)) / (tf.reduce_sum(mask) + eps)
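# Caveat: valid pixels are inferred from y_true != 0, so an ice-shelf pixel whose
# normalized ΔM is exactly 0 would also be dropped from the loss. This should be
# rare with float data; passing an explicit mask would remove the ambiguity.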

# === CNN Model ===
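# Note: InputLayer(shape=...) follows the Keras 3 API; with tf.keras based on Keras 2,
# the equivalent argument is input_shape.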
model = Sequential([
    InputLayer(shape=(Xn.shape[1], Xn.shape[2], 1)),
    Conv2D(16, (3,3), activation='relu', padding='same'),
    Conv2D(32, (3,3), activation='relu', padding='same'),
    Conv2D(16, (3,3), activation='relu', padding='same'),
    Conv2D(1, (1,1), padding='same')
])
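# The three stacked 3x3 convolutions give an effective receptive field of 7x7 pixels;
# the final 1x1 convolution maps the 16 feature maps back to a single ΔM channel.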

model.compile(optimizer="adam", loss=masked_mse_zero)

# === Train ===
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=20,
    batch_size=4,
    shuffle=True
)
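# Optionally persist the trained model for later reuse; the filename below is illustrative only.
# model.save("cnn_deltam_isf94.keras")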

# === Predict ΔM ===
DeltaM_pred_n = model.predict(Xn)
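# Note: predictions cover every year, training and validation alike; any skill
# assessment should be limited to the validation indices (val_idx) to avoid an
# optimistic bias.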

# === Un-normalize only on ice shelves ===
DeltaM_pred = np.zeros_like(DeltaM_pred_n, dtype=np.float32)
DeltaM_pred[mask[..., np.newaxis]] = DeltaM_pred_n[mask[..., np.newaxis]] * std + mean

# === Reconstruct Y_pred only on ice shelves ===
Y_pred = np.copy(X)
Y_pred[mask] += DeltaM_pred[mask[..., np.newaxis]]  # boolean indexing on both sides yields matching 1-D arrays
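# Outside the ice-shelf mask, Y_pred keeps the values copied from X (including any NaNs).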

# === Save predicted field ===
pred = xr.DataArray(
    Y_pred,
    dims=("time","y","x"),
    coords=dict(time=ds_in.time, y=ds_in.y, x=ds_in.x),
    name="Melt_predicted"
)
pred.to_netcdf("pred_isf94local2coef_2014-2100.nc")
print("✅ Done — output saved to pred_isf94local2coef_2014-2100.nc")

# === Plot training vs validation loss ===
train_loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(train_loss) + 1)

plt.figure(figsize=(8,5))
plt.plot(epochs, train_loss, 'o-', label='Training Loss')
plt.plot(epochs, val_loss, 's-', label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss (masked MSE)')
plt.title('Training vs Validation Loss')
plt.legend()
plt.grid(True)
plt.show()
