Skip to content

Commit 6a928ad

Browse files
committed
add spatial join with openassistant 0.0.4
1 parent 884658f commit 6a928ad

File tree

7 files changed

+1446
-83
lines changed

7 files changed

+1446
-83
lines changed

examples/demo-app/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"@loaders.gl/csv": "^4.3.2",
2626
"@loaders.gl/json": "^4.3.2",
2727
"@loaders.gl/parquet": "^4.3.2",
28-
"@openassistant/ui": "^0.0.1",
28+
"@openassistant/ui": "^0.0.4",
2929
"@types/classnames": "^2.3.1",
3030
"@types/keymirror": "^0.1.1",
3131
"classnames": "^2.2.1",

src/ai-assistant/package.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@
3636
"@kepler.gl/layers": "3.1.0-alpha.1",
3737
"@kepler.gl/types": "3.1.0-alpha.1",
3838
"@kepler.gl/utils": "3.1.0-alpha.1",
39-
"@openassistant/core": "^0.0.3",
40-
"@openassistant/echarts": "^0.0.3",
41-
"@openassistant/geoda": "^0.0.3",
42-
"@openassistant/ui": "^0.0.3",
39+
"@openassistant/core": "^0.0.4",
40+
"@openassistant/echarts": "^0.0.4",
41+
"@openassistant/geoda": "^0.0.4",
42+
"@openassistant/ui": "^0.0.4",
4343
"color-interpolate": "^1.0.5"
4444
},
4545
"devDependencies": {

src/ai-assistant/src/components/ai-assistant-component.tsx

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
import React, {useEffect} from 'react';
55
import styled, {withTheme} from 'styled-components';
66
import {MessageModel, useAssistant} from '@openassistant/core';
7-
import {dataClassifyFunctionDefinition} from '@openassistant/geoda';
7+
import {
8+
dataClassifyFunctionDefinition,
9+
spatialCountFunctionDefinition,
10+
SpatialJoinGeometries
11+
} from '@openassistant/geoda';
812
import {histogramFunctionDefinition, scatterplotFunctionDefinition} from '@openassistant/echarts';
913
import {AiAssistant} from '@openassistant/ui';
1014
import '@openassistant/echarts/dist/index.css';
@@ -34,9 +38,11 @@ import {updateLayerColorFunctionDefinition} from '../tools/layer-style-function'
3438
import {SelectedKeplerGlActions} from './ai-assistant-manager';
3539
import {
3640
getDatasetContext,
41+
getGeometriesFromDataset,
3742
getScatterplotValuesFromDataset,
3843
getValuesFromDataset,
39-
highlightRows
44+
highlightRows,
45+
saveAsDataset
4046
} from '../tools/utils';
4147

4248
export type AiAssistantComponentProps = {
@@ -72,6 +78,21 @@ function AiAssistantComponentFactory() {
7278
mapStyle,
7379
visState
7480
}: AiAssistantComponentProps) => {
81+
// get values from dataset, used by LLM functions
82+
const getValuesCallback = (datasetName: string, variableName: string): number[] =>
83+
getValuesFromDataset(visState.datasets, datasetName, variableName);
84+
85+
// highlight rows, used by LLM functions
86+
const highlightRowsCallback = (datasetName: string, selectedRowIndices: number[]) =>
87+
highlightRows(
88+
visState.datasets,
89+
visState.layers,
90+
datasetName,
91+
selectedRowIndices,
92+
keplerGlActions.layerSetIsValid
93+
);
94+
95+
// define LLM functions
7596
const functions = [
7697
basemapFunctionDefinition({mapStyleChange: keplerGlActions.mapStyleChange, mapStyle}),
7798
loadUrlFunctionDefinition({
@@ -95,49 +116,36 @@ function AiAssistantComponentFactory() {
95116
setFilterPlot: keplerGlActions.setFilterPlot
96117
}),
97118
histogramFunctionDefinition({
98-
getValues: (datasetName: string, variableName: string): number[] =>
99-
getValuesFromDataset(visState.datasets, datasetName, variableName),
100-
onSelected: (datasetName: string, selectedRowIndices: number[]) =>
101-
highlightRows(
102-
visState.datasets,
103-
visState.layers,
104-
datasetName,
105-
selectedRowIndices,
106-
keplerGlActions.layerSetIsValid
107-
)
119+
getValues: getValuesCallback,
120+
onSelected: highlightRowsCallback
108121
}),
109122
scatterplotFunctionDefinition({
110-
getValues: (
111-
datasetName: string,
112-
xVariableName: string,
113-
yVariableName: string
114-
): Promise<{x: number[]; y: number[]}> =>
115-
Promise.resolve(
116-
getScatterplotValuesFromDataset(
117-
visState.datasets,
118-
datasetName,
119-
xVariableName,
120-
yVariableName
121-
)
122-
),
123-
onSelected: (datasetName: string, selectedRowIndices: number[]) =>
124-
highlightRows(
125-
visState.datasets,
126-
visState.layers,
127-
datasetName,
128-
selectedRowIndices,
129-
keplerGlActions.layerSetIsValid
130-
)
123+
getValues: async (datasetName: string, xVar: string, yVar: string) =>
124+
getScatterplotValuesFromDataset(visState.datasets, datasetName, xVar, yVar),
125+
onSelected: highlightRowsCallback
131126
}),
132127
dataClassifyFunctionDefinition({
133-
getValues: (datasetName: string, variableName: string): number[] =>
134-
getValuesFromDataset(visState.datasets, datasetName, variableName)
128+
getValues: getValuesCallback
129+
}),
130+
spatialCountFunctionDefinition({
131+
getValues: getValuesCallback,
132+
getGeometries: (datasetName: string): SpatialJoinGeometries =>
133+
getGeometriesFromDataset(
134+
visState.datasets,
135+
visState.layers,
136+
visState.layerData,
137+
datasetName
138+
),
139+
saveAsDataset: (datasetName: string, data: Record<string, number[]>) =>
140+
saveAsDataset(visState.datasets, datasetName, data, keplerGlActions.addDataToMap)
135141
})
136142
];
137143

144+
// enable voice and screen capture
138145
const enableVoiceAndScreenCapture =
139146
aiAssistant.config.provider === 'openai' || aiAssistant.config.provider === 'google';
140147

148+
// define assistant props
141149
const assistantProps = {
142150
name: ASSISTANT_NAME,
143151
description: ASSISTANT_DESCRIPTION,
@@ -151,12 +159,14 @@ function AiAssistantComponentFactory() {
151159

152160
const {initializeAssistant, addAdditionalContext} = useAssistant(assistantProps);
153161

162+
// initialize assistant with context
154163
const initializeAssistantWithContext = async () => {
155164
await initializeAssistant();
156165
const context = getDatasetContext(visState.datasets, visState.layers);
157166
addAdditionalContext({context});
158167
};
159168

169+
// (re-)initialize the assistant with the dataset context inside an effect
160170
useEffect(() => {
161171
initializeAssistantWithContext();
162172
// re-initialize assistant when datasets, filters or layers change

src/ai-assistant/src/components/ai-assistant-manager.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ export type SelectedKeplerGlActions = {
4141
setFilter: ActionHandler<typeof setFilter>;
4242
setFilterPlot: ActionHandler<typeof setFilterPlot>;
4343
layerSetIsValid: ActionHandler<typeof layerSetIsValid>;
44+
addTableColumn: ActionHandler<typeof addTableColumn>;
4445
};
4546

4647
export type AiAssistantManagerState = {

src/ai-assistant/src/constants.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// Copyright contributors to the kepler.gl project
33

44
export const TASK_LIST =
5-
'1. Show dataset/layer/variable info.\n2. Change the basemap style.\n3. Load data from url.\n4. Create a map layer using variable.\n5. Filter the data of a variable.\n6. Create a histogram.';
5+
'1. Show dataset/layer/variable info.\n2. Change the basemap style.\n3. Load data from url.\n4. Create a map layer using variable.\n5. Filter the data of a variable.\n6. Create a histogram.\n7. Classify the data of a variable.\n8. Spatial join two datasets.\n9. Query the data using SQL.';
66

77
export const WELCOME_MESSAGE = `Hi, I am Kepler.gl AI Assistant!\nHere are some tasks I can help you with:\n\n${TASK_LIST}`;
88

@@ -14,6 +14,8 @@ When responding to user queries:
1414
- Identify the appropriate function to call
1515
- Determine all required parameters
1616
- If parameters are missing, ask the user to provide them
17+
- Please ask the user to confirm the parameters
18+
- If the user doesn't agree, try to provide variable functions to the user
1719
- Execute functions in a sequential order
1820
1921
You can execute multiple functions to complete complex tasks, but execute them one at a time in a logical sequence. Always validate the success of each function call before proceeding to the next one.

src/ai-assistant/src/tools/utils.ts

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,17 @@ import interpolate from 'color-interpolate';
55

66
import {Layer} from '@kepler.gl/layers';
77
import {Datasets, KeplerTable} from '@kepler.gl/table';
8+
import {SpatialJoinGeometries} from '@openassistant/geoda';
9+
import {ALL_FIELD_TYPES} from '@kepler.gl/constants';
10+
import {AddDataToMapPayload, ProtoDataset, ProtoDatasetField} from '@kepler.gl/types';
811

12+
/**
13+
* Check if the dataset exists
14+
* @param datasets The kepler.gl datasets
15+
* @param datasetName The name of the dataset
16+
* @param functionName The name of the function
17+
* @returns The result of the check
18+
*/
919
export function checkDatasetNotExists(
1020
datasets: Datasets,
1121
datasetName: string,
@@ -26,6 +36,13 @@ export function checkDatasetNotExists(
2636
return null;
2737
}
2838

39+
/**
40+
* Check if the field exists
41+
* @param dataset The kepler.gl dataset
42+
* @param fieldName The name of the field
43+
* @param functionName The name of the function
44+
* @returns The result of the check
45+
*/
2946
export function checkFieldNotExists(dataset: KeplerTable, fieldName: string, functionName: string) {
3047
const field = dataset.fields.find(f => f.name === fieldName);
3148
if (!field) {
@@ -43,6 +60,12 @@ export function checkFieldNotExists(dataset: KeplerTable, fieldName: string, fun
4360
return null;
4461
}
4562

63+
/**
64+
* Interpolate the colors from the original colors with the given number of colors
65+
* @param originalColors The original colors
66+
* @param numberOfColors The number of colors
67+
* @returns The interpolated colors
68+
*/
4669
export function interpolateColor(originalColors: string[], numberOfColors: number) {
4770
if (originalColors.length === numberOfColors) {
4871
return originalColors;
@@ -98,6 +121,14 @@ export function getScatterplotValuesFromDataset(
98121
return {x: xValues, y: yValues};
99122
}
100123

124+
/**
125+
* Highlight the rows in a dataset
126+
* @param datasets The kepler.gl datasets
127+
* @param layers The kepler.gl layers
128+
* @param datasetName The name of the dataset
129+
* @param selectedRowIndices The indices of the rows to highlight
130+
* @param layerSetIsValid The function to set the layer validity
131+
*/
101132
export function highlightRows(
102133
datasets: Datasets,
103134
layers: Layer[],
@@ -117,11 +148,17 @@ export function highlightRows(
117148
selectLayers.forEach(layer => {
118149
layer.formatLayerData(datasets);
119150
// trigger a re-render using layerSetIsValid() for each selected layer
120-
layerSetIsValid(selectLayers[0], true);
151+
layerSetIsValid(layer, true);
121152
});
122153
}
123154
}
124155

156+
/**
157+
* Get the dataset context, which is used to provide the dataset information to the AI assistant
158+
* @param datasets The kepler.gl datasets
159+
* @param layers The kepler.gl layers
160+
* @returns The dataset context
161+
*/
125162
export function getDatasetContext(datasets: Datasets, layers: Layer[]) {
126163
const context = 'Please remember the following dataset context:';
127164
const dataMeta = Object.values(datasets).map(dataset => ({
@@ -130,7 +167,128 @@ export function getDatasetContext(datasets: Datasets, layers: Layer[]) {
130167
fields: dataset.fields.map(field => ({[field.name]: field.type})),
131168
layers: layers
132169
.filter(layer => layer.config.dataId === dataset.id)
133-
.map(layer => ({id: layer.id, label: layer.config.label, type: layer.type}))
170+
.map(layer => ({
171+
id: layer.id,
172+
label: layer.config.label,
173+
type: layer.type,
174+
geometryMode: layer.config.columnMode,
175+
// get the valid geometry columns as string
176+
geometryColumns: Object.fromEntries(
177+
Object.entries(layer.config.columns)
178+
.filter(([_, value]) => value !== null)
179+
.map(([key, value]) => [
180+
key,
181+
typeof value === 'object' && value !== null
182+
? Object.fromEntries(Object.entries(value).filter(([_, v]) => v !== null))
183+
: value
184+
])
185+
)
186+
}))
134187
}));
135188
return `${context}\n${JSON.stringify(dataMeta)}`;
136189
}
190+
191+
/**
 * Get the geometries of a dataset from the data of the first layer that renders it.
 *
 * The dataset is looked up by its label, and `layerData` is indexed with the position
 * of the matching layer in `layers`.
 *
 * @param datasets The kepler.gl datasets
 * @param layers The kepler.gl layers
 * @param layerData The layer data, presumably aligned index-by-index with `layers` — confirm against visState
 * @param datasetName The name (label) of the dataset
 * @returns The `data` of the matching layer-data entry, or an empty array when the
 * dataset, its layer, or the layer data cannot be found
 */
export function getGeometriesFromDataset(
  datasets: Datasets,
  layers: Layer[],
  layerData: any[],
  datasetName: string
): SpatialJoinGeometries {
  const datasetId = Object.keys(datasets).find(dataId => datasets[dataId].label === datasetName);
  if (!datasetId) return [];
  const dataset = datasets[datasetId];

  // get the index of the first layer that uses this dataset
  const layerIndex = layers.findIndex(layer => layer.config.dataId === dataset.id);
  if (layerIndex === -1) return [];

  // the entry (or its data) can be missing, e.g. before the layer has been rendered;
  // fall back to an empty array so callers always receive the declared return type
  return layerData?.[layerIndex]?.data ?? [];
}
217+
218+
/**
 * Save the data as a new dataset by joining it with the left dataset.
 *
 * The new dataset uses `${datasetName}_joined` as both id and label. Every row holds
 * the new column values first, followed by the values of all fields of the left
 * dataset. Nothing is added when no dataset is labeled `datasetName`.
 *
 * @param datasets The kepler.gl datasets
 * @param datasetName The name (label) of the left dataset
 * @param data The new columns keyed by field name; presumably each column is
 * row-aligned with the left dataset — confirm against the caller
 * @param addDataToMap The function to add the data to the map
 */
export function saveAsDataset(
  datasets: Datasets,
  datasetName: string,
  data: Record<string, number[]>,
  addDataToMap: (data: AddDataToMapPayload) => void
) {
  // find datasetId from datasets
  const datasetId = Object.keys(datasets).find(dataId => datasets[dataId].label === datasetName);
  if (!datasetId) return;

  const leftDataset = datasets[datasetId];
  const numRows = leftDataset.length;
  const newColumns = Object.entries(data);

  const fields: ProtoDatasetField[] = [
    // New fields from data
    ...newColumns.map(([fieldName, column], index) => {
      // A column is only integer-typed when none of its values is fractional, so prefer
      // a non-integer sample; otherwise fall back to the first present value.
      const sample =
        column.find(value => typeof value === 'number' && !Number.isInteger(value)) ??
        column.find(value => value !== null && value !== undefined);
      return {
        name: fieldName,
        id: `${fieldName}_${index}`,
        displayName: fieldName,
        type: determineFieldType(sample)
      };
    }),
    // Existing fields from leftDataset
    ...leftDataset.fields.map((field, index) => ({
      name: field.name,
      id: field.id || `${field.name}_${index}`,
      displayName: field.displayName,
      type: field.type
    }))
  ];

  const rows = Array.from({length: numRows}, (_, rowIdx) => [
    // New data values; a column shorter than the left dataset yields null, not undefined
    ...newColumns.map(([, column]) => column[rowIdx] ?? null),
    // Existing dataset values
    ...leftDataset.fields.map(field => leftDataset.getValue(field.name, rowIdx))
  ]);

  // create new dataset
  const newDatasetName = `${datasetName}_joined`;
  const newDataset: ProtoDataset = {
    info: {
      id: newDatasetName,
      label: newDatasetName
    },
    data: {
      fields,
      rows
    }
  };

  addDataToMap({datasets: [newDataset], options: {autoCreateLayers: true, centerMap: true}});
}
282+
283+
/**
 * Map a sample value to a kepler.gl field type.
 *
 * Anything that is not a number is reported as a string field; a number is an
 * integer field when `Number.isInteger` holds for it and a real field otherwise.
 *
 * @param value The sample value to inspect
 * @returns The matching entry of `ALL_FIELD_TYPES`
 */
function determineFieldType(value: unknown): keyof typeof ALL_FIELD_TYPES {
  if (typeof value !== 'number') {
    return ALL_FIELD_TYPES.string;
  }

  const isWholeNumber = Number.isInteger(value);
  return isWholeNumber ? ALL_FIELD_TYPES.integer : ALL_FIELD_TYPES.real;
}

0 commit comments

Comments
 (0)