funkcionálně.cz

Přední český blog o funkcionálním programování, kde se o funkcionálním programování nepíše
««« »»»

Jaccardovo tajemství - jak počítat podobnost množin pomalu, jak ji počítat rychle a jak při výpočtu podvádět

8. 12. 2015 — k47

Jac­car­dův index po­dob­nosti je jed­no­du­chá funkce, která udává míru po­dob­nosti mezi dvěma mno­ži­nami. Je de­fi­no­vána jako ve­li­kost prů­niku vy­dě­lená ve­li­kostí sjed­no­cení dvou množin.

J(A, B) = |A ∩ B| / |A ∪ B|

Funkce je to jed­no­du­chá. Otázka je, jak ji im­ple­men­to­vat, aby běžela rychle. V ná­sle­du­jí­cích od­stav­cích se vydám na cestu za nej­větší efek­ti­vi­tou za každou cenu. A když říkám za každou cenu, myslím tím, že se sku­tečně před ničím ne­za­sta­vím.

Když půjde všechno podle plánu, cestou se možná do­staví jeden nebo dva mo­menty osví­cení.


Naivní im­ple­men­tace ve Scale by mohla vy­pa­dat nějak takhle:

def jacc(a: Set[Int], b: Set[Int]): Double = {
  val is = a.intersection(b).size
  val us = a.union(b).size
  if (us == 0) 0 else is.toDouble / us
}

Jed­no­du­chý kód, straš­livý výkon. Pro­blém spo­čívá v tom, že je třeba vy­tvo­řit dvě mno­žiny jen proto, abych zjis­tit jejich ve­li­kost. Ve­li­kost sjed­no­cení není třeba vůbec po­čí­tat, pro­tože se dá jed­no­duše od­vo­dit z prin­cipu in­kluze a ex­kluze.

def jacc(a: Set[Int], b: Set[Int]): Double = {
  val is = a.intersection(b).size
  val us = a.size + b.size - is
  if (us == 0) 0 else is.toDouble / us
}

To je lepší, ale pořád je třeba vy­tvo­řit jednu mno­žinu se všemi alo­ka­cemi a in­ter­ními re­ži­emi, které to obnáší. Lo­gic­kým krokem je nic ne­a­lo­ko­vat a v jedné ite­raci spo­čí­tat ve­li­kost prů­niku.

