@@ -85,6 +85,12 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
85
85
return *this ;
86
86
}
87
87
88
+ Embedding &Embedding::operator *=(double Factor) {
89
+ std::transform (this ->begin (), this ->end (), this ->begin (),
90
+ [Factor](double Elem) { return Elem * Factor; });
91
+ return *this ;
92
+ }
93
+
88
94
Embedding &Embedding::scaleAndAdd (const Embedding &Src, float Factor) {
89
95
assert (this ->size () == Src.size () && " Vectors must have the same dimension" );
90
96
for (size_t Itr = 0 ; Itr < this ->size (); ++Itr)
@@ -101,6 +107,13 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
101
107
return true ;
102
108
}
103
109
110
+ void Embedding::print (raw_ostream &OS) const {
111
+ OS << " [" ;
112
+ for (const auto &Elem : Data)
113
+ OS << " " << format (" %.2f" , Elem) << " " ;
114
+ OS << " ]\n " ;
115
+ }
116
+
104
117
// ==----------------------------------------------------------------------===//
105
118
// Embedder and its subclasses
106
119
// ===----------------------------------------------------------------------===//
@@ -196,18 +209,12 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
196
209
for (const auto &I : BB.instructionsWithoutDebug ()) {
197
210
Embedding InstVector (Dimension, 0 );
198
211
199
- const auto OpcVec = lookupVocab (I.getOpcodeName ());
200
- InstVector.scaleAndAdd (OpcVec, OpcWeight);
201
-
202
212
// FIXME: Currently lookups are string based. Use numeric Keys
203
213
// for efficiency.
204
- const auto Type = I.getType ();
205
- const auto TypeVec = getTypeEmbedding (Type);
206
- InstVector.scaleAndAdd (TypeVec, TypeWeight);
207
-
214
+ InstVector += lookupVocab (I.getOpcodeName ());
215
+ InstVector += getTypeEmbedding (I.getType ());
208
216
for (const auto &Op : I.operands ()) {
209
- const auto OperandVec = getOperandEmbedding (Op.get ());
210
- InstVector.scaleAndAdd (OperandVec, ArgWeight);
217
+ InstVector += getOperandEmbedding (Op.get ());
211
218
}
212
219
InstVecMap[&I] = InstVector;
213
220
BBVector += InstVector;
@@ -251,6 +258,43 @@ bool IR2VecVocabResult::invalidate(
251
258
return !(PAC.preservedWhenStateless ());
252
259
}
253
260
261
+ Error IR2VecVocabAnalysis::parseVocabSection (
262
+ StringRef Key, const json::Value &ParsedVocabValue,
263
+ ir2vec::Vocab &TargetVocab, unsigned &Dim) {
264
+ json::Path::Root Path (" " );
265
+ const json::Object *RootObj = ParsedVocabValue.getAsObject ();
266
+ if (!RootObj)
267
+ return createStringError (errc::invalid_argument,
268
+ " JSON root is not an object" );
269
+
270
+ const json::Value *SectionValue = RootObj->get (Key);
271
+ if (!SectionValue)
272
+ return createStringError (errc::invalid_argument,
273
+ " Missing '" + std::string (Key) +
274
+ " ' section in vocabulary file" );
275
+ if (!json::fromJSON (*SectionValue, TargetVocab, Path))
276
+ return createStringError (errc::illegal_byte_sequence,
277
+ " Unable to parse '" + std::string (Key) +
278
+ " ' section from vocabulary" );
279
+
280
+ Dim = TargetVocab.begin ()->second .size ();
281
+ if (Dim == 0 )
282
+ return createStringError (errc::illegal_byte_sequence,
283
+ " Dimension of '" + std::string (Key) +
284
+ " ' section of the vocabulary is zero" );
285
+
286
+ if (!std::all_of (TargetVocab.begin (), TargetVocab.end (),
287
+ [Dim](const std::pair<StringRef, Embedding> &Entry) {
288
+ return Entry.second .size () == Dim;
289
+ }))
290
+ return createStringError (
291
+ errc::illegal_byte_sequence,
292
+ " All vectors in the '" + std::string (Key) +
293
+ " ' section of the vocabulary are not of the same dimension" );
294
+
295
+ return Error::success ();
296
+ };
297
+
254
298
// FIXME: Make this optional. We can avoid file reads
255
299
// by auto-generating a default vocabulary during the build time.
256
300
Error IR2VecVocabAnalysis::readVocabulary () {
@@ -259,32 +303,40 @@ Error IR2VecVocabAnalysis::readVocabulary() {
259
303
return createFileError (VocabFile, BufOrError.getError ());
260
304
261
305
auto Content = BufOrError.get ()->getBuffer ();
262
- json::Path::Root Path ( " " );
306
+
263
307
Expected<json::Value> ParsedVocabValue = json::parse (Content);
264
308
if (!ParsedVocabValue)
265
309
return ParsedVocabValue.takeError ();
266
310
267
- bool Res = json::fromJSON (*ParsedVocabValue, Vocabulary, Path);
268
- if (!Res)
269
- return createStringError (errc::illegal_byte_sequence,
270
- " Unable to parse the vocabulary" );
311
+ ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab;
312
+ unsigned OpcodeDim = 0 , TypeDim = 0 , ArgDim = 0 ;
313
+ if (auto Err = parseVocabSection (" Opcodes" , *ParsedVocabValue, OpcodeVocab,
314
+ OpcodeDim))
315
+ return Err;
271
316
272
- if (Vocabulary. empty ())
273
- return createStringError (errc::illegal_byte_sequence,
274
- " Vocabulary is empty " ) ;
317
+ if (auto Err =
318
+ parseVocabSection ( " Types " , *ParsedVocabValue, TypeVocab, TypeDim))
319
+ return Err ;
275
320
276
- unsigned Dim = Vocabulary.begin ()->second .size ();
277
- if (Dim == 0 )
321
+ if (auto Err =
322
+ parseVocabSection (" Arguments" , *ParsedVocabValue, ArgVocab, ArgDim))
323
+ return Err;
324
+
325
+ if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
278
326
return createStringError (errc::illegal_byte_sequence,
279
- " Dimension of vocabulary is zero " );
327
+ " Vocabulary sections have different dimensions " );
280
328
281
- if (!std::all_of (Vocabulary.begin (), Vocabulary.end (),
282
- [Dim](const std::pair<StringRef, Embedding> &Entry) {
283
- return Entry.second .size () == Dim;
284
- }))
285
- return createStringError (
286
- errc::illegal_byte_sequence,
287
- " All vectors in the vocabulary are not of the same dimension" );
329
+ auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) {
330
+ for (auto &Entry : Vocab)
331
+ Entry.second *= Weight;
332
+ };
333
+ scaleVocabSection (OpcodeVocab, OpcWeight);
334
+ scaleVocabSection (TypeVocab, TypeWeight);
335
+ scaleVocabSection (ArgVocab, ArgWeight);
336
+
337
+ Vocabulary.insert (OpcodeVocab.begin (), OpcodeVocab.end ());
338
+ Vocabulary.insert (TypeVocab.begin (), TypeVocab.end ());
339
+ Vocabulary.insert (ArgVocab.begin (), ArgVocab.end ());
288
340
289
341
return Error::success ();
290
342
}
@@ -304,7 +356,6 @@ void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
304
356
IR2VecVocabAnalysis::Result
305
357
IR2VecVocabAnalysis::run (Module &M, ModuleAnalysisManager &AM) {
306
358
auto Ctx = &M.getContext ();
307
- // FIXME: Scale the vocabulary once. This would avoid scaling per use later.
308
359
// If vocabulary is already populated by the constructor, use it.
309
360
if (!Vocabulary.empty ())
310
361
return IR2VecVocabResult (std::move (Vocabulary));
@@ -323,16 +374,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
323
374
}
324
375
325
376
// ==----------------------------------------------------------------------===//
326
- // IR2VecPrinterPass
377
+ // Printer Passes
327
378
// ===----------------------------------------------------------------------===//
328
379
329
- void IR2VecPrinterPass::printVector (const Embedding &Vec) const {
330
- OS << " [" ;
331
- for (const auto &Elem : Vec)
332
- OS << " " << format (" %.2f" , Elem) << " " ;
333
- OS << " ]\n " ;
334
- }
335
-
336
380
PreservedAnalyses IR2VecPrinterPass::run (Module &M,
337
381
ModuleAnalysisManager &MAM) {
338
382
auto IR2VecVocabResult = MAM.getResult <IR2VecVocabAnalysis>(M);
@@ -353,15 +397,15 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
353
397
354
398
OS << " IR2Vec embeddings for function " << F.getName () << " :\n " ;
355
399
OS << " Function vector: " ;
356
- printVector ( Emb->getFunctionVector ());
400
+ Emb->getFunctionVector (). print (OS );
357
401
358
402
OS << " Basic block vectors:\n " ;
359
403
const auto &BBMap = Emb->getBBVecMap ();
360
404
for (const BasicBlock &BB : F) {
361
405
auto It = BBMap.find (&BB);
362
406
if (It != BBMap.end ()) {
363
407
OS << " Basic block: " << BB.getName () << " :\n " ;
364
- printVector ( It->second );
408
+ It->second . print (OS );
365
409
}
366
410
}
367
411
@@ -373,10 +417,24 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
373
417
if (It != InstMap.end ()) {
374
418
OS << " Instruction: " ;
375
419
I.print (OS);
376
- printVector ( It->second );
420
+ It->second . print (OS );
377
421
}
378
422
}
379
423
}
380
424
}
381
425
return PreservedAnalyses::all ();
382
426
}
427
+
428
+ PreservedAnalyses IR2VecVocabPrinterPass::run (Module &M,
429
+ ModuleAnalysisManager &MAM) {
430
+ auto IR2VecVocabResult = MAM.getResult <IR2VecVocabAnalysis>(M);
431
+ assert (IR2VecVocabResult.isValid () && " IR2Vec Vocabulary is invalid" );
432
+
433
+ auto Vocab = IR2VecVocabResult.getVocabulary ();
434
+ for (const auto &Entry : Vocab) {
435
+ OS << " Key: " << Entry.first << " : " ;
436
+ Entry.second .print (OS);
437
+ }
438
+
439
+ return PreservedAnalyses::all ();
440
+ }
0 commit comments