package com.bokesoft.distro.tech.bootsupport.starter.wrapper;

import com.bokesoft.distro.tech.bootsupport.starter.execctl.ExecutionTimeoutManager;
import com.bokesoft.distro.tech.bootsupport.starter.execctl.impl.DefaultConnectionRecoveryFactory;
import com.bokesoft.distro.tech.bootsupport.starter.execctl.impl.recoverers.WarnOnlyConnectionRecoverer;
import com.bokesoft.distro.tech.bootsupport.starter.execctl.model.StartTimeObject;
import com.bokesoft.distro.tech.bootsupport.starter.execctl.recovery.IConnectionRecoverer;
import com.bokesoft.distro.tech.bootsupport.starter.execctl.recovery.IConnectionRecoveryFactory;
import com.bokesoft.distro.tech.bootsupport.starter.execctl.recovery.flags.IUnmanagedRecoverProcess;

import com.bokesoft.distro.tech.yigosupport.extension.intf.IJDBCCompatible;
import com.bokesoft.distro.tech.yigosupport.extension.utils.yigo.YigoDBHelper;
import com.bokesoft.yes.mid.connection.IQueryColumnMetaData;
import com.bokesoft.yes.mid.connection.dbmanager.BatchPsPara;
import com.bokesoft.yes.mid.connection.dbmanager.GeneralDBManager;
import com.bokesoft.yes.mid.connection.dbmanager.PsPara;
import com.bokesoft.yes.mid.connection.dbmanager.QueryArguments;
import com.bokesoft.yes.mid.dbcache.ICacheDBRequest;
import com.bokesoft.yes.tools.preparesql.PrepareSQL;
import com.bokesoft.yigo.common.trace.TraceSystemManager;
import com.bokesoft.yigo.common.trace.intf.ITraceSupplier;
import com.bokesoft.yigo.meta.dataobject.MetaTable;
import com.bokesoft.yigo.meta.schema.MetaSchemaColumn;
import com.bokesoft.yigo.meta.schema.MetaSchemaTable;
import com.bokesoft.yigo.mid.base.DefaultContext;
import com.bokesoft.yigo.mid.base.MidCoreException;
import com.bokesoft.yigo.mid.connection.DataBaseInfo;
import com.bokesoft.yigo.mid.connection.IDBManager;
import com.bokesoft.yigo.mid.util.ContextBuilder;
import com.bokesoft.yigo.struct.datatable.DataTable;
import com.bokesoft.yigo.tools.ve.VE;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.event.Level;

import java.sql.*;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;

/**
 * {@link IDBManager} decorator that delegates to an inner manager and adds:
 * <ul>
 *   <li>trace logging of SQL execution via {@link TraceSystemManager#withTraceLog}, at a
 *       configurable {@link Level} (default {@code DEBUG});</li>
 *   <li>optional transaction start-time tracking: when enabled, a {@link StartTimeObject} is
 *       registered with {@link ExecutionTimeoutManager} on construction, its start time is reset on
 *       {@link #commit()} / {@link #rollback()}, and it is unregistered on {@link #close()};</li>
 *   <li>SQL-timeout recovery: when a traced call fails with {@link SQLTimeoutException}, an
 *       {@link IConnectionRecoverer} is built and prepared against the raw JDBC connection, and the
 *       prepared recoverers are applied (once) in {@link #close()}.</li>
 * </ul>
 * All other methods delegate directly to the inner manager.
 */
public class WrappedDBManager implements IDBManager, IJDBCCompatible {
    private static final Logger LOGGER = LoggerFactory.getLogger(WrappedDBManager.class);

    private static final IConnectionRecoveryFactory DEFAULT_RECOVERY_FACTORY = new DefaultConnectionRecoveryFactory();

    private final IDBManager inner;

    private final boolean execTimeoutControlEnabled;
    /** Transaction start-time holder; {@code null} unless execution-timeout control is enabled. */
    private final StartTimeObject startTimeObject;

    private Level logLevel = Level.DEBUG;

