import os
import re
import requests
from PIL import Image

from multiprocessing import Pool

from region_config import region_config

# Runtime settings
# When True, process_radar_frames() distributes frames over a multiprocessing
# Pool; when False, frames are processed one at a time in this process.
PARALLEL_MODE = False
# Directory that is scanned for downloaded source frames (note the trailing slash).
res_path = "res/"


def download_frames_from_server():
    """Mirror the frame images listed on the remote server into ``res/``.

    The server's index page is scanned for quoted ``twc_<layer>-<digits>_8.png``
    names. Local files in ``res/`` that are not in that list are deleted
    (except files starting with ``prod_``), and listed files that are not yet
    present locally are downloaded.

    Raises:
        requests.RequestException: if the index page cannot be fetched or a
            request times out. In that case nothing is deleted locally.
    """
    print("Downloading images")
    url = "https://wxdata-us-central1.opntele.com/res-raw-asdklfaj842h78tg67asfhgakjfk/"
    # Without a timeout a stalled server would block this call forever.
    timeout_seconds = 30

    listing = requests.get(url, timeout=timeout_seconds)
    # Abort on an HTTP error: an error page contains no frame names, and an
    # empty list would make the pruning step below delete every cached frame.
    listing.raise_for_status()

    # The name part may not contain a quote, path separator or newline, so one
    # match can never swallow several names or point outside of res/.
    filename_pattern = re.compile(r'"(twc_[^"/\\\n]*-\d+_8\.png)"')
    filenames = filename_pattern.findall(listing.text)
    wanted = set(filenames)

    # Delete files from res/ that the server no longer lists
    for file in os.listdir("res"):
        if file in wanted or file.startswith("prod_"):
            continue

        print("Deleting", file)
        os.remove(f"res/{file}")

    for filename in filenames:
        target = f"res/{filename}"
        # If file already exists, skip
        if os.path.isfile(target):
            continue

        # Download file
        response = requests.get(
            f"{url}{filename}", allow_redirects=True, timeout=timeout_seconds
        )
        if not response.ok:
            # Do not store an error page under a .png name; the frame is
            # simply retried on the next run.
            print("Failed to download", filename, response.status_code)
            continue

        with open(target, "wb") as out_file:
            out_file.write(response.content)


def find_image_pairs():
    """Pair satellite and radar frames in ``res_path`` that share a timestamp.

    Returns:
        list[dict]: one ``{"sat": <filename>, "rad": <filename>}`` entry for
        every satellite frame that has a radar frame with the same number,
        in directory-listing order. Frames without a partner are left out.
    """
    # fullmatch + escaped dot: only exact frame names qualify, not names that
    # merely start like one (e.g. leftovers with an extra suffix).
    sat_pattern = re.compile(r"twc_sat-(\d+)_8\.png")
    rad_pattern = re.compile(r"twc_twcRadarMosaic-(\d+)_8\.png")

    # List the directory once so both passes see the same snapshot.
    files = os.listdir(res_path)

    # Index radar frames by number so each satellite frame needs one lookup
    # instead of a scan over every radar file.
    rad_by_number = {}
    for file in files:
        rad_match = rad_pattern.fullmatch(file)
        if rad_match:
            rad_by_number.setdefault(rad_match.group(1), file)

    matching_pairs = []
    for file in files:
        sat_match = sat_pattern.fullmatch(file)
        if not sat_match:
            continue

        rad_file = rad_by_number.get(sat_match.group(1))
        if rad_file is not None:
            matching_pairs.append({"sat": file, "rad": rad_file})

    return matching_pairs


# Static layers, opened once at import time and reused by process_frame().

# Water background: bottom layer of every cropped region image
water_bg_image = Image.open("data/static/slides/water_bg_640.png")

# Basemap: bottom layer of the full composite, beneath the satellite image
basemap_image = Image.open("data/static/slides/final2.png")

# Outline: top layer of the full composite, above the radar image
outline_image = Image.open("data/static/slides/outlines.png")

# Radar overlay: composited last onto every cropped region image
radar_overlay = Image.open("data/static/slides/forecast_radar-addition.png")


def _make_radar_shadow(rad_image):
    """Return a solid black drop shadow of *rad_image*, offset by (3, 6) pixels.

    Every pixel whose shifted radar alpha is non-zero becomes opaque black
    ``(0, 0, 0, 255)``; every other pixel is fully transparent
    ``(255, 255, 255, 0)``.
    """
    # Shift the radar's alpha channel by the shadow offset.
    shifted_alpha = Image.new("L", rad_image.size, 0)
    shifted_alpha.paste(rad_image.getchannel("A"), (3, 6))
    # Any radar coverage at all casts a fully opaque shadow.
    shadow_mask = shifted_alpha.point(lambda alpha: 255 if alpha else 0)

    shadow = Image.new("RGBA", rad_image.size, (255, 255, 255, 0))
    shadow.paste((0, 0, 0, 255), (0, 0) + rad_image.size, shadow_mask)
    return shadow


