#
# This file is part of INPE/ENANDES Project - Processing Module.
# Copyright (C) 2024 INPE.
#
# enandes-processing is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
#
import pandas as pd
import argparse
import os
from datetime import datetime

import numpy as np
import xarray as xr
from osgeo import gdal

from enandes_processing.constants import LAT_LON_WGS84
from enandes_processing.utils import getGeoT

def _output_prefix(input_path):
    """Build the GeoTiff file-name prefix from the input file name.

    The basename (extension removed) is split on ``_``; every token after
    the first ``mcwd`` token is appended to ``MERGE_mcwd``, joined by ``_``.
    For example ``foo_mcwd_anomaly.nc`` -> ``MERGE_mcwd_anomaly`` and
    ``foo_mcwd.nc`` -> ``MERGE_mcwd``.

    Raises:
        ValueError: if the file name has no ``mcwd`` token.
    """
    stem = os.path.splitext(os.path.basename(input_path))[0]
    tokens = stem.split("_")
    try:
        idx = tokens.index("mcwd")
    except ValueError:
        raise ValueError(
            f"Input file name '{stem}' must contain an underscore-separated "
            "'mcwd' token"
        ) from None
    # Join everything with "_" so the suffix stays separated from
    # "MERGE_mcwd" (plain concatenation yielded e.g. "MERGE_mcwdanomaly").
    return "_".join(["MERGE_mcwd", *tokens[idx + 1:]])


def _year_end_tag(time_value):
    """Return the ``YYYY1231`` date tag used in output names for a time step."""
    # Only the year of the time step is used; the day/month are always the
    # literal "1231". NOTE(review): if the product were sub-annual, several
    # time steps would map to the same file name and overwrite each other —
    # confirm the time axis is annual.
    return pd.to_datetime(time_value).strftime('%Y1231')


def mcwd2tif(input, output):
    """Convert each time step of an MCWD anomaly netCDF file to a GeoTiff.

    Reads the ``mcwd_anomaly`` variable from *input* and writes one
    single-band, LZW-compressed Float32 GeoTiff per time step into the
    directory *output* (created if missing). Files are named
    ``<prefix>_<YYYY>1231.tiff``; see ``_output_prefix``.

    Args:
        input: path to the netCDF file; its name must contain an
            underscore-separated ``mcwd`` token.
        output: directory where the GeoTiff files are written.

    Raises:
        ValueError: if the input file name has no ``mcwd`` token.
        RuntimeError: if GDAL cannot create an output file.
    """
    # Validate the naming convention before doing any I/O.
    prefix = _output_prefix(input)

    # Load everything needed into memory, then release the file handle.
    with xr.open_dataset(input) as dataset:
        longitudes = dataset.longitude.values
        latitudes = dataset.latitude.values
        times = dataset.time.values
        var_data = dataset["mcwd_anomaly"].values

    # Extent [llx, lly, urx, ury] from the first/last coordinate values.
    # NOTE(review): these are presumably pixel centres; whether getGeoT
    # expects cell edges cannot be confirmed from this file.
    extent = [longitudes[0], latitudes[0], longitudes[-1], latitudes[-1]]

    # Flip the first (row) axis; presumably latitude is stored ascending
    # while GeoTiff rows run north to south — confirm.
    var_data = np.flipud(var_data)

    # NOTE(review): indexing below assumes var_data is (lat, lon, time).
    nlines, ncols = var_data.shape[0], var_data.shape[1]

    os.makedirs(output, exist_ok=True)

    # Loop-invariant GDAL setup.
    driver = gdal.GetDriverByName('GTiff')
    projection = LAT_LON_WGS84.ExportToWkt()
    geotransform = getGeoT(extent, nlines, ncols)

    for t, time_value in enumerate(times):
        print('Converting {}'.format(time_value))

        file_gtiff = f"{prefix}_{_year_end_tag(time_value)}.tiff"
        path = os.path.join(output, file_gtiff)

        grid = driver.Create(path, ncols, nlines,
            1, gdal.GDT_Float32, options=['COMPRESS=LZW'])
        if grid is None:
            raise RuntimeError(f"GDAL could not create '{path}'")

        band = grid.GetRasterBand(1)
        # NOTE(review): 0.0 is flagged as NoData although a zero anomaly
        # looks like a valid value, and NaNs are written unchanged — confirm.
        band.SetNoDataValue(0.0)
        band.Fill(0.0)

        # Setup projection and geo-transformation
        grid.SetProjection(projection)
        grid.SetGeoTransform(geotransform)

        # Write this time step's 2-D slice.
        band.WriteArray(var_data[:, :, t])

        # Dropping the references flushes and closes the GeoTiff.
        band = None
        grid = None

def main(argv=None):
    """Command-line entry point: parse arguments and run ``mcwd2tif``.

    Args:
        argv: optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, matching the previous behaviour.
    """
    # Create command-line parser
    parser = argparse.ArgumentParser(description='mcwd2tif - Convert MCWD/MERGE product to GeoTiff files.')

    # Create command-line parameters
    #> Input file
    parser.add_argument('--input', '-i', help='Path to MCWD/MERGE file that will be converted',
        type=str, dest='input', required=True)
    #> Output directory
    parser.add_argument('--output', '-o', help='Path to output directory where GeoTiff files will be generated',
        type=str, dest='output', required=True)

    # Parse input
    args = parser.parse_args(argv)

    # Convert!
    mcwd2tif(args.input, args.output)

if __name__ == '__main__':
    # Run the CLI only when this file is executed as a script, not on import.
    main()
    