Feat/add prediction artifact and upload method #292
base: main
Conversation
Left some comments about the model structure. We can chat about it some more if you'd like. 😄
Nice work, @j279li!
Force-pushed from fe10e34 to db77f00.
@jstlaurent I've encountered circular import issues with BenchmarkV2Specification, since the benchmark package already imports from the evaluate package. To work around this, I moved the Predictions class into its own package for now. Would that be alright, or do you have other ideas?
Nice job!
@j279li Ping me once this is ready for another review! Happy to take a look!
Thanks again, @j279li 💪 I've left a few comments we should probably take a look at. It's coming along nicely!
```python
if len(data) > 0:
    # Extract a sample value to determine the type
    if isinstance(data, (list, tuple)):
        sample_value = data[0]
    elif hasattr(data, "__getitem__"):
        try:
            sample_value = data[0]
        except (IndexError, TypeError):
            sample_value = None
    else:
        sample_value = None

    if sample_value is not None:
        sample_type_name = type(sample_value).__name__
        if sample_type_name == "Mol" and "rdkit" in str(type(sample_value).__module__):
            codec_kwargs["object_codec"] = RDKitMolCodec()
            codec_kwargs["dtype"] = object
        elif sample_type_name == "AtomArray" and "biotite" in str(type(sample_value).__module__):
            codec_kwargs["object_codec"] = AtomArrayCodec()
            codec_kwargs["dtype"] = object
        elif annotation.dtype == np.dtype(object):
            # For other object types, use object dtype
            codec_kwargs["dtype"] = object

    # Create the array in the Zarr archive
    if "object_codec" in codec_kwargs:
        # For object codecs, we need to create a numpy object array first.
        # Use np.empty to avoid numpy trying to convert AtomArrays to numpy arrays.
        data_array = np.empty(len(data), dtype=object)
        for i, item in enumerate(data):
            data_array[i] = item
        root.array(col, data=data_array, **codec_kwargs)
    else:
        root.array(col, data=data, **codec_kwargs)
```
I'll need some more context on this block here. Let's chat about this a little further once @cwognum has had a chance to look at it.
Hi @j279li, sorry for not spotting any of this on my first review. I've been out of the loop on recent changes to Polaris, so I may also be missing some context here.
We're making good progress, but I think we should take another look at how we've done things for competitions. Specifically, I would like us to revisit the user-facing API as well as the validation of the Predictions
class.
Happy to find some time to chat about this tomorrow!
```python
if hasattr(self.benchmark.dataset, "annotations") and col in self.benchmark.dataset.annotations:
    annotation = self.benchmark.dataset.annotations[col]
```
Change requested: Rather than using the dataset annotations as the source of truth, we should use the dataset's Zarr archive to determine the dtype. Each prediction column should match the corresponding dataset column.
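A minimal sketch of the requested change, using a plain dict of NumPy arrays as a stand-in for the dataset's Zarr root (the column name `logP`, the values, and the dict-based stand-in are illustrative assumptions, not the Polaris API):

```python
import numpy as np

# Stand-in for the dataset's Zarr root: column name -> stored array.
# In Polaris, this would be the benchmark dataset's Zarr archive.
dataset_columns = {"logP": np.array([1.2, 3.4, 5.6], dtype=np.float32)}

# Predictions for the same column, arriving as a plain Python list.
predicted = {"logP": [0.9, 2.8, 4.1]}

for col, values in predicted.items():
    # The dataset column, not the annotations, is the source of truth:
    # each prediction column should match the dtype of its dataset column.
    expected_dtype = dataset_columns[col].dtype
    arr = np.asarray(values, dtype=expected_dtype)
    assert arr.dtype == expected_dtype
```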
```python
def zarr_root_path(self) -> str:
    return self._zarr_root_path

def _create_zarr_from_predictions(self):
```
Change requested: I think this method can be simplified a lot. It can copy the configuration (dtype, codec, etc.) of each array from the benchmark's dataset's Zarr Root and infer the array size from the test split.
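A hedged sketch of that simplification, again with a plain dict standing in for the dataset's Zarr root (the column names, dtypes, and split indices are illustrative assumptions):

```python
import numpy as np

# Stand-in for the dataset's Zarr root: column name -> array configuration.
# In Polaris, the dtype and codec would be read from the dataset's Zarr arrays.
dataset_config = {
    "logP": {"dtype": np.float32},
    "mol": {"dtype": object},  # e.g. a column stored with an object codec
}
test_indices = [2, 5, 7]  # the benchmark's test split

# Copy each array's configuration from the dataset and infer its size
# from the test split, instead of re-deriving dtypes from sample values.
prediction_arrays = {
    name: np.empty(len(test_indices), dtype=cfg["dtype"])
    for name, cfg in dataset_config.items()
}
```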
Summary of Changes

This PR introduces a new Prediction artifact and an associated upload method.

New Features

- Prediction Artifact (`Predictions`): adds a `Predictions` model to represent prediction artifacts as Zarr archives, exposing metadata (`columns`, `dtypes`, `n_rows`, etc.) for quick data exploration.
- Upload API (`upload_prediction`): adds an `upload_prediction` method in `PolarisHubClient` for uploading prediction artifacts.