HttpApiCachingFilter 와 CachedHttpServletRequestWrapper
- JSON과 multipart뿐 아니라 form-urlencoded 요청도 처리하면서 요청/응답 로그를 남기는 filter 구현 예시입니다.
- request의 InputStream은 한 번 읽으면 소진되어 다시 읽을 수 없으므로, 본문을 미리 캐싱해 두고 이후에도 반복해서 읽을 수 있게 하는 로직이 중요합니다.
HttpApiCachingFilter.java
/**
 * Filter that logs the headers and payload of every request and response.
 *
 * <p>Non-multipart requests are wrapped in {@code CachedHttpServletRequestWrapper} so the body
 * can be read here for logging and read again further down the chain. Multipart requests are
 * wrapped in {@code StandardMultipartHttpServletRequest}. Responses are always wrapped in
 * {@code CachedHttpServletResponseWrapper}; the cached body is copied to the real response
 * once logging is done.
 *
 * <p>Logging is best-effort on the response side: a failure while logging the response never
 * prevents the cached body from being written back to the client.
 */
@Component
@Order(value = Ordered.HIGHEST_PRECEDENCE)
@WebFilter(filterName = "HttpApiCachingFilter", urlPatterns = "/*")
@Slf4j
public class HttpApiCachingFilter extends OncePerRequestFilter {

    /** Media types whose payload is logged as text; anything else is reported as binary. */
    private static final List<MediaType> VISIBLE_TYPES = Arrays.asList(
            MediaType.valueOf("text/*"),
            MediaType.APPLICATION_FORM_URLENCODED,
            MediaType.APPLICATION_JSON,
            MediaType.APPLICATION_XML,
            MediaType.valueOf("application/*+json"),
            MediaType.valueOf("application/*+xml"),
            MediaType.MULTIPART_FORM_DATA);

    /** Headers that carry credentials or session tokens; their values are masked in the log. */
    private static final List<String> SENSITIVE_HEADERS = Arrays.asList(
            "Authorization", "Proxy-Authorization", "Cookie", "Set-Cookie");

    /** Replacement text logged instead of a sensitive header value. */
    private static final String MASK = "****";

    /** Value of {@code spring.application.name}; it is not read anywhere in this class. */
    private static String APPLICATION_NAME;

    @Value("${spring.application.name}")
    private void setApplicationName(String value) {
        APPLICATION_NAME = value;
    }

    /**
     * Chooses the request wrapper and hands over to {@link #doFilterWrapped}.
     * Async dispatches are passed through untouched.
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
            throws ServletException, IOException {
        if (isAsyncDispatch(request)) {
            filterChain.doFilter(request, response);
        } else if (isMultipartRequest(request)) {
            doFilterWrapped(new StandardMultipartHttpServletRequest(request),
                    new CachedHttpServletResponseWrapper(response), filterChain);
        } else {
            doFilterWrapped(new CachedHttpServletRequestWrapper(request),
                    new CachedHttpServletResponseWrapper(response), filterChain);
        }
    }

    /**
     * Logs the request, runs the rest of the chain, then logs the response and flushes the
     * cached response body to the client.
     *
     * @param request     wrapped request whose body can be read for logging
     * @param response    caching response wrapper
     * @param filterChain remaining filter chain
     * @throws ServletException propagated from the chain
     * @throws IOException      propagated from request logging, the chain, or the body copy
     */
    protected void doFilterWrapped(HttpServletRequestWrapper request, ContentCachingResponseWrapper response, FilterChain filterChain)
            throws ServletException, IOException {
        try {
            logRequestHeader(request);
            logRequest(request);
            filterChain.doFilter(request, response);
        } finally {
            try {
                logResponseHeader(response);
                logResponse(response);
            } catch (IOException | RuntimeException e) {
                // A logging failure must neither hide an exception thrown by the chain nor
                // stop the cached body from reaching the client.
                log.warn("Failed to log response", e);
            } finally {
                // Without this call the client receives an empty body.
                response.copyBodyToResponse();
            }
        }
    }

    /** Logs every request header (first value only), masking sensitive ones. */
    private static void logRequestHeader(HttpServletRequestWrapper request) {
        Enumeration<String> headerNames = request.getHeaderNames();
        if (headerNames == null) {
            return;
        }
        for (String headerName : Collections.list(headerNames)) {
            String headerValue = request.getHeader(headerName);
            log.info("Request Header {}: {}", headerName, maskIfSensitive(headerName, headerValue));
        }
    }

    /**
     * Logs the request line and then either the form parameters (form-urlencoded requests)
     * or the raw payload (everything else).
     */
    private static void logRequest(HttpServletRequestWrapper request) throws IOException {
        String contentType = request.getContentType();
        log.info("Request body: {} uri=[{}] content-type=[{}]", request.getMethod(),
                request.getRequestURI(), contentType);
        if (contentType != null && contentType.startsWith(MediaType.APPLICATION_FORM_URLENCODED_VALUE)) {
            Map<String, String[]> paramMap = request.getParameterMap();
            paramMap.forEach((key, values) -> {
                for (String value : values) {
                    log.info("REQUEST Param: {} = {}", key, value);
                }
            });
        } else {
            logPayload("REQUEST", contentType, request.getInputStream(), request.getRequestURI());
        }
    }

