1
2
3
4
5
6
7
8
9
10
11
12
13
14 package ch.qos.logback.core.net;
15
16 import java.io.IOException;
17 import java.io.InputStream;
18 import java.io.InvalidClassException;
19 import java.io.ObjectInputStream;
20 import java.io.ObjectStreamClass;
21 import java.util.ArrayList;
22 import java.util.List;
23
24
25
26
27
28
29
30
31
32
33
34
35
36
/**
 * An {@link ObjectInputStream} that only resolves classes which have been explicitly authorized.
 *
 * <p>
 * A class descriptor read from the stream is accepted when its name lies in one of the
 * {@link #JAVA_PACKAGES} (or in a sub-package of one of them), or when its name is equal to an
 * entry of the whitelist supplied by the caller. Any other descriptor makes
 * {@link #resolveClass(ObjectStreamClass)} throw an {@link InvalidClassException} before the class
 * is looked up.
 *
 * <p>
 * Names are compared textually. A descriptor whose name does not literally start with an
 * authorized package (for example an array descriptor such as {@code [Ljava.lang.String;}) is
 * therefore accepted only when that exact name has been whitelisted.
 *
 * <p>
 * NOTE(review): every class below {@code java.lang} and {@code java.util} is authorized, which is
 * a broad set. This class only restricts which classes may be resolved; it places no limit on the
 * size or nesting depth of the incoming object graph.
 */
public class HardenedObjectInputStream extends ObjectInputStream {

    /** Fully qualified class names authorized in addition to {@link #JAVA_PACKAGES}. */
    final List<String> whitelistedClassNames;

    /** Packages whose classes (including those of their sub-packages) are always authorized. */
    final static String[] JAVA_PACKAGES = new String[] { "java.lang", "java.util" };

    /**
     * Creates a stream that authorizes the given class names.
     *
     * @param in        the stream to read serialized data from
     * @param whitelist fully qualified names of the additional classes to authorize; may be
     *                  {@code null}, in which case no additional class is authorized
     * @throws IOException if the underlying {@link ObjectInputStream} cannot be created, e.g.
     *                     because the stream header cannot be read
     */
    public HardenedObjectInputStream(InputStream in, String[] whitelist) throws IOException {
        super(in);

        this.whitelistedClassNames = new ArrayList<String>();
        if (whitelist != null) {
            for (String className : whitelist) {
                this.whitelistedClassNames.add(className);
            }
        }
    }

    /**
     * Creates a stream that authorizes the given class names. The list is copied, so later changes
     * made by the caller to {@code whitelist} do not affect this stream.
     *
     * @param in        the stream to read serialized data from
     * @param whitelist fully qualified names of the additional classes to authorize; may be
     *                  {@code null}, in which case no additional class is authorized
     * @throws IOException if the underlying {@link ObjectInputStream} cannot be created, e.g.
     *                     because the stream header cannot be read
     */
    public HardenedObjectInputStream(InputStream in, List<String> whitelist) throws IOException {
        super(in);

        this.whitelistedClassNames = new ArrayList<String>();
        // Accept null exactly like the array-based constructor does.
        if (whitelist != null) {
            this.whitelistedClassNames.addAll(whitelist);
        }
    }

    /**
     * Resolves the class for the given descriptor, but only if its name is authorized.
     *
     * @param anObjectStreamClass the class descriptor read from the stream
     * @return the class resolved by {@link ObjectInputStream#resolveClass(ObjectStreamClass)}
     * @throws InvalidClassException  if the descriptor names a class that is not authorized; the
     *                                exception's {@code classname} field holds the rejected name
     * @throws IOException            if the superclass implementation fails with an I/O error
     * @throws ClassNotFoundException if the authorized class cannot be found
     */
    @Override
    protected Class<?> resolveClass(ObjectStreamClass anObjectStreamClass) throws IOException, ClassNotFoundException {
        String incomingClassName = anObjectStreamClass.getName();

        if (!isWhitelisted(incomingClassName)) {
            // InvalidClassException takes (class name, reason), in that order.
            throw new InvalidClassException(incomingClassName, "Unauthorized deserialization attempt");
        }

        return super.resolveClass(anObjectStreamClass);
    }

    /**
     * Tells whether a class name is authorized for deserialization.
     *
     * @param incomingClassName the class name taken from the stream's class descriptor
     * @return {@code true} if the name lies in one of the {@link #JAVA_PACKAGES} or equals a
     *         whitelisted name, {@code false} otherwise
     */
    private boolean isWhitelisted(String incomingClassName) {
        for (String javaPackage : JAVA_PACKAGES) {
            // The trailing dot anchors the match on a package boundary, so that look-alike
            // prefixes such as "java.utilx." are not mistaken for "java.util".
            if (incomingClassName.startsWith(javaPackage + ".")) {
                return true;
            }
        }
        return whitelistedClassNames.contains(incomingClassName);
    }

    /**
     * Authorizes additional classes after construction.
     *
     * @param additionalAuthorizedClasses fully qualified names of the classes to authorize; a
     *                                    {@code null} list is ignored
     */
    protected void addToWhitelist(List<String> additionalAuthorizedClasses) {
        if (additionalAuthorizedClasses != null) {
            whitelistedClassNames.addAll(additionalAuthorizedClasses);
        }
    }
}