How to compare deeply nested discriminated unions?

I want to unit test a function that returns a Result (see below).

My question is: how can I easily check whether the result is equal to the expected value within a small numerical tolerance?

Here's the version with exact matching.

```fsharp
open System.Numerics   // assumed: Complex is System.Numerics.Complex
open Xunit

type QuadraticResult =
    | ComplexResult of Complex * Complex
    | DoubleResult of float
    | TwoResults of float * float

type Result =
    | QuadraticResult of QuadraticResult
    | LinearResult of LinearFormulaSolver.Result

/// Solves ax² + bx + c = 0
let Compute (a, b, c) : Result =
    // (body omitted in the question)

[<Fact>]
member test.``the solution for x² = 0.0 is a double 0.0`` () =
    let result = Compute (1.0, 0.0, 0.0)
    let expected = Result.QuadraticResult (DoubleResult 0.0)

    // Only an exact match; I'd like to test whether the difference is below a certain threshold.
    // (xUnit's argument order is Assert.Equal(expected, actual).)
    Assert.Equal(expected, result)
```

Here's the solution I use so far. It's based on Andrey's solution, extended to handle an allowed distance, permutations of the results, and the linear case:

```fsharp
let ComplexEquality distance (x: Complex) (y: Complex) =
    let dx = x.Real - y.Real
    let dy = x.Imaginary - y.Imaginary
    abs dx < distance && abs dy < distance

// The two roots may come back in either order, so both pairings are accepted.
let QuadraticEquality distance x y =
    match x, y with
    | ComplexResult (a, b), ComplexResult (c, d) ->
        (ComplexEquality distance a c && ComplexEquality distance b d)
        || (ComplexEquality distance a d && ComplexEquality distance b c)
    | DoubleResult a, DoubleResult b -> abs (a - b) < distance
    | TwoResults (a, b), TwoResults (c, d) ->
        // abs is needed on every difference; without it any b much smaller than d would pass
        (abs (a - c) < distance && abs (b - d) < distance)
        || (abs (a - d) < distance && abs (b - c) < distance)
    | _ -> false

let LinearEquality distance x y =
    match x, y with
    | SingleResult a, SingleResult b -> abs (a - b) < distance
    | NoResults, NoResults
    | InfiniteResults, InfiniteResults -> true
    | _ -> false

let ResultEquality distance x y =
    match x, y with
    | QuadraticResult a, QuadraticResult b -> QuadraticEquality distance a b
    | LinearResult a, LinearResult b -> LinearEquality distance a b
    | _ -> false

[<Fact>]
member test.``the solution for x² = 0 is a double 0`` () =
    let result = QuadraticFormulaSolver.Compute (1.0, 0.0, 0.0)
    let expected = Result.QuadraticResult (QuadraticFormulaSolver.DoubleResult 0.00001)

    Assert.True(ResultEquality 0.001 result expected)
```
asked Mar 11 '23 16:03 by Onur


1 Answer

I don't think there is any "magic trick" that would let you do this automatically. I think you have three options:

  1. Write a custom function to do the equality test that works over your existing type and performs a special kind of comparison for all the nested float values.

  2. Write a wrapper over float that implements custom comparison, and then use this type inside the discriminated unions.

  3. Write some reflection-based magic to perform custom equality testing.

Of these, I think (1) is probably the easiest option, even though it means some more typing. Option (2) might be interesting if you wanted to use this custom comparison everywhere in your program. Finally, (3) might make sense if you had lots of different nested types, but it is also the most error-prone option.

I wrote a minimal demo of (2), but I still think (1) is probably the better approach:

```fsharp
// A float wrapper whose equality tolerates differences of up to 0.001.
[<Struct; CustomComparison; CustomEquality>]
type ApproxFloat(f:float) =
  member x.Value = f
  // Caveat: values that compare "equal" can still hash differently,
  // so ApproxFloat should not be used as a dictionary or set key.
  override x.GetHashCode() = f.GetHashCode()
  override x.Equals(another) =
    match another with
    | :? ApproxFloat as y -> abs (x.Value - y.Value) <= 0.001
    | _ -> false
  interface System.IComparable with
    member x.CompareTo(another) =
      match another with
      | :? ApproxFloat as y -> compare x.Value y.Value
      | _ -> failwith "Cannot compare"

type Complex =
  | Complex of ApproxFloat * ApproxFloat

type Result =
  | Result of Complex

// Structural equality on the unions calls ApproxFloat.Equals, so this evaluates to true.
Result(Complex(ApproxFloat(1.0), ApproxFloat(1.0))) =
  Result(Complex(ApproxFloat(1.0001), ApproxFloat(1.0001)))
```
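Option (3) from the list above is not demonstrated in the answer. As a rough sketch of what it might look like, the function below uses `FSharp.Reflection` to walk unions, tuples and records, and compares any `float` or `System.Numerics.Complex` value it finds with a tolerance. The name `approxEqual` and the `Inner`/`Outer` types are hypothetical stand-ins for the question's `Result`, not part of any library.

```fsharp
open Microsoft.FSharp.Reflection

/// Hypothetical helper: structural comparison that treats float and
/// System.Numerics.Complex values as equal when every component differs by
/// less than `tolerance`. Unions, tuples and records are walked recursively;
/// any other type falls back to ordinary Equals.
let rec approxEqual (tolerance: float) (x: obj) (y: obj) : bool =
    match x, y with
    | null, null -> true
    | null, _ | _, null -> false
    | (:? float as a), (:? float as b) -> abs (a - b) < tolerance
    | (:? System.Numerics.Complex as a), (:? System.Numerics.Complex as b) ->
        abs (a.Real - b.Real) < tolerance && abs (a.Imaginary - b.Imaginary) < tolerance
    // Different runtime types (e.g. different union cases) are never equal.
    | _ when x.GetType() <> y.GetType() -> false
    | _ when FSharpType.IsUnion(x.GetType()) ->
        let caseX, fieldsX = FSharpValue.GetUnionFields(x, x.GetType())
        let caseY, fieldsY = FSharpValue.GetUnionFields(y, y.GetType())
        caseX.Tag = caseY.Tag && Array.forall2 (approxEqual tolerance) fieldsX fieldsY
    | _ when FSharpType.IsTuple(x.GetType()) ->
        Array.forall2 (approxEqual tolerance) (FSharpValue.GetTupleFields(x)) (FSharpValue.GetTupleFields(y))
    | _ when FSharpType.IsRecord(x.GetType()) ->
        Array.forall2 (approxEqual tolerance) (FSharpValue.GetRecordFields(x)) (FSharpValue.GetRecordFields(y))
    | _ -> x.Equals y

// Hypothetical nested types standing in for the question's Result.
type Inner =
    | Pair of float * float
    | Single of float

type Outer =
    | Wrapped of Inner

// evaluates to true: 1.0 and 1.0001 differ by less than 0.001
let example =
    approxEqual 0.001 (box (Wrapped (Pair (1.0, 2.0)))) (box (Wrapped (Pair (1.0001, 2.0))))
```

A generic walker like this compares fields strictly in order, so it would reject two `TwoResults` or `ComplexResult` values whose roots come back swapped. That order-insensitivity is exactly what the hand-written functions in the question add, which is one reason option (1) is often the safer choice.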
answered Mar 28 '23 21:03 by Tomas Petricek