
jdk.serialFilter is not working for restricting depth of TreeMap in Java (prevent DoS attack through Java)

How to prevent DoS attack through Java TreeMap?

My code has an API which accepts a Map object. I want to prevent clients from sending Map objects beyond a certain size.

The maxarray limit in jdk.serialFilter is able to reject a HashMap whose size exceeds maxarray.

I want to do the same for TreeMap, but the maxarray limit is not working for TreeMap; the request is not rejected.

I set maxdepth too, but nothing works.

Can anyone please help me with this?

asked Jul 29 '19 by learner

3 Answers

TL;DR;

This was an entire adventure in exploring the code that handles the serialization of TreeMap, but I managed to find gold. For the gold (the code), scroll all the way to the bottom of the answer. If you want to follow the deduction process so you can do this with other classes, you'll have to struggle through my ramblings.

I might make it more concise later, but I just spent 7 hours reading code and experimenting and I'm fed up with it for now; this post as it stands might be instructive for others who wish to undertake this adventure.

Introduction

My attack angle is this: deserializing the entire thing takes up too much memory, allocating objects you may never use. So my idea is to read just the raw data and check the TreeMap size there. That way we have the only piece of data we need in order to decide whether to accept the payload. Yes, this means reading the data twice when it is accepted, but that's the trade-off you make when you use this. This code skips a lot of the verification steps Java performs, because we're not interested in those; we just want a dependable way to get to the TreeMap size without having to load the entire map with all its data.

Where normally you'd read the entire file/bytestream and use it to initialize the object, we only need to read part of the start of it, reducing the work to be done and the time to be wasted. We just need to move the read pointer forward in a dependable way so we always land on the correct bytes. This cuts the workload for the Java process significantly. After the size has been checked by a quick plain read, the stream can be put through the actual deserialization process, or be discarded. It's just a little overhead compared to the normal work, but it serves as an effective barrier while staying flexible in your acceptance criteria.

(diagram: normal work vs. checking work)
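The workflow above can be sketched as a two-pass gate. This is a minimal, self-contained sketch: the cheap first-pass check here is just the payload's byte length, a stand-in for the real TreeMap size probe developed further down, and MAX_BYTES is an arbitrary illustrative budget.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.TreeMap;

public class TwoPassGate {
    // Arbitrary illustrative budget for the cheap first pass.
    static final int MAX_BYTES = 4096;

    // Pass 1: a cheap check on the buffered raw bytes (here: plain length,
    // a stand-in for the size probe). Pass 2: the real deserialization,
    // which only payloads that survived pass 1 ever reach.
    static TreeMap<?, ?> gate(byte[] raw) throws IOException, ClassNotFoundException {
        if (raw.length > MAX_BYTES)
            throw new IOException("rejected before deserialization: " + raw.length + " bytes");
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(raw))) {
            return (TreeMap<?, ?>) ois.readObject();
        }
    }

    static byte[] serialize(TreeMap<Integer, String> map) throws IOException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(map);
        }
        return baos.toByteArray();
    }

    public static void main(String[] args) throws Exception {
        TreeMap<Integer, String> small = new TreeMap<>();
        for (int i = 0; i < 3; i++) small.put(i, "Data" + i);
        System.out.println("small map entries: " + gate(serialize(small)).size());

        TreeMap<Integer, String> big = new TreeMap<>();
        for (int i = 0; i < 10000; i++) big.put(i, "Data" + i);
        try {
            gate(serialize(big));
        } catch (IOException e) {
            System.out.println("big map: " + e.getMessage());
        }
    }
}
```

The point of buffering into a byte array is that the same bytes can be fed to the expensive pass only after the cheap pass said yes.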

Starting the adventure

Looking at the source code of TreeMap https://github.com/openjdk-mirror/jdk7u-jdk/blob/master/src/share/classes/java/util/TreeMap.java#L123 we see that size is a transient value. This means it doesn't get encoded as a field in the serialized data, so a quick check can't verify it by reading a field value from the bytes sent.

But... not all hope is lost. Because if we check the writeObject() we see that the size IS encoded https://github.com/openjdk-mirror/jdk7u-jdk/blob/master/src/share/classes/java/util/TreeMap.java#L2268

This means there are byte values we can check in the raw data that was sent!
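As a quick sanity check (my own sketch, jumping a little ahead of the deduction below): serialize a five-entry TreeMap and search the raw bytes for the size int. Since writeInt goes through block data, it should appear as the TC_BLOCKDATA marker 0x77, a block length of 4, and the four bytes of the int 5.

```java
import java.io.ByteArrayOutputStream;
import java.io.ObjectOutputStream;
import java.util.TreeMap;

public class SizeBytesDemo {
    public static void main(String[] args) throws Exception {
        TreeMap<Integer, String> map = new TreeMap<>();
        for (int i = 0; i < 5; i++) map.put(i, "Data" + i);

        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(map);
        }
        byte[] raw = baos.toByteArray();

        // writeInt(size) lands in block data: TC_BLOCKDATA (0x77),
        // block length 4, then the four big-endian bytes of the int 5.
        byte[] needle = {0x77, 0x04, 0x00, 0x00, 0x00, 0x05};
        boolean found = false;
        outer:
        for (int i = 0; i <= raw.length - needle.length; i++) {
            for (int j = 0; j < needle.length; j++)
                if (raw[i + j] != needle[j]) continue outer;
            found = true;
            break;
        }
        System.out.println("size bytes present: " + found);
    }
}
```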

Now let's check what defaultReadObject does.

L492 First it checks if it's deserializing, if it isn't it blocks. Okay, not interesting for us.
L495 Then it wants the object instance, the SerialCallbackContext was initialized with this so it doesn't perform a read.
L496 Then it gets an ObjectStreamClass instance from the SerialCallbackContext, so now we're going to work with the ObjectStream.
L497 Some modes were changed, but then we go read the fields.

Alright, moving on to ObjectInputStream.
L1944 again a class reference that was provided to the object stream instantiator (for a quick rundown: L262, which is set in L442), so it doesn't perform a read.
L1949 getting the size of the default fields with getPrimDataSize, which is set in the computeFieldOffsets method. This is useful for us; the only shame is... it's not accessible. So, as a note, we'll have to figure out how to emulate it.


L1255 It uses a fields variable. This is set in getSerialFields, which sadly is also private. At this point I get the impression I'm messing with powers I'm not supposed to touch. But onwards I go, ignoring the forbidden sign; adventure awaits!
getDeclaredSerialFields and getDefaultSerialFields are called in this method, so we can use their contents to emulate its functionality.
Analyzing getDeclaredSerialFields we see it only takes effect if a serialPersistentFields field is declared in the TreeMap class. Neither TreeMap nor its parent AbstractMap declares it, so we can ignore getDeclaredSerialFields. On to getDefaultSerialFields.

So if we take that code and fiddle around with it, we can get meaningful data: we see that TreeMap has exactly one non-transient field, and we now have a dynamic way to "emulate" getting the default fields, should things change for whatever reason.

https://ideone.com/UqqKSG (I left classnames with full paths so it's easier to see which classes I'm using)

    java.lang.reflect.Field[] clFields = TreeMap.class.getDeclaredFields();
    ArrayList<java.lang.reflect.Field> list = new ArrayList<>();
    int mask = java.lang.reflect.Modifier.STATIC | java.lang.reflect.Modifier.TRANSIENT;

    for (int i = 0; i < clFields.length; i++) {
        // Check for non transient and non static fields.
        if ((clFields[i].getModifiers() & mask) == 0) {
            list.add(clFields[i]);
            System.out.println("Found field " + clFields[i].getName());
        }
    }
    int size = list.size();
    System.out.println(size);

Found field comparator
1


L1951 Back in the ObjectInputStream we see that this size is used to create an array used as a read buffer, which is then read fully with arguments: empty array, offset 0, length of fields (1), and false. This method is called on the BlockDataInputStream, and the false means the data won't be copied. It's just a helper method for handling the data stream with a PeekInputStream(in); we could use the same methods on our stream with some fiddling, but we don't need to here because TreeMap stores no primitive fields. So I'll drop this train of thought for this answer.

L1964 calls for readObject0 which reads the comparator used in TreeMap. It checks for oldMode, which returns whether the stream is read in block data mode or not, and we can see that this was set to stream mode(false) in readFields so I'll skip that part.
L1315 simple check that recursion doesn't happen more than once, but one byte is peeked. Let's see what TreeMap has to provide for that. That took me longer than expected. I can't post the code here, it's too long, but I have it on ideone and a gist.

  • Basically you need to copy over the inline class BlockDataInputStream,
  • add private static native void bytesToFloats(byte[] src, int srcpos, float[] dst, int dstpos, int nfloats); and private static native void bytesToDoubles(byte[] src, int srcpos, double[] dst, int dstpos, int ndoubles); to BlockDataInputStream. If you actually need these methods, substitute them with a plain-Java implementation; as copied natives they will give a runtime error.
  • copy over the inline class PeekInputStream
  • copy over the java.io.Bits class.
  • The TC_ references need to point to java.io.ObjectStreamConstants.TC_
    BlockDataInputStream bin = new BlockDataInputStream(getTreeMapInputStream());
    bin.setBlockDataMode(false);
    byte b = bin.peekByte();
    System.out.println("Does b ("+String.format("%02X ", b)+") equals TC_RESET?" + (java.io.ObjectStreamConstants.TC_RESET == b ? "yes": "no"));

Does b (-84) equals TC_RESET?no

We see that we read 0xAC. Let's take a shortcut and look in java.io.ObjectStreamConstants to see what it is. There's no entry for 0xAC by itself, but it does look like part of the stream header.
Let's do the sanity check from readStreamHeader and insert the contents of that method right before our peekByte code, updating the TC_ references again. We now get an output of 0x73. Progress!
0x73 is TC_OBJECT, so let's jump to L1347.
There we find that readOrdinaryObject is called, which does a readByte().
Then the class description is read, which jumps to readNonProxy.
We then have a call to readUTF(), a readLong(), a readByte(), a readShort(), to read the fields... then for every field a readByte() and a readUTF().

So, let's emulate that. The first thing I encounter is that it tries to read beyond the string length for the class name (a 29184-character class name? I don't think so), so I'm missing something. I have no idea what I'm missing at this point; maybe ideone runs a version of Java where an extra byte is written before the UTF is read. I honestly can't be bothered to look it up. Anyway, after reading one extra byte it runs perfectly, and we're right where we want to be. It works, I'm happy. TODO: figure out where the extra byte is read.

    BlockDataInputStream bin = new BlockDataInputStream(getTreeMapInputStream());
    bin.setBlockDataMode(false);
    short s0 = bin.readShort();
    short s1 = bin.readShort();
    if (s0 != java.io.ObjectStreamConstants.STREAM_MAGIC || s1 != java.io.ObjectStreamConstants.STREAM_VERSION) {
        throw new StreamCorruptedException(
            String.format("invalid stream header: %04X%04X", s0, s1));
    }
    byte b = bin.readByte();
    if(b == java.io.ObjectStreamConstants.TC_OBJECT) {
        bin.readByte();
        String name = bin.readUTF();
        System.out.println(name);
        System.out.println("Is string ("+name+")it a java.util.TreeMap? "+(name.equals("java.util.TreeMap") ? "yes":"no"));
        bin.readLong();
        bin.readByte();
        short fields = bin.readShort();
        for(short i = 0; i < fields; i++) {
            bin.readByte();
            System.out.println("Read field name "+bin.readUTF());
        }
    }

Now we continue on Line 1771 to see what is read after the class description. Following this there's a lot of object-instantiation checking, etc... It's spaghetti I don't feel like plowing through. Let's get hackish and analyze the data.

The data as string

tLjava/util/Comparator;xppwsrjava.lang.Integer¬ᅠᄂ￷チヌ8Ivaluexrjava.lang.Numberニᆲユヤ¢ヒxptData1sq~tData5sq~tData4sq~tData2sq~FtData3x -74 -00 -16 -4C -6A -61 -76 -61 -2F -75 -74 -69 -6C -2F -43 -6F -6D -70 -61 -72 -61 -74 -6F -72 -3B -78 -70 -70 -77 -04 -00 -00 -00 -05 -73 -72 -00 -11 -6A -61 -76 -61 -2E -6C -61 -6E -67 -2E -49 -6E -74 -65 -67 -65 -72 -12 -E2 -A0 -A4 -F7 -81 -87 -38 -02 -00 -01 -49 -00 -05 -76 -61 -6C -75 -65 -78 -72 -00 -10 -6A -61 -76 -61 -2E -6C -61 -6E -67 -2E -4E -75 -6D -62 -65 -72 -86 -AC -95 -1D -0B -94 -E0 -8B -02 -00 -00 -78 -70 -00 -00 -00 -01 -74 -00 -05 -44 -61 -74 -61 -31 -73 -71 -00 -7E -00 -03 -00 -00 -00 -02 -74 -00 -05 -44 -61 -74 -61 -35 -73 -71 -00 -7E -00 -03 -00 -00 -00 -04 -74 -00 -05 -44 -61 -74 -61 -34 -73 -71 -00 -7E -00 -03 -00 -00 -00 -17 -74 -00 -05 -44 -61 -74 -61 -32 -73 -71 -00 -7E -00 -03 -00 -00 -00 -46 -74 -00 -05 -44 -61 -74 -61 -33 -78

We know the size is written before the elements. The Data1-Data5 strings are the values stored in the map, so everything from the Data1sq part onward is moot. Let's add an item to the map and see which byte changes!

74 -00 -16 -4C -6A -61 -76 -61 -2F -75 -74 -69 -6C -2F -43 -6F -6D -70 -61 -72 -61 -74 -6F -72 -3B -78 -78 -70 -70 -77 -04 -00 -00 -00 -05 -73 -72 -00 -11 -6A -61 -76 -61 -2E
74 -00 -16 -4C -6A -61 -76 -61 -2F -75 -74 -69 -6C -2F -43 -6F -6D -70 -61 -72 -61 -74 -6F -72 -3B -78 -78 -70 -70 -77 -04 -00 -00 -00 -06 -73 -72 -00 -11 -6A -61 -76 -61 -2E

Okay, now we know how many bytes we still have to slaughter. Let's see if we can deduce some logic from the given values.
The first value is a 74. Checking the ObjectStreamConstants we see that it stands for a string (TC_STRING). Let's read that byte and then the UTF.
Now we have remaining -70 -70 -77 -04 -00 -00 -00 -06. Let's lay that beside the constants.

NULL - NULL - BLOCKDATA - value 4 - value 0 - value 0 - value 0 - value 6

We could theorize here:

After block data, an integer is written. An integer is four bytes, hence the four. The next four positions make up the integer.
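That theory is easy to check in isolation; a tiny sketch using DataInputStream on the observed tail bytes:

```java
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;

public class BlockIntDemo {
    public static void main(String[] args) throws Exception {
        // The observed tail: BLOCKDATA marker, block length 4, then the int 6.
        byte[] tail = {0x77, 0x04, 0x00, 0x00, 0x00, 0x06};
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(tail));
        int marker = in.readUnsignedByte();   // 0x77 = TC_BLOCKDATA
        int blockLen = in.readUnsignedByte(); // 4 = one int's worth of payload
        int size = in.readInt();              // the map size
        System.out.println(marker == 0x77 && blockLen == 4 ? "size = " + size : "unexpected");
    }
}
```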

Let's see what happens if we add a comparator to the TreeMap.

xpsr'java.util.Collections$ReverseComparatordハ￰SNJ￐xpwsrjava.lang.Integer¬ᅠᄂ￷チヌ8I
-78 -70 -73 -72 -00 -27 -6A -61 -76 -61 -2E -75 -74 -69 -6C -2E -43 -6F -6C -6C -65 -63 -74 -69 -6F -6E -73 -24 -52 -65 -76 -65 -72 -73 -65 -43 -6F -6D -70 -61 -72 -61 -74 -6F -72 -64 -04 -8A -F0 -53 -4E -4A -D0 -02 -00 -00 -78 -70 -77 -04 -00 -00 -00 -06

We see END_BLOCK, NULL, OBJECT

Okay. So now we know that the second NULL is the placeholder for the comparator data, so we can peek at that one. We need to skip two bytes, then peek whether it's an object byte. If it is, we need to read the object data so we can get to our desired position.

Let's take a pause and review the code thus far: https://ideone.com/ma6nQy

    BlockDataInputStream bin = new BlockDataInputStream(getTreeMapInputStream());
    bin.setBlockDataMode(false);
    short s0 = bin.readShort();
    short s1 = bin.readShort();
    if (s0 != java.io.ObjectStreamConstants.STREAM_MAGIC || s1 != java.io.ObjectStreamConstants.STREAM_VERSION) {
        throw new StreamCorruptedException(
            String.format("invalid stream header: %04X%04X", s0, s1));
    }
    byte b = bin.peekByte();

    if(b == java.io.ObjectStreamConstants.TC_OBJECT) {
        Ideone.readObject(bin,true);
    }

    if(bin.readByte() == java.io.ObjectStreamConstants.TC_STRING) {
        String className = bin.readUTF();
        System.out.println(className + "starts with L "+(className.charAt(0) == 'L' ? "yes": "no"));
        if(className.charAt(0) == 'L') {
            // Skip two bytes
            bin.readByte();
            bin.readByte();
            b = bin.peekByte();
            if(b == java.io.ObjectStreamConstants.TC_OBJECT) {
                System.out.println("reading object");
                Ideone.readObject(bin,true);
            }
            else {
                // remove the null byte so we end up at same position
                bin.readByte();
            }
        }
    }
    int length = 50;
    byte[] bytes = new byte[length];
    for(int c=0;c<length;c++) {
        bytes[c] = bin.readByte();
        System.out.print((char)(bytes[c]));
    }
    for(int c=0;c<length;c++) {
        System.out.print("-"+String.format("%02X ", bytes[c]));
    }
}

public static void readObject(BlockDataInputStream bin, boolean doExtra) throws Exception {
    byte b = bin.readByte();
    if(b == java.io.ObjectStreamConstants.TC_OBJECT) {
        if(doExtra) {
          bin.readByte();
        }
        String name = bin.readUTF();
        System.out.println(name);
        System.out.println("Is string ("+name+")it a java.util.TreeMap? "+(name.equals("java.util.TreeMap") ? "yes":"no"));
        bin.readLong();
        bin.readByte();
        short fields = bin.readShort();
        for(short i = 0; i < fields; i++) {
            bin.readByte();
            System.out.println("Read field name "+bin.readUTF());
        }
    }
}

Found field comparator
1
java.util.TreeMap
Is string (java.util.TreeMap)it a java.util.TreeMap? yes
Read field name comparator
Ljava/util/Comparator;starts with L yes
reading object
java.util.Collections$ReverseComparator
Is string (java.util.Collections$ReverseComparator)it a java.util.TreeMap? no
xpwsrjava.lang.Integer¬ᅠᄂ￷チヌ8Ivaluexr
-78 -70 -77 -04 -00 -00 -00 -06 -73 -72 -00 -11 -6A -61 -76 -61 -2E -6C -61 -6E -67 -2E -49 -6E -74 -65 -67 -65 -72 -12 -E2 -A0 -A4 -F7 -81 -87 -38 -02 -00 -01 -49 -00 -05 -76 -61 -6C -75 -65 -78 -72

Sadly, we don't end up on the same points in the timeline though.

When there is a comparator, we end with:

-78 -70 -77 -04 -00 -00 -00 -06

When the comparator is removed we end with:

-77 -04 -00 -00 -00 -06

Hmmm. That BLOCK END and NULL look very familiar: those are the same bytes we skipped when reading the comparator. These two bytes are always removed, but apparently the comparator also adds its own BLOCK END and NULL values.

So, if there is a comparator, remove the two trailing bytes so we get what we want, consistently. https://ideone.com/pTu8Fd

-77 -04 -00 -00 -00 -06

We then skip the next BLOCKDATA marker (the 77) and reach the gold!

Adding the extra lines we get our output: https://ideone.com/wy0uF2

    System.out.println(String.format("%02X ", bin.readByte()));
    if(bin.readByte() == (byte)4) {
        System.out.println("The length is "+ bin.readInt());
    }

77
The length is 6

And we have the magic number we need!

Okay. Deducing done, let's clean it up

The usable stuff you care about

Runnable snippet: https://ideone.com/J6ovMy
Complete code also as a gist: https://gist.github.com/tschallacka/8f89982e9569d0b9974dff37d8f45faf

 /**
This is dual licensed: you can choose whether you want to use CC-BY-SA or MIT.
Copyright 2020 Tschallacka
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
import java.util.*;
import java.lang.*;
import java.io.*;

/* Name of the class has to be "Main" only if the class is public. */
class Ideone
{
    public static void main (String[] args) throws java.lang.Exception
    {

        doTest(1,true);
        doTest(1,false);

        doTest(20,true);
        doTest(20,false);

        doTest(4,true);
        doTest(19,false);
    }

    public static void doTest(int size, boolean comparator) throws java.lang.Exception {
        SerializedTreeMapAnalyzer analyzer = new SerializedTreeMapAnalyzer();
        System.out.println(analyzer.getSize(Ideone.getTreeMapInputStream(size,comparator)));
    }

    public static ByteArrayInputStream getTreeMapInputStream(int size, boolean comparator) throws Exception {
      TreeMap<Integer, String> tmap = 
             new TreeMap<Integer, String>(comparator?Collections.reverseOrder():null);

      /*Adding elements to TreeMap*/
      for(int i = 0; size > 0 && i < size; i++) {
        tmap.put(i, "Data"+i);
      }

      ByteArrayOutputStream baos = new ByteArrayOutputStream();
      ObjectOutputStream oos = new ObjectOutputStream( baos );
      oos.writeObject( tmap );
      oos.close();
      return  new ByteArrayInputStream(baos.toByteArray());
    }
}

class SerializedTreeMapAnalyzer 
{
    public int getSize(InputStream stream) throws IOException, StreamCorruptedException, Exception {
        BlockDataInputStream bin = new BlockDataInputStream(stream);
        bin.setBlockDataMode(false);

        short s0 = bin.readShort();
        short s1 = bin.readShort();

        if (s0 != java.io.ObjectStreamConstants.STREAM_MAGIC || s1 != java.io.ObjectStreamConstants.STREAM_VERSION) {
            throw new StreamCorruptedException(
                String.format("invalid stream header: %04X%04X", s0, s1));
        }

        byte b = bin.peekByte();

        if(b == java.io.ObjectStreamConstants.TC_OBJECT) {
            this.readObject(bin,true);
        }

        if(bin.readByte() == java.io.ObjectStreamConstants.TC_STRING) {
            String className = bin.readUTF();

            if(className.charAt(0) == 'L') {
                // Skip two bytes
                bin.readByte();
                bin.readByte();
                b = bin.peekByte();
                if(b == java.io.ObjectStreamConstants.TC_OBJECT) {

                    this.readObject(bin,true);
                    bin.readByte();
                    bin.readByte();
                }
                else {
                    // remove the null byte so we end up at same position
                    bin.readByte();
                }
            }
        }
        bin.readByte();
        if(bin.readByte() == (byte)4) {
            return bin.readInt();
        }
        return -1;
    }

    protected void readObject(BlockDataInputStream bin, boolean doExtra) throws Exception {
        byte b = bin.readByte();
        if(b == java.io.ObjectStreamConstants.TC_OBJECT) {
            if(doExtra) {
              bin.readByte();
            }
            String name = bin.readUTF();
            bin.readLong();
            bin.readByte();
            short fields = bin.readShort();
            for(short i = 0; i < fields; i++) {
                bin.readByte();
                bin.readUTF();
            }
        }
    }
}

1
1
20
20
4
19

answered Oct 31 '22 by Tschallacka


I don't know your API, but usually you would limit the POST size accepted by your application server. In WildFly, for example, you can set the max-post-size attribute on your http/https-listener. This limits the amount of data your server is willing to receive, and as a result the amount of data which can be processed per request.

Another approach is to introduce something like a rate limit: when a client executes too many queries, you deny further processing. This is a common approach to limiting the processing power consumed by individual customers. Since your API doesn't seem to be open (at least you didn't say it was), you can define the rate limit per customer. This might be the best approach in your case.
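For illustration, a minimal fixed-window rate limiter could look like this. All class and method names here are invented for the sketch, not any real library's API; in production you would more likely use an existing library or an API gateway.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

public class RateLimitSketch {

    // Minimal fixed-window limiter: at most maxPerWindow requests
    // per client within each windowMillis-long window.
    static class RateLimiter {
        private final int maxPerWindow;
        private final long windowMillis;
        private final Map<String, Window> windows = new ConcurrentHashMap<>();

        RateLimiter(int maxPerWindow, long windowMillis) {
            this.maxPerWindow = maxPerWindow;
            this.windowMillis = windowMillis;
        }

        boolean allow(String clientId) {
            long now = System.currentTimeMillis();
            // Start a fresh window when the old one has expired.
            Window w = windows.compute(clientId, (k, old) ->
                old == null || now - old.start >= windowMillis ? new Window(now) : old);
            return w.count.incrementAndGet() <= maxPerWindow;
        }

        private static final class Window {
            final long start;
            final AtomicInteger count = new AtomicInteger();
            Window(long start) { this.start = start; }
        }
    }

    public static void main(String[] args) {
        RateLimiter limiter = new RateLimiter(3, 60_000);
        for (int i = 1; i <= 4; i++)
            System.out.println("request " + i + ": "
                + (limiter.allow("client-42") ? "ok" : "denied"));
    }
}
```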

To your approach: when your server knows how big a Map is, it has already accepted and received the data, so the resources for that are already spent (although further processing can still be limited).

In the end, you'll have to choose an appropriate way for your use case. Your case doesn't sound like the network is the bottleneck, but rather computing power. So I assume a combination of a limited post-size and rate limit would be the best thing in your case.

answered Oct 31 '22 by maio290


Searching the OpenJDK for getJavaObjectInputStreamAccess().checkArray turns up maxarray checking in these classes.

  • java.util.ArrayDeque
  • java.util.ArrayList
  • java.util.Collection
  • java.util.HashMap
  • java.util.HashSet
  • java.util.Hashtable
  • java.util.IdentityHashMap
  • java.util.ImmutableCollections
  • java.util.PriorityBlockingQueue
  • java.util.PriorityQueue
  • java.util.Properties
  • java.util.concurrent.CopyOnWriteArrayList
  • javax.management.openmbean.TabularDataSupport

And java.io.ObjectInputStream uses it, of course.

What is maxarray attempting to defend against? Presumably malicious streams causing allocation of an extremely disproportionate amount of memory. But this limit isn't cumulative, so it would appear to be ineffective against anything but the most naive attacks.

TreeMap doesn't use arrays, so maxarray couldn't apply. If we wanted to limit the size of a TreeMap as part of efforts to reduce the maximum size of deserialised objects, then maxrefs and maxbytes are appropriate as with any other serialisable object.
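For example, a stream-specific filter with maxrefs rejects an oversized TreeMap even though maxarray never fires. This sketch uses java.io.ObjectInputFilter, which is public API from Java 9 onward; on Java 8 the same limits can only be applied globally via the jdk.serialFilter system property.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InvalidClassException;
import java.io.ObjectInputFilter;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.TreeMap;

public class TreeMapFilterDemo {
    public static void main(String[] args) throws Exception {
        TreeMap<Integer, String> big = new TreeMap<>();
        for (int i = 0; i < 1000; i++) big.put(i, "Data" + i);

        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
            oos.writeObject(big);
        }

        try (ObjectInputStream ois =
                 new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()))) {
            // Cap the total number of references in the stream; a 1000-entry
            // TreeMap blows well past 100 and gets rejected mid-read.
            ois.setObjectInputFilter(ObjectInputFilter.Config.createFilter("maxrefs=100"));
            ois.readObject();
            System.out.println("accepted");
        } catch (InvalidClassException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```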

answered Oct 31 '22 by Tom Hawtin - tackline