############################################
# ENANDES PROJECT 
# November/2024
# Functions and packages - modelo produtividade
############################################
import pandas as pd
import numpy as np
import pickle  # Use 'import pickle5 as pickle' if your files need it
import geopandas as gpd
import rioxarray
import xarray as xa # rioxarray relies on xarray
from shapely.geometry import mapping # To convert GeoDataFrame geometry for clipping
import os
import os.path as osp
import math # Importar math para usar ceil
#from sklearn.model_selection import train_test_split
#from sklearn.preprocessing import MinMaxScaler
#from tensorflow.keras.models import Sequential
#from tensorflow.keras.layers import LSTM, Dense, Input, Dropout
#from tensorflow.keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt



# --- 1. Data Loading Functions ---

def carregar_produtividade(filepath):
    """Read the tab-separated annual productivity table.

    Parameters
    ----------
    filepath : str
        Path to a tab-separated file with at least the columns
        ``Ano`` and ``Produtividade``.

    Returns
    -------
    pandas.Series or None
        ``Produtividade`` values indexed by ``Ano`` in ascending order, or
        None when the file is missing, lacks the required columns, or cannot
        be parsed. The error is printed rather than raised.
    """
    colunas_obrigatorias = ('Ano', 'Produtividade')
    try:
        tabela = pd.read_csv(filepath, sep='\t')
        if not all(coluna in tabela.columns for coluna in colunas_obrigatorias):
            raise ValueError("CSV must have 'Ano' and 'Produtividade' columns.")
        serie = tabela.set_index('Ano').sort_index()['Produtividade']
        primeiro_ano, ultimo_ano = serie.index.min(), serie.index.max()
        print(f"Productivity data loaded. Years: {primeiro_ano}-{ultimo_ano}")
    except FileNotFoundError:
        print(f"Error: Productivity file '{filepath}' not found.")
        return None
    except Exception as e:
        # Best-effort loader: report any other problem and signal it with None
        print(f"Error loading productivity data: {e}")
        return None
    return serie

def carregar_mascara_mt(shapefile_path):
    """Read the Mato Grosso boundary shapefile into a GeoDataFrame.

    When the file holds more than one feature, all features are dissolved
    together so callers can treat the mask as a single geometry.

    Parameters
    ----------
    shapefile_path : str
        Path to the shapefile (or any vector file ``geopandas.read_file`` reads).

    Returns
    -------
    geopandas.GeoDataFrame or None
        The (possibly dissolved) mask in its original CRS, or None if loading
        failed. The error is printed rather than raised.
    """
    try:
        mascara = gpd.read_file(shapefile_path)
        if len(mascara) > 1:
            print("Warning: Shapefile has multiple features. Dissolving into one.")
            mascara = mascara.dissolve()
        print(f"Mato Grosso mask loaded. Original CRS: {mascara.crs}")
    except Exception as erro:
        print(f"Error loading shapefile '{shapefile_path}': {erro}")
        mascara = None
    return mascara

