#! /usr/bin/env python

# Copyright (C) 2024- The University of Notre Dame
# This software is distributed under the GNU General Public License.
# See the file COPYING for details.

import sys
import matplotlib.pyplot as plt
import ast
import os
import argparse
import copy

from matplotlib.animation import FuncAnimation
from functools import partial

MICROSECONDS = 1000000
# Divisor applied to every size read from the log before it is stored.
GIGABYTES = 1000000000

# URL scheme of the source in 'will get' lines that name another worker.
_WORKER_URL_PREFIX = 'workerip://'


def _split_fields(line, count):
    """Split *line* into exactly *count* whitespace-separated fields.

    Text beyond the last field stays attached to that field.  Raises
    ValueError when the line holds fewer than *count* fields.
    """
    fields = line.split(maxsplit=count - 1)
    if len(fields) != count:
        raise ValueError(
            'expected {} fields, found {}: {!r}'.format(count, len(fields), line))
    return fields


def _clean_addr(token):
    """Return the bare address from a log token such as '(host:port):'."""
    return token.strip('):').strip('(')


def parse_debug(log):
    """Parse a manager debug log into transfer and worker-cache records.

    Parameters:
        log: path of the debug log to read.

    Returns a dict with the keys:
        'ids'              -- maps 'manager' to 0 and every address seen in a
                              transfer or cache update to 1, 2, ... in order
                              of first appearance.
        'manager_info'     -- {'log': log, 'type': 'debug'}.
        'transfer_info'    -- list of [from_addr, to_addr, cachename, size],
                              where size is the largest size logged for that
                              cachename divided by GIGABYTES (0.0 when the log
                              never reports a size for it).
        'worker_file_info' -- {addr: {'files': {cachename: size}}} built from
                              the 'cache-update' lines.

    Raises:
        ValueError: a recognised line has too few fields or a non-numeric size.
        KeyError: a 'will get' line names a worker transfer address for which
            no 'transfer-port' line has been seen yet.
    """
    # Each entry starts as [from_addr, to_addr, cachename]; the size is
    # appended once the whole log has been read.
    transfer_info = []
    file_sizes = {}          # cachename -> every size logged for it
    worker_file_info = {}    # worker addr -> {'files': {cachename: size}}
    transfer_addresses = {}  # "ip:transfer_port" -> worker addr

    with open(log, 'r') as log_file:
        lines = log_file.read().splitlines()

    # The checks below are deliberately independent 'if's: a line is tested
    # against every pattern, exactly as a line matching several would need.
    for line in lines:
        if 'tx to' in line and ' file ' in line:
            fields = _split_fields(line, 12)
            size = float(fields[10])
            cachename = fields[9]
            transfer_info.append(['manager', _clean_addr(fields[7]), cachename])
            file_sizes.setdefault(cachename, []).append(size / GIGABYTES)

        if 'rx from' in line and ' file ' in line:
            try:
                fields = _split_fields(line, 12)
            except ValueError:
                # WQ fallback: these lines have no trailing mode field.
                fields = _split_fields(line, 11)
            size = float(fields[10])
            cachename = fields[9]
            transfer_info.append([_clean_addr(fields[7]), 'manager', cachename])
            file_sizes.setdefault(cachename, []).append(size / GIGABYTES)

        if 'will get' in line:
            fields = _split_fields(line, 12)
            source = fields[11]
            # Only worker-to-worker sources can be mapped onto a worker id;
            # any other URL is left out of the transfer list.
            if source.startswith(_WORKER_URL_PREFIX):
                # Remove the scheme as a prefix.  str.strip() would treat it
                # as a set of characters and also eat the end of the name.
                peer, cachename = source[len(_WORKER_URL_PREFIX):].split('/', 1)
                transfer_info.append(
                    [transfer_addresses[peer], _clean_addr(fields[5]), cachename])

        if "cache-update" in line:
            fields = _split_fields(line, 14)
            size = float(fields[10])
            cachename = fields[9]
            addr = _clean_addr(fields[7])
            # Placeholder value; replaced by the real size after the loop.
            worker_file_info.setdefault(addr, {'files': {}})['files'][cachename] = 0
            file_sizes.setdefault(cachename, []).append(size / GIGABYTES)

        if "transfer-port" in line:
            fields = _split_fields(line, 10)
            addr = _clean_addr(fields[7])
            # Peers are addressed by the worker's host plus its transfer port.
            transfer_addresses[addr.split(':')[0] + ':' + fields[9]] = addr

    # Attach a size to every transfer: the largest one logged for the file.
    for transfer in transfer_info:
        transfer.append(max(file_sizes.get(transfer[2], [0.0])))

    for worker_entry in worker_file_info.values():
        cached = worker_entry['files']
        for cachename in cached:
            cached[cachename] = max(file_sizes[cachename])

    # Assign ids to the manager and workers in order of first appearance.
    ids = {'manager': 0}
    for transfer in transfer_info:
        for endpoint in transfer[:2]:
            ids.setdefault(endpoint, len(ids))
    for worker in worker_file_info:
        ids.setdefault(worker, len(ids))

    return {
        'ids': ids,
        'manager_info': {'log': log, 'type': 'debug'},
        'transfer_info': transfer_info,
        'worker_file_info': worker_file_info,
    }

