読者です 読者をやめる 読者になる 読者になる

Scala初心者がShapelessのHListでderivingをどうやるのかを学んだ話

ScalaにはHaskellのようなShowやEqをderivingしてくれる機能は無いんですか?と質問をしたところ

typelevel.scala | Deriving Type Class Instances (Part 1)を紹介してもらったのでものは試しとやってみることにしました。

本題に入る前に、このページ中のコードで使っているbuild.sbtを貼っておきます。

name := "deriving"

version := "1.0"

scalaVersion := "2.11.6"

resolvers ++= Seq(
  Resolver.sonatypeRepo("releases")
)

libraryDependencies ++= Seq(
  "com.chuusai" %% "shapeless" % "2.0.0"
)

HListについて

とつとつと進めていたのですが、 唐突にHListにbe familiar withじゃねえやつはこれを見て勉強しな!というお達しが出たので方針転換でこっちを先に進めることにしました。

Shapeless: Exploring Generic Programming in Scala - YouTube

oh, english video online gakusyuu...(´・_・`)

本題

HListというのはどうやら、TupleListのいいとこ取りのようなデータ構造のようです。

ここで

  1. Tupleは長さは固定だが、異なる型を入れられる
  2. Listは長さは可変だが、同じ型しか入らない

ということを踏まえると、これのいいとこ取りを考えると

  1. HListは長さは可変だし、異なる型を入れられる

ということになりそうです。

たぶんHeterogeneous Listの略だと思います。 Heterogeneousは、異なるとか異種のとかそういう意味だったと思うので、 たしかにヘテロジニアスなリストと言えそうです。

ここで、おもむろにScala worksheetを開きます。

そして

import shapeless._
import HList._

val l: Int :: String :: HNil = 1 :: "foo" :: HNil

とすると

l: shapeless.::[Int,shapeless.::[String,shapeless.HNil]] = 1 :: foo :: HNil

が定義できました!これはやばいですね。完全にやばい扉を開けている感じがします。

なお、中身はheadとかtailとかするといつものように取得することが出来ます。

appendの導出

ここからは主にtypelevel.scala | Deriving Type Class Instances (Part 1)の話に戻ります。

HListの話は一旦おいておいて、型クラス の話をすることにしましょう。

型クラスにはいろいろありますが、今回はSemigroup(半群)という型クラスを考えることにします。

Semigroup というのはどういうものかというとこういうものです。

trait Semigroup[S] {
  def append(s1: S, s2: S): S
}

なんやこれ append があるだけやないか!と思うかもしれませんが、 implicit parameter を使うと

// Intの足し算
implicit val intInstance = new Semigroup[Int] {
  def append(s1: Int, s2: Int) = s1 + s2
}

// Stringの足し算
implicit val stringInstance = new Semigroup[String] {
  def append(s1: String, s2: String) = s1 + s2
}

// なんでも足し算
def plus[A](a: A, b: A)(implicit semigroup: Semigroup[A]) = {
  semigroup.append(a, b)
}

plus(1, 2) // 3
plus("aaa", "bbb") // "aaabbb"

のようにplusを呼び出した際に勝手に引数に渡されて、append が統一的に使えるよ。といった効果があります。

うーん。わかるけどわからない。そう思いますか?僕もそう思います(´・_・`)

これならどうでしょうか。

implicit def tupleInstance[A, B](implicit A: Semigroup[A], B: Semigroup[B]) =
  new Semigroup[(A, B)] {
    def append(t1: (A, B), t2: (A, B)): (A, B) = (A.append(t1._1, t2._1), B.append(t1._2, t2._2))
  }

plus((1, "aaa"), (2, "bbb")) // (3, "aaabbb")

なんとSemigroupインスタンスを組み合わせてタプルのSemigroupインスタンスができちゃうんです。

どういうふうに動いているかというと、

  1. plustuple[Int, String]が渡されているのでそれに合うimplicit parameterコンパイラが探し始める
  2. tuple._1の型はIntなのでSemigroup[Int]であるintInstanceAに渡される
  3. tuple._2の型はStringなのでSemigroup[String]であるstringInstanceBに渡される

というような感じです。

ここまでできるとなると、たしかに色々出来そうな気がしてきましたね。

Semigroupは仮定していることが少ないので、このくらいかもしれませんが(Monoidとかだともっと色々できる)、 それでも十分に色々なことができそうです。

3次元ベクトルの足し算

ここからいよいよ本題です。

case class Vector3D(x: Int, y: Int, z: Int) {}

のようなclassがあったとします。

これに足し算を追加したいときはどうするかというと

case class Vector3D(x: Int, y: Int, z: Int) {
  def +(that: Vector3D): Vector3D =
    Vector3D(this.x + that.x, this.y + that.y, this.z + that.z)
}

Vector3D(1,2,3) + Vector3D(4,5,6) // Vector3D(5,7,9)

とか、

