
zipWithIndex on Apache Flink

I'd like to assign each row of my input an ID: a number from 0 to N - 1, where N is the number of rows in the input.

Roughly, I'd like to be able to do something like the following:

val data = sc.textFile(textFilePath, numPartitions)
val rdd = data.map(line => process(line))
val rddMatrixLike = rdd.zipWithIndex.map { case (v, idx) => someStuffWithIndex(idx, v) }

Is the same possible in Apache Flink?

Alexey Grigorev asked Jun 02 '15 12:06


2 Answers

This is now part of the 0.10-SNAPSHOT release of Apache Flink. Examples for zipWithIndex(in) and zipWithUniqueId(in) are available in the official Flink documentation.
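A minimal sketch of how the utility can be invoked from the Java DataSet API, assuming it is exposed through the DataSetUtils helper class as in the documentation of that era (verify the package path and signature against your Flink version):

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;

public class ZipWithIndexExample {
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<String> in = env.fromElements("a", "b", "c");

        // zipWithIndex assigns consecutive ids 0..N-1 (requires an extra counting pass);
        // zipWithUniqueId is a single pass but only guarantees unique, not consecutive, ids
        DataSet<Tuple2<Long, String>> indexed = DataSetUtils.zipWithIndex(in);

        indexed.print();
    }
}

The choice between the two comes down to whether you need a dense 0..N-1 numbering (as the question asks) or merely distinct identifiers.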

peterschrott answered Nov 07 '22 17:11


Here is a simple implementation of the function:

import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;

public class ZipWithIndex {

public static void main(String[] args) throws Exception {

    ExecutionEnvironment ee = ExecutionEnvironment.getExecutionEnvironment();

    DataSet<String> in = ee.readTextFile("/home/robert/flink-workdir/debug/input");

    // phase 1: count the elements in each partition
    DataSet<Tuple2<Integer, Long>> counts = in.mapPartition(new RichMapPartitionFunction<String, Tuple2<Integer, Long>>() {
        @Override
        public void mapPartition(Iterable<String> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
            long cnt = 0;
            for (String v : values) {
                cnt++;
            }
            // emit (subtask index, number of elements in this partition)
            out.collect(new Tuple2<Integer, Long>(getRuntimeContext().getIndexOfThisSubtask(), cnt));
        }
    });

    // phase 2: assign consecutive IDs, offset by the counts of all lower-numbered partitions
    DataSet<Tuple2<Long, String>> result = in.mapPartition(new RichMapPartitionFunction<String, Tuple2<Long, String>>() {
        long start = 0;

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariable("counts");
            // sort the per-partition counts by subtask index so the prefix sum is well defined
            Collections.sort(offsets, new Comparator<Tuple2<Integer, Long>>() {
                @Override
                public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
                    return ZipWithIndex.compare(o1.f0, o2.f0);
                }
            });
            // this subtask's first ID is the sum of all preceding partitions' counts
            for (int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
                start += offsets.get(i).f1;
            }
        }

        @Override
        public void mapPartition(Iterable<String> values, Collector<Tuple2<Long, String>> out) throws Exception {
            for (String v : values) {
                out.collect(new Tuple2<Long, String>(start++, v));
            }
        }
    }).withBroadcastSet(counts, "counts");

    result.print();

}

public static int compare(int x, int y) {
    return (x < y) ? -1 : ((x == y) ? 0 : 1);
}
}

This is how it works: the first mapPartition() operation goes over all elements in each partition and counts them. I need the number of elements per partition to set the correct offsets when assigning the IDs. The result of the first mapPartition() is a DataSet of (subtask index, count) pairs. I'm broadcasting this DataSet to all instances of the second mapPartition() operator, which assigns the IDs to the input elements. In the open() method of the second mapPartition(), each subtask computes its starting offset by summing the counts of all lower-numbered partitions.
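The offset logic can be exercised in isolation. Below is a small plain-Java sketch (no Flink involved; the partition contents and helper names are made up for illustration) that mimics the two phases: count each partition, prefix-sum the counts, then hand out consecutive IDs:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class OffsetDemo {
    // phase 1: element count of each "partition"
    static long[] countPerPartition(List<List<String>> partitions) {
        long[] counts = new long[partitions.size()];
        for (int i = 0; i < partitions.size(); i++) {
            counts[i] = partitions.get(i).size();
        }
        return counts;
    }

    // open(): the starting ID of partition p is the sum of the counts of partitions 0..p-1
    static long startOffset(long[] counts, int p) {
        long start = 0;
        for (int i = 0; i < p; i++) {
            start += counts[i];
        }
        return start;
    }

    public static void main(String[] args) {
        List<List<String>> partitions = Arrays.asList(
                Arrays.asList("a", "b"),       // partition 0 -> ids 0, 1
                Arrays.asList("c"),            // partition 1 -> id 2
                Arrays.asList("d", "e", "f")); // partition 2 -> ids 3, 4, 5

        long[] counts = countPerPartition(partitions);
        List<String> result = new ArrayList<>();
        for (int p = 0; p < partitions.size(); p++) {
            long id = startOffset(counts, p);
            for (String v : partitions.get(p)) {
                result.add(id++ + "=" + v);
            }
        }
        System.out.println(result); // [0=a, 1=b, 2=c, 3=d, 4=e, 5=f]
    }
}
```

Because every ID in partition p starts exactly where partition p-1 left off, the numbering is dense and collision-free regardless of how unevenly the elements are distributed.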

I'm probably going to contribute the code to Flink (after discussing it with the other committers).

Robert Metzger answered Nov 07 '22 18:11