def plot_data_exg(log_info, axs, frames, args):
    """Build cumulative data-exchange matrices for the requested frames.

    A square matrix with one row and one column per entry of
    log_info['ids'] accumulates transfer sizes at [source id][destination id]
    while log_info['transfer_info'] is replayed in order.

    Parameters:
        log_info: dict providing 'ids' (address -> index) and 'transfer_info'
            (entries indexed as [from, to, <unused>, size]).
        axs: not used by this function.
        frames: container of counter values at which a snapshot is taken.
            The counter is 2 after the first transfer, 3 after the second,
            and so on.
        args: not used by this function.

    Returns a list of matrix snapshots (lists of lists).  The first element is
    always the all-zero starting matrix.
    """
    node_ids = log_info['ids']
    dimension = len(node_ids)

    # One independent row per node so that cells never alias each other.
    totals = [[0] * dimension for _ in range(dimension)]
    snapshots = [copy.deepcopy(totals)]

    # The counter starts at 2 for the first transfer.
    for position, record in enumerate(log_info['transfer_info'], start=2):
        source = node_ids[record[0]]
        destination = node_ids[record[1]]
        totals[source][destination] += record[3]
        if position in frames:
            snapshots.append(copy.deepcopy(totals))

    return snapshots

def data_exg_frame(frame, captured_frames, axs, cbar):
    """Draw one animation frame of the data-exchange heat map.

    Parameters:
        frame: index into captured_frames of the matrix to draw.
        captured_frames: sequence of 2-D matrices, one per frame.
        axs: the matplotlib axes to draw the mesh on.
        cbar: the colorbar to re-sync with the newly drawn mesh.
    """
    # Remove what the previous frame drew.  Without this every call stacks one
    # more mesh on the axes and each redraw has to render all of them.
    axs.clear()
    im = axs.pcolormesh(captured_frames[frame])

    # Point the colorbar at the new mesh so it reflects this frame's values.
    cbar.update_normal(im)

    # clear() also resets the axis labels, so they are set on every frame.
    axs.set_ylabel('Source(ID)')
    axs.set_xlabel('Destination(ID)')
    cbar.ax.set_ylabel("Exchanged Data (GB)")

# Script body: animate the data exchanged between the manager and the workers
# recorded in the debug log named on the command line, and save it as anim.gif.
if len(sys.argv) < 2:
    sys.exit("usage: {} <debug_log>".format(sys.argv[0]))

fig, axs = plt.subplots(1, 1, squeeze=False)
log = sys.argv[1]
log_info = parse_debug(log)

fps = 5
seconds = 20
frame_count = len(log_info['transfer_info']) + 1

# Sample roughly fps * seconds frames.  Short logs would otherwise yield a
# step of 0, which range() rejects, so the step is clamped to at least 1.
frame_step = max(1, frame_count // (fps * seconds))
frames = list(range(0, frame_count, frame_step))
# plot_data_exg's counter reaches frame_count right after the last transfer.
# Include it so the animation ends on the complete totals.
if frames[-1] != frame_count:
    frames.append(frame_count)
captured_frames = plot_data_exg(log_info, axs[0][0], frames, 0)

# Draw the first frame up front so a colorbar can be attached to it.
im = axs[0][0].pcolormesh(captured_frames[0])
cbar = fig.colorbar(im, ax=axs[0][0])
anim = FuncAnimation(
    fig,
    partial(data_exg_frame, captured_frames=captured_frames, axs=axs[0][0], cbar=cbar),
    len(captured_frames),
)
anim.save('anim.gif')


