############################################
# ENANDES PROJECT
# November/2024
# Functions and packages
############################################


import os
import os.path as osp
import pickle
import json
import time
import pystac_client
import osgeo 
import osgeo_utils 
from osgeo import gdal
from osgeo import ogr, osr
import pandas as pd 
import netCDF4
import numpy as np
import shapely
import rpy2
from matplotlib import pyplot as plt
from pyproj import Transformer
from pyproj.crs import CRS
import rasterio
from rasterio.crs import CRS
from rasterio.warp import transform
from rasterio.windows import Window
from rasterio.windows import from_bounds
from rasterio.enums import Resampling
from rasterio.io import MemoryFile
from rasterio.transform import Affine
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import scipy
from joblib import Parallel, delayed
from scipy.ndimage import generic_filter
from scipy.signal import savgol_filter
import matplotlib.pyplot as plt
#%matplotlib inline

# Prepare and run read and mask function
def run_read_and_mask(set_data, set_tile, n_processors, set_period):
    #### Using STAC
    service = pystac_client.Client.open('https://data.inpe.br/bdc/stac/v1/')
    #### listing data and retrieving items
    collection = service.get_collection(set_data)
    item_search = service.search(query={'bdc:tiles': {'eq': set_tile}},
                                datetime=set_period,
                                collections=[set_data])
    item_search.matched()
    items = list(item_search.items())
    # read and mask data
    uris = [(item.assets['NDVI'].href, item.assets['pixel_reliability'].href) for item in items]
    ndvi_3d = Parallel(n_jobs=n_processors)(delayed(read_and_mask)([uris[i]]) for i in range(len(uris)))
    ndvi_3d = np.array(ndvi_3d)
    t, extra, x, y = ndvi_3d.shape
    ndvi_3d = ndvi_3d.reshape(t, x, y)
    
    return ndvi_3d

def run_read_no_mask(set_data, set_tile, n_processors, set_period):
    #### Using STAC
    service = pystac_client.Client.open('https://data.inpe.br/bdc/stac/v1/')
    #### listing data and retrieving items
    collection = service.get_collection(set_data)
    item_search = service.search(query={'bdc:tiles': {'eq': set_tile}},
                                datetime=set_period,
                                collections=[set_data])
    item_search.matched()
    items = list(item_search.items())
    # read and mask data
    uris = [(item.assets['NDVI'].href, item.assets['pixel_reliability'].href) for item in items]
    ndvi_3d = Parallel(n_jobs=n_processors)(delayed(read_no_mask)([uris[i]]) for i in range(len(uris)))
    ndvi_3d = np.array(ndvi_3d)
    t, extra, x, y = ndvi_3d.shape
    ndvi_3d = ndvi_3d.reshape(t, x, y)
    
    return ndvi_3d

# read and mask data using quality reliability parameter
def read_and_mask(uris, crs=None):
    """Lê cada raster como um numpy.ma.masked_array, aplicando máscara de confiabilidade."""
    from rasterio.crs import CRS
    source_crs = CRS.from_string('EPSG:4326')
    if crs:
        source_crs = CRS.from_string(crs)
    if not uris:
        raise ValueError("A lista de URIs está vazia. Verifique se há itens retornados pela busca.")
    num_images = len(uris)

    # Todas as imagens têm a mesma dimensão, então pegamos a altura e largura da primeira
    with rasterio.open(uris[0][0]) as first_dataset:
        height, width = first_dataset.height, first_dataset.width
    ndvi_masked_data = np.zeros((num_images, height, width))
    
    for i, (ndvi_uri, pixR_uri) in enumerate(uris):
        try:
            with rasterio.open(ndvi_uri) as ndvi_dataset, \
                 rasterio.open(pixR_uri) as pixel_reliability_dataset:

                # Ler os dados das bandas
                ndvi_data = ndvi_dataset.read(1)
                pixR_data = pixel_reliability_dataset.read(1)

                # Aplicar a máscara e armazenar os resultados nas matrizes
                #ndvi_masked_data[i] = np.ma.masked_array(ndvi_data, mask=mask) if masked else ndvi_data
                ndvi_masked_data[i] = np.where(pixR_data != 0, float('nan'), ndvi_data)

        except Exception as e:
            print(f"Erro ao processar os dados da imagem {i}: {e}")

    return ndvi_masked_data/10000

