Jaccardovo tajemství - jak počítat podobnost množin pomalu, jak ji počítat rychle a jak při výpočtu podvádět
Jaccardův index podobnosti je jednoduchá funkce, která udává míru podobnosti mezi dvěma množinami. Je definována jako velikost průniku vydělená velikostí sjednocení dvou množin.
J(A, B) = |A ∩ B| / |A ∪ B|
Funkce je to jednoduchá. Otázka je, jak ji implementovat, aby běžela rychle. V následujících odstavcích se vydám na cestu za největší efektivitou za každou cenu. A když říkám za každou cenu, myslím tím, že se skutečně před ničím nezastavím.
Když půjde všechno podle plánu, cestou se možná dostaví jeden nebo dva momenty osvícení.
Naivní implementace ve Scale by mohla vypadat nějak takhle:
def jacc(a: Set[Int], b: Set[Int]): Double = { val is = a.intersection(b).size val us = a.union(b).size if (us == 0) 0 else is.toDouble / us }
Jednoduchý kód, strašlivý výkon. Problém spočívá v tom, že je třeba vytvořit dvě množiny jen proto, abych zjistit jejich velikost. Velikost sjednocení není třeba vůbec počítat, protože se dá jednoduše odvodit z principu inkluze a exkluze.
def jacc(a: Set[Int], b: Set[Int]): Double = { val is = a.intersection(b).size val us = a.size + b.size - is if (us == 0) 0 else is.toDouble / us }
To je lepší, ale pořád je třeba vytvořit jednu množinu se všemi alokacemi a interními režiemi, které to obnáší. Logickým krokem je nic nealokovat a v jedné iteraci spočítat velikost průniku.
def jacc(a: Set[Int], b: Set[Int]): Double = { val small = (if (a.size < b.size) a else b) val big = (if (b.size < a.size) a else b) val is = small.count { el => big.contains(el) } val us = a.size + b.size - is if (us == 0) 0 else is.toDouble / us }
To je lepší, ale zdaleka ne ideální. Problém může představovat uspořádání dat a
layout paměti. V případě JVM má generický HashSet
celkem velkou režii a mizernou lokalitu3 . Set[Int]
neuchovává primitivní čtyřbajtové inty, ale reference na boxované Integer
objekty.
Kombinace ref+box může zabírat klidně 32 bajtů na 64-bitovém systému a musí udělat jednu
dereferenci pointeru.
Tomu se dá vyhnout používáním jazyka/runtime, který dělá specializaci typů (reifikovaná generikav C# nebo C++ šablony) nebo kolekcemi specializovanými pro primitivní typy. Na JVM je k dispozici několik takových knihoven a jedna z nejlepších je Koloboke.
import net.openhft.koloboke.collect.set.hash.HashIntSet def jacc(a: HashIntSet, b: HashIntSet): Double = { val small = (if (a.size < b.size) a else b) val big = (if (b.size < a.size) a else b) var is = 0 val cur = small.cursor while (cur.moveNext()) { if (big.contains(cur.elem)) { is += 1 } } val us = a.size + b.size - is if (us == 0) 0 else is.toDouble / us }
Kód je o něco delší, ale na druhou stranu může být mnohem rychlejší. Data jsou uložena v plochých polích a nepotřebují nahánět pointery.
Všechny předchozí změny představovaly pokrok v mezích zákona, pozvolné zlepšování jednoho řešení. Nešlo však o žádné radikální skoky vpřed. Těch můžu dosáhnout jedině, když to vezmu z druhého konce a začnu přemýšlet o tom, co je skutečně potřeba. V tomto případě mě zajímá jen Jaccardova podobnost, nic jiného. Všechny reprezentace množin, které jsem doteď používal byly založeny na hash tabulkách a nabízely tedy mnoho jiné funkcionality. Uměly například v konstantním čase zjistit, zdali se daný element nachází v množině. Já však potřebuji jen rychlý výpočet velikosti průniku. Když budu reprezentovat množinu seřazeným polem, dá se právě tahle veličina spočítat velice efektivně.
def jacc(a: Array[Int], b: Array[Int]): Double = { def intersectionSize(a, b) = { var ai, bi, size = 0 while (ai < a.length && bi < b.length) { val av = a(ai) val bv = b(bi) size += (if (av == bv) 1 else 0) ai += (if (av <= bv) 1 else 0) bi += (if (av >= bv) 1 else 0) } size } val is = intersectionSize(a, b) val us = a.length + b.length - is if (us == 0) 0 else is.toDouble / us }
Stačí velice jednoduchá smyčka, která udělá lineární průchod oběma poli a
nepotřebuje dělat žádné hledání v interních hash tabulkách. Tři
řádky ve tvaru x += (if (cond) 1 else 0)
kompilátoru/JITu silně naznačují,
aby místo podmíněných skoků použil cmov
instrukce1 . To odstraní
potenciální nepředvídatelný skok v těle smyčky, který by všechno mohl
výrazně zpomalit.
Tento kód je nejen velice rychlý, ale také překvapivě jednoduchý. Funguje tak,
že hledá shodné prvky v poli. Když je najde, inkrementuje proměnnou size
a i
oba indexy. V ostatních případech inkrementuje index, který ukazuje na menší
prvek. Jde o algoritmus velice podobný merge sortu.
Tělo smyčky obsahuje pouhých ±20 instrukcí a nepůjde zrychlit redukcí počtu operací v jedné iteraci, ale zmenšením počtu iterací.
Jedním způsobem jak toho dosáhnout, je dívat se n
míst dopředu4 s
tím, že když najdu element menší než ten hledaný, můžu přeskočit n
míst a tím
pádem i n
iterací. Tělo smyčky se trochu zkomplikuje, výsledek však často je
o pár desítek procent rychlejší a jen málokdy dojde ke zpomalení. Pokud je skok
malý, kód přeskakuje jen malé úseky a neušetří příliš iterací. Na druhou
stranu, když je skok velký, kód nemůže nic přeskočit a iteruje jako normální
verze. Je potřeba najít nějaké přijatelné optimum.
while (ai < alen && bi < blen) { val av = a(ai) val bv = b(bi) val _ai = ai val _bi = bi size += (if (av == bv) 1 else 0) ai += (if (av <= bv) (if (a(_ai+skip) < bv) skip else 1) else 0) bi += (if (av >= bv) (if (b(_bi+skip) < av) skip else 1) else 0) }
Pokud bych chtěl jít ještě dál, mohl bych pole předzpracovat a za každý element vložit předpočítanou bitmapu obsahující osm čtyřbitových čísel, které udávají, jak daleko může jeden index poskočit v závislosti na rozdílu porovnávaných hodnot. Za zmínku stojí, že kód obsahuje jen rychlé bitové operace a nepotřebuje žádný podmíněný extra skok. Je třeba jen něco přes 10 instrukcí navíc.
def intersectionSizeWithEmbeddedSkiplists(a: Array[Int], b: Array[Int]): Int = { var size, ai, bi = 0 while (ai != a.length && bi != b.length) { val av = a(ai) val bv = b(bi) val s = (if (av < bv) av else bv) val si = (if (av < bv) ai else bi) val d = java.lang.Math.abs(av - bv) val bits = 32 - Long.numberOfLeadingZeros(d) val slot = bits / 4 val slotval = (s(si+1) >>> (slot * 4)) & ((1 << 4) - 1) val skip = slotval << (slot - 1) size += (if (av == bv) 1 else 0) ai += (if (av <= bv) skip else 0) bi += (if (av >= bv) skip else 0) } size }
O kolik nebo jestli vůbec to zrychlí výsledek jsem ale netestoval, protože mi došla kuráž, když jsem začal přemýšlet jak napsat funkci, která vypočítá možné skoky.
Ale to stále není všechno. Na začátku jsem psal za každou cenu a pořád to myslím vážně.
Když zajdu do extrému, je možné implementovat Jaccarda pomocí AVX2 SIMD instrukcí2 , které prohledávají vektor osmi hodnot paralelně. Jak takováto hrůza vypadá se můžete přesvědčit na vlastní oči v tomto gistu. Vektorizované řešení je v nejhorším případě, kdy nikdy není možné přeskočit několik iterací, 2x pomalejší než přímočará implementace (protože potřebuje vykonat víc instrukcí a každá iterace dělá víc práce), ale v nejlepším případě, kdy může často přeskakovat velký kus vstupního pole, až 4x rychlejší.
Pro další zrychlení je možné na začátku kontrolovat jestli je začátek jednoho pole větší
než konec toho druhého. V takovém případě je jasné, že množiny nemají žádný
společný prvek a Jaccardova podobnost bude vždycky 0. Ke stejnému účelu se dá použít
bitmapa (např. 64 bitů) fungující jako maličký Bloom filtr. Když
udělám logický and
dvou bitmap a dostanu 0 (tj. dvě mapy nemají žádné
společné bity), je jasné, že dvě množiny, ze kterých byly tyto bitmapy odvozeny
nemají žádné společné prvky a Jaccardova podobnost bude opět nulová.
if (b(0) < a(a.length-1) || a(0) < b(b.length-1) || (aBitmap & bBitmap) == 0) { return 0.0 }
To ale stále není všechno. Když mě nezajímá přesná Jaccardova podobnost, ale vystačím si s odhadem, můžu použít MinHash. Ten produkuje jen přibližné výsledky5 , ale může být výrazně rychlejší, protože nepočítá podobnost mezi celými množinami, ale jen jejich otisky fixní velikosti.
MinHash a mnoho dalších skečů jsem implementoval v knihovně sketches. S ní se dá spočítat odhad podobnosti velice jednoduše:
val mh = atrox.sketches.MinHash(sets, 128) mh.estimateSimilarity(i, j)
Když ani tohle nestačí, pomůže už jen locality sensitive hashing (LSH) (viz Mining of Massive Datasets, kapitola 3.3 a 3.4)6 . LSH může výpočet podobnosti výrazně zrychlit, protože omezuje hledání jen na kandidáty, které jsou s velkou pravděpodobností podobné a zcela přeskočí ty, které jsou (opět s velkou pravděpodobností) nepodobné.
A knihovna sketches také obsahuje implementaci LSH.
val mh = atrox.sketches.MinHash(sets, 128) val lsh = atrox.sketches.LSH(mh, bands = 32) lsh.similarItems(i, similarityThreshold)
S vhodně nastavenou LSH je možné úlohu, která by trvala 24 hodin hrubou silou, spočítat za 24 vteřin s celkem rozumnou ztrátou přesnosti.
Myslím, že rychleji než tohle už to není možné.
Dále k tématu:
- Od pohledu dobrý, aneb jak najít skoro stejné obrázky mezi dvěma miliony souborů za méně než deset minut
- Mining of Massive Datasets, kapitola 3.6: Methods for High Degrees of Similarity
- Better bitmap performance with Roaring bitmaps
- Efficient Set Intersection for Inverted Indexing
- Fast Sorted-Set Intersection using SIMD Instructions
- Technicky vzato kompilátor může tento řádek přeložit na trojici instrukcí
cmp
,setX
aadd
a ani nepotřebujecmov
. - Síla SIMD operací je vidět v Někdy je nejchytřejší nedělat nic chytrého
- V případě neměnných kolekcí je to ještě horší, protože jsou interně implementované jako hash array mapped trie a to s sebou přináší další úrovně pointerů a referencí.
- viz Introduction to information retrieval, kapitola 2.3: Faster posting list intersection via skip pointers
- Někoho tady může napadnout, že když mi stačí odhady, můžu použít HyperLogLog k odhadnutí velikosti sjednocení a pak dopočítat průnik, před princip inkluze a exkluze. To funguje, ale není to příliš přesné, protože chyba je relativní vzhledem k velikosti sjednocení a nikoli k průniku, který může být mnohem menší.
- Když mluvím o LSH, měl bych se také zmínit, že existuje alternativní přístup hledání nejbližších sousedů, který není postavený na hashování, ale na binárních stromech. Přibližně to odpovídá rozdílu mezi hash tabulkami a binárními vyhledávacími stromy - hashování nabízí O(1) hledání, stromy hledají v čase O(log n), ale jejich obsah je seřazený. To v případě hledání nejbližších sousedů znamená, že si můžu říct o body, které jsou trochu dál, což LSH nedokáže.