    private IConnectionRecoveryFactory recoveryFactory = DEFAULT_RECOVERY_FACTORY;
    /** Recoverers prepared after SQL timeouts; lazily created, guarded by {@code this}. */
    private List<IConnectionRecoverer> connectionRecoverers;

    /**
     * @param inner                     the manager every call is delegated to
     * @param connectionInstanceId      id used to build the transaction start-time object
     *                                  (only used when {@code execTimeoutControlEnabled} is true)
     * @param execTimeoutControlEnabled whether to register a start-time object with
     *                                  {@link ExecutionTimeoutManager}
     */
    public WrappedDBManager(IDBManager inner, String connectionInstanceId, boolean execTimeoutControlEnabled) {
        this.inner = inner;
        this.execTimeoutControlEnabled = execTimeoutControlEnabled;
        if (execTimeoutControlEnabled) {
            // Create the transaction start-time object and register it with the timeout manager
            this.startTimeObject = StartTimeObject.buildTransactionStartTimeObject(connectionInstanceId);
            ExecutionTimeoutManager.addStartTimeObject(startTimeObject);
        } else {
            this.startTimeObject = null;
        }
    }

    /** Sets the level used for trace logging of SQL execution. */
    public void setLogLevel(Level logLevel) {
        this.logLevel = logLevel;
    }

    /**
     * Sets the factory used to build recoverers after an SQL timeout.
     *
     * @param recoveryFactory the factory to use; {@code null} restores the default factory
     */
    public void setRecoveryFactory(IConnectionRecoveryFactory recoveryFactory) {
        this.recoveryFactory = (recoveryFactory == null) ? DEFAULT_RECOVERY_FACTORY : recoveryFactory;
    }

    @Override
    public VE getVE() {
        return inner.getVE();
    }

    @Override
    public void setVE(VE ve) {
        inner.setVE(ve);
    }

    @Override
    public boolean checkViewExist(String s) throws Throwable {
        return inner.checkViewExist(s);
    }

    @Override
    public boolean checkTableExist(String s) throws Throwable {
        return inner.checkTableExist(s);
    }

    @Override
    public HashSet<String> getIndexSet(String s) throws Throwable {
        return inner.getIndexSet(s);
    }

    @Override
    public String searchIndex(String s) throws Throwable {
        return inner.searchIndex(s);
    }

    @Override
    public HashSet<String> getTableColumnSet(String s) throws Throwable {
        return inner.getTableColumnSet(s);
    }

    @Override
    public String getColumnDef(MetaSchemaColumn metaSchemaColumn) throws Throwable {
        return inner.getColumnDef(metaSchemaColumn);
    }

    @Override
    public String getAlterTableStr(MetaSchemaTable metaSchemaTable, List<MetaSchemaColumn> list) throws Throwable {
        return inner.getAlterTableStr(metaSchemaTable, list);
    }

    @Override
    public String keyWordEscape(String s) {
        return inner.keyWordEscape(s);
    }

    @Override
    public StringBuilder appendKeyWordEscape(StringBuilder stringBuilder, String s) {
        return inner.appendKeyWordEscape(stringBuilder, s);
    }

    @Override
    public int getDBType() {
        return inner.getDBType();
    }

    @Override
    public long getCurTime() throws Throwable {
        return inner.getCurTime();
    }

    @Override
    public int getTimezoneOffset() throws Throwable {
        return inner.getTimezoneOffset();
    }

    @Override
    public Object convert(Object o, int i, int i1) throws Throwable {
        return inner.convert(o, i, i1);
    }

    @Override
    public int convertDataType(int i) {
        return inner.convertDataType(i);
    }

    // ---- Traced query / update execution ----

    @Override
    public DataTable execPrepareQuery(String s, Object... objects) throws Throwable {
        return withTrace(() -> inner.execPrepareQuery(s, objects));
    }

    @Override
    public DataTable execPrepareQuery(String s, IQueryColumnMetaData iQueryColumnMetaData, Object... objects) throws Throwable {
        return withTrace(() -> inner.execPrepareQuery(s, iQueryColumnMetaData, objects));
    }

