############################################
# ENANDES PROJECT
# November/2024
# Functions and packages
############################################

from functions_produtividade import *

# --- Configuration ---
# Number of years of NDVI, immediately before the target year, that form the
# input of each training sample.
ANOS_NDVI_ANTERIORES = 2
# Number of NDVI values expected per year; used to name the columns
# NDVI_1..NDVI_<n> of the per-year NDVI table.
# NOTE(review): presumably one value per composite period - confirm against
# what processar_ndvi_anual returns.
NDVI_POR_ANO = 23
# First and last year (both inclusive) for which NDVI is processed.
ANO_INICIO = 2000
ANO_FIM = 2024 # Inclusive, adjust if needed
# Productivity table, passed to carregar_produtividade.
ARQUIVO_PRODUTIVIDADE = 'produtividade_soja_mt.csv' # CSV: 'Ano', 'Produtividade'
# NDVI input folder, passed to processar_ndvi_anual for every year.
PASTA_DADOS_NDVI = 'dados_ndvi_pkl/' # Folder containing ndvi_YYYY.pkl files
# Mask shapefile, passed to carregar_mascara_mt.
ARQUIVO_SHAPEFILE_MT = 'limite_mt.shp' # Path to Mato Grosso shapefile

# --- 2. Main Execution ---

# Load initial data
produtividade_series = carregar_produtividade(ARQUIVO_PRODUTIVIDADE)
gdf_mt = carregar_mascara_mt(ARQUIVO_SHAPEFILE_MT)

# A None result from either loader is treated as a load failure.
if produtividade_series is None or gdf_mt is None:
    print("Exiting due to errors loading initial data.")
    # Raise SystemExit directly: the bare exit() helper is injected by the
    # `site` module (absent under `python -S`) and would report status 0,
    # hiding the failure from calling shells and schedulers.
    raise SystemExit(1)

# Process NDVI data year by year
# all_ndvi_data maps each successfully processed year to the value returned
# by processar_ndvi_anual for that year.
all_ndvi_data = {}
anos_processar = range(ANO_INICIO, ANO_FIM + 1)

print("\n--- Processing NDVI Data ---")
for ano in anos_processar:
    print(f"Processing year: {ano}")
    # Each call receives its own copy of the mask so that any reprojection
    # done by the helper cannot accumulate across iterations.
    ndvi_means = processar_ndvi_anual(ano, PASTA_DADOS_NDVI, gdf_mt.copy())
    if ndvi_means is None:
        # A None result marks a failed year; it is simply left out.
        print(f"Failed to process NDVI for year {ano}. This year will be excluded.")
        continue
    all_ndvi_data[ano] = ndvi_means

# Convert processed NDVI data to DataFrame
# One row per processed year (index 'Ano'), one column per NDVI value of that
# year (NDVI_1 .. NDVI_<NDVI_POR_ANO>), sorted by year.
ndvi_columns = [f'NDVI_{i+1}' for i in range(NDVI_POR_ANO)]
ndvi_df = pd.DataFrame.from_dict(all_ndvi_data, orient='index', columns=ndvi_columns)
ndvi_df.index.name = 'Ano'
ndvi_df = ndvi_df.sort_index()

if ndvi_df.empty:
    print("Error: No valid NDVI data could be processed. Exiting.")
    # Exit with a non-zero status so callers can detect the failure; the bare
    # exit() helper comes from `site` and would report success (status 0).
    raise SystemExit(1)

print("\n--- Processed Mean NDVI Data (first 5 rows) ---")
print(ndvi_df.head())

# --- 3. Data Preparation for LSTM ---

# Align productivity and NDVI data by available years
# Index.intersection does not sort by default, so the result would follow the
# row order of the productivity file. Everything downstream (sliding windows,
# temporal train/test split) assumes chronological order, so sort explicitly.
common_years = produtividade_series.index.intersection(ndvi_df.index).sort_values()
produtividade_aligned = produtividade_series.loc[common_years]
ndvi_aligned = ndvi_df.loc[common_years]

print(f"\nYears with aligned Productivity and processed NDVI data: {len(common_years)}")
# At least one full input window plus one target year is required.
if len(common_years) < ANOS_NDVI_ANTERIORES + 1:
    print(f"Error: Need at least {ANOS_NDVI_ANTERIORES + 1} years of aligned data. Found {len(common_years)}. Exiting.")
    # Non-zero status marks the run as failed; the bare exit() helper comes
    # from `site` and would report success (status 0).
    raise SystemExit(1)

# Create sequences
# Each sample pairs the flattened NDVI rows of the ANOS_NDVI_ANTERIORES
# calendar years immediately before a target year with that target year's
# productivity.
X = []
y = []
target_years_list = [] # Keep track of the year being predicted

# Work on an explicitly sorted list so the windows (and the later temporal
# split) are chronological regardless of the order of common_years.
years_sorted = sorted(common_years)

for start, target_year in enumerate(years_sorted[ANOS_NDVI_ANTERIORES:]):
    input_years = years_sorted[start : start + ANOS_NDVI_ANTERIORES]

    # Years can be missing (e.g. NDVI processing failed for one of them). A
    # purely positional window would then silently pair the target with years
    # that are not the ones directly preceding it, so require contiguity.
    expected_years = [target_year - lag for lag in range(ANOS_NDVI_ANTERIORES, 0, -1)]
    if input_years != expected_years:
        print(f"Warning: Skipping year {target_year} because the {ANOS_NDVI_ANTERIORES} preceding year(s) are not all available.")
        continue

    # Get NDVI data for the preceding years and flatten (oldest year first)
    ndvi_sequence = ndvi_aligned.loc[input_years].values.flatten()

    # Check for NaNs in the sequence (might occur if filling failed)
    if np.isnan(ndvi_sequence).any():
        print(f"Warning: Skipping year {target_year} due to NaN values in input NDVI sequence.")
        continue

    X.append(ndvi_sequence)
    y.append(produtividade_aligned.loc[target_year])
    target_years_list.append(target_year)

