View Javadoc
1   package ch.qos.logback.core.net;
2   
3   import java.io.ByteArrayInputStream;
4   import java.io.ByteArrayOutputStream;
5   import java.io.IOException;
6   import java.io.InvalidClassException;
7   import java.io.ObjectOutputStream;
8   import java.util.HashSet;
9   import java.util.Set;
10  
11  import org.junit.jupiter.api.AfterEach;
12  import org.junit.jupiter.api.BeforeEach;
13  import org.junit.jupiter.api.Test;
14  
15  import static org.junit.jupiter.api.Assertions.assertEquals;
16  import static org.junit.jupiter.api.Assertions.assertNotNull;
17  import static org.junit.jupiter.api.Assertions.assertThrows;
18  
19  public class HardenedObjectInputStreamTest {
20  
21      ByteArrayOutputStream bos;
22      ObjectOutputStream oos;
23      HardenedObjectInputStream inputStream;
24      String[] whitelist = new String[] { Innocent.class.getName() };
25  
26      @BeforeEach
27      public void setUp() throws Exception {
28          bos = new ByteArrayOutputStream();
29          oos = new ObjectOutputStream(bos);
30      }
31  
32      @AfterEach
33      public void tearDown() throws Exception {
34      }
35  
36      @Test
37      public void smoke() throws ClassNotFoundException, IOException {
38          Innocent innocent = new Innocent();
39          innocent.setAnInt(1);
40          innocent.setAnInteger(2);
41          innocent.setaString("smoke");
42          Innocent back = writeAndRead(innocent);
43          assertEquals(innocent, back);
44      }
45  
46      private Innocent writeAndRead(Innocent innocent) throws IOException, ClassNotFoundException {
47          writeObject(oos, innocent);
48          ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
49          inputStream = new HardenedObjectInputStream(bis, whitelist);
50          Innocent fooBack = (Innocent) inputStream.readObject();
51          inputStream.close();
52          return fooBack;
53      }
54  
55      private void writeObject(ObjectOutputStream oos, Object o) throws IOException {
56          oos.writeObject(o);
57          oos.flush();
58          oos.close();
59      }
60  
61      @Test
62      public void denialOfService() throws ClassNotFoundException, IOException {
63          ByteArrayInputStream bis = new ByteArrayInputStream(payload());
64          inputStream = new HardenedObjectInputStream(bis, whitelist);
65          try {
66              assertThrows(InvalidClassException.class, () -> inputStream.readObject());
67          } finally {
68              inputStream.close();
69          }
70      }
71  
72      private byte[] payload() throws IOException {
73          Set root = buildEvilHashset();
74          writeObject(oos, root);
75          return bos.toByteArray();
76      }
77  
78      private Set buildEvilHashset() {
79          Set root = new HashSet();
80          Set s1 = root;
81          Set s2 = new HashSet();
82          for (int i = 0; i < 100; i++) {
83              Set t1 = new HashSet();
84              Set t2 = new HashSet();
85              t1.add("foo"); // make it not equal to t2
86              s1.add(t1);
87              s1.add(t2);
88              s2.add(t1);
89              s2.add(t2);
90              s1 = t1;
91              s2 = t2;
92          }
93          return root;
94      }
95  }