    @Override
    public DataTable execPrepareQuery(String s, List<Object> list) throws Throwable {
        return withTrace(() -> inner.execPrepareQuery(s, list));
    }

    @Override
    public DataTable execPrepareQuery(String s, IQueryColumnMetaData iQueryColumnMetaData, List<Object> list) throws Throwable {
        return withTrace(() -> inner.execPrepareQuery(s, iQueryColumnMetaData, list));
    }

    @Override
    public DataTable execPrepareQuery(String s, List<Integer> list, List<Object> list1) throws Throwable {
        return withTrace(() -> inner.execPrepareQuery(s, list, list1));
    }

    @Override
    public DataTable execPrepareQuery(String s, IQueryColumnMetaData iQueryColumnMetaData, List<Integer> list, List<Object> list1) throws Throwable {
        return withTrace(() -> inner.execPrepareQuery(s, iQueryColumnMetaData, list, list1));
    }

    @Override
    public DataTable execQuery(String s) throws Throwable {
        return withTrace(() -> inner.execQuery(s));
    }

    @Override
    public DataTable execQuery(String s, IQueryColumnMetaData iQueryColumnMetaData) throws Throwable {
        return withTrace(() -> inner.execQuery(s, iQueryColumnMetaData));
    }

    @Override
    public int execPrepareUpdate(String s, Object... objects) throws Throwable {
        return withTrace(() -> inner.execPrepareUpdate(s, objects));
    }

    @Override
    public int execPrepareUpdate(String s, List<Object> list) throws Throwable {
        return withTrace(() -> inner.execPrepareUpdate(s, list));
    }

    @Override
    public int execPrepareUpdate(String s, List<Integer> list, List<Object> list1) throws Throwable {
        return withTrace(() -> inner.execPrepareUpdate(s, list, list1));
    }

    @Override
    public int execUpdate(String s) throws Throwable {
        return withTrace(() -> inner.execUpdate(s));
    }

    @Override
    public Statement createUpdateStatement() throws Throwable {
        // Previously delegated to createQueryStatement() by mistake
        return inner.createUpdateStatement();
    }

    @Override
    public Statement createQueryStatement() throws Throwable {
        return inner.createQueryStatement();
    }

    @Override
    public PreparedStatement preparedQueryStatement(String s) throws Throwable {
        return inner.preparedQueryStatement(s);
    }

    @Override
    public PreparedStatement preparedUpdateStatement(String s) throws Throwable {
        return inner.preparedUpdateStatement(s);
    }

    @Override
    public ResultSet executeQuery(PreparedStatement preparedStatement, String s, QueryArguments queryArguments) throws Throwable {
        return withTrace(() -> inner.executeQuery(preparedStatement, s, queryArguments));
    }

    @Override
    public int executeUpdate(PreparedStatement preparedStatement, String s, QueryArguments queryArguments) throws Throwable {
        return withTrace(() -> inner.executeUpdate(preparedStatement, s, queryArguments));
    }

    @Override
    public void executeUpdate(BatchPsPara batchPsPara) throws Throwable {
        withTrace(() -> {
            inner.executeUpdate(batchPsPara);
            return -1;
        });
    }

    @Override
    public ResultSet executeQuery(PsPara psPara, QueryArguments queryArguments) throws Throwable {
        return withTrace(() -> inner.executeQuery(psPara, queryArguments));
    }

    @Override
    public int executeUpdate(PsPara psPara, QueryArguments queryArguments) throws Throwable {
        return withTrace(() -> inner.executeUpdate(psPara, queryArguments));
    }

    @Override
    public String getConditionValue(int i, String s) {
        return inner.getConditionValue(i, s);
    }

    @Override
    public String getLikeConditionValue(String s, int i, String s1) {
        return inner.getLikeConditionValue(s, i, s1);
    }

    @Override
    public PrepareSQL getLimitString(String s, String s1, boolean b, int i, int i1) {
        return inner.getLimitString(s, s1, b, i, i1);
    }