    /** Logs every value of every response header, masking sensitive ones. */
    private static void logResponseHeader(ContentCachingResponseWrapper response) {
        Collection<String> headerNames = response.getHeaderNames();
        for (String headerName : headerNames) {
            for (String value : response.getHeaders(headerName)) {
                log.info("Response Header {}: {}", headerName, maskIfSensitive(headerName, value));
            }
        }
    }

    /** Logs the response payload held by the caching wrapper. */
    private static void logResponse(ContentCachingResponseWrapper response) throws IOException {
        logPayload("RESPONSE", response.getContentType(), response.getContentInputStream(), null);
    }

    /**
     * Logs a payload as text when its media type is one of {@link #VISIBLE_TYPES}, otherwise
     * logs a placeholder.
     *
     * @param direction   label used in the log line ("REQUEST" or "RESPONSE")
     * @param contentType raw Content-Type value; {@code null} is treated as application/json
     * @param inputStream stream to read the payload from
     * @param targetUri   request URI; currently not used
     * @throws IOException if the stream cannot be read
     */
    private static void logPayload(String direction,
                                   String contentType,
                                   InputStream inputStream,
                                   String targetUri
    ) throws IOException {
        MediaType mediaType;
        try {
            mediaType = MediaType.valueOf(contentType == null ? "application/json" : contentType);
        } catch (IllegalArgumentException e) {
            // A client-supplied Content-Type may be malformed; that must not fail the request.
            log.info("{} Payload: not logged, unparseable content-type=[{}]", direction, contentType);
            return;
        }
        if (!isVisible(mediaType)) {
            log.info("{} Payload: Binary Content", direction);
            return;
        }
        byte[] content = StreamUtils.copyToByteArray(inputStream);
        if (content.length > 0) {
            // Use the declared charset when present; never fall back to the platform default.
            java.nio.charset.Charset charset = mediaType.getCharset() != null
                    ? mediaType.getCharset()
                    : java.nio.charset.StandardCharsets.UTF_8;
            log.info("{} Payload: {}", direction, new String(content, charset));
        }
    }

    /** Returns true when the media type is covered by one of {@link #VISIBLE_TYPES}. */
    private static boolean isVisible(MediaType mediaType) {
        return VISIBLE_TYPES.stream().anyMatch(visibleType -> visibleType.includes(mediaType));
    }

    /** Returns the mask for credential-bearing headers and the value itself otherwise. */
    private static String maskIfSensitive(String headerName, String headerValue) {
        boolean sensitive = SENSITIVE_HEADERS.stream()
                .anyMatch(candidate -> candidate.equalsIgnoreCase(headerName));
        return sensitive ? MASK : headerValue;
    }

    /** Returns true for POST requests whose Content-Type starts with multipart/form-data. */
    private boolean isMultipartRequest(HttpServletRequest request) {
        String contentType = request.getContentType();
        return request.getMethod().equalsIgnoreCase("POST")
                && contentType != null
                && contentType.startsWith("multipart/form-data");
    }
}
CachedHttpServletRequestWrapper.java
/**
 * Request wrapper that reads the whole body once at construction time and serves it from
 * memory afterwards, so {@link #getInputStream()} and {@link #getReader()} can be called any
 * number of times.
 *
 * <p>For {@code application/x-www-form-urlencoded} requests the cached body is parsed here
 * and merged with the parameters the wrapped request still reports. All four parameter
 * accessors ({@code getParameter}, {@code getParameterNames}, {@code getParameterValues},
 * {@code getParameterMap}) answer from that same map so they always agree with each other.
 */
public class CachedHttpServletRequestWrapper extends HttpServletRequestWrapper {

    /** Complete request body as read from the wrapped request. */
    private final byte[] cachedBody;

    /** Parameter snapshot taken at construction time. */
    private final Map<String, String[]> parameterMap;

    /**
     * Reads and caches the body, then builds the parameter map.
     *
     * @param request request to wrap
     * @throws IOException if the body cannot be read or the declared character encoding is
     *                     not supported
     */
    public CachedHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        this.cachedBody = request.getInputStream().readAllBytes();

        // NOTE(review): with the body already consumed above, the wrapped request is expected
        // to report only query-string parameters here - confirm for the target container.
        Map<String, String[]> containerParams = super.getParameterMap();
        if (containerParams == null) {
            containerParams = java.util.Collections.emptyMap();
        }

