Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

How can I make this F# function avoid a stack overflow?

Tags:

I've written an interesting function in F# that can traverse and map any data structure (much like the everywhere function available in Haskell's Scrap Your Boilerplate). Unfortunately, it quickly causes a stack overflow on even fairly small data structures. How can I convert it to a tail-recursive version, a continuation-passing-style version, or an equivalent imperative algorithm? I believe F# supports monads (through computation expressions), so the continuation monad is an option.

// Module-level caches of precomputed FSharpValue readers and union case
// metadata. Each is an association list keyed by runtime System.Type (plus the
// case tag for unionReaders): it is searched with List.tryFind, and a new entry
// is prepended on a miss, so each reader is computed only on the first miss
// for its key.
// These are used for a 50% speedup
let mutable tupleReaders : List<System.Type * (obj -> obj[])> = []
let mutable unionTagReaders : List<System.Type * (obj -> int)> = []
let mutable unionReaders : List<(System.Type * int) * (obj -> obj[])> = []
let mutable unionCaseInfos : List<System.Type * Microsoft.FSharp.Reflection.UnionCaseInfo[]> = []
let mutable recordReaders : List<System.Type * (obj -> obj[])> = []

(*
    Traverses any data structure in a preorder traversal
    Calls f, g, h, i, j which determine the mapping of the current node being considered

    Each node is first passed to the first of f, g, h, i, j whose type parameter
    equals the runtime type of the node or is a base class of it. The value
    produced by that mapping step is then taken apart, and its children are
    traversed in the same way. Only F# unions, tuples and records are descended
    into; any other value is returned as produced by the mapping step.

    WARNING: Not able to handle option types
    At runtime, option None values are represented as null and so you cannot determine their runtime type.
    Null values are returned unchanged and are never passed to f, g, h, i or j.

    See http://stackoverflow.com/questions/21855356/dynamically-determine-type-of-option-when-it-has-value-none
    http://stackoverflow.com/questions/13366647/how-to-generalize-f-option

    Stack usage: traverse calls drill, and drill calls traverse on every child
    through Array.map, so neither call is in tail position. Every level of
    nesting in src adds stack frames, which is why deeply nested values
    overflow the stack.
*)
open Microsoft.FSharp.Reflection
let map5<'a,'b,'c,'d,'e,'z> (f:'a->'a) (g:'b->'b) (h:'c->'c) (i:'d->'d) (j:'e->'e) (src:'z) =
    // Runtime types of the five mapping targets; every node is compared against them.
    let ft = typeof<'a>
    let gt = typeof<'b>
    let ht = typeof<'c>
    let it = typeof<'d>
    let jt = typeof<'e>

    // Rebuilds a node of the same union case, tuple type or record type from
    // its traversed children. Null and any other value are returned as-is.
    let rec drill (o:obj) : obj =
        if o = null then
            o
        else
            let ot = o.GetType()
            if FSharpType.IsUnion(ot) then
                // Read the case tag, then the UnionCaseInfo for that tag, then the
                // field values; each step consults (and fills) its cache.
                let tag = match List.tryFind (fst >> ot.Equals) unionTagReaders with
                              | Some (_, reader) ->
                                   reader o
                              | None ->
                                   let newReader = FSharpValue.PreComputeUnionTagReader(ot)
                                   unionTagReaders <- (ot, newReader)::unionTagReaders
                                   newReader o
                let info = match List.tryFind (fst >> ot.Equals) unionCaseInfos with
                               | Some (_, caseInfos) ->
                                   Array.get caseInfos tag
                               | None ->
                                   let newCaseInfos = FSharpType.GetUnionCases(ot)
                                   unionCaseInfos <- (ot, newCaseInfos)::unionCaseInfos
                                   Array.get newCaseInfos tag
                // Union field readers are cached per (type, tag) pair because
                // each case has its own set of fields.
                let vals = match List.tryFind (fun ((tau, tag'), _) -> ot.Equals tau && tag = tag') unionReaders with
                               | Some (_, reader) ->
                                   reader o
                               | None ->
                                   let newReader = FSharpValue.PreComputeUnionReader info
                                   unionReaders <- ((ot, tag), newReader)::unionReaders
                                   newReader o
                // Non-tail recursion: traverse runs on every field, and only then
                // is the result passed to MakeUnion.
                FSharpValue.MakeUnion(info, Array.map traverse vals)
            elif FSharpType.IsTuple(ot) then
                let fields = match List.tryFind (fst >> ot.Equals) tupleReaders with
                                 | Some (_, reader) ->
                                     reader o
                                 | None ->
                                     let newReader = FSharpValue.PreComputeTupleReader(ot)
                                     tupleReaders <- (ot, newReader)::tupleReaders
                                     newReader o
                // Non-tail recursion, as in the union branch.
                FSharpValue.MakeTuple(Array.map traverse fields, ot)
            elif FSharpType.IsRecord(ot) then
                let fields = match List.tryFind (fst >> ot.Equals) recordReaders with
                                 | Some (_, reader) ->
                                     reader o
                                 | None ->
                                     let newReader = FSharpValue.PreComputeRecordReader(ot)
                                     recordReaders <- (ot, newReader)::recordReaders
                                     newReader o
                // Non-tail recursion, as in the union branch.
                FSharpValue.MakeRecord(ot, Array.map traverse fields)
            else
                o

    // Applies the first matching mapping (checked in the order f, g, h, i, j)
    // to o, then drills into the mapped value. This is the preorder step.
    and traverse (o:obj) =
        let parent =
            if o = null then
                o
            else
                let ot = o.GetType()
                if ft = ot || ot.IsSubclassOf(ft) then
                    f (o :?> 'a) |> box
                elif gt = ot || ot.IsSubclassOf(gt) then
                    g (o :?> 'b) |> box
                elif ht = ot || ot.IsSubclassOf(ht) then
                    h (o :?> 'c) |> box
                elif it = ot || ot.IsSubclassOf(it) then
                    i (o :?> 'd) |> box
                elif jt = ot || ot.IsSubclassOf(jt) then
                    j (o :?> 'e) |> box
                else
                    o
        drill parent
    // The rebuilt root comes back as obj and is unboxed to the input type.
    traverse src |> unbox : 'z
like image 706
blink Avatar asked Apr 11 '16 06:04

blink


2 Answers

Try this (I pass a continuation function as an extra parameter):

namespace Solution

[<CompilationRepresentation(CompilationRepresentationFlags.ModuleSuffix)>]
[<AutoOpen>]
module Solution =

    // Module-level caches of precomputed FSharpValue readers and union case
    // metadata, keyed by runtime type (plus the case tag for unionReaders).
    // Each is searched with List.tryFind and extended on a miss.
    // These are used for a 50% speedup
    let mutable tupleReaders : List<System.Type * (obj -> obj[])> = []
    let mutable unionTagReaders : List<System.Type * (obj -> int)> = []
    let mutable unionReaders : List<(System.Type * int) * (obj -> obj[])> = []
    let mutable unionCaseInfos : List<System.Type * Microsoft.FSharp.Reflection.UnionCaseInfo[]> = []
    let mutable recordReaders : List<System.Type * (obj -> obj[])> = []

    (*
        Traverses any data structure in a preorder traversal
        Calls f, g, h, i, j which determine the mapping of the current node being considered

        WARNING: Not able to handle option types
        At runtime, option None values are represented as null and so you cannot determine their runtime type.

        See http://stackoverflow.com/questions/21855356/dynamically-determine-type-of-option-when-it-has-value-none
        http://stackoverflow.com/questions/13366647/how-to-generalize-f-option

        The traversal is written in continuation-passing style. Every recursive
        call to traverse and every call to a continuation is in tail position,
        so the stack does not grow with nesting depth as long as the compiler
        emits tail calls (the --tailcalls+ option, on by default in Release
        builds). Pending work is held in heap-allocated closures instead.
    *)
    open Microsoft.FSharp.Reflection
    let map5<'a,'b,'c,'d,'e,'z> (f:'a->'a) (g:'b->'b) (h:'c->'c) (i:'d->'d) (j:'e->'e) (src:'z) =
        let ft = typeof<'a>
        let gt = typeof<'b>
        let ht = typeof<'c>
        let it = typeof<'d>
        let jt = typeof<'e>

        // Splits a node into its children (Some fields for F# unions, tuples and
        // records; None otherwise) and a function that rebuilds a node of the
        // same shape from a replacement array of children. For null or any other
        // value, the rebuild function ignores its argument and returns the node.
        let drill (o:obj) : obj[] option * (obj[] -> obj) =
            if o = null then
                (None, (fun _ -> o))
            else
                let ot = o.GetType()
                if FSharpType.IsUnion(ot) then
                    let tag =
                        match List.tryFind (fst >> ot.Equals) unionTagReaders with
                        | Some (_, reader) -> reader o
                        | None ->
                            let newReader = FSharpValue.PreComputeUnionTagReader(ot)
                            unionTagReaders <- (ot, newReader)::unionTagReaders
                            newReader o
                    let info =
                        match List.tryFind (fst >> ot.Equals) unionCaseInfos with
                        | Some (_, caseInfos) -> Array.get caseInfos tag
                        | None ->
                            let newCaseInfos = FSharpType.GetUnionCases(ot)
                            unionCaseInfos <- (ot, newCaseInfos)::unionCaseInfos
                            Array.get newCaseInfos tag
                    // Field readers differ per case, so they are cached per (type, tag).
                    let vals =
                        match List.tryFind (fun ((tau, tag'), _) -> ot.Equals tau && tag = tag') unionReaders with
                        | Some (_, reader) -> reader o
                        | None ->
                            let newReader = FSharpValue.PreComputeUnionReader info
                            unionReaders <- ((ot, tag), newReader)::unionReaders
                            newReader o
                    (Some vals, (fun (children: obj[]) -> FSharpValue.MakeUnion(info, children)))
                elif FSharpType.IsTuple(ot) then
                    let fields =
                        match List.tryFind (fst >> ot.Equals) tupleReaders with
                        | Some (_, reader) -> reader o
                        | None ->
                            let newReader = FSharpValue.PreComputeTupleReader(ot)
                            tupleReaders <- (ot, newReader)::tupleReaders
                            newReader o
                    (Some fields, (fun (children: obj[]) -> FSharpValue.MakeTuple(children, ot)))
                elif FSharpType.IsRecord(ot) then
                    let fields =
                        match List.tryFind (fst >> ot.Equals) recordReaders with
                        | Some (_, reader) -> reader o
                        | None ->
                            let newReader = FSharpValue.PreComputeRecordReader(ot)
                            recordReaders <- (ot, newReader)::recordReaders
                            newReader o
                    (Some fields, (fun (children: obj[]) -> FSharpValue.MakeRecord(ot, children)))
                else
                    (None, (fun _ -> o))

        // Applies the first mapping whose type parameter equals the runtime type
        // of the node or is a base class of it, checked in the order f, g, h, i, j.
        // Null and unmatched nodes are returned unchanged.
        let mapNode (o:obj) : obj =
            if o = null then
                o
            else
                let ot = o.GetType()
                if ft = ot || ot.IsSubclassOf(ft) then
                    f (o :?> 'a) |> box
                elif gt = ot || ot.IsSubclassOf(gt) then
                    g (o :?> 'b) |> box
                elif ht = ot || ot.IsSubclassOf(ht) then
                    h (o :?> 'c) |> box
                elif it = ot || ot.IsSubclassOf(it) then
                    i (o :?> 'd) |> box
                elif jt = ot || ot.IsSubclassOf(jt) then
                    j (o :?> 'e) |> box
                else
                    o

        // Maps o, traverses every child of the mapped node left to right, and
        // passes the rebuilt node to cont. Every child is traversed, including
        // the only child of a single-field union case such as Some x.
        let rec traverse (o:obj) (cont: obj -> obj) : obj =
            let children, rebuild = drill (mapNode o)
            match children with
            | None
            | Some [||] ->
                rebuild [||] |> cont
            | Some fields ->
                let results = System.Collections.Generic.List<obj>(fields.Length)
                // One continuation per remaining field: each records the result of
                // the previous field and then traverses the next one. The last one
                // records its result and rebuilds this node from all results.
                let afterFirst =
                    Array.foldBack
                        (fun (child: obj) (next: obj -> obj) ->
                            fun (result: obj) -> results.Add(result); traverse child next)
                        fields.[1..]
                        (fun (result: obj) -> results.Add(result); results.ToArray() |> rebuild |> cont)
                traverse fields.[0] afterFirst

        traverse (box src) id |> unbox : 'z

Build this with the Generate tail calls option (--tailcalls+) enabled. By default, this option is disabled in Debug builds and enabled in Release builds.

Example:

// Two mutually recursive unions: the A case of each wraps a value of the other
// type, so values can be nested to any depth, alternating between A1 and A2.
// The B cases hold an int and end the nesting.
type A1 =
    | A of A2
    | B of int

and A2 =
    | A of A1
    | B of int

// Top-level wrapper that can hold either of the two unions.
and Root = 
    | A1 of A1
    | A2 of A2

/// Example entry point: builds a value nested 100000 levels deep that alternates
/// between the A1.A and A2.A cases, maps it with five identity functions and
/// prints the result with %A. Returns 0 as the process exit code.
[<EntryPoint>]
let main args =
    // Adds one wrapper per step, switching between Root.A1 and Root.A2,
    // until the remaining depth reaches zero.
    let rec grow remaining (current: Root) =
        match remaining, current with
        | 0, _ -> current
        | _, A1 inner -> grow (remaining - 1) (Root.A2(A2.A(inner)))
        | _, A2 inner -> grow (remaining - 1) (Root.A1(A1.A(inner)))

    let tree = grow 100000 (Root.A1(A1.B(2)))
    let mapped = map5 (fun x -> x) (fun x -> x) (fun x -> x) (fun x -> x) (fun x -> x) tree
    printf "%A" mapped
    0

This code finishes without a stack overflow exception.

like image 87
Sattar Imamov Avatar answered Sep 28 '22 04:09

Sattar Imamov


I ended up converting the code to an imperative style to avoid the stack overflow:

open Microsoft.FSharp.Reflection

// Module-level caches of precomputed FSharpValue readers and union case
// metadata, keyed by runtime type (plus the case tag for unionReaders). Each is
// an association list searched with List.tryFind and extended on a miss.
let mutable tupleReaders : List<System.Type * (obj -> obj[])> = []
let mutable unionTagReaders : List<System.Type * (obj -> int)> = []
let mutable unionReaders : List<(System.Type * int) * (obj -> obj[])> = []
let mutable unionCaseInfos : List<System.Type * Microsoft.FSharp.Reflection.UnionCaseInfo[]> = []
let mutable recordReaders : List<System.Type * (obj -> obj[])> = []

/// Records how a node was taken apart so it can be rebuilt after its children
/// have been mapped: the union case, the tuple type, the record type, or Leaf
/// for a node that is kept as-is.
type StructureInfo = Union of UnionCaseInfo
                   | Tuple of System.Type
                   | Record of System.Type
                   | Leaf

/// Maps every node of src in preorder without recursion. Each node is passed to
/// the first of f, g, h, i, j whose type parameter equals the runtime type of the
/// node or is a base class of it. The children of the mapped value are then
/// visited. Only F# unions, tuples and records are descended into and rebuilt.
/// Null values (such as None) are leaves and are never passed to f, g, h, i or j.
let map5<'a,'b,'c,'d,'e,'z> (f:'a->'a) (g:'b->'b) (h:'c->'c) (i:'d->'d) (j:'e->'e) (src:'z) : 'z =
    let ft = typeof<'a>
    let gt = typeof<'b>
    let ht = typeof<'c>
    let it = typeof<'d>
    let jt = typeof<'e>

    // Returns the structure of o together with its child values. Null and values
    // that are not F# unions, tuples or records are leaves with no children.
    let getStructureInfo (o : obj) =
        if o = null then
            (Leaf, [||])
        else
            let ot = o.GetType()
            if FSharpType.IsUnion(ot) then
                let tag =
                    match List.tryFind (fst >> ot.Equals) unionTagReaders with
                    | Some (_, reader) -> reader o
                    | None ->
                        let newReader = FSharpValue.PreComputeUnionTagReader(ot)
                        unionTagReaders <- (ot, newReader)::unionTagReaders
                        newReader o
                let info =
                    match List.tryFind (fst >> ot.Equals) unionCaseInfos with
                    | Some (_, caseInfos) -> Array.get caseInfos tag
                    | None ->
                        let newCaseInfos = FSharpType.GetUnionCases(ot)
                        unionCaseInfos <- (ot, newCaseInfos)::unionCaseInfos
                        Array.get newCaseInfos tag
                // Field readers differ per case, so they are cached per (type, tag).
                let children =
                    match List.tryFind (fun ((tau, tag'), _) -> ot.Equals tau && tag = tag') unionReaders with
                    | Some (_, reader) -> reader o
                    | None ->
                        let newReader = FSharpValue.PreComputeUnionReader info
                        unionReaders <- ((ot, tag), newReader)::unionReaders
                        newReader o
                (Union info, children)
            elif FSharpType.IsTuple(ot) then
                let children =
                    match List.tryFind (fst >> ot.Equals) tupleReaders with
                    | Some (_, reader) -> reader o
                    | None ->
                        let newReader = FSharpValue.PreComputeTupleReader(ot)
                        tupleReaders <- (ot, newReader)::tupleReaders
                        newReader o
                (Tuple ot, children)
            elif FSharpType.IsRecord(ot) then
                let children =
                    match List.tryFind (fst >> ot.Equals) recordReaders with
                    | Some (_, reader) -> reader o
                    | None ->
                        let newReader = FSharpValue.PreComputeRecordReader(ot)
                        recordReaders <- (ot, newReader)::recordReaders
                        newReader o
                (Record ot, children)
            else
                (Leaf, [||])

    // Applies the first matching mapping (checked in the order f, g, h, i, j).
    // Null and unmatched values are returned unchanged.
    let mapNode (o : obj) : obj =
        if o = null then
            o
        else
            let ot = o.GetType()
            if ft = ot || ot.IsSubclassOf(ft) then
                f (o :?> 'a) |> box
            elif gt = ot || ot.IsSubclassOf(gt) then
                g (o :?> 'b) |> box
            elif ht = ot || ot.IsSubclassOf(ht) then
                h (o :?> 'c) |> box
            elif it = ot || ot.IsSubclassOf(it) then
                i (o :?> 'd) |> box
            elif jt = ot || ot.IsSubclassOf(jt) then
                j (o :?> 'e) |> box
            else
                o

    // Phase 1: preorder walk with an explicit stack. Every node lives in its own
    // ref cell so that phase 2 can overwrite it with the rebuilt value.
    let root = ref (box src)
    let pending = System.Collections.Generic.Stack<obj ref>()
    pending.Push(root)
    let mutable completed : (obj ref * StructureInfo * obj ref[]) list = []
    while pending.Count > 0 do
        let node = pending.Pop()
        let mapped = mapNode node.Value
        node.Value <- mapped
        let structure, children = getStructureInfo mapped
        let childContainers = children |> Array.map ref
        completed <- (node, structure, childContainers) :: completed
        // Push in reverse so the first child is popped next (left-to-right preorder).
        for k in childContainers.Length - 1 .. -1 .. 0 do
            pending.Push(childContainers.[k])

    // Phase 2: completed holds the nodes in reverse preorder, so every child
    // has been rebuilt before its parent reads it.
    for (container, structure, childContainers) in completed do
        let children = childContainers |> Array.map (fun c -> c.Value)
        match structure with
        | Union info -> container.Value <- FSharpValue.MakeUnion(info, children)
        | Tuple ot -> container.Value <- FSharpValue.MakeTuple(children, ot)
        | Record ot -> container.Value <- FSharpValue.MakeRecord(ot, children)
        | Leaf -> ()
    unbox<'z> root.Value
like image 23
blink Avatar answered Sep 28 '22 02:09

blink