if not X:
    print("Error: No valid sequences could be created. Check data processing and alignment. Exiting.")
    # Non-zero status marks the run as failed; the bare exit() helper comes
    # from `site` and would report success (status 0).
    raise SystemExit(1)

X = np.array(X)
y = np.array(y)
target_years_list = np.array(target_years_list)

print(f"\nInput sequences shape (X): {X.shape}") # (samples, timesteps*features)
print(f"Target values shape (y): {y.shape}")   # (samples,)
print(f"Years being predicted: {target_years_list}")

# Train/Test Split (Temporal) and Scaling
# The split point is chosen first and the scalers are fitted on the training
# slice only. Fitting them on all samples would leak the min/max of the
# held-out years into training and make the test metrics optimistic.
test_ratio = 0.2
n_test_samples = max(1, int(len(X) * test_ratio))
split_index = len(X) - n_test_samples

if split_index < 1:
    # With a single sequence the whole data set would become the test set.
    print("Error: Not enough sequences to build a training set after holding out the test samples. Exiting.")
    raise SystemExit(1)

y_test_original = y[split_index:] # Original scale for evaluation
test_years = target_years_list[split_index:] # Years corresponding to the test set

# Fit on the training rows, then transform every row with those parameters.
# Test values may therefore fall slightly outside [0, 1]; that is expected.
scaler_X = MinMaxScaler(feature_range=(0, 1))
scaler_X.fit(X[:split_index])
X_scaled = scaler_X.transform(X)

scaler_y = MinMaxScaler(feature_range=(0, 1))
scaler_y.fit(y[:split_index].reshape(-1, 1))
y_scaled = scaler_y.transform(y.reshape(-1, 1))

# Reshape for LSTM: [samples, timesteps, features]
# Here, timesteps = flattened sequence length
# (ANOS_NDVI_ANTERIORES * NDVI_POR_ANO), features = 1
n_timesteps = X.shape[1]
n_features = 1
X_scaled_lstm = X_scaled.reshape((X_scaled.shape[0], n_timesteps, n_features))
print(f"Reshaped X for LSTM: {X_scaled_lstm.shape}")

X_train, X_test = X_scaled_lstm[:split_index], X_scaled_lstm[split_index:]
y_train, y_test = y_scaled[:split_index], y_scaled[split_index:]

print(f"Training samples: {len(X_train)}, Test samples: {len(X_test)}")
print(f"Test years: {test_years}")

# --- 4. LSTM Model Building and Training ---

# Two stacked LSTM layers, each followed by dropout, feeding a single linear
# output unit for regression.
model = Sequential()
for layer in [
    Input(shape=(n_timesteps, n_features)),
    # The first LSTM returns the full sequence so the second LSTM can consume it.
    LSTM(units=64, activation='relu', return_sequences=True),
    Dropout(0.2),  # Dropout for regularization
    LSTM(units=32, activation='relu', return_sequences=False),
    Dropout(0.2),
    Dense(units=1),  # Output layer for regression
]:
    model.add(layer)

model.compile(optimizer='adam', loss='mean_squared_error')
model.summary()

# Stop once the validation loss has not improved for 15 epochs and roll back
# to the best weights seen.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True,
)

fit_options = {
    'epochs': 150,
    'batch_size': 8,  # Smaller batch size can sometimes help
    'validation_split': 0.15,  # Use 15% of training data for validation
    'callbacks': [early_stopping],
    'verbose': 1,
}
history = model.fit(X_train, y_train, **fit_options)

# --- 5. Evaluation ---

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Predict on the held-out samples and map the predictions back to the
# original productivity scale before scoring.
y_pred_scaled = model.predict(X_test)
y_pred = scaler_y.inverse_transform(y_pred_scaled)

# All three metrics compare the unscaled targets with the unscaled predictions.
mse, mae, r2 = (
    metric(y_test_original, y_pred)
    for metric in (mean_squared_error, mean_absolute_error, r2_score)
)

report_lines = [
    "\n--- Test Set Evaluation ---",
    f"Test Years: {test_years}",
    f"Mean Squared Error (MSE): {mse:.4f}",
    f"Mean Absolute Error (MAE): {mae:.4f} (unit: productivity)",
    f"R-squared (R²): {r2:.4f}",
]
print("\n".join(report_lines))

# --- 6. Visualization ---

# Figure 1: observed vs. predicted productivity for the test years.
plt.figure(figsize=(14, 7))
test_set_curves = (
    (y_test_original, 'o', '-', 'Real Productivity'),
    (y_pred.flatten(), 'x', '--', 'Predicted Productivity (LSTM)'),
)
for values, marker, linestyle, label in test_set_curves:
    plt.plot(test_years, values, marker=marker, linestyle=linestyle, label=label)
plt.title('Mato Grosso Soybean Productivity: Real vs. Predicted (Test Set)')
plt.xlabel('Year')
plt.ylabel('Productivity')
plt.xticks(test_years, rotation=45)  # one tick per test year
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Figure 2: training and validation loss per epoch, as recorded by model.fit.
plt.figure(figsize=(10, 5))
loss_curves = (
    ('loss', 'Training Loss'),
    ('val_loss', 'Validation Loss'),
)
for history_key, label in loss_curves:
    plt.plot(history.history[history_key], label=label)
plt.title('Model Loss During Training')
plt.xlabel('Epoch')
plt.ylabel('Mean Squared Error (Scaled)')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

print("\nScript finished.")