package com.uccc.commons.logging;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

/**
 * Logging filter.
 *
 * <p>When INFO logging is enabled, wraps each request/response in {@link RequestWrapper} /
 * {@link ResponseWrapper} and logs the HTTP method, the request URI (with query string) and the
 * request/response bodies. When INFO is disabled the request is passed through untouched.
 *
 * <p>Supported init parameters:
 * <ul>
 *   <li>{@code ignoreBodySize} – size limit handed to the wrappers; a body for which the wrapper
 *       reports {@code isBiggerThanIgnoreBody()} is logged as {@code {}}.</li>
 *   <li>{@code mergeRequestAndResponse} – if {@code true}, request and response are logged as one
 *       line after the chain completes instead of two separate lines.</li>
 *   <li>{@code totalTime} – if {@code true}, an additional line with the elapsed time of the
 *       downstream chain (in milliseconds) is logged.</li>
 * </ul>
 *
 * Created by kidbei on 2016/12/14.
 */
public class CommonLoggingFilter implements Filter {

    private static final Logger log = LoggerFactory.getLogger(CommonLoggingFilter.class);
    private long    ignoreBodySize = LoggerWrapper.IGNORE_BODY_SIZE;
    private boolean merge = false;
    private boolean totalTime = false;

    /**
     * Reads the optional init parameters; blank or missing values keep the defaults.
     *
     * @param filterConfig the filter configuration supplied by the container
     * @throws ServletException if {@code ignoreBodySize} is present but not a valid long
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        String ignoreBodySizeStr = trimToNull(filterConfig.getInitParameter("ignoreBodySize"));
        if (ignoreBodySizeStr != null) {
            try {
                ignoreBodySize = Long.parseLong(ignoreBodySizeStr);
            } catch (NumberFormatException e) {
                throw new ServletException("Invalid init parameter ignoreBodySize: " + ignoreBodySizeStr, e);
            }
        }
        String mergeString = trimToNull(filterConfig.getInitParameter("mergeRequestAndResponse"));
        if (mergeString != null) {
            merge = Boolean.parseBoolean(mergeString);
        }
        String totalTimeString = trimToNull(filterConfig.getInitParameter("totalTime"));
        if (totalTimeString != null) {
            totalTime = Boolean.parseBoolean(totalTimeString);
        }
    }

    /** Returns the trimmed value, or {@code null} if the value is {@code null} or blank. */
    private static String trimToNull(String value) {
        if (value == null) {
            return null;
        }
        String trimmed = value.trim();
        return trimmed.isEmpty() ? null : trimmed;
    }

    /**
     * Logs the request (before the chain) and the response (after the chain) when INFO is enabled.
     * Failures while producing log output are reported at ERROR and never abort the request.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        // Nothing is logged below INFO, so skip wrapping entirely. Previously the request was still
        // cast to RequestWrapper here, which failed with ClassCastException.
        if (!log.isInfoEnabled()) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }

        long start = System.currentTimeMillis();
        HttpServletRequest  request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        String method = request.getMethod();
        String query = request.getQueryString();
        String path = request.getRequestURI() + (query != null ? ("?" + query) : "");

        RequestWrapper  requestWrapper = new RequestWrapper(request, ignoreBodySize);
        ResponseWrapper responseWrapper = new ResponseWrapper(response, ignoreBodySize);

        if (!merge) {
            // A logging failure must not prevent the request from being processed.
            try {
                logRequest(requestWrapper, method, path);
            } catch (Exception e) {
                log.error("Print log error", e);
            }
        }
        try {
            filterChain.doFilter(requestWrapper, responseWrapper);
        } finally {
            // Measure before logging so body decoding/formatting is not counted as request time.
            long elapsed = System.currentTimeMillis() - start;
            try {
                if (merge) {
                    logMergedRequestAndResponse(requestWrapper, responseWrapper, method, path);
                } else {
                    logResponse(responseWrapper, method, path);
                }
                if (totalTime) {
                    logTotalTime(requestWrapper, method, path, elapsed);
                }
            } catch (Exception e) {
                log.error("Print log error", e);
            }
        }
    }

    /** Logs {@code Request <method> <path> for <body>}. */
    private void logRequest(RequestWrapper request, String method, String path) {
        log.info("Request {} {} for {}", method, path, getRequestBodyString(request));
    }

    /**
     * Returns the request body decoded as UTF-8, or {@code "{}"} when the body is binary
     * (image/video/audio), multipart, over the ignore size, or empty.
     */
    private String getRequestBodyString(RequestWrapper request) {
        if (isBinaryContent(request) || isMultipart(request) || request.isBiggerThanIgnoreBody()) {
            return "{}";
        }

        byte[] bodyArray = request.toByteArray();
        if (bodyArray == null || bodyArray.length == 0) {
            return "{}";
        }
        return new String(bodyArray, StandardCharsets.UTF_8);
    }

    /** Logs {@code Response <method> <path> to <body>}. */
    private void logResponse(ResponseWrapper response, String method, String path) {
        log.info("Response {} {} to {}", method, path, getResponseBodyString(response));
    }

    /**
     * Returns the captured response body decoded as UTF-8, or {@code "{}"} when it is over the
     * ignore size or empty.
     */
    private String getResponseBodyString(ResponseWrapper response) {
        // Check the size first so an oversized body is not copied just to be discarded.
        if (response.isBiggerThanIgnoreBody()) {
            return "{}";
        }
        byte[] bodyArray = response.toByteArray();
        if (bodyArray == null || bodyArray.length == 0) {
            return "{}";
        }
        return new String(bodyArray, StandardCharsets.UTF_8);
    }

    /** Logs request and response bodies on one line: {@code Response <method> <path> for <req> to <resp>}. */
    private void logMergedRequestAndResponse(RequestWrapper request, ResponseWrapper response, String method, String path) {
        log.info("Response {} {} for {} to {}",
                method, path, getRequestBodyString(request), getResponseBodyString(response));
    }

    /**
     * Logs {@code TotalTime <ms> ms Request <method> <path> for <body>}.
     *
     * @param elapsedMillis time spent in the downstream filter chain, in milliseconds
     */
    private void logTotalTime(RequestWrapper request, String method, String path, long elapsedMillis) {
        log.info("TotalTime {} ms Request {} {} for {}",
                elapsedMillis, method, path, getRequestBodyString(request));
    }

    /** True when the content type starts with {@code image}, {@code video} or {@code audio}. */
    private boolean isBinaryContent(final HttpServletRequest request) {
        String contentType = request.getContentType();
        if (contentType == null) {
            return false;
        }
        return contentType.startsWith("image") || contentType.startsWith("video") || contentType.startsWith("audio");
    }

    /** True when the content type starts with {@code multipart/form-data}. */
    private boolean isMultipart(final HttpServletRequest request) {
        String contentType = request.getContentType();
        return contentType != null && contentType.startsWith("multipart/form-data");
    }

    /** No resources are held by this filter, so there is nothing to release. */
    @Override
    public void destroy() {

    }
}