def processar_ndvi_anual(ano, pasta_ndvi, gdf_mt_mask, n_obs=None):
    """
    Load one year of NDVI observations from a pickle file, clip each one to
    the Mato Grosso mask and return the mean NDVI of every observation.

    Parameters
    ----------
    ano : int
        Year; used to build the file name ``ndvi_{ano}.pkl``.
    pasta_ndvi : str
        Folder containing the per-year NDVI pickle files.
    gdf_mt_mask : geopandas.GeoDataFrame
        Mask geometry (e.g. from ``carregar_mascara_mt``). A reprojected copy
        is used when its CRS differs from the first valid raster's CRS; the
        caller's object is not modified.
    n_obs : int, optional
        Expected number of observations in the file. Defaults to the module
        constant ``NDVI_POR_ANO`` when it is defined, otherwise 23.

    Returns
    -------
    list[float] or None
        ``n_obs`` mean NDVI values. Observations that could not be processed
        are filled with the mean of that year's valid values. Returns None if
        the file is missing, has the wrong number of observations, cannot be
        unpickled, or no observation could be processed.

    Notes
    -----
    - The file is looked up as ``ndvi_{ano}.pkl``. For backward compatibility,
      ``ndvi_{ano}.tif`` is used when no ``.pkl`` exists. Either way the
      content is read with ``pickle``.
    - Each item must be an ``xarray.DataArray`` with the rioxarray accessor
      and a CRS, or something ``rioxarray.open_rasterio`` accepts. Plain NumPy
      arrays carry no georeferencing and are recorded as NaN.
    - SECURITY: ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    # NDVI_POR_ANO is not defined in this module. Referencing it directly raised
    # a NameError that the broad except below turned into "None for every year".
    if n_obs is None:
        n_obs = globals().get('NDVI_POR_ANO', 23)

    # The contract (and pickle.load) expects a pickle; '.tif' is the legacy name.
    filepath = os.path.join(pasta_ndvi, f'ndvi_{ano}.pkl')
    if not os.path.exists(filepath):
        legacy_path = os.path.join(pasta_ndvi, f'ndvi_{ano}.tif')
        if not os.path.exists(legacy_path):
            print(f"Warning: NDVI file not found for year {ano}: {filepath}")
            return None  # Indicate missing data for the year
        filepath = legacy_path

    try:
        with open(filepath, 'rb') as f:
            ndvi_data_anual = pickle.load(f)

        if len(ndvi_data_anual) != n_obs:
            print(f"Warning: Year {ano} has {len(ndvi_data_anual)} NDVI observations, expected {n_obs}.")
            # Strict policy: a year with a different observation count is rejected.
            return None

        mean_ndvi_values = []
        target_crs = None  # CRS of the first raster that defines one

        for i, ndvi_item in enumerate(ndvi_data_anual):
            try:
                if isinstance(ndvi_item, np.ndarray):
                    # A bare array has no transform/CRS and cannot be clipped.
                    # Georeferencing would have to be supplied from elsewhere.
                    print(f"Error: NDVI item {i} in year {ano} is NumPy array - requires manual georeferencing.")
                    mean_ndvi_values.append(np.nan)
                    # This continue was previously commented out. Execution then
                    # reached `rds` (undefined or stale) and appended a second value.
                    continue

                if isinstance(ndvi_item, xa.DataArray) and hasattr(ndvi_item, 'rio'):
                    rds = ndvi_item  # Already usable by rioxarray
                else:
                    # Last resort: path-like or dataset-like objects rioxarray can open
                    try:
                        rds = rioxarray.open_rasterio(ndvi_item, masked=True)
                        print(f"Info: Opened NDVI item {i} for year {ano} using rioxarray directly.")
                    except Exception as rio_ex:
                        print(f"Error: Cannot interpret NDVI item {i} for year {ano} with rioxarray. Type: {type(ndvi_item)}. Details: {rio_ex}")
                        mean_ndvi_values.append(np.nan)
                        continue

                # --- Geospatial Operations ---
                if rds.rio.crs is None:
                    print(f"Warning: NDVI item {i} for year {ano} has no CRS defined. Skipping.")
                    mean_ndvi_values.append(np.nan)
                    continue

                if target_crs is None:
                    target_crs = rds.rio.crs
                    print(f"Determined target CRS from NDVI data: {target_crs}")

                # Rebinds the local name only; the caller's GeoDataFrame is untouched.
                if gdf_mt_mask.crs != target_crs:
                    print(f"Reprojecting mask from {gdf_mt_mask.crs} to {target_crs}")
                    gdf_mt_mask = gdf_mt_mask.to_crs(target_crs)

                # The mask CRS is passed explicitly so rioxarray knows which CRS
                # the geometries are in.
                clipped_rds = rds.rio.clip(gdf_mt_mask.geometry.apply(mapping), gdf_mt_mask.crs, drop=True, invert=False)

                # Mask an explicit nodata value. A None/NaN nodata needs no
                # masking, since mean() already skips NaN.
                nodata = rds.rio.nodata
                if nodata is not None and not np.isnan(nodata):
                    clipped_rds = clipped_rds.where(clipped_rds != nodata)
                mean_val = clipped_rds.mean().item()

                if np.isnan(mean_val):
                    print(f"Warning: Mean NDVI is NaN after clipping for item {i}, year {ano}. Possibly no overlapping data.")
                    mean_ndvi_values.append(np.nan)
                else:
                    mean_ndvi_values.append(mean_val)

            except Exception as item_e:
                print(f"Error processing NDVI item {i} for year {ano}: {item_e}")
                mean_ndvi_values.append(np.nan)  # One NaN per failed item

        # Sanity check: exactly one value per observation
        if len(mean_ndvi_values) != n_obs:
            print(f"Error: Could not process all {n_obs} NDVI items for year {ano}. Got {len(mean_ndvi_values)}.")
            return None

        valores = np.array(mean_ndvi_values, dtype=float)
        faltantes = np.isnan(valores)
        if faltantes.all():
            # Checked explicitly to avoid nanmean's all-NaN RuntimeWarning
            print(f"Error: Cannot fill NaNs for year {ano} as yearly mean is also NaN.")
            return None
        if faltantes.any():
            valid_mean = valores[~faltantes].mean()
            print(f"Warning: Filling {faltantes.sum()} NaN NDVI values in year {ano} with yearly mean ({valid_mean:.4f})")
            valores[faltantes] = valid_mean

        return valores.tolist()  # n_obs mean values

    except FileNotFoundError:  # File removed between the existence check and open()
        print(f"Warning: NDVI file not found for year {ano}: {filepath}")
        return None
    except Exception as e:
        print(f"Error processing pickle file for year {ano}: {e}")
        return None

def concatenate_pkl_area(data_p1, data_p2, directory_path, year, n_images=23):
    """Join two pickled image stacks side by side (along axis 1) for one year.

    Parameters
    ----------
    data_p1, data_p2 : str
        Path templates with one ``{}`` placeholder that is filled with
        ``year``. Each file must unpickle to an indexable sequence (list or
        array) of images. Images must have at least 2 dimensions, with matching
        sizes on every axis except axis 1.
    directory_path : str
        Output folder. The result is written to ``filtered_{year}.pkl`` there.
    year : int or str
        Value substituted into the templates and the output file name.
    n_images : int, default 23
        Number of leading images to concatenate (the value previously hard-coded).

    Returns
    -------
    str
        Path of the written pickle. It holds ``np.stack`` of the concatenated
        images, i.e. one extra leading axis of length ``n_images``.

    Raises
    ------
    IndexError
        If either input contains fewer than ``n_images`` images.
    """
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    with open(data_p1.format(year), 'rb') as f1:
        data_1 = pickle.load(f1)
    with open(data_p2.format(year), 'rb') as f2:
        data_2 = pickle.load(f2)

    # Fail with a clear message instead of an anonymous IndexError mid-loop
    for label, data in (('data_p1', data_1), ('data_p2', data_2)):
        if len(data) < n_images:
            raise IndexError(
                f"{label} for year {year} has {len(data)} images, expected at least {n_images}."
            )

    combined_images = [
        np.concatenate((data_1[i], data_2[i]), axis=1) for i in range(n_images)
    ]
    result_final = np.stack(combined_images)

    out_path = osp.join(directory_path, f'filtered_{year}.pkl')
    with open(out_path, 'wb') as f:
        pickle.dump(result_final, f)

    print("concatenated")
    return out_path