I have a simple REST controller that I use to accept a file uploaded from an HTML form. The project uses Spring Boot 2.6.1 and Java 17, but the same problem also occurred with Spring Boot 2.3.7 and Java 15.
@PostMapping(path = "/file", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
public void handleFileUpload(@RequestParam("file") MultipartFile file) throws IOException {
    fileService.upload(file.getInputStream(), file.getOriginalFilename());
}
The problem is that file is always null. I found a lot of different answers about setting a MultipartResolver bean or enabling spring.http.multipart.enabled = true, but nothing helped. I have a logging filter as one of the first filters in the chain. After debugging the filter chain, I found out that calling request.getParts() made everything work. My filter looks like this:
public class LoggingFilter extends GenericFilterBean {
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        BufferedRequestWrapper bufferedRequest = new BufferedRequestWrapper(httpServletRequest);
        BufferedResponseWrapper bufferedResponse = new BufferedResponseWrapper((HttpServletResponse) response);
        filterChain.doFilter(bufferedRequest, bufferedResponse);
        logRequest(httpServletRequest, bufferedRequest);
        logResponse(httpServletRequest, bufferedResponse);
    }
    // logRequest(...), logResponse(...) and the wrapper classes are in the complete example below.
}
I changed the filter to:
public class LoggingFilter extends GenericFilterBean {
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        if (request.getContentType() != null && request.getContentType().startsWith("multipart/form-data")) {
            httpServletRequest.getParts(); // Trigger initialization of multi-part.
        }
        BufferedRequestWrapper bufferedRequest = new BufferedRequestWrapper(httpServletRequest);
        BufferedResponseWrapper bufferedResponse = new BufferedResponseWrapper((HttpServletResponse) response);
        filterChain.doFilter(bufferedRequest, bufferedResponse);
        logRequest(httpServletRequest, bufferedRequest);
        logResponse(httpServletRequest, bufferedResponse);
    }
    // logRequest(...), logResponse(...) and the wrapper classes are in the complete example below.
}
and everything was working. My question is: why is this needed, and is there a better way of doing this?
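A possible explanation, stated as an assumption rather than confirmed container behavior: the servlet container parses multipart parts lazily, on the first getParts() call, from the same one-shot input stream that BufferedRequestWrapper drains in its constructor. The stdlib-only model below illustrates that ordering effect; FakeRequest and OneShotBodyDemo are hypothetical names, not servlet API types.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

// Simplified model, not servlet API code: FakeRequest is hypothetical. The body
// is a one-shot stream, and "parts" are parsed lazily from it on first access.
public class OneShotBodyDemo {

    static final class FakeRequest {
        private final InputStream body;
        private String parsedParts; // null until getParts() is first called

        FakeRequest(String body) {
            this.body = new ByteArrayInputStream(body.getBytes(StandardCharsets.UTF_8));
        }

        InputStream getInputStream() {
            return body;
        }

        String getParts() throws IOException {
            if (parsedParts == null) {
                // Parse whatever is still left in the stream, then cache the result.
                parsedParts = new String(body.readAllBytes(), StandardCharsets.UTF_8);
            }
            return parsedParts;
        }
    }

    // Order in the original filter: drain the body first, parse parts later.
    public static String bufferThenParse(String body) throws IOException {
        FakeRequest request = new FakeRequest(body);
        request.getInputStream().readAllBytes(); // logging wrapper reads everything
        return request.getParts();               // nothing left to parse
    }

    // Order with the workaround: parse parts first, then drain the body.
    public static String parseThenBuffer(String body) throws IOException {
        FakeRequest request = new FakeRequest(body);
        request.getParts();                      // parts parsed and cached
        request.getInputStream().readAllBytes(); // stream is already empty
        return request.getParts();               // cached parts are still available
    }

    public static void main(String[] args) throws IOException {
        System.out.println("[" + bufferThenParse("file-content") + "]"); // prints []
        System.out.println("[" + parseThenBuffer("file-content") + "]"); // prints [file-content]
    }
}
```

In this model the workaround also means the wrapper buffers an already-empty stream, so a multipart body would not appear in the request log.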
Below is a complete example; only the actual logging has been removed, because we use a custom logging framework.
package com.unwire.ticketing.filter.logging;
import lombok.Getter;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.output.TeeOutputStream;
import org.springframework.web.filter.GenericFilterBean;
import javax.servlet.*;
import javax.servlet.http.*;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Locale;
import java.util.stream.Collectors;
public class Log extends GenericFilterBean {
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        if (request.getContentType() != null && request.getContentType().startsWith("multipart/form-data")) {
            httpServletRequest.getParts(); // Trigger initialization of multi-part.
        }
        try {
            BufferedRequestWrapper bufferedRequest = new BufferedRequestWrapper(httpServletRequest);
            BufferedResponseWrapper bufferedResponse = new BufferedResponseWrapper((HttpServletResponse) response);
            filterChain.doFilter(bufferedRequest, bufferedResponse);
            logRequest(httpServletRequest, bufferedRequest);
            logResponse(httpServletRequest, bufferedResponse);
        } catch (Throwable t) {
            // Error logging omitted; rethrow so failures are not silently swallowed.
            throw t;
        }
    }

    private void logRequest(HttpServletRequest request, BufferedRequestWrapper bufferedRequest) throws IOException {
        String body = bufferedRequest.getRequestBody();
        // Log request
    }

    private void logResponse(HttpServletRequest httpServletRequest, BufferedResponseWrapper bufferedResponse) {
        // Log response
    }

    private static final class BufferedRequestWrapper extends HttpServletRequestWrapper {
        private final byte[] buffer;

        BufferedRequestWrapper(HttpServletRequest req) throws IOException {
            super(req);
            String contentType = req.getContentType();
            // Buffer everything except form posts.
            // Note: multipart/form-data bodies are drained here as well.
            if (contentType == null || !contentType.startsWith("application/x-www-form-urlencoded")) {
                // Read InputStream and store its content in a buffer.
                InputStream is = req.getInputStream();
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                byte[] buf = new byte[1024];
                int read;
                while ((read = is.read(buf)) != -1) {
                    baos.write(buf, 0, read);
                }
                this.buffer = baos.toByteArray();
            } else {
                this.buffer = new byte[0];
            }
        }

        @Override
        public ServletInputStream getInputStream() {
            return new BufferedServletInputStream(new ByteArrayInputStream(this.buffer));
        }

        @Override
        public Collection<Part> getParts() throws IOException, ServletException {
            return super.getParts();
        }

        String getRequestBody() throws IOException {
            return IOUtils.readLines(this.getInputStream(), StandardCharsets.UTF_8.name()).stream()
                    .map(String::trim)
                    .collect(Collectors.joining());
        }
    }

    private static final class BufferedServletInputStream extends ServletInputStream {
        private final ByteArrayInputStream bais;

        BufferedServletInputStream(ByteArrayInputStream bais) {
            this.bais = bais;
        }

        @Override
        public int available() {
            return this.bais.available();
        }

        @Override
        public int read() {
            return this.bais.read();
        }

        @Override
        public int read(byte[] buf, int off, int len) {
            return this.bais.read(buf, off, len);
        }

        @Override
        public boolean isFinished() {
            // The buffered body is finished once every byte has been read.
            return this.bais.available() == 0;
        }

        @Override
        public boolean isReady() {
            return true;
        }

        @Override
        public void setReadListener(ReadListener readListener) {
        }
    }

    public static class TeeServletOutputStream extends ServletOutputStream {
        private final TeeOutputStream targetStream;

        TeeServletOutputStream(OutputStream one, OutputStream two) {
            targetStream = new TeeOutputStream(one, two);
        }

        @Override
        public void write(int arg0) throws IOException {
            this.targetStream.write(arg0);
        }

        @Override
        public void flush() throws IOException {
            super.flush();
            this.targetStream.flush();
        }

        @Override
        public void close() throws IOException {
            super.close();
            this.targetStream.close();
        }

        @Override
        public boolean isReady() {
            // Blocking stream: always ready for writing.
            return true;
        }

        @Override
        public void setWriteListener(WriteListener writeListener) {
        }
    }
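
    // Extending HttpServletResponseWrapper supplies the HttpServletResponse methods
    // not overridden below (e.g. the deprecated encodeUrl and setStatus(int, String)).
    public class BufferedResponseWrapper extends HttpServletResponseWrapper {
        HttpServletResponse original;
        TeeServletOutputStream tee;
        ByteArrayOutputStream bos;
        @Getter
        Long startTime;

        BufferedResponseWrapper(HttpServletResponse response) {
            super(response);
            this.original = response;
            this.startTime = System.currentTimeMillis();
        }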

        String getContent() {
            if (bos != null) {
                return bos.toString();
            } else {
                return "";
            }
        }

        @Override
        public PrintWriter getWriter() throws IOException {
            return original.getWriter();
        }

        @Override
        public ServletOutputStream getOutputStream() throws IOException {
            if (tee == null) {
                bos = new ByteArrayOutputStream();
                tee = new TeeServletOutputStream(original.getOutputStream(), bos);
            }
            return tee;
        }

        @Override
        public String getCharacterEncoding() {
            return original.getCharacterEncoding();
        }

        @Override
        public void setCharacterEncoding(String charset) {
            original.setCharacterEncoding(charset);
        }

        @Override
        public String getContentType() {
            return original.getContentType();
        }

        @Override
        public void setContentType(String type) {
            original.setContentType(type);
        }

        @Override
        public void setContentLength(int len) {
            original.setContentLength(len);
        }

        @Override
        public void setContentLengthLong(long l) {
            original.setContentLengthLong(l);
        }

        @Override
        public int getBufferSize() {
            return original.getBufferSize();
        }

        @Override
        public void setBufferSize(int size) {
            original.setBufferSize(size);
        }

        @Override
        public void flushBuffer() throws IOException {
            if (tee != null) {
                tee.flush();
            }
        }

        @Override
        public void resetBuffer() {
            original.resetBuffer();
        }

        @Override
        public boolean isCommitted() {
            return original.isCommitted();
        }

        @Override
        public void reset() {
            original.reset();
        }

        @Override
        public Locale getLocale() {
            return original.getLocale();
        }

        @Override
        public void setLocale(Locale loc) {
            original.setLocale(loc);
        }

        @Override
        public void addCookie(Cookie cookie) {
            original.addCookie(cookie);
        }

        @Override
        public boolean containsHeader(String name) {
            return original.containsHeader(name);
        }

        @Override
        public String encodeURL(String url) {
            return original.encodeURL(url);
        }

        @Override
        public String encodeRedirectURL(String url) {
            return original.encodeRedirectURL(url);
        }

        @Override
        public void sendError(int sc, String msg) throws IOException {
            original.sendError(sc, msg);
        }

        @Override
        public void sendError(int sc) throws IOException {
            original.sendError(sc);
        }

        @Override
        public void sendRedirect(String location) throws IOException {
            original.sendRedirect(location);
        }

        @Override
        public void setDateHeader(String name, long date) {
            original.setDateHeader(name, date);
        }

        @Override
        public void addDateHeader(String name, long date) {
            original.addDateHeader(name, date);
        }

        @Override
        public void setHeader(String name, String value) {
            original.setHeader(name, value);
        }

        @Override
        public void addHeader(String name, String value) {
            original.addHeader(name, value);
        }

        @Override
        public void setIntHeader(String name, int value) {
            original.setIntHeader(name, value);
        }

        @Override
        public void addIntHeader(String name, int value) {
            original.addIntHeader(name, value);
        }

        @Override
        public String getHeader(String arg0) {
            return original.getHeader(arg0);
        }

        @Override
        public Collection<String> getHeaderNames() {
            return original.getHeaderNames();
        }

        @Override
        public Collection<String> getHeaders(String arg0) {
            return original.getHeaders(arg0);
        }

        @Override
        public int getStatus() {
            return original.getStatus();
        }

        @Override
        public void setStatus(int sc) {
            original.setStatus(sc);
        }
    }
}
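BufferedRequestWrapper's constructor skips only application/x-www-form-urlencoded, so multipart/form-data bodies are drained as well. A possible alternative to calling getParts() up front, offered as a sketch rather than a tested fix, is to leave multipart bodies alone in the same way form posts already are. The helper below is a hypothetical, stdlib-only version of that decision: BodyBufferingPolicy and shouldBuffer are illustrative names, and the wrapper's constructor would call shouldBuffer(req.getContentType()) in place of its current check. Unlike the startsWith check, it compares the media type case-insensitively and ignores parameters such as the boundary.

```java
import java.util.Locale;
import java.util.Set;

// Hypothetical decision helper for a request-buffering wrapper: buffer the raw
// body only when the servlet container will not need to parse it itself.
public final class BodyBufferingPolicy {

    // Media types whose bodies the container parses from the input stream
    // (form parameters and multipart parts), so they must not be drained.
    private static final Set<String> CONTAINER_PARSED = Set.of(
            "application/x-www-form-urlencoded",
            "multipart/form-data");

    private BodyBufferingPolicy() {
    }

    public static boolean shouldBuffer(String contentType) {
        if (contentType == null) {
            return true; // same as the original wrapper: no content type, buffer it
        }
        // Strip parameters such as "; boundary=..." and normalize the case.
        int semicolon = contentType.indexOf(';');
        String mediaType = (semicolon >= 0 ? contentType.substring(0, semicolon) : contentType)
                .trim()
                .toLowerCase(Locale.ROOT);
        return !CONTAINER_PARSED.contains(mediaType);
    }

    public static void main(String[] args) {
        System.out.println(shouldBuffer("application/json"));                  // prints true
        System.out.println(shouldBuffer("multipart/form-data; boundary=xyz")); // prints false
        System.out.println(shouldBuffer("application/x-www-form-urlencoded")); // prints false
        System.out.println(shouldBuffer(null));                                // prints true
    }
}
```

With this policy, multipart uploads reach the multipart resolver untouched, at the cost of not logging their raw bodies.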
Please consider using ContentCachingRequestWrapper.
It is built into Spring and caches all content read from the request's input stream and reader, so the body can be retrieved from the cache (getContentAsByteArray()) after the rest of the chain has consumed it.
Be aware that for multipart uploads Spring already wraps the request in a MultipartHttpServletRequest.
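ContentCachingRequestWrapper sidesteps the problem because it does not read the body up front; content is cached only as downstream code reads it. The class below is a simplified, stdlib-only model of that read-through idea, not Spring's implementation; CachingInputStream and getCachedBytes are hypothetical names.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

// Simplified model of read-through caching (not Spring's code): bytes are copied
// into a cache only as the consumer reads them, so nothing is drained up front.
public class CachingInputStream extends FilterInputStream {

    private final ByteArrayOutputStream cache = new ByteArrayOutputStream();

    public CachingInputStream(InputStream source) {
        super(source);
    }

    @Override
    public int read() throws IOException {
        int b = super.read();
        if (b != -1) {
            cache.write(b);
        }
        return b;
    }

    @Override
    public int read(byte[] buf, int off, int len) throws IOException {
        int n = super.read(buf, off, len);
        if (n > 0) {
            cache.write(buf, off, n);
        }
        return n;
    }

    // Everything read so far; empty if nothing downstream has read the body yet.
    public byte[] getCachedBytes() {
        return cache.toByteArray();
    }

    public static void main(String[] args) throws IOException {
        byte[] body = "{\"name\":\"ticket\"}".getBytes(StandardCharsets.UTF_8);
        CachingInputStream stream = new CachingInputStream(new ByteArrayInputStream(body));

        System.out.println(stream.getCachedBytes().length); // prints 0
        byte[] chunk = new byte[4];
        while (stream.read(chunk) != -1) {
            // Simulates downstream code reading the body in small chunks.
        }
        System.out.println(new String(stream.getCachedBytes(), StandardCharsets.UTF_8)); // prints {"name":"ticket"}
    }
}
```

The same consequence applies to the real wrapper: the cached content is complete only after the chain has read the body, so the logging call belongs after filterChain.doFilter(...).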