def _remove_if_present(path):
    """Delete *path*, ignoring the case where it does not exist."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def process_frame(frame_pair):
    """Render one satellite/radar pair into a cropped image per configured region.

    The full composite is basemap + satellite + radar shadow + radar +
    outlines. For every entry of ``region_config`` the composite is cropped to
    ``entry["crop"]``, resized to 640x480, layered between the water
    background and the radar overlay, and saved as
    ``temp/output_<name>_<timestamp>.png``.

    Args:
        frame_pair: dict with ``"sat"`` and ``"rad"`` file names (relative to
            ``res_path``), as produced by ``find_image_pairs``.

    Returns:
        The timestamp string parsed from the satellite file name. It is also
        returned when every output already exists and nothing is rendered.

    Raises:
        OSError: if a source frame cannot be opened or decoded, or an output
            cannot be written. Both source frames are deleted before the
            error is re-raised.
    """
    timestamp = frame_pair["sat"].split("-")[1].split("_")[0]

    all_outputs_exist = all(
        os.path.isfile(f"temp/output_{entry['name']}_{timestamp}.png")
        for entry in region_config
    )
    if all_outputs_exist:
        print("Skipping rad/sat since we already have it", timestamp)
        return timestamp

    print("Processing", frame_pair)
    sat_image_path = os.path.join(res_path, frame_pair["sat"])
    rad_image_path = os.path.join(res_path, frame_pair["rad"])

    try:
        # Opening happens inside the try so that a file that is not a valid
        # image at all is cleaned up too, and the with-block releases both
        # file handles on every path.
        with Image.open(sat_image_path) as sat_image, Image.open(
            rad_image_path
        ) as rad_image:
            rad_shadow_image = _make_radar_shadow(rad_image)

            # Overlay
            output = Image.alpha_composite(basemap_image, sat_image)
            output = Image.alpha_composite(output, rad_shadow_image)
            output = Image.alpha_composite(output, rad_image)
            output = Image.alpha_composite(output, outline_image)

            print("Starting to create cuts for", timestamp)
            for entry in region_config:
                # Crop the composite to the region and scale it to 640x480
                cropped = output.crop(entry["crop"])
                cropped = cropped.resize((640, 480), Image.Resampling.LANCZOS)

                # Water background below, radar overlay on top
                img = Image.alpha_composite(water_bg_image, cropped)
                img = Image.alpha_composite(img, radar_overlay)

                # Save the output image
                outname = entry["name"]
                img.save(f"temp/output_{outname}_{timestamp}.png")
                print("Done saving", outname, timestamp)
    except OSError:
        # The source frames are assumed to be unusable; drop them so the next
        # download run fetches them again. The handles are already closed
        # here, which matters on platforms that refuse to delete open files.
        # NOTE(review): a failed save to temp/ also ends up here and deletes
        # the source frames - confirm that is intended.
        _remove_if_present(rad_image_path)
        _remove_if_present(sat_image_path)
        raise

    return timestamp


def process_radar_frames(matching_pairs):
    """Run ``process_frame`` on every pair and collect the returned values.

    Uses a pool of 4 worker processes when ``PARALLEL_MODE`` is set, otherwise
    processes the pairs one after another in the current process.

    Args:
        matching_pairs: iterable of frame-pair dicts accepted by ``process_frame``.

    Returns:
        list: the ``process_frame`` results, in the order of ``matching_pairs``.
    """
    if not PARALLEL_MODE:
        print("Running in 1 process")
        return [process_frame(pair) for pair in matching_pairs]

    print("Running in parallel")
    with Pool(4) as pool:
        pending = [
            pool.apply_async(process_frame, args=(pair,)) for pair in matching_pairs
        ]
        # No more work will be submitted; block until the workers are done.
        pool.close()
        pool.join()
        for task in pending:
            task.wait()
        # get() re-raises any exception that occurred in a worker.
        results = [task.get() for task in pending]

    return results


def cleanup_current_output_frames(current_timestamps):
    """Delete rendered ``temp/output_*`` frames whose timestamp is not current.

    Args:
        current_timestamps: iterable of timestamp strings whose output frames
            must be kept.
    """
    print("cleaning up output frames")
    keep = set(current_timestamps)

    for file in os.listdir("temp"):
        # Only rendered region frames are pruned; everything else in temp/
        # (for example the 3day_ and hourly_ frames) is left alone.
        if not file.startswith("output_"):
            continue

        # Names look like output_<region>_<timestamp>.png. The region name may
        # itself contain underscores, so the timestamp is whatever follows the
        # LAST underscore, up to the first dot.
        timestamp = file.rsplit("_", 1)[-1].split(".")[0]
        if timestamp not in keep:
            print("Deleting", file)
            os.remove(f"temp/{file}")


def process_rad_all():
    """Download, render and prune radar frames in one pass.

    Returns:
        list[str]: sorted timestamps of all satellite/radar pairs found on disk.
    """
    download_frames_from_server()
    pairs = find_image_pairs()
    process_radar_frames(pairs)

    # The timestamp sits between the "-" and the following "_" of the
    # satellite file name.
    timestamps = []
    for pair in pairs:
        sat_name = pair["sat"]
        after_dash = sat_name.split("-")[1]
        timestamps.append(after_dash.split("_")[0])

    cleanup_current_output_frames(timestamps)
    return sorted(timestamps)


# Run the full download / render / cleanup pass when executed as a script.
if __name__ == "__main__":
    process_rad_all()