# read and mask data using quality reliability parameter
def read_no_mask(uris, crs=None):
    """Lê cada raster como um numpy.ma.masked_array, aplicando máscara de confiabilidade."""
    from rasterio.crs import CRS
    source_crs = CRS.from_string('EPSG:4326')
    if crs:
        source_crs = CRS.from_string(crs)
    if not uris:
        raise ValueError("A lista de URIs está vazia. Verifique se há itens retornados pela busca.")
    num_images = len(uris)

    # Todas as imagens têm a mesma dimensão, então pegamos a altura e largura da primeira
    with rasterio.open(uris[0][0]) as first_dataset:
        height, width = first_dataset.height, first_dataset.width
    ndvi_masked_data = np.zeros((num_images, height, width))
    
    for i, (ndvi_uri, pixR_uri) in enumerate(uris):
        try:
            with rasterio.open(ndvi_uri) as ndvi_dataset, \
                 rasterio.open(pixR_uri) as pixel_reliability_dataset:

                # Ler os dados das bandas
                ndvi_data = ndvi_dataset.read(1)
                pixR_data = pixel_reliability_dataset.read(1)

                # Aplicar a máscara e armazenar os resultados nas matrizes
                #ndvi_masked_data[i] = np.ma.masked_array(ndvi_data, mask=mask) if masked else ndvi_data
                ndvi_masked_data[i] = ndvi_data

        except Exception as e:
            print(f"Erro ao processar os dados da imagem {i}: {e}")

    return ndvi_masked_data/10000


# funtion to interpolate and apply savgol filter
def savitzky_golay_filtering(timeseries, wnds=[10, 7], orders=[2, 4], debug=False):                                     
    interp_ts = pd.Series(timeseries)
    #interp_ts = interp_ts.interpolate(method='linear', limit=14)
    interp_ts = interp_ts.interpolate(method='linear', limit_direction='both')
    smooth_ts = interp_ts                                                                                              
    wnd, order = wnds[0], orders[0]
    F = 1e8
    W = None
    it = 0                                                                                                             
    while True:
        smoother_ts = savgol_filter(smooth_ts, window_length=wnd, polyorder=order)                                     
        diff = smoother_ts - interp_ts
        sign = diff > 0                                                                                                                       
        if W is None:
            W = 1 - np.abs(diff) / np.max(np.abs(diff)) * sign                                                         
            wnd, order = wnds[1], orders[1]                                                                            
        fitting_score = np.sum(np.abs(diff) * W)                                                                       
        #print(it, ' : ', fitting_score)
        if fitting_score > F:
            break
        else:
            F = fitting_score
            it += 1        
        smooth_ts = smoother_ts * sign + interp_ts * (1 - sign)
    if debug:
        return smooth_ts, interp_ts
    return smooth_ts

def apply_filter_savgol_2d(ndvi_3d_teste):
    new_data = []
    data_3d_array = []
    t, y = ndvi_3d_teste.shape
    for ydim in range(y):
        #print("ydim", ydim)
        if len(np.isnan(ndvi_3d_teste[:, ydim])) - np.sum(np.isnan(ndvi_3d_teste[:, ydim])) < 5: 
            new_data.append(ndvi_3d_teste[:, ydim])
        else:
            new_data.append(savitzky_golay_filtering(ndvi_3d_teste[:, ydim]))
    data_3d_array = np.array(new_data)
    ndvi_interp = data_3d_array.reshape(y, t) 
    ndvi_interp_t = np.transpose(ndvi_interp, (1,0))
    return ndvi_interp_t



def run_filter(ndvi_3d, n_processors):
    t, x, y = ndvi_3d.shape
    resultado = Parallel(n_jobs=n_processors)(delayed(apply_filter_savgol_2d)(ndvi_3d[:,i,:]) for i in range(x))
    resultado_array = np.array(resultado) # x, t, y e ndvi (t,x,y)
    resultado_array = np.transpose(resultado_array, (1,0, 2))

    return resultado_array


