/*
 * Copyright (c) 2021, 2026 Contributors to the Eclipse Foundation
 *
 * This program and the accompanying materials are made
 * available under the terms of the Eclipse Public License 2.0
 * which is available at https://www.eclipse.org/legal/epl-2.0/
 *
 * SPDX-License-Identifier: EPL-2.0
 */
package org.eclipse.lsat.scheduler;

import static org.slf4j.LoggerFactory.getLogger;

import java.math.BigDecimal;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map.Entry;

import org.eclipse.core.runtime.IProgressMonitor;
import org.eclipse.lsat.common.graph.directed.Aspect;
import org.eclipse.lsat.common.graph.directed.DirectedGraphFactory;
import org.eclipse.lsat.common.scheduler.graph.Task;
import org.eclipse.lsat.common.scheduler.schedule.Schedule;
import org.eclipse.lsat.common.scheduler.schedule.ScheduledDependency;
import org.eclipse.lsat.common.scheduler.schedule.ScheduledTask;
import org.slf4j.Logger;

/**
 * Critical path analysis on a {@link Schedule}.
 * <p>
 * For every scheduled task an earliest start time ({@code sMin}) is computed with a forward pass over the tasks in
 * topological order, and a latest start time ({@code sMax}) with a backward pass anchored at the latest end time found
 * in the schedule. Tasks whose earliest and latest start coincide (see {@link #almostEqual(BigDecimal, BigDecimal)})
 * are collected in a new {@link Aspect} named {@value #CRITICAL}, which is added to the schedule.
 * </p>
 *
 * @param <T> the task type of the analysed schedule
 */
public class CriticalPathAnalysis<T extends Task> {
    private static final Logger LOGGER = getLogger(CriticalPathAnalysis.class);

    /** Tolerance used by {@link #almostEqual(BigDecimal, BigDecimal)}. */
    private static final BigDecimal EPSILON = new BigDecimal("1e-10");

    /** Name of the aspect that holds the critical tasks and dependencies. */
    public static final String CRITICAL = "Critical";

    /** Start-time window of a single task, filled in by the forward and backward passes. */
    private static class PathStatistics {
        /** Earliest start time of the task (forward pass). */
        BigDecimal sMin;

        /** Latest start time of the task that does not move the latest end time (backward pass). */
        BigDecimal sMax;
    }

    /**
     * Determines the critical tasks and dependencies of {@code input} and records them in a new aspect named
     * {@value #CRITICAL}, which is added to the schedule's aspects.
     * <p>
     * Besides the critical aspect, every aspect shared by both end points of a dependency is also added to that
     * dependency. The critical aspect itself is only added to dependencies without slack.
     * </p>
     *
     * @param input the schedule to analyse; it is modified in place
     * @param monitor progress monitor; currently not used by this implementation
     * @return {@code input}, extended with the critical aspect
     */
    @SuppressWarnings("unchecked")
    public Schedule<T> transformModel(Schedule<T> input, IProgressMonitor monitor) {
        LOGGER.trace("Starting critical path analysis");
        List<ScheduledTask<T>> tasks = input.allNodesInTopologicalOrder();
        // LinkedHashMap keeps the topological order of 'tasks' for the forward pass below
        var stats = new LinkedHashMap<ScheduledTask<T>, PathStatistics>();
        tasks.forEach(task -> stats.put(task, new PathStatistics()));

        // Forward pass: earliest start time per task (sMin). The topological order guarantees that all
        // predecessors already have their sMin. Tasks without predecessors start at their resource's start time.
        for (var entry: stats.entrySet()) {
            var t = entry.getKey();
            var cp = entry.getValue();
            cp.sMin = t.getIncomingEdges().stream().map(e -> (ScheduledTask<T>)e.getSourceNode())
                    .map(sourceNode -> stats.get(sourceNode).sMin.add(sourceNode.getDuration())).reduce(BigDecimal::max)
                    // resolved lazily: only tasks without predecessors need the resource start time
                    .orElseGet(() -> t.getSequence().getResource().getStart());
        }

        // The latest end time in the schedule acts as the start of a virtual end task for the backward pass
        BigDecimal virtualEndStartTime = tasks.stream().map(ScheduledTask::getEndTime).reduce(BigDecimal::max)
                .orElse(BigDecimal.ZERO);

        // Backward pass: latest start time per task (sMax), visiting successors before predecessors
        for (ScheduledTask<T> t: tasks.reversed()) {
            var cp = stats.get(t);
            BigDecimal endFirst = t.getOutgoingEdges().stream().map(e -> (ScheduledTask<T>)e.getTargetNode())
                    .map(targetNode -> stats.get(targetNode).sMax).reduce(BigDecimal::min).orElse(virtualEndStartTime);
            cp.sMax = endFirst.subtract(t.getDuration());
        }

        // A task is critical when it has no slack, i.e. its earliest and latest start coincide
        List<ScheduledTask<T>> criticalPathTasks = stats.entrySet().stream()
                .filter(e -> almostEqual(e.getValue().sMin, e.getValue().sMax)).map(Entry::getKey).toList();

        Aspect<ScheduledTask<T>, ScheduledDependency> criticalAspect = DirectedGraphFactory.eINSTANCE.createAspect();
        criticalAspect.setName(CRITICAL);
        criticalAspect.getNodes().addAll(criticalPathTasks);
        input.getAspects().add(criticalAspect);

        // Propagate the aspects shared by both end points of a dependency to that dependency. Two critical tasks can
        // also be connected by a dependency that has slack (e.g. a shortcut next to a longer critical chain); such a
        // dependency is not on the critical path, so it must not receive the critical aspect.
        input.getEdges().forEach(e -> {
            var edgeAspects = new LinkedHashSet<>(e.getSourceNode().getAspects());
            edgeAspects.retainAll(e.getTargetNode().getAspects());
            if (edgeAspects.contains(criticalAspect)) {
                // both end points are critical tasks, hence both have an entry in 'stats'
                var source = (ScheduledTask<T>)e.getSourceNode();
                var target = (ScheduledTask<T>)e.getTargetNode();
                BigDecimal sourceEnd = stats.get(source).sMin.add(source.getDuration());
                if (!almostEqual(sourceEnd, stats.get(target).sMin)) {
                    edgeAspects.remove(criticalAspect);
                }
            }
            e.getAspects().addAll(edgeAspects);
        });

        LOGGER.trace("Finished critical path analysis");
        return input;
    }

    /**
     * Compares two values with a small absolute tolerance to be robust against rounding issues.
     *
     * @param a first value
     * @param b second value
     * @return {@code true} if {@code |a - b|} is strictly smaller than {@code 1e-10}
     */
    public static boolean almostEqual(BigDecimal a, BigDecimal b) {
        return a.subtract(b).abs().compareTo(EPSILON) < 0;
    }
}