    @Override
    public PrepareSQL getCountString(String s) {
        return inner.getCountString(s);
    }

    /** Resets the transaction start time (when timeout control is enabled), then rolls back. */
    @Override
    public void rollback() throws SQLException {
        if (execTimeoutControlEnabled) {
            startTimeObject.setStartTime(System.currentTimeMillis());
        }
        inner.rollback();
    }

    /** Resets the transaction start time (when timeout control is enabled), then commits. */
    @Override
    public void commit() throws SQLException {
        if (execTimeoutControlEnabled) {
            startTimeObject.setStartTime(System.currentTimeMillis());
        }
        inner.commit();
    }

    /**
     * Unregisters the start-time object (when timeout control is enabled), applies any pending
     * timeout recoverers, and then always closes the inner manager.
     * <p>
     * If recovery was applied, an exception from closing the inner manager is logged and not
     * rethrown; otherwise it is rethrown.
     */
    @Override
    public void close() throws SQLException {
        boolean recoveryApplied = false;
        try {
            if (execTimeoutControlEnabled) {
                // Remove the transaction start-time object
                ExecutionTimeoutManager.removeStartTimeObject(startTimeObject);
            }
            // If an SQL timeout happened during execution, run the prepared recovery now
            recoveryApplied = applyRecovery();
        } finally {
            // The inner manager must be closed no matter what happened above
            try {
                inner.close();
            } catch (Exception e) {
                if (recoveryApplied) {
                    LOGGER.warn("数据库恢复处理后关闭 DBManager 出错(已忽略): " + e.getMessage(), e);
                } else {
                    ExceptionUtils.rethrow(e);
                }
            }
        }
    }

    @Override
    public void setParameter(PreparedStatement preparedStatement, int i, Object o, int i1) throws SQLException, MidCoreException {
        inner.setParameter(preparedStatement, i, o, i1);
    }

    @Override
    public void setRowLock(String s, String s1, Long aLong) throws Throwable {
        withTrace(() -> {
            inner.setRowLock(s, s1, aLong);
            return 0;
        });
    }

    @Override
    public void setRowLock(String s, String s1, QueryArguments queryArguments) throws Throwable {
        withTrace(() -> {
            inner.setRowLock(s, s1, queryArguments);
            return 0;
        });
    }

    @Override
    public void setRowLockEnsureInSYSLock(String s) throws Throwable {
        inner.setRowLockEnsureInSYSLock(s);
    }

    @Override
    public IDBManager getNewDBManager() throws Throwable {
        return inner.getNewDBManager();
    }

    @Override
    public void initDataBaseInfo(DataBaseInfo dataBaseInfo) throws Throwable {
        inner.initDataBaseInfo(dataBaseInfo);
    }

    @Override
    public boolean saveDataTable(DataTable dataTable, String s, MetaTable metaTable, Object o) throws Throwable {
        return inner.saveDataTable(dataTable, s, metaTable, o);
    }

    @Override
    public boolean saveDataTable(DataTable dataTable, String s, MetaTable metaTable) throws Throwable {
        return inner.saveDataTable(dataTable, s, metaTable);
    }

    @Override
    public String getDBName() throws Throwable {
        return inner.getDBName();
    }

    @Override
    public String getViewExistCheckSql() {
        return inner.getViewExistCheckSql();
    }

    @Override
    public String getViewStructSql() {
        return inner.getViewStructSql();
    }

    @Override
    public String getTableExistCheckSql() {
        return inner.getTableExistCheckSql();
    }

    @Override
    public String getTableStructSql() {
        return inner.getTableStructSql();
    }

    @Override
    public String getColumnCheckSql() {
        return inner.getColumnCheckSql();
    }

    @Override
    public String getColumnStructSql() {
        return inner.getColumnStructSql();
    }

    @Override
    public String getIndexStructSql() {
        return inner.getIndexStructSql();
    }

    @Override
    public String getIndexSearchSql() {
        return inner.getIndexSearchSql();
    }

