1 package ch.qos.logback.core.net;
2
3 import java.io.ByteArrayInputStream;
4 import java.io.ByteArrayOutputStream;
5 import java.io.IOException;
6 import java.io.InvalidClassException;
7 import java.io.ObjectOutputStream;
8 import java.util.HashSet;
9 import java.util.Set;
10
11 import org.junit.jupiter.api.AfterEach;
12 import org.junit.jupiter.api.BeforeEach;
13 import org.junit.jupiter.api.Test;
14
15 import static org.junit.jupiter.api.Assertions.assertEquals;
16 import static org.junit.jupiter.api.Assertions.assertNotNull;
17 import static org.junit.jupiter.api.Assertions.assertThrows;
18
19 public class HardenedObjectInputStreamTest {
20
21 ByteArrayOutputStream bos;
22 ObjectOutputStream oos;
23 HardenedObjectInputStream inputStream;
24 String[] whitelist = new String[] { Innocent.class.getName() };
25
26 @BeforeEach
27 public void setUp() throws Exception {
28 bos = new ByteArrayOutputStream();
29 oos = new ObjectOutputStream(bos);
30 }
31
32 @AfterEach
33 public void tearDown() throws Exception {
34 }
35
36 @Test
37 public void smoke() throws ClassNotFoundException, IOException {
38 Innocent innocent = new Innocent();
39 innocent.setAnInt(1);
40 innocent.setAnInteger(2);
41 innocent.setaString("smoke");
42 Innocent back = writeAndRead(innocent);
43 assertEquals(innocent, back);
44 }
45
46 private Innocent writeAndRead(Innocent innocent) throws IOException, ClassNotFoundException {
47 writeObject(oos, innocent);
48 ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
49 inputStream = new HardenedObjectInputStream(bis, whitelist);
50 Innocent fooBack = (Innocent) inputStream.readObject();
51 inputStream.close();
52 return fooBack;
53 }
54
55 private void writeObject(ObjectOutputStream oos, Object o) throws IOException {
56 oos.writeObject(o);
57 oos.flush();
58 oos.close();
59 }
60
61 @Test
62 public void denialOfService() throws ClassNotFoundException, IOException {
63 ByteArrayInputStream bis = new ByteArrayInputStream(payload());
64 inputStream = new HardenedObjectInputStream(bis, whitelist);
65 try {
66 assertThrows(InvalidClassException.class, () -> inputStream.readObject());
67 } finally {
68 inputStream.close();
69 }
70 }
71
72 private byte[] payload() throws IOException {
73 Set root = buildEvilHashset();
74 writeObject(oos, root);
75 return bos.toByteArray();
76 }
77
78 private Set buildEvilHashset() {
79 Set root = new HashSet();
80 Set s1 = root;
81 Set s2 = new HashSet();
82 for (int i = 0; i < 100; i++) {
83 Set t1 = new HashSet();
84 Set t2 = new HashSet();
85 t1.add("foo");
86 s1.add(t1);
87 s1.add(t2);
88 s2.add(t1);
89 s2.add(t2);
90 s1 = t1;
91 s2 = t2;
92 }
93 return root;
94 }
95 }