Skip to content

Commit fbec267

Browse files
joshlongilayaperumalg
authored andcommitted
first cut of aot improvements
Signed-off-by: Josh Long <[email protected]>
1 parent a43cdc8 commit fbec267

File tree

4 files changed

+36
-21
lines changed

4 files changed

+36
-21
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050
*
5151
* @author Christian Tzolov
5252
* @author Thomas Vitale
53+
* @author Josh Long
54+
*
5355
*/
5456
public class OpenAiEmbeddingModel extends AbstractEmbeddingModel {
5557

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/aot/OpenAiRuntimeHints.java

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,13 @@
1616

1717
package org.springframework.ai.openai.aot;
1818

19-
import java.util.Set;
20-
19+
import org.springframework.ai.openai.OpenAiChatOptions;
2120
import org.springframework.ai.openai.api.OpenAiApi;
2221
import org.springframework.ai.openai.api.OpenAiAudioApi;
2322
import org.springframework.ai.openai.api.OpenAiImageApi;
2423
import org.springframework.aot.hint.MemberCategory;
2524
import org.springframework.aot.hint.RuntimeHints;
2625
import org.springframework.aot.hint.RuntimeHintsRegistrar;
27-
import org.springframework.aot.hint.TypeReference;
2826
import org.springframework.lang.NonNull;
2927
import org.springframework.lang.Nullable;
3028

@@ -40,23 +38,22 @@
4038
*/
4139
public class OpenAiRuntimeHints implements RuntimeHintsRegistrar {
4240

43-
private static Set<TypeReference> eval(Set<TypeReference> referenceSet) {
44-
referenceSet.forEach(tr -> System.out.println(tr.toString()));
45-
return referenceSet;
46-
}
47-
4841
@Override
4942
public void registerHints(@NonNull RuntimeHints hints, @Nullable ClassLoader classLoader) {
5043
var mcs = MemberCategory.values();
51-
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiApi.class))) {
44+
for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiChatOptions.class))) {
5245
hints.reflection().registerType(tr, mcs);
5346
}
54-
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiAudioApi.class))) {
47+
for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiApi.class))) {
5548
hints.reflection().registerType(tr, mcs);
5649
}
57-
for (var tr : eval(findJsonAnnotatedClassesInPackage(OpenAiImageApi.class))) {
50+
for (var tr : (findJsonAnnotatedClassesInPackage(OpenAiAudioApi.class))) {
5851
hints.reflection().registerType(tr, mcs);
5952
}
53+
for (var tr : findJsonAnnotatedClassesInPackage(OpenAiImageApi.class)) {
54+
hints.reflection().registerType(tr, mcs);
55+
}
56+
6057
}
6158

6259
}

spring-ai-model/src/main/java/org/springframework/ai/embedding/AbstractEmbeddingModel.java

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,41 @@
1616

1717
package org.springframework.ai.embedding;
1818

19+
import org.springframework.aot.hint.RuntimeHints;
20+
import org.springframework.aot.hint.RuntimeHintsRegistrar;
21+
import org.springframework.context.annotation.ImportRuntimeHints;
22+
import org.springframework.core.io.ClassPathResource;
23+
import org.springframework.core.io.Resource;
24+
import org.springframework.util.Assert;
25+
1926
import java.io.IOException;
2027
import java.util.Map;
2128
import java.util.Properties;
2229
import java.util.concurrent.atomic.AtomicInteger;
2330
import java.util.stream.Collectors;
2431

25-
import org.springframework.core.io.DefaultResourceLoader;
26-
2732
/**
2833
* Abstract implementation of the {@link EmbeddingModel} interface that provides
2934
* dimensions calculation caching.
3035
*
3136
* @author Christian Tzolov
37+
* @author Josh Long
3238
*/
39+
@ImportRuntimeHints(AbstractEmbeddingModel.Hints.class)
3340
public abstract class AbstractEmbeddingModel implements EmbeddingModel {
3441

42+
private static final Resource EMBEDDING_MODEL_DIMENSIONS_PROPERTIES = new ClassPathResource(
43+
"/embedding/embedding-model-dimensions.properties");
44+
3545
private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = loadKnownModelDimensions();
3646

37-
/**
38-
* Default constructor.
39-
*/
40-
public AbstractEmbeddingModel() {
47+
static class Hints implements RuntimeHintsRegistrar {
48+
49+
@Override
50+
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
51+
hints.resources().registerResource(EMBEDDING_MODEL_DIMENSIONS_PROPERTIES);
52+
}
53+
4154
}
4255

4356
/**
@@ -69,10 +82,13 @@ public static int dimensions(EmbeddingModel embeddingModel, String modelName, St
6982

7083
private static Map<String, Integer> loadKnownModelDimensions() {
7184
try {
72-
Properties properties = new Properties();
73-
properties.load(new DefaultResourceLoader()
74-
.getResource("classpath:/embedding/embedding-model-dimensions.properties")
75-
.getInputStream());
85+
var resource = EMBEDDING_MODEL_DIMENSIONS_PROPERTIES;
86+
Assert.notNull(resource, "the embedding dimensions must be non-null");
87+
Assert.state(resource.exists(), "the embedding dimensions properties file must exist");
88+
var properties = new Properties();
89+
try (var in = resource.getInputStream()) {
90+
properties.load(in);
91+
}
7692
return properties.entrySet()
7793
.stream()
7894
.collect(Collectors.toMap(e -> e.getKey().toString(), e -> Integer.parseInt(e.getValue().toString())));

0 commit comments

Comments
 (0)