def calculate_stats(data, stats_used):
    int_img=0
    stats_images = []
    for images in range(0,23):
        img1 = img = []
        for years in range(int(data.shape[0]/23) + 1):
            if (years > int(data.shape[0]/23)) and (images > (data.shape[0]%23)): 
                StopIteration
            if (int_img >= data.shape[0]):
                StopIteration
            else:
                img1.append(data[int_img,:,:])
                int_img=int_img+23
            #print(years)
            #print(int_img)
        img = np.array(img1)
        if stats_used == "average": img_stats = np.average(img, axis=0)
        if stats_used == "maximum": img_stats = np.max(img, axis=0)
        if stats_used == "minimum": img_stats = np.min(img, axis=0)
        if stats_used == "variance": img_stats = (np.var(img, axis=0))*img.shape[0]
        int_img=images+1
        stats_images.append(img_stats)
    stats_images = np.array(stats_images)

    return stats_images

def calculate_via(data, average_data, var_data):
    via = stats_images = []
    for j in range(0,(data.shape[0])): 
        i = j % 23
        via.append((data[j,:,:] - average_data[i,:,:]) / np.sqrt(var_data[i,:,:]))
    stats_images = np.array(via)

    return stats_images


def calculate_vci(data, max_data, min_data, fill=np.nan):
    vci = stats_images = []
    for j in range(0,(data.shape[0])): 
        i = j % 23
        via_calculated = (data[j,:,:] - min_data[i,:,:]) / (max_data[i,:,:] - min_data[i,:,:])
        if np.isscalar(via_calculated):
            return via_calculated if np.isfinite(via_calculated) \
                else fill
        else:
            via_calculated[ ~ np.isfinite(via_calculated)] = fill
        vci.append(via_calculated)
    stats_images = np.array(vci)

    return stats_images

def scatter_plot_with_correlation_line(x, y, graph_filepath):
    plt.scatter(x, y)
    axes = plt.gca()
    m, b = np.polyfit(x, y, 1)
    X_plot = np.linspace(axes.get_xlim()[0],axes.get_xlim()[1],100)
    plt.plot(X_plot, m*X_plot + b, '-')
    plt.savefig(graph_filepath, dpi=300, format='png', bbox_inches='tight')

def run_processing(masked_data_s1, name_obj, type_obj):
    # filter masked data
    filtered_data = run_filter(masked_data_s1, n_processors = 10)
    #print("filtered")

    # statistics - climatology
    average_data = calculate_stats(data = filtered_data, stats_used = "average")
    max_data = calculate_stats(data = filtered_data, stats_used = "maximum")
    min_data = calculate_stats(data = filtered_data, stats_used = "minimum")
    var_data = calculate_stats(data = filtered_data, stats_used = "variance")

    # statistics - anomaly
    via_data = calculate_via(data = filtered_data, average_data = average_data, var_data = var_data)
    vci_data = calculate_vci(data = filtered_data, max_data = max_data, min_data = min_data)

    # save objects (statistics climatology and anomaly)
    with open(name_obj, 'wb') as f:  # Python 3: open(..., 'wb')
        if type_obj == "filtered_data": 
            return pickle.dump([filtered_data], f)
        if type_obj == "average_data": 
            return pickle.dump([average_data], f)
        if type_obj == "max_data": 
            return pickle.dump([max_data], f)
        if type_obj == "min_data": 
            return pickle.dump([min_data], f)
        if type_obj == "var_data": 
            return pickle.dump([var_data], f)
        if type_obj == "via_data": 
            return pickle.dump([via_data], f)
        if type_obj == "vci_data": 
            return pickle.dump([vci_data], f)

    return print("Objs saved")

def concatenate_pkl(directory_path, name_obj, i, j, name_var):
    combined_df = []
    for i in range(i,j): #replace 10 for 48
        with open(name_obj.format(i), 'rb') as f:
            data = pickle.load(f)
            data_ave = data[0]
        combined_df.append(data_ave)
        print(i)
    result = np.concatenate(combined_df, axis=2)

    with open(osp.join(directory_path, 'result{}.pkl'.format(name_var)), 'wb') as f:  # Python 3: open(..., 'wb')
        pickle.dump(result, f)
        
    return print("concatenated")


