123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- package net.chenlin.dp.common.xss;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
- /**
- * XSS过滤处理
- * @author zcl<yczclcn@163.com>
- */
- public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
- //没被包装过的HttpServletRequest(特殊场景,需要自己过滤)
- HttpServletRequest orgRequest;
- //html过滤
- private final static HTMLFilter htmlFilter = new HTMLFilter();
- public XssHttpServletRequestWrapper(HttpServletRequest request) {
- super(request);
- orgRequest = request;
- }
- @Override
- public ServletInputStream getInputStream() throws IOException {
- //非json类型,直接返回
- if(!super.getHeader(HttpHeaders.CONTENT_TYPE).equalsIgnoreCase(MediaType.APPLICATION_JSON_VALUE)){
- return super.getInputStream();
- }
- //为空,直接返回
- String json = IOUtils.toString(super.getInputStream(), "utf-8");
- if (StringUtils.isBlank(json)) {
- return super.getInputStream();
- }
- //xss过滤
- json = xssEncode(json);
- final ByteArrayInputStream bis = new ByteArrayInputStream(json.getBytes("utf-8"));
- return new ServletInputStream() {
- @Override
- public boolean isFinished() {
- return true;
- }
- @Override
- public boolean isReady() {
- return true;
- }
- @Override
- public void setReadListener(ReadListener readListener) {
- }
- @Override
- public int read() throws IOException {
- return bis.read();
- }
- };
- }
- @Override
- public String getParameter(String name) {
- String value = super.getParameter(xssEncode(name));
- if (StringUtils.isNotBlank(value)) {
- value = xssEncode(value);
- }
- return value;
- }
- @Override
- public String[] getParameterValues(String name) {
- String[] parameters = super.getParameterValues(name);
- if (parameters == null || parameters.length == 0) {
- return null;
- }
- for (int i = 0; i < parameters.length; i++) {
- parameters[i] = xssEncode(parameters[i]);
- }
- return parameters;
- }
- @Override
- public Map<String,String[]> getParameterMap() {
- Map<String,String[]> map = new LinkedHashMap<>();
- Map<String,String[]> parameters = super.getParameterMap();
- for (String key : parameters.keySet()) {
- String[] values = parameters.get(key);
- for (int i = 0; i < values.length; i++) {
- values[i] = xssEncode(values[i]);
- }
- map.put(key, values);
- }
- return map;
- }
- @Override
- public String getHeader(String name) {
- String value = super.getHeader(xssEncode(name));
- if (StringUtils.isNotBlank(value)) {
- value = xssEncode(value);
- }
- return value;
- }
- private String xssEncode(String input) {
- return htmlFilter.filter(input);
- }
- /**
- * 获取最原始的request
- */
- public HttpServletRequest getOrgRequest() {
- return orgRequest;
- }
- /**
- * 获取最原始的request
- */
- public static HttpServletRequest getOrgRequest(HttpServletRequest request) {
- if (request instanceof XssHttpServletRequestWrapper) {
- return ((XssHttpServletRequestWrapper) request).getOrgRequest();
- }
- return request;
- }
- }
|