    @Override
    public String getIndexCheckSql() {
        return inner.getIndexCheckSql();
    }

    /** Traced like {@link #executeUpdate(BatchPsPara)}, so timeouts here also trigger recovery. */
    @Override
    public int[] executeUpdateReturn(BatchPsPara batchPsPara) throws Throwable {
        return withTrace(() -> inner.executeUpdateReturn(batchPsPara));
    }

    @Override
    public int getTransactionID() {
        return inner.getTransactionID();
    }

    @Override
    public Statement createJDBCStatement() throws SQLException {
        return inner.createJDBCStatement();
    }

    @Override
    public Statement createJDBCPrepareStatement(String s) throws SQLException {
        return inner.createJDBCPrepareStatement(s);
    }

    @Override
    public void setKey(String s) {
        inner.setKey(s);
    }

    @Override
    public String getKey() {
        return inner.getKey();
    }

    /**
     * Runs {@code supplier} through {@link TraceSystemManager#withTraceLog} at {@link #logLevel}.
     * On failure, prepares timeout recovery when applicable (see {@link #prepareRecovery}) and
     * rethrows the original exception unchanged.
     */
    private <T> T withTrace(ITraceSupplier<T> supplier) throws SQLException {
        try {
            return TraceSystemManager.withTraceLog(supplier::get, this, true, LOGGER, logLevel);
        } catch (Exception e) {
            prepareRecovery(e);
            return ExceptionUtils.rethrow(e);
        }
    }

    /**
     * Extracts the raw JDBC connection behind a manager: unwraps any number of
     * {@link WrappedDBManager} layers, reads {@code GeneralDBManager#connection} reflectively,
     * then unwraps any {@link WrappedConnection} layers.
     *
     * @throws UnsupportedOperationException if the innermost manager is not a {@link GeneralDBManager}
     */
    private static Connection getRawConnection(IDBManager dbm) {
        IDBManager target = dbm;
        while (target instanceof WrappedDBManager) {
            target = ((WrappedDBManager) target).inner;
        }
        if (!(target instanceof GeneralDBManager)) {
            String typeName = (target == null) ? "null" : target.getClass().getName();
            throw new UnsupportedOperationException(
                "无法获取 DBManager '" + typeName + "' 的 connection 对象");
        }
        try {
            // Read GeneralDBManager#connection
            Connection conn = (Connection) FieldUtils.readField(target, "connection", true);
            return getRawConnection(conn);
        } catch (IllegalAccessException e) {
            return ExceptionUtils.rethrow(e);
        }
    }

    /** Unwraps nested {@link WrappedConnection}s and returns the innermost connection. */
    private static Connection getRawConnection(Connection conn) {
        Connection current = conn;
        while (current instanceof WrappedConnection) {
            current = ((WrappedConnection) current).getInnerConnection();
        }
        return current;
    }

    /**
     * Prepares a connection recoverer when {@code e} is an {@link SQLTimeoutException} and the raw
     * connection is still open.
     * <p>
     * This runs while the original exception is being handled, so it never throws: any failure
     * while preparing is logged and attached to {@code e} as a suppressed exception, so the
     * original timeout still reaches the caller.
     *
     * @return {@code true} if a recoverer was prepared and queued for {@link #close()}
     */
    private boolean prepareRecovery(Exception e) {
        // Currently only SQLTimeoutException needs recovery
        if (!(e instanceof SQLTimeoutException)) {
            return false;
        }
        try {
            Connection conn = getRawConnection(this);
            if (conn == null || conn.isClosed()) {
                return false;
            }
            IConnectionRecoverer r = this.recoveryFactory.buildRecoverer(this.getDBType(), e);
            r.prepare(conn);
            appendConnectionRecoverers(r);

            LOGGER.error("发现 SQL 执行超时, 将使用 " + r.getClass().getName() + " 进行恢复处理", e);
            return true;
        } catch (Exception prepareError) {
            e.addSuppressed(prepareError);
            LOGGER.error("准备 SQL 超时恢复处理失败, 将不进行恢复: " + prepareError.getMessage(), prepareError);
            return false;
        }
    }

