package com.bokesoft.distro.tech.commons.basis.coordinate.impl;

import com.bokesoft.distro.tech.commons.basis.coordinate.struct.Semaphore;
import com.bokesoft.distro.tech.commons.basis.coordinate.struct.ThrottleConfig;

import java.time.Duration;
import java.time.LocalDateTime;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

/**
 * Simple throttling filter that merges bursts of semaphores by time interval.
 *
 * <p>Each call to {@link #onFilter(Semaphore, Consumer)} either forwards the semaphore to the
 * supplied action immediately, or parks it and schedules a delayed delivery on the executor.
 * A parked semaphore is replaced by any semaphore that arrives later, so only the most recent
 * one of a burst is delivered once the burst goes quiet.
 *
 * <p>All state transitions happen while holding this instance's monitor. Note that the
 * downstream action is invoked while the monitor is held.
 */
public class SimpleSemaphoreThrottle {


    /**
     * When semaphores arrive back to back, this is the maximum interval within which
     * they are merged; it is also the maximum delay applied to the last semaphore.
     */
    final Duration step;
    /**
     * While merging consecutive semaphores, the maximum accumulated number of semaphores
     * after which one must be sent.
     */
    final int maxCount;
    /**
     * While merging consecutive semaphores, the maximum time interval after which
     * one must be sent.
     */
    final Duration maxStep;
    final ScheduledExecutorService executor;


    /**
     * Time of the most recent call to {@link #onFilter(Semaphore, Consumer)}.
     */
    private LocalDateTime lastTime;

    /**
     * Number of semaphores received since the last one that was passed through.
     */
    private int count;
    /**
     * Time at which a semaphore was last passed through.
     */
    private LocalDateTime lastPostTime;

    /** Semaphore waiting for delayed delivery, or {@code null} when nothing is pending. */
    private Semaphore nextSemaphore;
    /** Action that will receive {@link #nextSemaphore}, or {@code null} when nothing is pending. */
    private Consumer<Semaphore> nextAction;
    /** Handle of the currently scheduled delayed delivery, or {@code null}. */
    private ScheduledFuture<?> nextFuture;
    /**
     * Incremented on every incoming semaphore. A delayed task remembers the value it was
     * scheduled with and only fires if the value is still current, so a task that could not
     * be cancelled in time (it had already started and was waiting for the monitor) can never
     * deliver a semaphore that belongs to a newer timer.
     */
    private long generation;

    /**
     * Creates a throttle from the values exposed by the given configuration.
     *
     * @param config source of step, max count, max step and executor
     */
    public SimpleSemaphoreThrottle(ThrottleConfig config) {
        this(config.getStep(), config.getMaxCount(), config.getMaxStep(), config.getExecutor());
    }

    /**
     * Creates a throttle.
     *
     * @param step     merge interval and delay of the deferred delivery
     * @param maxCount number of merged semaphores that forces an immediate delivery
     * @param maxStep  time since the last delivery that forces an immediate delivery
     * @param executor executor used to schedule the deferred delivery
     */
    public SimpleSemaphoreThrottle(Duration step, int maxCount, Duration maxStep, ScheduledExecutorService executor) {
        this.step = step;
        this.maxCount = maxCount;
        this.maxStep = maxStep;
        this.executor = executor;
    }


    /**
     * Feeds one semaphore into the throttle.
     *
     * <p>The semaphore is forwarded to {@code nextAction} immediately when any of the following
     * holds: it is the first semaphore ever seen; the semaphore's {@code time} lies more than
     * {@code maxStep} after the last delivery; {@code maxCount} semaphores have accumulated since
     * the last delivery; or more than {@code step} elapsed since the previous call. Otherwise it
     * is parked and delivered after {@code step} unless another semaphore arrives first.
     *
     * @param semaphore  the incoming semaphore
     * @param nextAction downstream consumer that receives the semaphore when it passes
     */
    public synchronized void onFilter(Semaphore semaphore, Consumer<Semaphore> nextAction) {
        LocalDateTime prevTime = this.lastTime;
        lastTime = LocalDateTime.now();
        // Any incoming semaphore supersedes the previously deferred one. Bumping the generation
        // also neutralises a timer task that has already started and can no longer be cancelled.
        long ticket = ++generation;
        if (nextFuture != null) {
            nextFuture.cancel(false);
            nextFuture = null;
        }
        if (lastPostTime == null) {// first semaphore ever received
            postRun(semaphore, nextAction);
            return;
        }
        if (lastPostTime.plus(maxStep).isBefore(semaphore.time)) {// exceeds the maximum interval since the last delivery
            postRun(semaphore, nextAction);
            return;
        }
        if (++count >= maxCount) {// number of merged semaphores reached the threshold
            postRun(semaphore, nextAction);
            return;
        }
        if (prevTime.plus(step).isBefore(lastTime)) {// gap since the previous semaphore is larger than step
            postRun(semaphore, nextAction);
            return;
        }
        // Park the semaphore and register the delayed delivery.
        this.nextAction = nextAction;
        this.nextSemaphore = semaphore;
        this.nextFuture = executor.schedule(() -> delayRun(ticket), step.toMillis(), TimeUnit.MILLISECONDS);

    }

    /**
     * Passes a semaphore through: resets the merge counter, records the delivery time, clears
     * any parked semaphore and then invokes the action.
     *
     * @param semaphore  semaphore to deliver
     * @param nextAction consumer to invoke (the parameter, not the field of the same name)
     */
    private void postRun(Semaphore semaphore, Consumer<Semaphore> nextAction) {
        this.count = 0;
        this.lastPostTime = LocalDateTime.now();
        this.nextSemaphore = null;
        this.nextAction = null;
        nextAction.accept(semaphore);

    }

    /**
     * Timer callback that delivers the parked semaphore.
     *
     * <p>Does nothing when a newer semaphore arrived after this task was scheduled or when
     * nothing is parked any more. The delivery goes through {@link #postRun} so that the
     * counter and last-delivery time reflect the delayed delivery as well.
     *
     * @param ticket value of {@link #generation} at the time this task was scheduled
     */
    private synchronized void delayRun(long ticket) {
        if (ticket != generation || nextAction == null || nextSemaphore == null) {
            return;
        }
        nextFuture = null;
        // Arguments are read before postRun clears the fields.
        postRun(nextSemaphore, nextAction);
    }

}
