/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.engine.server.task;

import com.hazelcast.cluster.Address;
import com.hazelcast.spi.impl.AbstractInvocationFuture;
import com.hazelcast.spi.impl.operationservice.Operation;
import java.io.IOException;
import java.io.Serializable;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.ExecutionException;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.apache.seatunnel.api.serialization.Serializer;
import org.apache.seatunnel.api.source.SourceEvent;
import org.apache.seatunnel.api.source.SourceSplit;
import org.apache.seatunnel.api.source.SourceSplitEnumerator;
import org.apache.seatunnel.engine.common.utils.ExceptionUtil;
import org.apache.seatunnel.engine.core.dag.actions.SourceAction;
import org.apache.seatunnel.engine.server.checkpoint.ActionStateKey;
import org.apache.seatunnel.engine.server.checkpoint.ActionSubtaskState;
import org.apache.seatunnel.engine.server.checkpoint.CheckpointBarrier;
import org.apache.seatunnel.engine.server.checkpoint.operation.TaskAcknowledgeOperation;
import org.apache.seatunnel.engine.server.execution.ProgressState;
import org.apache.seatunnel.engine.server.execution.TaskLocation;
import org.apache.seatunnel.engine.server.task.CoordinatorTask;
import org.apache.seatunnel.engine.server.task.context.SeaTunnelSplitEnumeratorContext;
import org.apache.seatunnel.engine.server.task.operation.checkpoint.BarrierFlowOperation;
import org.apache.seatunnel.engine.server.task.operation.source.LastCheckpointNotifyOperation;
import org.apache.seatunnel.engine.server.task.record.Barrier;
import org.apache.seatunnel.engine.server.task.statemachine.SeaTunnelTaskState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Coordinator task that hosts the {@link SourceSplitEnumerator} of a single {@link SourceAction}.
 *
 * <p>The task keeps track of which reader tasks have registered and on which cluster member each
 * one runs. It drives the enumerator through the {@link SeaTunnelTaskState} state machine. It also
 * forwards checkpoint barriers to every registered reader and acknowledges enumerator snapshots
 * to the master.
 *
 * @param <SplitT> the split type handed out by the enumerator
 */