def concatenate_parts(directory_path, name_obj1, name_obj2, name_var):

    with open(name_obj1, 'rb') as f:
        data_p1 = pickle.load(f)

    with open(name_obj2, 'rb') as f:
        data_p2 = pickle.load(f)
        
    result = np.concatenate((data_p1, data_p2), axis=2)
    print("concatenated")

    with open(osp.join(directory_path, 'result{}.pkl'.format(name_var)), 'wb') as f:  # Python 3: open(..., 'wb')
        pickle.dump(result, f)
        
    return print("results")




def run_processing_stats(masked_data_s1, name_obj):
    # filter masked data
    filtered_data = run_filter(masked_data_s1, n_processors = 8)
    average_data = calculate_stats(data = filtered_data, stats_used = "average")
    max_data = calculate_stats(data = filtered_data, stats_used = "maximum")
    min_data = calculate_stats(data = filtered_data, stats_used = "minimum")
    var_data = calculate_stats(data = filtered_data, stats_used = "variance")
    via = calculate_via(data = filtered_data, average_data = average_data, var_data = var_data)
    vci = calculate_vci(data = filtered_data, max_data = max_data, min_data = min_data)

    # save objects (statistics climatology and anomaly)
    with open(name_obj, 'wb') as f:  # Python 3: open(..., 'wb')
        pickle.dump([average_data, max_data, min_data, var_data], f)
    return print("Objs saved")

def concatenate_pkl_stats(directory_path, variable_index):
    combined_df = []
    name_var = []
    i=0
    for i in range(i,48): #replace 10 for 48
        with open(osp.join(directory_path, 'objs_slice{}.pkl'.format(i)), 'rb') as f:
            data = pickle.load(f)
            data_ave = data[variable_index]
        combined_df.append(data_ave)
        print(i)
    result = np.concatenate(combined_df, axis=2)
    if variable_index == 0: name_var = 'Ave'
    if variable_index == 1: name_var = 'Max'
    if variable_index == 2: name_var = 'Min'
    if variable_index == 3: name_var = 'Var'

    with open(osp.join(directory_path, 'result{}.pkl'.format(name_var)), 'wb') as f:  # Python 3: open(..., 'wb')
        pickle.dump(result, f)



