1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  package ch.qos.logback.core.net;
15  
import java.io.IOException;
import java.io.InputStream;
import java.io.InvalidClassException;
import java.io.ObjectInputFilter;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
/**
 * An {@link ObjectInputStream} that refuses to resolve any class that has not been
 * explicitly authorized.
 *
 * <p>A class is accepted when its binary name lies in the {@code java.lang} or
 * {@code java.util} package trees (including their sub-packages), or when it equals one of
 * the names supplied to a constructor or to {@link #addToWhitelist(List)}. Array classes are
 * reported under JVM descriptor names such as {@code [Ljava.lang.String;} or {@code [B}, which
 * match neither package prefix, so they must be whitelisted by that exact name. The same rule
 * is applied to every interface named by an incoming dynamic proxy class descriptor.
 *
 * <p>Independently of the whitelist, each stream installs an {@link ObjectInputFilter} that
 * rejects arrays longer than {@code ARRAY_LIMIT} elements and object graphs nested deeper than
 * {@code DEPTH_LIMIT} levels. If the stream was created with a JVM-wide serialization filter
 * already in place, that filter is still consulted after the limits instead of being replaced.
 */
public class HardenedObjectInputStream extends ObjectInputStream {

    /** Exact binary class names accepted in addition to the JDK package trees below. */
    final private Set<String> whitelistedClassNames = new HashSet<String>();
    /**
     * Package prefixes that are always accepted. The trailing dot anchors each match at a
     * package boundary, so a name like {@code java.utilXFoo} is not mistaken for
     * {@code java.util}.
     */
    final private static String[] JAVA_PACKAGES = new String[] { "java.lang.", "java.util." };
    final private static int DEPTH_LIMIT = 16;
    final private static int ARRAY_LIMIT = 10000;

    /**
     * Creates a hardened stream reading from {@code in}.
     *
     * @param in        the underlying stream; the {@link ObjectInputStream} constructor reads
     *                  and verifies its serialization header immediately
     * @param whitelist exact binary class names to accept in addition to the
     *                  {@code java.lang}/{@code java.util} trees; may be {@code null}
     * @throws IOException if the stream header cannot be read
     */
    public HardenedObjectInputStream(InputStream in, String[] whitelist) throws IOException {
        super(in);
        this.initObjectFilter();
        if (whitelist != null) {
            for (String className : whitelist) {
                this.whitelistedClassNames.add(className);
            }
        }
    }

    /**
     * Creates a hardened stream reading from {@code in}.
     *
     * @param in        the underlying stream; the {@link ObjectInputStream} constructor reads
     *                  and verifies its serialization header immediately
     * @param whitelist exact binary class names to accept in addition to the
     *                  {@code java.lang}/{@code java.util} trees; may be {@code null}, which is
     *                  treated like an empty list (consistent with the array constructor)
     * @throws IOException if the stream header cannot be read
     */
    public HardenedObjectInputStream(InputStream in, List<String> whitelist) throws IOException {
        super(in);
        this.initObjectFilter();
        if (whitelist != null) {
            this.whitelistedClassNames.addAll(whitelist);
        }
    }

    /**
     * Installs the array-length / graph-depth limits on this stream.
     *
     * <p>{@code setObjectInputFilter} replaces whatever filter the stream currently has, and a
     * freshly constructed stream starts out with the JVM-wide filter (if any). To avoid
     * silently discarding an administrator-configured filter, it is chained behind the limits:
     * a limit violation rejects outright, otherwise the inherited filter decides.
     */
    private void initObjectFilter() {
        final ObjectInputFilter limits = ObjectInputFilter.Config.createFilter(
                "maxarray=" + ARRAY_LIMIT + ";maxdepth=" + DEPTH_LIMIT + ";");
        final ObjectInputFilter inherited = this.getObjectInputFilter();
        if (inherited == null) {
            this.setObjectInputFilter(limits);
            return;
        }
        this.setObjectInputFilter(info -> {
            ObjectInputFilter.Status status = limits.checkInput(info);
            if (status == ObjectInputFilter.Status.REJECTED) {
                return status;
            }
            return inherited.checkInput(info);
        });
    }

    /**
     * Resolves {@code anObjectStreamClass} only if its name is whitelisted.
     *
     * @throws InvalidClassException if the class name is not whitelisted
     */
    @Override
    protected Class<?> resolveClass(ObjectStreamClass anObjectStreamClass) throws IOException, ClassNotFoundException {
        String incomingClassName = anObjectStreamClass.getName();
        if (!isWhitelisted(incomingClassName)) {
            throw new InvalidClassException("Unauthorized deserialization attempt", incomingClassName);
        }
        return super.resolveClass(anObjectStreamClass);
    }

    /**
     * Resolves a dynamic proxy class only if every interface it implements is whitelisted.
     * Proxy descriptors never pass through {@link #resolveClass(ObjectStreamClass)}, so
     * without this override they would bypass the whitelist entirely.
     *
     * @throws InvalidClassException if any interface name is not whitelisted
     */
    @Override
    protected Class<?> resolveProxyClass(String[] interfaces) throws IOException, ClassNotFoundException {
        for (String interfaceName : interfaces) {
            if (!isWhitelisted(interfaceName)) {
                throw new InvalidClassException("Unauthorized deserialization attempt", interfaceName);
            }
        }
        return super.resolveProxyClass(interfaces);
    }

    /** Returns true if {@code incomingClassName} is in a trusted JDK package or explicitly listed. */
    private boolean isWhitelisted(String incomingClassName) {
        for (String javaPackage : JAVA_PACKAGES) {
            if (incomingClassName.startsWith(javaPackage)) {
                return true;
            }
        }
        return whitelistedClassNames.contains(incomingClassName);
    }

    /**
     * Authorizes additional exact binary class names for this stream.
     *
     * @param additionalAuthorizedClasses names to add; {@code null} is ignored
     */
    protected void addToWhitelist(List<String> additionalAuthorizedClasses) {
        if (additionalAuthorizedClasses != null) {
            whitelistedClassNames.addAll(additionalAuthorizedClasses);
        }
    }
}