Skip to content

Commit 0a209a2

Browse files
committed
Resolve review comments
1 parent 7e28625 commit 0a209a2

File tree

2 files changed

+1
-19
lines changed

2 files changed

+1
-19
lines changed

areal/api/workflow_api.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations # noqa
22

33
from abc import ABC, abstractmethod
4-
from dataclasses import dataclass
54
from typing import TYPE_CHECKING, Any
65

76
from areal.experimental.openai.types import InteractionWithTokenLogpReward
@@ -10,23 +9,6 @@
109
from areal.api.engine_api import InferenceEngine
1110

1211

13-
@dataclass(slots=True)
14-
class WorkflowTaskInput:
15-
"""Input payload provided to :class:`RolloutWorkflow` implementations.
16-
17-
Parameters
18-
----------
19-
data : dict[str, Any]
20-
Original sample provided by the dataloader or caller.
21-
session_id : int | None, optional
22-
Identifier registered with the global session tracer when tracing is
23-
enabled.
24-
"""
25-
26-
data: dict[str, Any]
27-
session_id: int | None = None
28-
29-
3012
class RolloutWorkflow(ABC):
3113
@abstractmethod
3214
async def arun_episode(

areal/workflow/vision_rlvr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ async def arun_episode(
157157
"multi_modal_input": multi_modal_input,
158158
"versions": torch.tensor(versions, dtype=torch.int32).unsqueeze(0),
159159
"attention_mask": torch.ones(len(seq), dtype=torch.bool).unsqueeze(0),
160-
"rewards": torch.tensor([reward], dtype=torch.float32),
160+
"rewards": torch.tensor(reward, dtype=torch.float32).unsqueeze(0),
161161
}
162162
results.append(res)
163163

0 commit comments

Comments
 (0)