1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package ch.qos.logback.core.net;
16
17 import java.io.ByteArrayInputStream;
18 import java.io.ByteArrayOutputStream;
19 import java.io.IOException;
20 import java.io.InvalidClassException;
21 import java.io.ObjectOutputStream;
22 import java.util.HashSet;
23 import java.util.Set;
24
25 import org.junit.jupiter.api.AfterEach;
26 import org.junit.jupiter.api.BeforeEach;
27 import org.junit.jupiter.api.Test;
28
29 import static org.junit.jupiter.api.Assertions.assertEquals;
30 import static org.junit.jupiter.api.Assertions.assertNotNull;
31 import static org.junit.jupiter.api.Assertions.assertThrows;
32
33 public class HardenedObjectInputStreamTest {
34
35 ByteArrayOutputStream bos;
36 ObjectOutputStream oos;
37 HardenedObjectInputStream inputStream;
38 String[] whitelist = new String[] { Innocent.class.getName() };
39
40 @BeforeEach
41 public void setUp() throws Exception {
42 bos = new ByteArrayOutputStream();
43 oos = new ObjectOutputStream(bos);
44 }
45
46 @AfterEach
47 public void tearDown() throws Exception {
48 }
49
50 @Test
51 public void smoke() throws ClassNotFoundException, IOException {
52 Innocent innocent = new Innocent();
53 innocent.setAnInt(1);
54 innocent.setAnInteger(2);
55 innocent.setaString("smoke");
56 Innocent back = writeAndRead(innocent);
57 assertEquals(innocent, back);
58 }
59
60 private Innocent writeAndRead(Innocent innocent) throws IOException, ClassNotFoundException {
61 writeObject(oos, innocent);
62 ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
63 inputStream = new HardenedObjectInputStream(bis, whitelist);
64 Innocent fooBack = (Innocent) inputStream.readObject();
65 inputStream.close();
66 return fooBack;
67 }
68
69 private void writeObject(ObjectOutputStream oos, Object o) throws IOException {
70 oos.writeObject(o);
71 oos.flush();
72 oos.close();
73 }
74
75 @Test
76 public void denialOfService() throws ClassNotFoundException, IOException {
77 ByteArrayInputStream bis = new ByteArrayInputStream(payload());
78 inputStream = new HardenedObjectInputStream(bis, whitelist);
79 try {
80 assertThrows(InvalidClassException.class, () -> inputStream.readObject());
81 } finally {
82 inputStream.close();
83 }
84 }
85
86 private byte[] payload() throws IOException {
87 Set root = buildEvilHashset();
88 writeObject(oos, root);
89 return bos.toByteArray();
90 }
91
92 private Set buildEvilHashset() {
93 Set root = new HashSet();
94 Set s1 = root;
95 Set s2 = new HashSet();
96 for (int i = 0; i < 100; i++) {
97 Set t1 = new HashSet();
98 Set t2 = new HashSet();
99 t1.add("foo");
100 s1.add(t1);
101 s1.add(t2);
102 s2.add(t1);
103 s2.add(t2);
104 s1 = t1;
105 s2 = t2;
106 }
107 return root;
108 }
109 }