def get_spatial_reference_from_stac(tile_id, collection_id, datetime_range, asset_key='NDVI'):
    """
    Busca um item STAC e extrai informações de referência espacial de um asset específico.

    Args:
        tile_id (str): O ID do tile (ex: '013012').
        collection_id (str): O ID da coleção STAC (ex: 'mod13q1-6.1').
        datetime_range (str): O intervalo de data/hora para a busca (ex: '2001-01-01/2001-02-10').
        asset_key (str, optional): A chave do asset do qual extrair os metadados
                                   (ex: 'NDVI', 'EVI'). Padrão 'NDVI'.

    Returns:
        tuple: (crs, transform, width, height) ou (None, None, None, None) se não encontrado.
               crs (rasterio.crs.CRS): Sistema de Referência de Coordenadas.
               transform (rasterio.Affine): Transformação afim.
               width (int): Largura da imagem do asset.
               height (int): Altura da imagem do asset.
    """
    print(f"Buscando item STAC para tile '{tile_id}', coleção '{collection_id}', data '{datetime_range}'...")
    try:
        service = pystac_client.Client.open('https://data.inpe.br/bdc/stac/v1/')
        item_search = service.search(
            query={'bdc:tiles': {'eq': tile_id}},
            datetime=datetime_range,
            collections=[collection_id]
        )

        items = list(item_search.items())
        if not items:
            print("Nenhum item STAC encontrado para os critérios fornecidos.")
            return None, None, None, None

        # Usando o primeiro item encontrado. No seu exemplo era items[1],
        # mas items[0] é geralmente o mais relevante se houver apenas um resultado esperado.
        # Se múltiplos itens são esperados e você precisa de um específico, ajuste a lógica.
        stac_item = items[0] # Ou items[1] conforme seu código original, se souber o porquê.
        print(f"Item STAC ID: {stac_item.id}")

        if asset_key not in stac_item.assets:
            print(f"Asset '{asset_key}' não encontrado no item STAC. Assets disponíveis: {list(stac_item.assets.keys())}")
            # Tentar com uma chave comum se a preferida não existir
            available_data_assets = [k for k in stac_item.assets.keys() if k.upper() in ['NDVI', 'EVI', 'RED', 'NIR']]
            if not available_data_assets:
                print("Nenhum asset de dados comum (NDVI, EVI, RED, NIR) encontrado.")
                return None, None, None, None
            asset_key_found = available_data_assets[0]
            print(f"Usando asset alternativo: '{asset_key_found}'")
            asset = stac_item.assets[asset_key_found]
        else:
            asset = stac_item.assets[asset_key]

        print(f"Extraindo metadados do asset: '{asset_key}'")

        # Extrair CRS (EPSG code)
        # epsg_code = asset.extra_fields.get('proj:epsg')
        epsg_code = '4326'
        if epsg_code is None:
            print("Código EPSG não encontrado no asset.")
            return None, None, None, None
        crs = rasterio.crs.CRS.from_epsg(epsg_code)
        print(crs)

        # from rasterio.crs import CRS
        # source_crs = CRS.from_string('EPSG:4326')
        # if crs:
        #     crs = CRS.from_string(crs)

        # Extrair Transform (geotransform)
        gdal_transform = asset.extra_fields.get('proj:transform')
        if gdal_transform is None:
            print("Transformação (geotransform) não encontrada no asset.")
            # Tentar calcular a partir do bbox e shape, se disponíveis
            # proj_bbox = asset.extra_fields.get('bbox')
            proj_bbox = stac_item.bbox
            #proj_shape = asset.extra_fields.get('proj:shape') # [height, width]
            proj_shape = [4800, 4800]
            if proj_bbox and proj_shape and len(proj_bbox) == 4 and len(proj_shape) == 2:
                print("Tentando calcular transform a partir de proj:bbox e proj:shape.")
                xmin, ymin, xmax, ymax = proj_bbox
                height_pixels, width_pixels = proj_shape
                # GDAL transform: [ulx, x_res, x_skew, uly, y_skew, y_res (negativo)]
                x_res = (xmax - xmin) / width_pixels
                y_res = (ymax - ymin) / height_pixels # y_res é negativo no transform do GDAL
                transform = Affine(x_res, 0.0, xmin, 0.0, -abs(y_res), ymax) # abs para garantir que y_res seja negativo
                print(f"Transform calculada: {transform}")
            else:
                print("Não foi possível calcular a transformação: proj:bbox ou proj:shape ausentes/inválidos.")
                return None, None, None, None
        else:
            transform = Affine.from_gdal(*gdal_transform)


        # Extrair dimensões (shape)
        shape = [4800, 4800]
        if shape is None or len(shape) != 2:
            print("Dimensões (shape) não encontradas ou inválidas no asset.")
            return None, None, None, None
        height, width = shape[0], shape[1]

        print(f"CRS: {crs}, Transform: {transform}, Width: {width}, Height: {height}")
        return crs, transform, width, height

    except Exception as e:
        print(f"Erro ao buscar ou processar item STAC: {e}")
        return None, None, None, None