        // Parse body parameters only for form-urlencoded requests.
        String contentType = request.getContentType();
        if (contentType != null && contentType.startsWith(MediaType.APPLICATION_FORM_URLENCODED_VALUE)) {
            Charset charset = resolveCharset(request.getCharacterEncoding());
            Map<String, List<String>> collected = new LinkedHashMap<>();
            // Query-string values come first, then body values.
            containerParams.forEach((name, values) -> {
                List<String> target = collected.computeIfAbsent(name, k -> new ArrayList<>());
                for (String value : values) {
                    target.add(value);
                }
            });
            parseFormParameters(new String(this.cachedBody, charset), charset, collected);

            Map<String, String[]> merged = new LinkedHashMap<>();
            collected.forEach((name, values) -> merged.put(name, values.toArray(new String[0])));
            this.parameterMap = java.util.Collections.unmodifiableMap(merged);
        } else {
            // Use the wrapped request's parameter map as it is.
            this.parameterMap = containerParams;
        }
    }

    /** Returns a fresh stream over the cached body on every call. */
    @Override
    public ServletInputStream getInputStream() {
        ByteArrayInputStream bais = new ByteArrayInputStream(cachedBody);
        return new ServletInputStream() {
            @Override
            public int read() {
                return bais.read();
            }

            // Bulk read so callers do not fall back to one read() call per byte.
            @Override
            public int read(byte[] buffer, int offset, int length) {
                return bais.read(buffer, offset, length);
            }

            @Override
            public boolean isFinished() {
                return bais.available() == 0;
            }

            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(ReadListener listener) {
                // The data is already in memory; the listener is ignored.
            }
        };
    }

    /** Returns the first value of the named parameter, or {@code null} when absent. */
    @Override
    public String getParameter(String name) {
        String[] values = this.parameterMap.get(name);
        return (values == null || values.length == 0) ? null : values[0];
    }

    /** Returns the names of all parameters in the snapshot. */
    @Override
    public java.util.Enumeration<String> getParameterNames() {
        return java.util.Collections.enumeration(this.parameterMap.keySet());
    }

    /** Returns a copy of all values of the named parameter, or {@code null} when absent. */
    @Override
    public String[] getParameterValues(String name) {
        String[] values = this.parameterMap.get(name);
        return values == null ? null : values.clone();
    }

    /** Returns the parameter snapshot taken at construction time. */
    @Override
    public Map<String, String[]> getParameterMap() {
        return this.parameterMap;
    }

    /**
     * Returns a fresh reader over the cached body, decoded with the request's character
     * encoding (UTF-8 when none is set).
     */
    @Override
    public BufferedReader getReader() {
        String encoding = getCharacterEncoding();
        if (encoding == null) {
            encoding = StandardCharsets.UTF_8.name();
        }
        return new BufferedReader(new InputStreamReader(getInputStream(), Charset.forName(encoding)));
    }

    /**
     * Maps an encoding name to a {@link Charset}, defaulting to UTF-8 for {@code null}.
     *
     * @throws UnsupportedEncodingException if the name is illegal or not supported
     */
    private static Charset resolveCharset(String encoding) throws UnsupportedEncodingException {
        if (encoding == null) {
            return StandardCharsets.UTF_8;
        }
        try {
            return Charset.forName(encoding);
        } catch (IllegalArgumentException e) {
            // Covers both illegal and unsupported charset names; surfaced as an IOException
            // subtype, matching what new String(bytes, name) reports.
            UnsupportedEncodingException failure = new UnsupportedEncodingException(encoding);
            failure.initCause(e);
            throw failure;
        }
    }

    /**
     * Parses a form-urlencoded body and appends every name/value pair to {@code target},
     * keeping the order in which names first appear.
     *
     * @param body    decoded request body; blank or {@code null} adds nothing
     * @param charset charset used to decode percent-escapes
     * @param target  map that receives the parsed values
     */
    private static void parseFormParameters(String body, Charset charset, Map<String, List<String>> target) {
        // Handle an empty body.
        if (body == null || body.trim().isEmpty()) {
            return;
        }
        for (String pair : body.split("&")) {
            if (pair.trim().isEmpty()) {
                continue; // skip empty pairs such as the one produced by "a=1&&b=2"
            }
            String[] parts = pair.split("=", 2);
            String key = decodeComponent(parts[0], charset);
            String value = parts.length > 1 ? decodeComponent(parts[1], charset) : "";
            target.computeIfAbsent(key, k -> new ArrayList<>()).add(value);
        }
    }

    /** Percent-decodes one component; a malformed escape yields the raw text unchanged. */
    private static String decodeComponent(String raw, Charset charset) {
        try {
            return URLDecoder.decode(raw, charset);
        } catch (IllegalArgumentException ignored) {
            // One bad escape in client input should not fail the whole request.
            return raw;
        }
    }
}
CachedHttpServletResponseWrapper.java
/**
 * Response wrapper that extends {@code ContentCachingResponseWrapper} without adding or
 * overriding any behavior; it only gives the caching response a project-specific type name.
 */
public class CachedHttpServletResponseWrapper extends ContentCachingResponseWrapper {
/**
 * Wraps the given response.
 *
 * @param response response passed straight to the superclass constructor
 */
public CachedHttpServletResponseWrapper(HttpServletResponse response) {
super(response);
}
}