public class SourceSplitEnumeratorTask<SplitT extends SourceSplit>
extends CoordinatorTask {
    private static final Logger log = LoggerFactory.getLogger(SourceSplitEnumeratorTask.class);
    private static final long serialVersionUID = -3713701594297977775L;
    private final SourceAction<?, SplitT, Serializable> source;
    private SourceSplitEnumerator<SplitT, Serializable> enumerator;
    private SeaTunnelSplitEnumeratorContext<SplitT> enumeratorContext;
    private Serializer<Serializable> enumeratorStateSerializer;
    /** Number of readers (the source parallelism) that must register before the enumerator opens. */
    private int maxReaderSize;
    /** Task IDs of registered readers that have not yet reported {@link #readerFinished(long)}. */
    private Set<Long> unfinishedReaders;
    /** Registered reader location -> address of the member that reader runs on. */
    private Map<TaskLocation, Address> taskMemberMapping;
    private Map<Long, TaskLocation> taskIDToTaskLocationMapping;
    private Map<Integer, TaskLocation> taskIndexToTaskLocationMapping;
    private volatile SeaTunnelTaskState currState;
    private volatile boolean readerRegisterComplete;
    private volatile boolean prepareCloseTriggered;

    /**
     * Initializes the enumerator context, the state serializer and the reader bookkeeping maps.
     * The enumerator itself is created later, in {@link #restoreState(List)}.
     */
    @Override
    public void init() throws Exception {
        this.currState = SeaTunnelTaskState.INIT;
        super.init();
        this.readerRegisterComplete = false;
        log.info("starting seatunnel source split enumerator task, source name: {}", source.getName());
        this.enumeratorContext =
                new SeaTunnelSplitEnumeratorContext<>(source.getParallelism(), this, getMetricsContext());
        this.enumeratorStateSerializer = source.getSource().getEnumeratorStateSerializer();
        this.taskMemberMapping = new ConcurrentHashMap<>();
        this.taskIDToTaskLocationMapping = new ConcurrentHashMap<>();
        this.taskIndexToTaskLocationMapping = new ConcurrentHashMap<>();
        this.maxReaderSize = source.getParallelism();
        this.unfinishedReaders = new CopyOnWriteArraySet<>();
    }

    /** Closes the parent task, then the enumerator (if one was created), and marks progress done. */
    @Override
    public void close() throws IOException {
        super.close();
        if (enumerator != null) {
            enumerator.close();
        }
        progress.done();
    }

    /**
     * @param jobID  id of the owning job
     * @param taskID location of this coordinator task
     * @param source the source action whose enumerator this task hosts
     */
    @SuppressWarnings("unchecked")
    public SourceSplitEnumeratorTask(long jobID, TaskLocation taskID, SourceAction<?, SplitT, ?> source) {
        super(jobID, taskID);
        // The enumerator state type is erased at this point; every state is handled as Serializable.
        this.source = (SourceAction<?, SplitT, Serializable>) source;
        this.currState = SeaTunnelTaskState.CREATED;
    }

    /** Advances the state machine by one step and reports the current progress. */
    @Override
    @NonNull
    public ProgressState call() throws Exception {
        stateProcess();
        return progress.toState();
    }

    /**
     * Handles a checkpoint barrier.
     *
     * <p>When the barrier requests a snapshot, the enumerator state is captured and serialized. The
     * barrier is then forwarded to every registered reader. Both steps run while holding the
     * enumerator context's monitor. Afterwards the serialized state is acknowledged to the master.
     * A prepare-close barrier additionally records its id so the task can close once that
     * checkpoint completes or aborts.
     */
    @Override
    public void triggerBarrier(Barrier barrier) throws Exception {
        log.debug("split enumer trigger barrier [{}]", barrier);
        if (barrier.prepareClose()) {
            this.prepareCloseTriggered = true;
            this.prepareCloseBarrierId.set(barrier.getId());
        }
        final long barrierId = barrier.getId();
        Serializable snapshotState = null;
        byte[] serialize = null;
        // Captured in a final local so that the monitor used is stable.
        final SeaTunnelSplitEnumeratorContext<SplitT> lock = this.enumeratorContext;
        synchronized (lock) {
            if (barrier.snapshot()) {
                snapshotState = enumerator.snapshotState(barrierId);
                serialize = enumeratorStateSerializer.serialize(snapshotState);
            }
            log.debug("source split enumerator send state [{}] to master", snapshotState);
            sendToAllReader(location -> new BarrierFlowOperation(barrier, location));
        }
        if (barrier.snapshot()) {
            ActionSubtaskState state =
                    new ActionSubtaskState(
                            ActionStateKey.of(source), -1, Collections.singletonList(serialize));
            getExecutionContext()
                    .sendToMaster(
                            new TaskAcknowledgeOperation(
                                    taskLocation,
                                    (CheckpointBarrier) barrier,
                                    Collections.singletonList(state)))
                    .join();
        }
    }

    /**
     * Creates the enumerator. If a non-null enumerator state is found in {@code actionStateList},
     * the enumerator is restored from the first such state; otherwise a fresh one is created.
     * In both cases {@code restoreComplete} is completed afterwards.
     */
    @Override
    public void restoreState(List<ActionSubtaskState> actionStateList) throws Exception {
        log.debug("restoreState for split enumerator [{}]", actionStateList);
        Optional<Serializable> state =
                actionStateList.stream()
                        .map(ActionSubtaskState::getState)
                        .flatMap(Collection::stream)
                        .filter(Objects::nonNull)
                        .map(bytes -> ExceptionUtil.sneaky(() -> enumeratorStateSerializer.deserialize(bytes)))
                        .findFirst();
        if (state.isPresent()) {
            this.enumerator = source.getSource().restoreEnumerator(enumeratorContext, state.get());
        } else {
            this.enumerator = source.getSource().createEnumerator(enumeratorContext);
        }
        this.restoreComplete.complete(null);
        log.debug("restoreState split enumerator [{}] finished", actionStateList);
    }

    /** Returns splits to the enumerator, waiting first for restore to complete. */
    public void addSplitsBack(List<SplitT> splits, int subtaskId) throws ExecutionException, InterruptedException {
        getEnumerator().addSplitsBack(splits, subtaskId);
    }

    /**
     * Registers a reader with this task and with the enumerator.
     *
     * <p>Once the number of registered readers equals the source parallelism, reader registration
     * is marked complete.
     */
    public void receivedReader(TaskLocation readerId, Address memberAddr) throws InterruptedException, ExecutionException {
        log.info("received reader register, readerID: {}", readerId);
        SourceSplitEnumerator<SplitT, Serializable> currentEnumerator = getEnumerator();
        addTaskMemberMapping(readerId, memberAddr);
        currentEnumerator.registerReader(readerId.getTaskIndex());
        int taskSize = taskMemberMapping.size();
        if (maxReaderSize == taskSize) {
            this.readerRegisterComplete = true;
            log.debug("reader register complete, current task size {}", taskSize);
        } else {
            log.debug("current task size {}, need size {} to complete register", taskSize, maxReaderSize);
        }
    }

    /** Forwards a split request from the reader with the given index to the enumerator. */
    public void requestSplit(long taskIndex) throws ExecutionException, InterruptedException {
        getEnumerator().handleSplitRequest((int) taskIndex);
    }

    /** Forwards a source event from a reader subtask to the enumerator. */
    public void handleSourceEvent(int subtaskId, SourceEvent sourceEvent) throws ExecutionException, InterruptedException {
        getEnumerator().handleSourceEvent(subtaskId, sourceEvent);
    }

    /** Records a reader's location, member address and id/index lookups, and marks it as unfinished. */
    public void addTaskMemberMapping(TaskLocation taskID, Address memberAdder) {
        taskMemberMapping.put(taskID, memberAdder);
        taskIDToTaskLocationMapping.put(taskID.getTaskID(), taskID);
        taskIndexToTaskLocationMapping.put(taskID.getTaskIndex(), taskID);
        unfinishedReaders.add(taskID.getTaskID());
    }

    /** @return the member address of the reader with the given task id, or {@code null} if not registered */
    public Address getTaskMemberAddress(long taskID) {
        return addressOf(taskIDToTaskLocationMapping.get(taskID));
    }

    /** @return the location of the reader with the given task id, or {@code null} if not registered */
    public TaskLocation getTaskMemberLocation(long taskID) {
        return taskIDToTaskLocationMapping.get(taskID);
    }

    /** @return the member address of the reader with the given index, or {@code null} if not registered */
    public Address getTaskMemberAddressByIndex(int taskIndex) {
        return addressOf(taskIndexToTaskLocationMapping.get(taskIndex));
    }

    /** @return the location of the reader with the given index, or {@code null} if not registered */
    public TaskLocation getTaskMemberLocationByIndex(int taskIndex) {
        return taskIndexToTaskLocationMapping.get(taskIndex);
    }

    /** Null-safe lookup of a reader's member address. */
    private Address addressOf(TaskLocation location) {
        // ConcurrentHashMap.get(null) throws NullPointerException, so an unknown reader is
        // short-circuited here. This matches the null-on-miss behavior of the location getters.
        return location == null ? null : taskMemberMapping.get(location);
    }

    /**
     * Returns the enumerator, blocking until {@link #restoreState(List)} has completed.
     * While {@code restoreComplete} is still unset, this polls every 200 ms.
     */
    private SourceSplitEnumerator<SplitT, Serializable> getEnumerator() throws InterruptedException, ExecutionException {
        while (null == this.restoreComplete) {
            log.warn("Task init is not complete, try to get it again after 200 ms");
            Thread.sleep(200L);
        }
        this.restoreComplete.get();
        return enumerator;
    }

    /** Marks a reader as finished; when none remain unfinished, the task is flagged for prepare-close. */
    public void readerFinished(long taskID) {
        unfinishedReaders.remove(taskID);
        if (unfinishedReaders.isEmpty()) {
            this.prepareCloseStatus = true;
        }
    }

    /**
     * Performs one step of the state machine:
     * INIT -> WAITING_RESTORE -> READY_START -> STARTING -> RUNNING -> PREPARE_CLOSE -> CLOSED.
     * CANCELLING closes the task and moves it to CANCELED. Waiting states sleep 100 ms per step.
     */
    private void stateProcess() throws Exception {
        switch (currState) {
            case INIT:
                this.currState = SeaTunnelTaskState.WAITING_RESTORE;
                reportTaskStatus(SeaTunnelTaskState.WAITING_RESTORE);
                break;
            case WAITING_RESTORE:
                if (restoreComplete.isDone()) {
                    this.currState = SeaTunnelTaskState.READY_START;
                    reportTaskStatus(SeaTunnelTaskState.READY_START);
                } else {
                    Thread.sleep(100L);
                }
                break;
            case READY_START:
                // The enumerator is opened only after start was requested and all readers registered.
                if (startCalled && readerRegisterComplete) {
                    this.currState = SeaTunnelTaskState.STARTING;
                    enumerator.open();
                } else {
                    Thread.sleep(100L);
                }
                break;
            case STARTING:
                this.currState = SeaTunnelTaskState.RUNNING;
                log.info("received enough reader, starting enumerator...");
                enumerator.run();
                break;
            case RUNNING:
                if (prepareCloseStatus) {
                    // All readers finished: ask the master for a final checkpoint.
                    getExecutionContext().sendToMaster(new LastCheckpointNotifyOperation(jobID, taskLocation));
                    this.currState = SeaTunnelTaskState.PREPARE_CLOSE;
                } else if (prepareCloseTriggered) {
                    this.currState = SeaTunnelTaskState.PREPARE_CLOSE;
                } else {
                    Thread.sleep(100L);
                }
                break;
            case PREPARE_CLOSE:
                if (closeCalled) {
                    this.currState = SeaTunnelTaskState.CLOSED;
                } else {
                    Thread.sleep(100L);
                }
                break;
            case CLOSED:
                close();
                return;
            case CANCELLING:
                close();
                this.currState = SeaTunnelTaskState.CANCELED;
                return;
            default:
                throw new IllegalArgumentException("Unknown Enumerator State: " + currState);
        }
    }

    /** @return the task indexes of all readers registered so far */
    public Set<Integer> getRegisteredReaders() {
        return taskMemberMapping.keySet().stream().map(TaskLocation::getTaskIndex).collect(Collectors.toSet());
    }

    /**
     * Sends one operation to every registered reader on the reader's member. The operation is
     * built per reader by {@code function}. Blocks until all invocations have completed.
     */
    private void sendToAllReader(Function<TaskLocation, Operation> function) {
        List<AbstractInvocationFuture<?>> futures = new ArrayList<>();
        taskMemberMapping.forEach(
                (location, address) -> {
                    log.debug(
                            "split enumerator send to read--size: {}, location: {}, address: {}",
                            taskMemberMapping.size(),
                            location,
                            address.toString());
                    futures.add(getExecutionContext().sendToMember(function.apply(location), address));
                });
        futures.forEach(AbstractInvocationFuture::join);
    }

    /** @return a copy of the source action's jar URLs */
    @Override
    public Set<URL> getJarsUrl() {
        return new HashSet<>(source.getJarUrls());
    }

    /**
     * Notifies the enumerator that the checkpoint completed. If it is the prepare-close
     * checkpoint, the task proceeds to close.
     */
    @Override
    public void notifyCheckpointComplete(long checkpointId) throws Exception {
        getEnumerator().notifyCheckpointComplete(checkpointId);
        if (prepareCloseBarrierId.get() == checkpointId) {
            closeCall();
        }
    }

    /**
     * Notifies the enumerator that the checkpoint was aborted. If it is the prepare-close
     * checkpoint, the task still proceeds to close.
     */
    @Override
    public void notifyCheckpointAborted(long checkpointId) throws Exception {
        getEnumerator().notifyCheckpointAborted(checkpointId);
        if (prepareCloseBarrierId.get() == checkpointId) {
            closeCall();
        }
    }
}