def pkl_to_cog_with_stac_ref(pkl_path, output_cog_path,
                             stac_tile_id, stac_collection_id, stac_datetime_range,
                             stac_asset_key='NDVI',
                             num_bands_pkl = 23, # Número de bandas esperado no arquivo PKL
                             dtype_pkl='float32', # Tipo de dado esperado no arquivo PKL
                             overviews_levels=None,
                             block_size=512, compression='DEFLATE'):
    """
    Converte dados de um arquivo .pkl para um Cloud Optimized GeoTIFF (COG),
    usando referência espacial de um item STAC.
    """
    crs, transform, expected_width, expected_height = get_spatial_reference_from_stac(
        stac_tile_id, stac_collection_id, stac_datetime_range, stac_asset_key
    )

    if not all([crs, transform, expected_width, expected_height]):
        print("Não foi possível obter a referência espacial do STAC. Abortando.")
        return

    print(f"Carregando dados de: {pkl_path}")
    with open(pkl_path, 'rb') as f:
        data_array = pickle.load(f)

    if not isinstance(data_array, np.ndarray):
        try:
            data_array = np.array(data_array, dtype=dtype_pkl)
        except Exception as e:
            print(f"Erro ao converter os dados carregados para NumPy array: {e}")
            return

    # Ajustar forma do array e verificar dimensões
    # Supondo que a resolução do .pkl DEVE corresponder à do asset STAC (4800x4800)
    current_height, current_width = data_array.shape[-2], data_array.shape[-1] # Lida com (H,W) ou (B,H,W)
    
    if data_array.ndim == 2: # Imagem de banda única (H, W)
        data_array = data_array.reshape(1, current_height, current_width)
        current_num_bands = 23
    elif data_array.ndim == 3: # (B, H, W) ou (H, W, B)
        if data_array.shape[0] == num_bands_pkl and data_array.shape[1] == current_height and data_array.shape[2] == current_width: # (B,H,W)
            current_num_bands = data_array.shape[0]
        elif data_array.shape[0] == current_height and data_array.shape[1] == current_width and data_array.shape[2] == num_bands_pkl: # (H,W,B)
            data_array = data_array.transpose(2, 0, 1) # Transforma para (B,H,W)
            current_num_bands = data_array.shape[0]
        else:
            print(f"Dimensões do array ({data_array.shape}) não correspondem ao número de bandas esperado ({num_bands_pkl}) e formato.")
            return
    else:
        print(f"Array com dimensões inesperadas: {data_array.ndim}")
        return

    if current_num_bands != num_bands_pkl:
        print(f"Aviso: Número de bandas no array PKL ({current_num_bands}) difere do esperado ({num_bands_pkl}). Usando {current_num_bands}.")
        # num_bands_pkl = current_num_bands # Ou pode decidir abortar

    if not (current_width == expected_width and current_height == expected_height):
        print(f"Erro: Dimensões do arquivo PKL ({current_height}x{current_width}) "
              f"não correspondem às dimensões do asset STAC ({expected_height}x{expected_width}).")
        print("Verifique se o pkl realmente tem 4800x4800 pixels e se a ordem (altura, largura) está correta.")
        return

    print(f"Forma do array PKL carregado e ajustado: {data_array.shape}, Tipo de dado: {data_array.dtype}")

    profile = {
        'driver': 'GTiff',
        'dtype': data_array.dtype,
        'nodata': None, # Defina se houver um valor nodata
        'width': expected_width,
        'height': expected_height,
        'count': current_num_bands, # Usa o número de bandas do array pkl
        'crs': crs,
        'transform': transform,
        'tiled': True,
        'blockxsize': block_size,
        'blockysize': block_size,
        'compress': compression,
        'interleave': 'pixel'
    }

    print(f"Criando COG em: {output_cog_path} com o perfil: {profile}")
    with rasterio.open(output_cog_path, 'w', **profile, BIGTIFF='YES') as dst:
        dst.write(data_array)

        if overviews_levels is None:
            max_dim = max(expected_width, expected_height)
            # Ajuste para evitar níveis de visão geral muito pequenos ou únicos
            overview_candidate_levels = [2**i for i in range(1, int(np.log2(max_dim / block_size)) + 1)]
            overviews_levels = [ov for ov in overview_candidate_levels if ov > 1 and max_dim / ov >= block_size / 2]


        if overviews_levels:
            print(f"Construindo overviews com fatores: {overviews_levels}")
            dst.build_overviews(overviews_levels, Resampling.average) # ou Resampling.nearest para dados categóricos
            dst.update_tags( overviews='yes') # Tag genérica para indicar overviews
        else:
            print("Nenhum overview será construído (ou níveis muito pequenos).")


    print(f"Arquivo COG '{output_cog_path}' criado com sucesso.")





