001package ch.qos.logback.core.net;
002
003import static org.junit.Assert.assertEquals;
004
005import java.io.ByteArrayInputStream;
006import java.io.ByteArrayOutputStream;
007import java.io.IOException;
008import java.io.ObjectOutputStream;
009
010import org.junit.After;
011import org.junit.Before;
012import org.junit.Test;
013
014public class HardenedObjectInputStreamTest {
015
016    ByteArrayOutputStream bos;
017    ObjectOutputStream oos;
018    HardenedObjectInputStream inputStream;
019    String[] whitelist = new String[] { Innocent.class.getName() };
020
021    @Before
022    public void setUp() throws Exception {
023        bos = new ByteArrayOutputStream();
024        oos = new ObjectOutputStream(bos);
025    }
026
027    @After
028    public void tearDown() throws Exception {
029    }
030
031    @Test
032    public void smoke() throws ClassNotFoundException, IOException {
033        Innocent innocent = new Innocent();
034        innocent.setAnInt(1);
035        innocent.setAnInteger(2);
036        innocent.setaString("smoke");
037        Innocent back = writeAndRead(innocent);
038        assertEquals(innocent, back);
039    }
040
041    private Innocent writeAndRead(Innocent innocent) throws IOException, ClassNotFoundException {
042        writeObject(oos, innocent);
043        ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
044        inputStream = new HardenedObjectInputStream(bis, whitelist);
045        Innocent fooBack = (Innocent) inputStream.readObject();
046        inputStream.close();
047        return fooBack;
048    }
049
050    private void writeObject(ObjectOutputStream oos, Object o) throws IOException {
051        oos.writeObject(o);
052        oos.flush();
053        oos.close();
054    }
055
056//    @Ignore
057//    @Test
058//    public void denialOfService() throws ClassNotFoundException, IOException {
059//        ByteArrayInputStream bis = new ByteArrayInputStream(payload());
060//        inputStream = new HardenedObjectInputStream(bis, whitelist);
061//        try {
062//            Set set = (Set) inputStream.readObject();
063//            assertNotNull(set);
064//        } finally {
065//            inputStream.close();
066//        }
067//    }
068//
069//    private byte[] payload() throws IOException {
070//        Set root = buildEvilHashset();
071//        return serialize(root);
072//    }
073//
074//    private Set buildEvilHashset() {
075//        Set root = new HashSet();
076//        Set s1 = root;
077//        Set s2 = new HashSet();
078//        for (int i = 0; i < 100; i++) {
079//            Set t1 = new HashSet();
080//            Set t2 = new HashSet();
081//            t1.add("foo"); // make it not equal to t2
082//            s1.add(t1);
083//            s1.add(t2);
084//            s2.add(t1);
085//            s2.add(t2);
086//            s1 = t1;
087//            s2 = t2;
088//        }
089//        return root;
090//    }
091}