File tree Expand file tree Collapse file tree 2 files changed +1
-19
lines changed
Expand file tree Collapse file tree 2 files changed +1
-19
lines changed Original file line number Diff line number Diff line change 11from __future__ import annotations # noqa
22
33from abc import ABC , abstractmethod
4- from dataclasses import dataclass
54from typing import TYPE_CHECKING , Any
65
76from areal .experimental .openai .types import InteractionWithTokenLogpReward
109 from areal .api .engine_api import InferenceEngine
1110
1211
13- @dataclass (slots = True )
14- class WorkflowTaskInput :
15- """Input payload provided to :class:`RolloutWorkflow` implementations.
16-
17- Parameters
18- ----------
19- data : dict[str, Any]
20- Original sample provided by the dataloader or caller.
21- session_id : int | None, optional
22- Identifier registered with the global session tracer when tracing is
23- enabled.
24- """
25-
26- data : dict [str , Any ]
27- session_id : int | None = None
28-
29-
3012class RolloutWorkflow (ABC ):
3113 @abstractmethod
3214 async def arun_episode (
Original file line number Diff line number Diff line change @@ -157,7 +157,7 @@ async def arun_episode(
157157 "multi_modal_input" : multi_modal_input ,
158158 "versions" : torch .tensor (versions , dtype = torch .int32 ).unsqueeze (0 ),
159159 "attention_mask" : torch .ones (len (seq ), dtype = torch .bool ).unsqueeze (0 ),
160- "rewards" : torch .tensor ([ reward ] , dtype = torch .float32 ),
160+ "rewards" : torch .tensor (reward , dtype = torch .float32 ). unsqueeze ( 0 ),
161161 }
162162 results .append (res )
163163
You can’t perform that action at this time.
0 commit comments