def pkl_to_cog_filtered(pkl_path, output_cog_path,
                             stac_tile_id, stac_collection_id, stac_datetime_range,
                             time_to_select,
                             stac_asset_key='NDVI',
                             num_bands_pkl = 1, # Número de bandas esperado no arquivo PKL
                             dtype_pkl='float32', # Tipo de dado esperado no arquivo PKL
                             overviews_levels=None,
                             block_size=512, compression='DEFLATE'):
    """
    Converte dados de um arquivo .pkl para um Cloud Optimized GeoTIFF (COG),
    usando referência espacial de um item STAC.
    """
    crs, transform, expected_width, expected_height = get_spatial_reference_from_stac(
        stac_tile_id, stac_collection_id, stac_datetime_range, stac_asset_key
    )

    if not all([crs, transform, expected_width, expected_height]):
        print("Não foi possível obter a referência espacial do STAC. Abortando.")
        return

    print(f"Carregando dados de: {pkl_path}")
    with open(pkl_path, 'rb') as f:
        data_array = pickle.load(f)
        data_array = data_array[time_to_select,:,:]

    if not isinstance(data_array, np.ndarray):
        try:
            data_array = np.array(data_array, dtype=dtype_pkl)
        except Exception as e:
            print(f"Erro ao converter os dados carregados para NumPy array: {e}")
            return

    # Ajustar forma do array e verificar dimensões
    # Supondo que a resolução do .pkl DEVE corresponder à do asset STAC (4800x4800)
    current_height, current_width = data_array.shape[-2], data_array.shape[-1] # Lida com (H,W) ou (B,H,W)
    
    if data_array.ndim == 2: # Imagem de banda única (H, W)
        data_array = data_array.reshape(1, current_height, current_width)
        current_num_bands = 1
    elif data_array.ndim == 3: # (B, H, W) ou (H, W, B)
        if data_array.shape[0] == num_bands_pkl and data_array.shape[1] == current_height and data_array.shape[2] == current_width: # (B,H,W)
            current_num_bands = data_array.shape[0]
        elif data_array.shape[0] == current_height and data_array.shape[1] == current_width and data_array.shape[2] == num_bands_pkl: # (H,W,B)
            data_array = data_array.transpose(2, 0, 1) # Transforma para (B,H,W)
            current_num_bands = data_array.shape[0]
        else:
            print(f"Dimensões do array ({data_array.shape}) não correspondem ao número de bandas esperado ({num_bands_pkl}) e formato.")
            return
    else:
        print(f"Array com dimensões inesperadas: {data_array.ndim}")
        return

    if current_num_bands != num_bands_pkl:
        print(f"Aviso: Número de bandas no array PKL ({current_num_bands}) difere do esperado ({num_bands_pkl}). Usando {current_num_bands}.")
        # num_bands_pkl = current_num_bands # Ou pode decidir abortar

    if not (current_width == expected_width and current_height == expected_height):
        print(f"Erro: Dimensões do arquivo PKL ({current_height}x{current_width}) "
              f"não correspondem às dimensões do asset STAC ({expected_height}x{expected_width}).")
        print("Verifique se o pkl realmente tem 4800x4800 pixels e se a ordem (altura, largura) está correta.")
        return

    print(f"Forma do array PKL carregado e ajustado: {data_array.shape}, Tipo de dado: {data_array.dtype}")

    profile = {
        'driver': 'GTiff',
        'dtype': data_array.dtype,
        'nodata': None, # Defina se houver um valor nodata
        'width': expected_width,
        'height': expected_height,
        'count': current_num_bands, # Usa o número de bandas do array pkl
        'crs': crs,
        'transform': transform,
        'tiled': True,
        'blockxsize': block_size,
        'blockysize': block_size,
        'compress': compression,
        'interleave': 'pixel'
    }

    print(f"Criando COG em: {output_cog_path} com o perfil: {profile}")
    with rasterio.open(output_cog_path, 'w', **profile, BIGTIFF='YES') as dst:
        dst.write(data_array)

        if overviews_levels is None:
            max_dim = max(expected_width, expected_height)
            # Ajuste para evitar níveis de visão geral muito pequenos ou únicos
            overview_candidate_levels = [2**i for i in range(1, int(np.log2(max_dim / block_size)) + 1)]
            overviews_levels = [ov for ov in overview_candidate_levels if ov > 1 and max_dim / ov >= block_size / 2]


        if overviews_levels:
            print(f"Construindo overviews com fatores: {overviews_levels}")
            dst.build_overviews(overviews_levels, Resampling.average) # ou Resampling.nearest para dados categóricos
            dst.update_tags( overviews='yes') # Tag genérica para indicar overviews
        else:
            print("Nenhum overview será construído (ou níveis muito pequenos).")


    print(f"Arquivo COG '{output_cog_path}' criado com sucesso.")





