001package ch.qos.logback.core.net; 002 003import static org.junit.Assert.assertEquals; 004 005import java.io.ByteArrayInputStream; 006import java.io.ByteArrayOutputStream; 007import java.io.IOException; 008import java.io.ObjectOutputStream; 009 010import org.junit.After; 011import org.junit.Before; 012import org.junit.Test; 013 014public class HardenedObjectInputStreamTest { 015 016 ByteArrayOutputStream bos; 017 ObjectOutputStream oos; 018 HardenedObjectInputStream inputStream; 019 String[] whitelist = new String[] { Innocent.class.getName() }; 020 021 @Before 022 public void setUp() throws Exception { 023 bos = new ByteArrayOutputStream(); 024 oos = new ObjectOutputStream(bos); 025 } 026 027 @After 028 public void tearDown() throws Exception { 029 } 030 031 @Test 032 public void smoke() throws ClassNotFoundException, IOException { 033 Innocent innocent = new Innocent(); 034 innocent.setAnInt(1); 035 innocent.setAnInteger(2); 036 innocent.setaString("smoke"); 037 Innocent back = writeAndRead(innocent); 038 assertEquals(innocent, back); 039 } 040 041 private Innocent writeAndRead(Innocent innocent) throws IOException, ClassNotFoundException { 042 writeObject(oos, innocent); 043 ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray()); 044 inputStream = new HardenedObjectInputStream(bis, whitelist); 045 Innocent fooBack = (Innocent) inputStream.readObject(); 046 inputStream.close(); 047 return fooBack; 048 } 049 050 private void writeObject(ObjectOutputStream oos, Object o) throws IOException { 051 oos.writeObject(o); 052 oos.flush(); 053 oos.close(); 054 } 055 056// @Ignore 057// @Test 058// public void denialOfService() throws ClassNotFoundException, IOException { 059// ByteArrayInputStream bis = new ByteArrayInputStream(payload()); 060// inputStream = new HardenedObjectInputStream(bis, whitelist); 061// try { 062// Set set = (Set) inputStream.readObject(); 063// assertNotNull(set); 064// } finally { 065// inputStream.close(); 066// } 067// } 068// 069// private byte[] payload() throws IOException { 070// Set root = buildEvilHashset(); 071// return serialize(root); 072// } 073// 074// private Set buildEvilHashset() { 075// Set root = new HashSet(); 076// Set s1 = root; 077// Set s2 = new HashSet(); 078// for (int i = 0; i < 100; i++) { 079// Set t1 = new HashSet(); 080// Set t2 = new HashSet(); 081// t1.add("foo"); // make it not equal to t2 082// s1.add(t1); 083// s1.add(t2); 084// s2.add(t1); 085// s2.add(t2); 086// s1 = t1; 087// s2 = t2; 088// } 089// return root; 090// } 091}