Skip to content

Commit fdac515

Browse files
authored
Merge pull request #32 from benjaminjellis/polars_integration
adding integration with polars
2 parents fe6cadc + 934f0ac commit fdac515

File tree

5 files changed

+115
-4
lines changed

5 files changed

+115
-4
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,15 @@ jobs:
2727
run: |
2828
brew install cmake
2929
brew install libomp
30-
cargo build
30+
cargo build --all-features
3131
- name: Build for ubuntu
3232
if: matrix.os == 'ubuntu-latest'
3333
run: |
3434
sudo apt-get update
3535
sudo apt-get install -y cmake libclang-dev libc++-dev gcc-multilib
36-
cargo build
36+
cargo build --all-features
3737
- name: Run tests
38-
run: cargo test
38+
run: cargo test --all-features
3939
continue-on-error: ${{ matrix.rust == 'nightly' }}
4040
- name: Run Clippy
4141
uses: actions-rs/clippy-check@v1

Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,9 @@ lightgbm-sys = { path = "lightgbm-sys", version = "0.3.0" }
1313
libc = "0.2.81"
1414
derive_builder = "0.5.1"
1515
serde_json = "1.0.59"
16+
polars = {version = "0.16.0", optional = true}
17+
18+
19+
[features]
20+
default = []
21+
dataframe = ["polars"]

src/dataset.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ use lightgbm_sys;
33
use std;
44
use std::ffi::CString;
55

6+
#[cfg(feature = "dataframe")]
7+
use polars::prelude::*;
8+
69
use crate::{Error, Result};
710

811
/// Dataset used throughout LightGBM for training.
@@ -118,6 +121,76 @@ impl Dataset {
118121

119122
Ok(Self::new(handle))
120123
}
124+
125+
/// Create a new `Dataset` from a polars DataFrame.
126+
///
127+
/// Note: the feature ```dataframe``` is required for this method
128+
///
129+
/// Example
130+
///
131+
#[cfg_attr(
132+
feature = "dataframe",
133+
doc = r##"
134+
extern crate polars;
135+
136+
use lightgbm::Dataset;
137+
use polars::prelude::*;
138+
use polars::df;
139+
140+
let df: DataFrame = df![
141+
"feature_1" => [1.0, 0.7, 0.9, 0.2, 0.1],
142+
"feature_2" => [0.1, 0.4, 0.8, 0.2, 0.7],
143+
"feature_3" => [0.2, 0.5, 0.5, 0.1, 0.1],
144+
"feature_4" => [0.1, 0.1, 0.1, 0.7, 0.9],
145+
"label" => [0.0, 0.0, 0.0, 1.0, 1.0]
146+
].unwrap();
147+
let dataset = Dataset::from_dataframe(df, String::from("label")).unwrap();
148+
"##
149+
)]
150+
#[cfg(feature = "dataframe")]
151+
pub fn from_dataframe(mut dataframe: DataFrame, label_column: String) -> Result<Self> {
152+
let label_col_name = label_column.as_str();
153+
154+
let (m, n) = dataframe.shape();
155+
156+
let label_series = &dataframe.select_series(label_col_name)?[0].cast::<Float32Type>()?;
157+
158+
if label_series.null_count() != 0 {
159+
panic!("Cannot create a dataset with null values, encountered nulls when creating the label array")
160+
}
161+
162+
dataframe.drop_in_place(label_col_name)?;
163+
164+
let mut label_values = Vec::with_capacity(m);
165+
166+
let label_values_ca = label_series.unpack::<Float32Type>()?;
167+
168+
label_values_ca
169+
.into_no_null_iter()
170+
.enumerate()
171+
.for_each(|(_row_idx, val)| {
172+
label_values.push(val);
173+
});
174+
175+
let mut feature_values = Vec::with_capacity(m);
176+
for _i in 0..m {
177+
feature_values.push(Vec::with_capacity(n));
178+
}
179+
180+
for (_col_idx, series) in dataframe.get_columns().iter().enumerate() {
181+
if series.null_count() != 0 {
182+
panic!("Cannot create a dataset with null values, encountered nulls when creating the features array")
183+
}
184+
185+
let series = series.cast::<Float64Type>()?;
186+
let ca = series.unpack::<Float64Type>()?;
187+
188+
ca.into_no_null_iter()
189+
.enumerate()
190+
.for_each(|(row_idx, val)| feature_values[row_idx].push(val));
191+
}
192+
Self::from_mat(feature_values, label_values)
193+
}
121194
}
122195

123196
impl Drop for Dataset {
@@ -151,4 +224,21 @@ mod tests {
151224
let dataset = Dataset::from_mat(data, label);
152225
assert!(dataset.is_ok());
153226
}
227+
228+
#[cfg(feature = "dataframe")]
229+
#[test]
230+
fn from_dataframe() {
231+
use polars::df;
232+
let df: DataFrame = df![
233+
"feature_1" => [1.0, 0.7, 0.9, 0.2, 0.1],
234+
"feature_2" => [0.1, 0.4, 0.8, 0.2, 0.7],
235+
"feature_3" => [0.2, 0.5, 0.5, 0.1, 0.1],
236+
"feature_4" => [0.1, 0.1, 0.1, 0.7, 0.9],
237+
"label" => [0.0, 0.0, 0.0, 1.0, 1.0]
238+
]
239+
.unwrap();
240+
241+
let df_dataset = Dataset::from_dataframe(df, String::from("label"));
242+
assert!(df_dataset.is_ok());
243+
}
154244
}

src/error.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
33
use std::error;
44
use std::ffi::CStr;
5-
use std::fmt::{self, Display};
5+
use std::fmt::{self, Debug, Display};
66

77
use lightgbm_sys;
88

9+
#[cfg(feature = "dataframe")]
10+
use polars::prelude::*;
11+
912
/// Convenience return type for most operations which can return an `LightGBM`.
1013
pub type Result<T> = std::result::Result<T, Error>;
1114

@@ -49,6 +52,15 @@ impl Display for Error {
4952
}
5053
}
5154

55+
#[cfg(feature = "dataframe")]
56+
impl From<PolarsError> for Error {
57+
fn from(pe: PolarsError) -> Self {
58+
Self {
59+
desc: pe.to_string(),
60+
}
61+
}
62+
}
63+
5264
#[cfg(test)]
5365
mod tests {
5466
use super::*;

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ extern crate libc;
22
extern crate lightgbm_sys;
33
extern crate serde_json;
44

5+
#[cfg(feature = "dataframe")]
6+
extern crate polars;
7+
58
#[macro_use]
69
macro_rules! lgbm_call {
710
($x:expr) => {

0 commit comments

Comments
 (0)