Skip to content

Agentic Retrieval for sub-queries generation #10667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace IssueLabelerService
{
public class OpenAiAnswerService : IAnswerService
{
private readonly ServiceConfiguration _serviceConfiguration;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
private readonly ServiceConfiguration _serviceConfiguration;
private readonly ServiceConfiguration ServiceConfiguration;

This one gets inconsistently applied, but the _name pattern indicates a class-level variable. Since this is a member that is intended to be constant after the instance is created, we generally name these like we would a constant.

private RepositoryConfiguration _config;
private TriageRag _ragService;
private ILogger<AnswerFactory> _logger;
Expand All @@ -18,102 +19,151 @@ public OpenAiAnswerService(ILogger<AnswerFactory> logger, RepositoryConfiguratio
_config = config;
_ragService = ragService;
_logger = logger;

_serviceConfiguration = new ServiceConfiguration(
ModelName: config.AnswerModelName,
IssueIndexName: config.IssueIndexName,
IssueSemanticName: config.IssueSemanticName,
IssueFieldName: config.IssueIndexFieldName,
DocumentIndexName: config.DocumentIndexName,
DocumentSemanticName: config.DocumentSemanticName,
DocumentFieldName: config.DocumentIndexFieldName,
Top: int.Parse(config.SourceCount),
ScoreThreshold: double.Parse(config.ScoreThreshold),
SolutionThreshold: double.Parse(config.SolutionThreshold),
SubqueriesPromptTemplate: config.SubqueriesGenerationPrompt,
SolutionInstructions: config.SolutionInstructions,
SuggestionInstructions: config.SuggestionInstructions,
SolutionUserPrompt: config.SolutionUserPrompt,
SuggestionUserPrompt: config.SuggestionUserPrompt);
}
public async Task<AnswerOutput> AnswerQuery(IssuePayload issue, Dictionary<string, string> labels)
{
// Configuration for Azure services
var modelName = _config.AnswerModelName;
var issueIndexName = _config.IssueIndexName;
var documentIndexName = _config.DocumentIndexName;

// Issue specific configurations
var issueSemanticName = _config.IssueSemanticName;
string issueFieldName = _config.IssueIndexFieldName;

// Document specific configurations
var documentSemanticName = _config.DocumentSemanticName;
string documentFieldName = _config.DocumentIndexFieldName;

// Query + Filtering configurations
string query = $"{issue.Title} {issue.Body}";
int top = int.Parse(_config.SourceCount);
double scoreThreshold = double.Parse(_config.ScoreThreshold);
double solutionThreshold = double.Parse(_config.SolutionThreshold);

var issues = await _ragService.SearchIssuesAsync(issueIndexName, issueSemanticName, issueFieldName, query, top, scoreThreshold, labels);
var docs = await _ragService.SearchDocumentsAsync(documentIndexName, documentSemanticName, documentFieldName, query, top, scoreThreshold, labels);

if (docs.Count == 0 && issues.Count == 0)
{
throw new Exception($"Not enough relevant sources found for {issue.RepositoryName} using the Complete Triage model for issue #{issue.IssueNumber}. Documents: {docs.Count}, Issues: {issues.Count}.");
}
var query = $"{issue.Title} {issue.Body}";
var uniqueIssues = new HashSet<Issue>(new IssueIdComparer());
var uniqueDocs = new HashSet<Document>(new DocumentUrlComparer());

double highestScore = _ragService.GetHighestScore(issues, docs, issue.RepositoryName, issue.IssueNumber);
bool solution = highestScore >= solutionThreshold;

_logger.LogInformation($"Highest relevance score among the sources: {highestScore}");
//Step 1: Generate sub-queries
var subqueries = await _ragService.GenerateSubqueriesAsync(_serviceConfiguration.SubqueriesPromptTemplate, query, _serviceConfiguration.ModelName);

// Format issues
var printableIssues = string.Join("\n\n", issues.Select(issue =>
$"Title: {issue.Title}\nDescription: {issue.chunk}\nURL: {issue.Url}\nScore: {issue.Score}"));
//Step 2: Retrieve results for sub-queries
await RetrieveAndAggregateSubqueryResults(subqueries, labels, uniqueIssues, uniqueDocs);

// Format documents
var printableDocs = string.Join("\n\n", docs.Select(doc =>
$"Content: {doc.chunk}\nURL: {doc.Url}\nScore: {doc.Score}"));
// Step 3: Retrieve results for the original query
var originalIssues = await _ragService.SearchIssuesAsync(_serviceConfiguration.IssueIndexName, _serviceConfiguration.IssueSemanticName, _serviceConfiguration.IssueFieldName, query, _serviceConfiguration.Top, _serviceConfiguration.ScoreThreshold, labels);
var originalDocs = await _ragService.SearchDocumentsAsync(_serviceConfiguration.DocumentIndexName, _serviceConfiguration.DocumentSemanticName, _serviceConfiguration.DocumentFieldName, query, _serviceConfiguration.Top, _serviceConfiguration.ScoreThreshold, labels);

var replacementsUserPrompt = new Dictionary<string, string>
{
{ "Title", issue.Title },
{ "Description", issue.Body },
{ "PrintableDocs", printableDocs },
{ "PrintableIssues", printableIssues }
};
// Step 4: Deduplicate original query results as they are added
uniqueIssues.UnionWith(originalIssues);
uniqueDocs.UnionWith(originalDocs);

string instructions, userPrompt;
if (solution)
if (uniqueIssues.Count == 0 && uniqueDocs.Count == 0)
{
instructions = _config.SolutionInstructions;
userPrompt = AzureSdkIssueLabelerService.FormatTemplate(_config.SolutionUserPrompt, replacementsUserPrompt, _logger);
}
else
{
instructions = _config.SuggestionInstructions;
userPrompt = AzureSdkIssueLabelerService.FormatTemplate(_config.SuggestionUserPrompt, replacementsUserPrompt, _logger);
throw new Exception($"Not enough relevant sources found for {issue.RepositoryName} using the Complete Triage model for issue #{issue.IssueNumber}. Documents: {uniqueDocs.Count}, Issues: {uniqueIssues.Count}.");
}

var response = await _ragService.SendMessageQnaAsync(instructions, userPrompt, modelName);
var highestScore = _ragService.GetHighestScore(uniqueIssues, uniqueDocs, issue.RepositoryName, issue.IssueNumber);
var isSolution = highestScore >= _serviceConfiguration.SolutionThreshold;

_logger.LogInformation($"Highest relevance score among the sources: {highestScore}");

var replacementsUserPrompt = BuildUserPromptData(uniqueIssues, uniqueDocs, issue);

var instructions = isSolution ? _serviceConfiguration.SolutionInstructions : _serviceConfiguration.SuggestionInstructions;
var prompt = isSolution ? _serviceConfiguration.SolutionUserPrompt: _serviceConfiguration.SuggestionUserPrompt;
var userPrompt = AzureSdkIssueLabelerService.FormatTemplate(prompt, replacementsUserPrompt, _logger);

var response = await _ragService.SendMessageQnaAsync(instructions, userPrompt, _serviceConfiguration.ModelName);

string intro, outro;
var replacementsIntro = new Dictionary<string, string>
{
{ "IssueUserLogin", issue.IssueUserLogin },
{ "RepositoryName", issue.RepositoryName }
};

if (solution)
{
intro = AzureSdkIssueLabelerService.FormatTemplate(_config.SolutionResponseIntroduction, replacementsIntro, _logger);
outro = _config.SolutionResponseConclusion;
}
else
{
intro = AzureSdkIssueLabelerService.FormatTemplate(_config.SuggestionResponseIntroduction, replacementsIntro, _logger);
outro = _config.SuggestionResponseConclusion;
}
var responseIntroduction = isSolution ? _config.SolutionResponseIntroduction : _config.SuggestionResponseIntroduction;
var intro = AzureSdkIssueLabelerService.FormatTemplate(responseIntroduction, replacementsIntro, _logger);
var outro = isSolution ? _config.SolutionResponseConclusion : _config.SuggestionResponseConclusion;

if (string.IsNullOrEmpty(response))
{
throw new Exception($"Open AI Response for {issue.RepositoryName} using the Complete Triage model for issue #{issue.IssueNumber} had an empty response.");
}

string formatted_response = intro + response + outro;
var formatted_response = intro + response + outro;

_logger.LogInformation($"Open AI Response for {issue.RepositoryName} using the Complete Triage model for issue #{issue.IssueNumber}.: \n{formatted_response}");

return new AnswerOutput {
Answer = formatted_response,
AnswerType = solution ? "solution" : "suggestion"
return new AnswerOutput
{
Answer = formatted_response,
AnswerType = isSolution ? "solution" : "suggestion"
};
}
/// <summary>
/// Runs the issue search and the document search for every sub-query and merges
/// the hits into the supplied de-duplicating sets.
/// </summary>
/// <param name="subqueries">Sub-queries to search for.</param>
/// <param name="labels">Labels forwarded unchanged to both searches.</param>
/// <param name="uniqueIssues">Set that receives the issues found; mutated in place.</param>
/// <param name="uniqueDocs">Set that receives the documents found; mutated in place.</param>
/// <returns>A task that completes once every sub-query has been searched and merged.</returns>
private async Task RetrieveAndAggregateSubqueryResults(
    IEnumerable<string> subqueries,
    Dictionary<string, string> labels,
    HashSet<Issue> uniqueIssues,
    HashSet<Document> uniqueDocs)
{
    // The searches run concurrently, so each task only RETURNS its results.
    // HashSet<T> is not safe for concurrent writers; mutating the shared sets
    // from inside the tasks could corrupt them or drop entries.
    var subqueryTasks = subqueries.Select(async subquery =>
    {
        var subqueryIssues = await _ragService.SearchIssuesAsync(_serviceConfiguration.IssueIndexName, _serviceConfiguration.IssueSemanticName, _serviceConfiguration.IssueFieldName, subquery, _serviceConfiguration.Top, _serviceConfiguration.ScoreThreshold, labels);
        var subqueryDocs = await _ragService.SearchDocumentsAsync(_serviceConfiguration.DocumentIndexName, _serviceConfiguration.DocumentSemanticName, _serviceConfiguration.DocumentFieldName, subquery, _serviceConfiguration.Top, _serviceConfiguration.ScoreThreshold, labels);
        return (Issues: subqueryIssues, Docs: subqueryDocs);
    });

    var results = await Task.WhenAll(subqueryTasks);

    // Merge sequentially, in sub-query order, once all searches have finished.
    // Task.WhenAll preserves the input order, so the merge order is deterministic.
    foreach (var (foundIssues, foundDocs) in results)
    {
        uniqueIssues.UnionWith(foundIssues);
        uniqueDocs.UnionWith(foundDocs);
    }
}

/// <summary>
/// Builds the placeholder values used to fill in the user prompt template:
/// the title and body of the issue being answered plus printable renderings
/// of the retrieved issues and documents.
/// </summary>
/// <param name="allIssues">Retrieved issues to render into the prompt.</param>
/// <param name="allDocs">Retrieved documents to render into the prompt.</param>
/// <param name="issue">The issue being answered.</param>
/// <returns>Dictionary keyed by "Title", "Description", "PrintableDocs" and "PrintableIssues".</returns>
private Dictionary<string, string> BuildUserPromptData(
    IEnumerable<Issue> allIssues,
    IEnumerable<Document> allDocs,
    IssuePayload issue)
{
    // One text block per retrieved issue; blocks are separated by a blank line.
    var issueBlocks = allIssues.Select(relatedIssue =>
        $"Title: {relatedIssue.Title}\nDescription: {relatedIssue.chunk}\nURL: {relatedIssue.Url}\nScore: {relatedIssue.Score}");
    var issuesText = string.Join("\n\n", issueBlocks);

    // One text block per retrieved document, joined the same way.
    var docBlocks = allDocs.Select(relatedDoc =>
        $"Content: {relatedDoc.chunk}\nURL: {relatedDoc.Url}\nScore: {relatedDoc.Score}");
    var docsText = string.Join("\n\n", docBlocks);

    var promptData = new Dictionary<string, string>();
    promptData.Add("Title", issue.Title);
    promptData.Add("Description", issue.Body);
    promptData.Add("PrintableDocs", docsText);
    promptData.Add("PrintableIssues", issuesText);
    return promptData;
}

/// <summary>
/// Equality comparer that treats two <see cref="Issue"/> instances as the same
/// item when their <c>Id</c> values are equal. Used to de-duplicate search hits.
/// </summary>
private class IssueIdComparer : IEqualityComparer<Issue>
{
    /// <summary>
    /// Returns true when both arguments are the same reference (including both null),
    /// or when both are non-null and have equal <c>Id</c> values.
    /// </summary>
    public bool Equals(Issue x, Issue y) =>
        ReferenceEquals(x, y) || (x is not null && y is not null && x.Id == y.Id);

    /// <summary>
    /// Hash of the <c>Id</c>; 0 when the issue or its <c>Id</c> is null so that a
    /// malformed search hit cannot crash the hash set.
    /// </summary>
    public int GetHashCode(Issue obj) => (obj?.Id is { } id) ? id.GetHashCode() : 0;
}

/// <summary>
/// Equality comparer that treats two <see cref="Document"/> instances as the same
/// item when their <c>Url</c> values are equal. Used to de-duplicate search hits.
/// </summary>
private class DocumentUrlComparer : IEqualityComparer<Document>
{
    /// <summary>
    /// Returns true when both arguments are the same reference (including both null),
    /// or when both are non-null and have equal <c>Url</c> values.
    /// </summary>
    public bool Equals(Document x, Document y) =>
        ReferenceEquals(x, y) || (x is not null && y is not null && x.Url == y.Url);

    /// <summary>
    /// Hash of the <c>Url</c>; 0 when the document or its <c>Url</c> is null so that
    /// a malformed search hit cannot crash the hash set.
    /// </summary>
    public int GetHashCode(Document obj) => (obj?.Url is { } url) ? url.GetHashCode() : 0;
}
/// <summary>
/// Immutable snapshot of the settings the answer service uses, built once in the
/// constructor from the repository configuration (numeric settings are parsed there).
/// </summary>
/// <param name="ModelName">Model name passed to sub-query generation and to the final Q&amp;A call.</param>
/// <param name="IssueIndexName">Index name passed to the issue search.</param>
/// <param name="IssueSemanticName">Semantic name passed to the issue search.</param>
/// <param name="IssueFieldName">Field name passed to the issue search.</param>
/// <param name="DocumentIndexName">Index name passed to the document search.</param>
/// <param name="DocumentSemanticName">Semantic name passed to the document search.</param>
/// <param name="DocumentFieldName">Field name passed to the document search.</param>
/// <param name="Top">Value passed as the "top" argument of both searches; parsed from the configured source count.</param>
/// <param name="ScoreThreshold">Score threshold passed to both searches.</param>
/// <param name="SolutionThreshold">The answer is treated as a solution when the highest source score is at least this value; otherwise as a suggestion.</param>
/// <param name="SubqueriesPromptTemplate">Prompt template handed to sub-query generation.</param>
/// <param name="SolutionInstructions">Instructions sent to the model when answering as a solution.</param>
/// <param name="SuggestionInstructions">Instructions sent to the model when answering as a suggestion.</param>
/// <param name="SolutionUserPrompt">User prompt template formatted when answering as a solution.</param>
/// <param name="SuggestionUserPrompt">User prompt template formatted when answering as a suggestion.</param>
private record ServiceConfiguration(
string ModelName,
string IssueIndexName,
string IssueSemanticName,
string IssueFieldName,
string DocumentIndexName,
string DocumentSemanticName,
string DocumentFieldName,
int Top,
double ScoreThreshold,
double SolutionThreshold,
string SubqueriesPromptTemplate,
string SolutionInstructions,
string SuggestionInstructions,
string SolutionUserPrompt,
string SuggestionUserPrompt);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ internal RepositoryConfiguration(IConfiguration config) =>
/// <summary>Value of the "LabelUserPrompt" item, looked up through GetItem.</summary>
public string LabelUserPrompt => GetItem("LabelUserPrompt");
/// <summary>Value of the "LabelInstructions" item, looked up through GetItem.</summary>
public string LabelInstructions => GetItem("LabelInstructions");
/// <summary>Value of the "LabelNames" item, looked up through GetItem.</summary>
public string LabelNames => GetItem("LabelNames");
/// <summary>
/// Value of the "SubqueriesGenerationPrompt" item, looked up through GetItem.
/// OpenAiAnswerService uses it as the prompt template for sub-query generation.
/// </summary>
public string SubqueriesGenerationPrompt => GetItem("SubqueriesGenerationPrompt");

public string GetItem(string name)
{
Expand Down
40 changes: 40 additions & 0 deletions tools/issue-labeler/src/IssueLabelerService/TraigeRag.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ public class TriageRag
private static SearchIndexClient s_searchIndexClient;
private ILogger<TriageRag> _logger;

private const int MinSubqueries = 3;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can we move these up so they are the first members of the class? Generally, C# conventions call for ordering by scope/visibility:

  • const (which is static)
  • static readonly
  • static

and each type by visibility

  • private
  • protected
  • internal
  • public

private const int MaxSubqueries = 7;
private const int CharsPerSubquery = 30;

public TriageRag(ILogger<TriageRag> logger, AzureOpenAIClient openAiClient, SearchIndexClient searchIndexClient)
{
s_openAiClient = openAiClient;
Expand Down Expand Up @@ -225,6 +229,42 @@ public string LabelsFilter(Dictionary<string, string> labels)

return null;
}

/// <summary>
/// Asks the model to break <paramref name="query"/> into a numbered list of
/// sub-queries and returns the parsed list.
/// </summary>
/// <param name="subqueriesPromptTemplate">
/// Prompt template; its "subqueryCount" and "query" placeholders are filled in before sending.
/// </param>
/// <param name="query">Original query text to decompose.</param>
/// <param name="modelName">Model name forwarded to <c>SendMessageQnaAsync</c>.</param>
/// <returns>
/// The sub-queries with their list numbering removed. Empty when the query is
/// null/whitespace, when the model returns nothing, or when no numbered lines are found.
/// </returns>
public async Task<List<string>> GenerateSubqueriesAsync(string subqueriesPromptTemplate, string query, string modelName)
{
    // Guard before anything else: the sub-query count below dereferences query.
    if (string.IsNullOrWhiteSpace(query))
    {
        _logger.LogWarning("Query is empty or null. Returning an empty list of subqueries.");
        return new List<string>();
    }

    _logger.LogInformation($"Starting subquery generation for query: {query}");

    // Longer queries get more sub-queries, bounded by MinSubqueries/MaxSubqueries.
    var subqueryCount = Math.Clamp(query.Length / CharsPerSubquery, MinSubqueries, MaxSubqueries).ToString();

    var replacementSubqueriesPrompt = new Dictionary<string, string>
    {
        { "subqueryCount", subqueryCount },
        { "query", query }
    };
    var subqueriesPrompt = AzureSdkIssueLabelerService.FormatTemplate(subqueriesPromptTemplate, replacementSubqueriesPrompt, _logger);

    _logger.LogInformation($"Instructions for subquery generation: {subqueriesPrompt}");

    string response = await SendMessageQnaAsync(subqueriesPrompt, query, modelName);

    // An empty model response means there is nothing to parse; avoid a null dereference on Split.
    if (string.IsNullOrWhiteSpace(response))
    {
        _logger.LogWarning("Subquery generation returned an empty response. Returning an empty list of subqueries.");
        return new List<string>();
    }

    // Only lines that start with a digit are treated as list items ("1. ...", "2) ...").
    var subqueries = response
        .Split('\n')
        .Select(line => line.Trim())
        .Where(line => line.Length > 0 && char.IsDigit(line[0]))
        .Select(StripListMarker)
        .Where(subquery => subquery.Length > 0)
        .ToList();

    _logger.LogInformation($"Generated {subqueries.Count} subqueries: {string.Join(", ", subqueries)}");

    return subqueries;
}

/// <summary>
/// Removes a leading list marker (one or more digits, optionally followed by
/// '.', ')', ':' or '-') and surrounding whitespace from a numbered line.
/// </summary>
private static string StripListMarker(string line)
{
    var index = 0;

    // Skip the whole number so multi-digit markers such as "10." are handled.
    while (index < line.Length && char.IsDigit(line[index]))
    {
        index++;
    }

    if (index < line.Length && (line[index] == '.' || line[index] == ')' || line[index] == ':' || line[index] == '-'))
    {
        index++;
    }

    return line.Substring(index).Trim();
}
}

public class Issue
Expand Down