
Scala: reflection and case classes

The following code succeeds, but is there a better way of doing the same thing? Perhaps something specific to case classes? In the following code, for each field of type String in my simple case class, the code goes through my list of instances of that case class and finds the length of the longest string of that field.

case class CrmContractorRow(
                             id: Long,
                             bankCharges: String,
                             overTime: String,
                             name$id: Long,
                             mgmtFee: String,
                             contractDetails$id: Long,
                             email: String,
                             copyOfVisa: String)

object Go {
  def main(args: Array[String]): Unit = {
    val a = CrmContractorRow(1,"1","1",4444,"1",1,"1","1")
    val b = CrmContractorRow(22,"22","22",22,"55555",22,"nine long","22")
    val c = CrmContractorRow(333,"333","333",333,"333",333,"333","333")
    val rows = List(a,b,c)

    // For each String field, print the maximum length of that field across all rows
    c.getClass.getDeclaredFields.filter(p => p.getType == classOf[String]).foreach { f =>
      f.setAccessible(true)
      println(f.getName + ": " + rows.map(row => f.get(row).asInstanceOf[String].length).max)
    }
  }
}

Result:

bankCharges: 3
overTime: 3
mgmtFee: 5
email: 9
copyOfVisa: 3
asked Apr 06 '16 13:04 by Anthony Holland


3 Answers

If you want to do this kind of thing with Shapeless, I'd strongly suggest defining a custom type class that handles the complicated part and allows you to keep that stuff separate from the rest of your logic.

In this case it sounds like the tricky part of what you're specifically trying to do is getting the mapping from field names to string lengths for all of the String members of a case class. Here's a type class that does that:

import shapeless._, shapeless.labelled.FieldType

trait StringFieldLengths[A] { def apply(a: A): Map[String, Int] }

object StringFieldLengths extends LowPriorityStringFieldLengths {
  implicit val hnilInstance: StringFieldLengths[HNil] =
    new StringFieldLengths[HNil] {
      def apply(a: HNil): Map[String, Int] = Map.empty
    }

  implicit def caseClassInstance[A, R <: HList](implicit
    gen: LabelledGeneric.Aux[A, R],
    sfl: StringFieldLengths[R]
  ): StringFieldLengths[A] = new StringFieldLengths[A] {
    def apply(a: A): Map[String, Int] = sfl(gen.to(a))
  }

  implicit def hconsStringInstance[K <: Symbol, T <: HList](implicit
    sfl: StringFieldLengths[T],
    key: Witness.Aux[K]
  ): StringFieldLengths[FieldType[K, String] :: T] =
    new StringFieldLengths[FieldType[K, String] :: T] {
      def apply(a: FieldType[K, String] :: T): Map[String, Int] =
        sfl(a.tail).updated(key.value.name, a.head.length)
    }
}

sealed class LowPriorityStringFieldLengths {
  implicit def hconsInstance[K, V, T <: HList](implicit
    sfl: StringFieldLengths[T]
  ): StringFieldLengths[FieldType[K, V] :: T] =
    new StringFieldLengths[FieldType[K, V] :: T] {
      def apply(a: FieldType[K, V] :: T): Map[String, Int] = sfl(a.tail)
    }
}

This looks complex, but once you start working with Shapeless a bit you learn to write this kind of thing in your sleep.
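To see what the derived instance amounts to for the question's row type, here is an equivalent instance written by hand. This is an illustration only: the object name HandWrittenInstance is hypothetical, and Shapeless does not generate this literal code. For any given row it should return the same map as the derived instance, with the non-String fields skipped.

```scala
// Same shape as the StringFieldLengths type class above, restated so this compiles on its own.
trait StringFieldLengths[A] { def apply(a: A): Map[String, Int] }

case class CrmContractorRow(
    id: Long,
    bankCharges: String,
    overTime: String,
    name$id: Long,
    mgmtFee: String,
    contractDetails$id: Long,
    email: String,
    copyOfVisa: String)

object HandWrittenInstance {
  // Hypothetical manual instance: one entry per String field, Long fields left out,
  // which is what the derivation assembles field by field at compile time.
  implicit val crmRowLengths: StringFieldLengths[CrmContractorRow] =
    new StringFieldLengths[CrmContractorRow] {
      def apply(r: CrmContractorRow): Map[String, Int] = Map(
        "bankCharges" -> r.bankCharges.length,
        "overTime" -> r.overTime.length,
        "mgmtFee" -> r.mgmtFee.length,
        "email" -> r.email.length,
        "copyOfVisa" -> r.copyOfVisa.length)
    }
}
```

Writing this by hand means editing it every time a field is added, removed, or renamed, which is exactly the maintenance the derivation avoids.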

Now you can write the logic of your operation in a relatively straightforward way:

def maxStringLengths[A: StringFieldLengths](as: List[A]): Map[String, Int] =
  as.map(implicitly[StringFieldLengths[A]].apply).foldLeft(
    Map.empty[String, Int]
  ) {
    case (x, y) => x.foldLeft(y) {
      case (acc, (k, v)) =>
        acc.updated(k, acc.get(k).fold(v)(accV => math.max(accV, v)))
    }
  }
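The foldLeft in maxStringLengths performs a per-key maximum merge of the maps produced for each row. Pulled out on its own as a sketch (MergeMaxDemo and mergeMax are hypothetical names, and the per-row maps below are made-up inputs rather than output of the type class):

```scala
object MergeMaxDemo {
  // Merge two field-name -> length maps, keeping the larger length for each key.
  // Keys present in only one map are kept as-is (the fold(v) branch).
  def mergeMax(x: Map[String, Int], y: Map[String, Int]): Map[String, Int] =
    x.foldLeft(y) { case (acc, (k, v)) =>
      acc.updated(k, acc.get(k).fold(v)(accV => math.max(accV, v)))
    }

  def main(args: Array[String]): Unit = {
    // Hypothetical per-row results for two rows
    val rowA = Map("mgmtFee" -> 1, "email" -> 1)
    val rowB = Map("mgmtFee" -> 5, "email" -> 9)
    println(List(rowA, rowB).foldLeft(Map.empty[String, Int])(mergeMax))
  }
}
```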

And then (given rows as defined in the question):

scala> maxStringLengths(rows).foreach(println)
(bankCharges,3)
(overTime,3)
(mgmtFee,5)
(email,9)
(copyOfVisa,3)

This works for any case class (Shapeless derives the LabelledGeneric instance automatically), with fields that aren't Strings simply skipped.

If this is a one-off thing, you might as well use runtime reflection, or you could use the Poly2 approach in Giovanni Caporaletti's answer. It's less generic and it mixes up the different parts of the solution in a way I don't prefer, but it should work just fine. If this is something you're doing a lot of, though, I'd suggest the approach I've given here.
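On Scala 2.13 or later (an assumption; productElementNames does not exist in 2.11 or 2.12), a one-off can also avoid both reflection and Shapeless: every case class implements Product#productElementNames, which lines up with productIterator. A minimal sketch; ProductNamesDemo and longestStringFields are hypothetical names, and the String check happens at runtime rather than at compile time as in the type class above.

```scala
case class CrmContractorRow(
    id: Long,
    bankCharges: String,
    overTime: String,
    name$id: Long,
    mgmtFee: String,
    contractDetails$id: Long,
    email: String,
    copyOfVisa: String)

object ProductNamesDemo {
  // For every String-valued field, the maximum length seen across all rows.
  // Pairs each field name with its value via productElementNames/productIterator (Scala 2.13+).
  def longestStringFields(rows: Seq[Product]): Map[String, Int] =
    rows
      .flatMap(row => row.productElementNames.zip(row.productIterator))
      .collect { case (name, s: String) => name -> s.length }
      .groupMapReduce(_._1)(_._2)(_ max _)

  def main(args: Array[String]): Unit = {
    val rows = List(
      CrmContractorRow(1, "1", "1", 4444, "1", 1, "1", "1"),
      CrmContractorRow(22, "22", "22", 22, "55555", 22, "nine long", "22"),
      CrmContractorRow(333, "333", "333", 333, "333", 333, "333", "333"))
    // Sorted by field name, since Map iteration order is unspecified
    longestStringFields(rows).toSeq.sortBy(_._1).foreach { case (k, v) => println(s"$k: $v") }
  }
}
```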

answered Oct 20 '22 03:10 by Travis Brown


If you want to use Shapeless to get the String fields of a case class and avoid reflection, you can do something like this:

import shapeless._
import labelled._

trait lowerPriorityfilterStrings extends Poly2 {
  implicit def default[A] = at[Vector[(String, String)], A] { case (acc, _) => acc }
}

object filterStrings extends lowerPriorityfilterStrings {
  implicit def caseString[K <: Symbol](implicit w: Witness.Aux[K]) = at[Vector[(String, String)], FieldType[K, String]] {
    case (acc, x) =>  acc :+ (w.value.name -> x)
  }
}

val gen = LabelledGeneric[CrmContractorRow]


val a = CrmContractorRow(1,"1","1",4444,"1",1,"1","1")
val b = CrmContractorRow(22,"22","22",22,"55555",22,"nine long","22")
val c = CrmContractorRow(333,"333","333",333,"333",333,"333","333")
val rows = List(a,b,c)

val result = rows
  // get for each element a Vector of (fieldName -> stringField) pairs for the string fields
  .map(r => gen.to(r).foldLeft(Vector[(String, String)]())(filterStrings))
  // get the maximum for each "column"
  .reduceLeft((best, row) => best.zip(row).map {
    case (kv1@(_, v1), (_, v2)) if v1.length > v2.length => kv1
    case (_, kv2) => kv2
  })

// print the length of the longest value in each String column
result foreach { case (k, v) => println(s"$k: ${v.length}") }
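The reduceLeft step does not depend on Shapeless at all: once each row is a Vector of (fieldName, value) pairs in the same field order, it compares the rows column by column. A standalone sketch with hypothetical hand-written vectors standing in for the filterStrings output (ColumnMaxDemo and longestPerColumn are made-up names). The one deliberate change is reduceLeftOption, because reduceLeft throws on an empty list.

```scala
object ColumnMaxDemo {
  type Row = Vector[(String, String)]

  // Per column, keep the (fieldName, value) pair whose value is longest
  // (ties go to the later row). Assumes every row lists its fields in the same order.
  // Returns None for an empty list instead of throwing.
  def longestPerColumn(rows: List[Row]): Option[Row] =
    rows.reduceLeftOption((best: Row, row: Row) => best.zip(row).map {
      case (kv1 @ (_, v1), (_, v2)) if v1.length > v2.length => kv1
      case (_, kv2) => kv2
    })

  def main(args: Array[String]): Unit = {
    // Hypothetical per-row vectors for two String fields
    val rows = List(
      Vector("mgmtFee" -> "1", "email" -> "1"),
      Vector("mgmtFee" -> "55555", "email" -> "nine long"),
      Vector("mgmtFee" -> "333", "email" -> "333"))
    longestPerColumn(rows).foreach(_.foreach { case (k, v) => println(s"$k: ${v.length}") })
  }
}
```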
answered Oct 20 '22 02:10 by Giovanni Caporaletti


You probably want to use Scala reflection:

import scala.reflect.runtime.universe._

val rm = runtimeMirror(getClass.getClassLoader)
val instanceMirrors = rows map rm.reflect
typeOf[CrmContractorRow].members collect {
  case m: MethodSymbol if m.isCaseAccessor && m.returnType =:= typeOf[String] =>
    // length of the longest value of this field across all rows
    val maxLength = instanceMirrors.map(_.reflectField(m).get.asInstanceOf[String].length).max
    println(s"${m.name}: $maxLength")
}

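This only considers the case accessors (the constructor parameters), so unlike the getDeclaredFields approach it won't pick up extra members in cases like: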

case class CrmContractorRow(
                             id: Long,
                             bankCharges: String,
                             overTime: String,
                             name$id: Long,
                             mgmtFee: String,
                             contractDetails$id: Long,
                             email: String,
                             copyOfVisa: String) {
  val unwantedVal = "jdjd"
}
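The difference is easy to check. A sketch assuming Scala 2.13+, using Product#productElementNames in place of scala-reflect so it runs without extra dependencies (UnwantedFieldDemo and its methods are hypothetical names): Java's getDeclaredFields also returns the backing field of unwantedVal, while the case-class field names do not include it.

```scala
case class CrmContractorRow(
    id: Long,
    bankCharges: String,
    overTime: String,
    name$id: Long,
    mgmtFee: String,
    contractDetails$id: Long,
    email: String,
    copyOfVisa: String) {
  val unwantedVal = "jdjd"
}

object UnwantedFieldDemo {
  // String-typed JVM fields, as the question's getDeclaredFields approach sees them
  def javaStringFieldNames: List[String] =
    classOf[CrmContractorRow].getDeclaredFields.toList
      .filter(_.getType == classOf[String])
      .map(_.getName)

  // Constructor parameters only (Product#productElementNames, Scala 2.13+)
  def caseFieldNames(row: CrmContractorRow): List[String] =
    row.productElementNames.toList

  def main(args: Array[String]): Unit = {
    val row = CrmContractorRow(1, "1", "1", 4444, "1", 1, "1", "1")
    println(javaStringFieldNames.mkString(", "))
    println(caseFieldNames(row).mkString(", "))
  }
}
```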

Cheers

answered Oct 20 '22 03:10 by Joan