// 試し終わったら消してください
implicit val vector3DInstance = new Semigroup[Vector3D] {
  def append(v1: Vector3D , v2: Vector3D) = 
    Vector3D(v1.x + v2.x, v1.y + v2.y, v1.z + v2.z)
}
plus(Vector3D(1,2,3), Vector3D(4,5,6)) // Vector3D(5,7,9)

とかしますよね。

しかし Vector2D とか Vector5D とか出てきた時にこれをいちいち書くのはしんどそうです(´・_・`)

なんとか統一的にひょろろーと表現すると、てれれーと足し算ができるようにならないでしょうか。

ということで、先ほどのHListが登場します。

3次元ベクトルの足し算 with HList

どういうことかというと、

まず、HList上での足し算を定義してあげます。

implicit val nilInstance = new Semigroup[HNil] {
  def append(x: HNil, y: HNil) = HNil
}

implicit def consInstance[H, T <: HList](implicit H: Semigroup[H], T: Semigroup[T]) =
  new Semigroup[H :: T] {
    def append(x: H :: T, y: H :: T) = H.append(x.head, y.head) :: T.append(x.tail, y.tail)
  }

このようにすると

    val a = 1 :: 3 :: HNil
    val b = 2 :: 4 :: HNil

    println(plus(a, b)) // 3 :: 7 :: HNil

のような演算が可能になります。

そして、

def to(vec: Vector3D): Int :: Int :: Int :: HNil =
  vec.x :: vec.y :: vec.z :: HNil

def from(hlist: Int :: Int :: Int :: HNil): Vector3D =
 Vector3D(hlist.head, hlist.tail.head, hlist.tail.tail.head)

val c = to(Vector3D(1,2,3)) // 1 :: 2 :: 3 :: HNil
val d = from(c) // Vector3D(1, 2, 3)

のようなHListへの変換する関数と、HListから元に戻す関数を用意してあげます。

なんとなく見えてきましたね。

あとは、Vector3DtoHListに変換して、appendで足し算した後に、fromで元に戻してくれるような関数があればよさそうです。

def subst[A, B](to: A => B, from: B => A)(implicit instance: Semigroup[B]) = new Semigroup[A] {
  def append(a1: A, a2: A) =
    from(instance.append(to(a1), to(a2)))
}

implicit val vectorInstance: Semigroup[Vector3D] = subst(to, from)

val e = plus(Vector3D(1, 2, 3), Vector3D(1, 2, 3)) // Vector3D(2, 4, 6)

できました!!

まとめ

ということで、駆け足でしたが、fromtoHListの世界を行ったり来たりすることで、 いろいろな型クラスのインスタンスを具体的に定義することを避ける事ができそうです。

ぶっちゃけfromとかtoとかを書きたくないんですが、マクロを使うとSemigroup.derive[Vector3D]のように一気に書くことが出来るらしいです。 今回はここまではやりません。(というかまだ出来てません(´;ω;`))

最後に全部のコードを載せておきます。

import shapeless._
import HList._

object Test {

  case class Vector3D(x: Int, y: Int, z: Int) {}

  // Semigroupをimplicit valで渡す
  trait Semigroup[S] {
    def append(s1: S, s2: S): S
  }

  // appendする関数
  def plus[A](a: A, b: A)(implicit semigroup: Semigroup[A]) = {
    semigroup.append(a, b)
  }

  // HListとの変換
  def to(vec: Vector3D): Int :: Int :: Int :: HNil =
    vec.x :: vec.y :: vec.z :: HNil

  def from(hlist: Int :: Int :: Int :: HNil): Vector3D =
    Vector3D(hlist.head, hlist.tail.head, hlist.tail.tail.head)

  // IntのSemigroup
  implicit val intInstance = new Semigroup[Int] {
    def append(x: Int, y: Int) = x + y
  }

  // HListのSemigroup
  implicit val nilInstance = new Semigroup[HNil] {
    def append(x: HNil, y: HNil) = HNil
  }

  implicit def consInstance[H, T <: HList](implicit H: Semigroup[H], T: Semigroup[T]) =
    new Semigroup[H :: T] {
      def append(x: H :: T, y: H :: T) = H.append(x.head, y.head) :: T.append(x.tail, y.tail)
    }

  def subst[A, B](to: A => B, from: B => A)(implicit instance: Semigroup[B]) = new Semigroup[A] {
    def append(a1: A, a2: A) =
      from(instance.append(to(a1), to(a2)))
  }

  implicit val vectorInstance: Semigroup[Vector3D] = subst(to, from)

  def main(args: Array[String]) {
    val a = 1 :: 3 :: HNil
    val b = 2 :: 4 :: HNil

    println(plus(a, b))
    println(plus(Vector3D(1, 2, 3), Vector3D(1, 2, 3)))
  }
}

今後

Scalaは始めたばかりですが、このようなおもしろいライブラリがいっぱいあるのですね! HListの実装周りは初心者にはちょっときつそうなので、もう少し時間を書けて勉強したいなーと思っています。

あとマクロとかをやらないとfromとかtoを書かなければならなくてちゃんと問題が解決しないので、 そのへんもやりたいんですが、この資料パート1で終わってて先が・・・(´・_・`)