def jacc(a: Set[Int], b: Set[Int]): Double = {
  val small = (if (a.size < b.size) a else b
  val big   = (if (b.size < a.size) a else b
  val is = small.count { el => big.contains(el) }
  val us = a.size + b.size - is
  if (us == 0) 0 else is.toDouble / us
}

To je lepší, ale zda­leka ne ide­ální. Pro­blém může před­sta­vo­vat uspo­řá­dání dat a layout paměti. V pří­padě JVM má ge­ne­rický Ha­sh­Set celkem velkou režii a mi­zer­nou lo­ka­litu3. Set[Int] ne­u­cho­vává pri­mi­tivní čtyř­baj­tové inty, ale re­fe­rence na bo­xo­vané Integer ob­jekty. Kom­bi­nace ref+box může za­bí­rat klidně 32 bajtů na 64-bi­to­vém sys­tému a musí udělat jednu de­re­fe­renci poin­teru.

Tomu se dá vy­hnout po­u­ží­vá­ním jazyka/run­time, který dělá spe­ci­a­li­zaci typů (re­i­fi­ko­vaná ge­ne­ri­kav C# nebo C++ ša­b­lony) nebo ko­lek­cemi spe­ci­a­li­zo­va­nými pro pri­mi­tivní typy. Na JVM je k dis­po­zici ně­ko­lik ta­ko­vých kniho­ven a jedna z nej­lep­ších je Ko­lo­boke.

import net.openhft.koloboke.collect.set.hash.HashIntSet

def jacc(a: HashIntSet, b: HashIntSet): Double = {
  val small = (if (a.size < b.size) a else b
  val big   = (if (b.size < a.size) a else b

  var is = 0
  val cur = small.cursor
  while (cur.moveNext()) {
    if (big.contains(cur.elem)) {
      is += 1
    }
  }

  val us = a.size + b.size - is
  if (us == 0) 0 else is.toDouble / us
}

Kód je o něco delší, ale na druhou stranu může být mnohem rych­lejší. Data jsou ulo­žena v plo­chých polích a ne­po­tře­bují na­há­nět poin­tery.

Všechny před­chozí změny před­sta­vo­valy pokrok v mezích zákona, po­zvolné zlep­šo­vání jed­noho řešení. Nešlo však o žádné ra­di­kální skoky vpřed. Těch můžu do­sáh­nout jedině, když to vezmu z dru­hého konce a začnu pře­mýš­let o tom, co je sku­tečně po­třeba. V tomto pří­padě mě zajímá jen Jac­car­dova po­dob­nost, nic jiného. Všechny re­pre­zen­tace množin, které jsem doteď po­u­ží­val byly za­lo­ženy na hash ta­bul­kách a na­bí­zely tedy mnoho jiné funk­ci­o­na­lity. Uměly na­pří­klad v kon­stant­ním čase zjis­tit, zdali se daný ele­ment na­chází v mno­žině. Já však po­tře­buji jen rychlý vý­po­čet ve­li­kosti prů­niku. Když budu re­pre­zen­to­vat mno­žinu se­řa­ze­ným polem, dá se právě tahle ve­li­čina spo­čí­tat velice efek­tivně.

def jacc(a: Array[Int], b: Array[Int]): Double = {
  def intersectionSize(a, b) = {
    var ai, bi, size = 0
    while (ai < a.length && bi < b.length) {
      val av = a(ai)
      val bv = b(bi)
      size += (if (av == bv) 1 else 0)
      ai   += (if (av <= bv) 1 else 0)
      bi   += (if (av >= bv) 1 else 0)
    }
    size
  }

  val is = intersectionSize(a, b)
  val us = a.length + b.length - is
  if (us == 0) 0 else is.toDouble / us
}

Stačí velice jed­no­du­chá smyčka, která udělá li­ne­ární prů­chod oběma poli a ne­po­tře­buje dělat žádné hle­dání v in­ter­ních hash ta­bul­kách. Tři řádky ve tvaru x += (if (cond) 1 else 0) kom­pi­lá­toru/JITu silně na­zna­čují, aby místo pod­mí­ně­ných skoků použil cmov in­strukce1. To od­straní po­ten­ci­ální ne­před­ví­da­telný skok v těle smyčky, který by všechno mohl vý­razně zpo­ma­lit.

Tento kód je nejen velice rychlý, ale také pře­kva­pivě jed­no­du­chý. Fun­guje tak, že hledá shodné prvky v poli. Když je najde, in­kre­men­tuje pro­měn­nou size a i oba indexy. V ostat­ních pří­pa­dech in­kre­men­tuje index, který uka­zuje na menší prvek. Jde o al­go­rit­mus velice po­dobný merge sortu.

Tělo smyčky ob­sa­huje pou­hých ±20 in­strukcí a ne­pů­jde zrych­lit re­dukcí počtu ope­rací v jedné ite­raci, ale zmen­še­ním počtu ite­rací.

Jedním způ­so­bem jak toho do­sáh­nout, je dívat se n míst do­předu4 s tím, že když najdu ele­ment menší než ten hle­daný, můžu pře­sko­čit n míst a tím pádem i n ite­rací. Tělo smyčky se trochu zkom­pli­kuje, vý­sle­dek však často je o pár de­sí­tek pro­cent rych­lejší a jen má­lo­kdy dojde ke zpo­ma­lení. Pokud je skok malý, kód pře­ska­kuje jen malé úseky a ne­u­šetří příliš ite­rací. Na druhou stranu, když je skok velký, kód nemůže nic pře­sko­čit a ite­ruje jako nor­mální verze. Je po­třeba najít nějaké při­ja­telné op­ti­mum.

  while (ai < alen && bi < blen) {
    val av = a(ai)
    val bv = b(bi)
    val _ai = ai
    val _bi = bi
    size += (if (av == bv) 1 else 0)
    ai   += (if (av <= bv) (if (a(_ai+skip) < bv) skip else 1) else 0)
    bi   += (if (av >= bv) (if (b(_bi+skip) < av) skip else 1) else 0)
  }

Pokud bych chtěl jít ještě dál, mohl bych pole předzpra­co­vat a za každý ele­ment vložit před­po­čí­ta­nou bitmapu ob­sa­hu­jící osm čtyř­bi­to­vých čísel, které udá­vají, jak daleko může jeden index po­sko­čit v zá­vis­losti na roz­dílu po­rov­ná­va­ných hodnot. Za zmínku stojí, že kód ob­sa­huje jen rychlé bitové ope­race a ne­po­tře­buje žádný pod­mí­něný extra skok. Je třeba jen něco přes 10 in­strukcí navíc.

def intersectionSizeWithEmbeddedSkiplists(a: Array[Int], b: Array[Int]): Int = {
  var size, ai, bi = 0
  while (ai != a.length && bi != b.length) {
    val av = a(ai)
    val bv = b(bi)

    val s  = (if (av < bv) av else bv)
    val si = (if (av < bv) ai else bi)

    val d = java.lang.Math.abs(av - bv)
    val bits = 32 - Long.numberOfLeadingZeros(d)
    val slot = bits / 4

    val slotval = (s(si+1) >>> (slot * 4)) & ((1 << 4) - 1)
    val skip = slotval << (slot - 1)

    size += (if (av == bv) 1 else 0)
    ai   += (if (av <= bv) skip else 0)
    bi   += (if (av >= bv) skip else 0)
  }
  size
}

O kolik nebo jestli vůbec to zrychlí vý­sle­dek jsem ale ne­tes­to­val, pro­tože mi došla kuráž, když jsem začal pře­mýš­let jak napsat funkci, která vy­po­čítá možné skoky.

Ale to stále není všechno. Na za­čátku jsem psal za každou cenu a pořád to myslím vážně.

Když zajdu do ex­trému, je možné im­ple­men­to­vat Jac­carda pomocí AVX2 SIMD in­strukcí2, které pro­hle­dá­vají vektor osmi hodnot pa­ra­lelně. Jak ta­ko­váto hrůza vypadá se můžete pře­svěd­čit na vlastní oči v tomto gistu. Vek­to­ri­zo­vané řešení je v nej­hor­ším pří­padě, kdy nikdy není možné pře­sko­čit ně­ko­lik ite­rací, 2x po­ma­lejší než pří­mo­čará im­ple­men­tace (pro­tože po­tře­buje vy­ko­nat víc in­strukcí a každá ite­race dělá víc práce), ale v nej­lep­ším pří­padě, kdy může často pře­ska­ko­vat velký kus vstup­ního pole, až 4x rych­lejší.

Pro další zrych­lení je možné na za­čátku kon­t­ro­lo­vat jestli je za­čá­tek jed­noho pole větší než konec toho dru­hého. V ta­ko­vém pří­padě je jasné, že mno­žiny nemají žádný spo­lečný prvek a Jac­car­dova po­dob­nost bude vždycky 0. Ke stej­nému účelu se dá použít bitmapa (např. 64 bitů) fun­gu­jící jako ma­ličký Bloom filtr. Když udělám lo­gický and dvou bitmap a do­stanu 0 (tj. dvě mapy nemají žádné spo­lečné bity), je jasné, že dvě mno­žiny, ze kte­rých byly tyto bitmapy od­vo­zeny nemají žádné spo­lečné prvky a Jac­car­dova po­dob­nost bude opět nulová.

  if (b(0) < a(a.length-1) ||
      a(0) < b(b.length-1) ||
      (aBitmap & bBitmap) == 0) {
    return 0.0
  }

To ale stále není všechno. Když mě ne­za­jímá přesná Jac­car­dova po­dob­nost, ale vy­sta­čím si s od­ha­dem, můžu použít MinHash. Ten pro­du­kuje jen při­bližné vý­sledky5, ale může být vý­razně rych­lejší, pro­tože ne­po­čítá po­dob­nost mezi celými mno­ži­nami, ale jen jejich otisky, které mají fixní ve­li­kost.

MinHash a mnoho dal­ších skečů jsem im­ple­men­to­val v knihovně sket­ches. S ní se dá spo­čí­tat odhad po­dob­nosti velice jed­no­duše:

val mh = atrox.sketches.MinHash(sets, 128)
mh.estimateSimilarity(i, j)

Když ani tohle ne­stačí, pomůže už jen lo­ca­lity sensi­tive ha­shing (LSH) (viz Mining of Mas­sive Da­ta­sets, ka­pi­tola 3.3 a 3.4)6. LSH může vý­po­čet po­dob­nosti vý­razně zrych­lit, pro­tože ome­zuje hle­dání jen na kan­di­dáty, které jsou s velkou prav­dě­po­dob­ností po­dobné a zcela pře­skočí ty, které jsou (opět s velkou prav­dě­po­dob­ností) ne­po­dobné.

A knihovna sket­ches také ob­sa­huje im­ple­men­taci LSH.

val mh = atrox.sketches.MinHash(sets, 128)
val lsh = atrox.sketches.LSH(mh, bands = 32)
lsh.similarItems(i, similarityThreshold)

S vhodně na­sta­ve­nou LSH je možné úlohu, která by trvala 24 hodin hrubou silou, spo­čí­tat za 24 vteřin s celkem ro­zum­nou ztrá­tou přes­nosti.

Myslím, že rych­leji než tohle už to není možné.


Dále k tématu:


  1. Tech­nicky vzato kom­pi­lá­tor může tento řádek pře­lo­žit na tro­jici in­strukcí cmp, setXadd a ani ne­po­tře­buje cmov.
  2. Síla SIMD ope­rací je vidět v Někdy je nej­chytřejší ne­dě­lat nic chyt­rého
  3. V pří­padě ne­měn­ných ko­lekcí je to ještě horší, pro­tože jsou in­terně im­ple­men­to­vané jako hash array mapped trie a to s sebou při­náší další úrovně poin­terů a re­fe­rencí.
  4. viz In­tro­duction to in­for­mation re­trie­val, ka­pi­tola 2.3: Faster po­sting list in­ter­section via skip poin­ters
  5. Někoho tady může na­pad­nout, že když mi stačí odhady, můžu použít Hy­per­Lo­gLog k od­had­nutí ve­li­kosti sjed­no­cení a pak do­po­čí­tat průnik, před prin­cip in­kluze a ex­kluze. To fun­guje, ale není to příliš přesné, pro­tože chyba je re­la­tivní vzhle­dem k ve­li­kosti sjed­no­cení a nikoli k prů­niku, který může být mnohem menší.
  6. Když mluvím o LSH, měl bych se také zmínit, že exis­tuje al­ter­na­tivní pří­stup hle­dání nej­bliž­ších sou­sedů, který není po­sta­vený na ha­sho­vání, ale na bi­nár­ních stro­mech. Při­bližně to od­po­vídá roz­dílu mezi hash ta­bul­kami a bi­nár­ními vy­hle­dá­va­cími stromy – ha­sho­vání nabízí O(1) hle­dání, stromy hle­dají v čase O(log n), ale jejich obsah je se­řa­zený. To v pří­padě hle­dání nej­bliž­ších sou­sedů zna­mená, že si můžu říct o body, které jsou trochu dál, což LSH ne­do­káže.
@kaja47, kaja47@k47.cz, deadbeef.k47.cz, starší články