    /** Queues a prepared recoverer to be applied on {@link #close()}. */
    private void appendConnectionRecoverers(IConnectionRecoverer r) {
        // A Connection is used by one thread at a time, so this simple lock is sufficient
        synchronized (this) {
            if (this.connectionRecoverers == null) {
                this.connectionRecoverers = new ArrayList<>();
            }
            this.connectionRecoverers.add(r);
        }
    }

    /**
     * Applies and then discards all queued recoverers, so a repeated {@link #close()} does not run
     * them again.
     *
     * @return {@code true} if recovery was attempted
     */
    private boolean applyRecovery() {
        List<IConnectionRecoverer> pending;
        synchronized (this) {
            pending = this.connectionRecoverers;
            this.connectionRecoverers = null;
        }
        if (pending == null || pending.isEmpty()) {
            return false;
        }
        // Any recoverer not flagged IUnmanagedRecoverProcess needs a freshly created connection
        boolean managed = pending.stream().anyMatch(cr -> !(cr instanceof IUnmanagedRecoverProcess));
        return managed ? doManagedApplyRecovery(pending) : doApplyRecovery(null, pending);
    }

    /**
     * Creates a new context via {@link ContextBuilder#create()}, runs the recoverers against its raw
     * connection, and always closes that context. On error the context is rolled back and the
     * error is logged, not thrown.
     */
    private boolean doManagedApplyRecovery(List<IConnectionRecoverer> recoverers) {
        DefaultContext newCtx = null;
        boolean applied = false;
        String connInfo = "未创建";
        try {
            newCtx = ContextBuilder.create();

            Connection newConn = getRawConnection(newCtx.getDBManager());
            connInfo = WarnOnlyConnectionRecoverer.buildConnectionInfo(newConn);

            LOGGER.warn("数据库恢复处理过程(Connection=["+connInfo+"])开始 ...");
            applied = doApplyRecovery(newConn, recoverers);
            LOGGER.warn("数据库恢复处理过程(Connection=["+connInfo+"])完成.");

            // No commit here: a recoverer that needs one must commit by itself
            return applied;
        } catch (Throwable ex) {
            if (newCtx != null) {
                try {
                    newCtx.rollback();
                } catch (Throwable ignored) {
                    // Best effort: the original failure is logged below
                }
            }
            LOGGER.error("数据库恢复处理过程(Connection=["+connInfo+"])出错: " + ex.getMessage(), ex);
            return applied;
        } finally {
            if (newCtx != null) {
                try {
                    newCtx.close();
                } catch (Throwable ignored) {
                    // Best effort cleanup of the recovery context
                }
            }
        }
    }

    /**
     * Runs each recoverer against {@code newConn} (may be {@code null} for unmanaged recoverers).
     * A failing recoverer is logged and does not stop the others.
     *
     * @return {@code true} if there was at least one recoverer to run
     */
    private boolean doApplyRecovery(Connection newConn, List<IConnectionRecoverer> recoverers) {
        if (recoverers == null || recoverers.isEmpty()) {
            return false;
        }
        for (IConnectionRecoverer cr : recoverers) {
            String recovererName = cr.getClass().getName();
            try {
                cr.recover(newConn);
                LOGGER.warn("数据库恢复处理 {} 成功完成.", recovererName);
            } catch (Exception ex) {
                LOGGER.error("数据库恢复处理 "+recovererName+" 失败: " + ex.getMessage(), ex);
            }
        }
        return true;
    }

    @Override
    public Connection getConnection() throws SQLException {
        return YigoDBHelper.getConnection(inner);
    }

    @Override
    public void begin() throws Throwable {
        inner.begin();
    }

    @Override
    public int getMainVersion() throws Throwable {
        return inner.getMainVersion();
    }

    @Override
    public ICacheDBRequest getCacheDBRequest() {
        return inner.getCacheDBRequest();
    }
}
