-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: split evaluation and export, add cost
- Loading branch information
1 parent
d35b953
commit 1e79331
Showing
8 changed files
with
516 additions
and
314 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from .strategies.openai import OpenAICostStrategy | ||
from .strategies.google import GoogleCostStrategy | ||
|
||
from typing import Optional | ||
|
||
|
||
class CostCalculator:
    """Route a cost computation to the strategy registered for a provider."""

    # Maps a provider identifier to the strategy class whose static
    # ``compute_costs(samples, model_name)`` performs the actual computation.
    __COST_STRATEGIES = {
        "openai-chatcompletion": OpenAICostStrategy,
        "google": GoogleCostStrategy,
    }

    @staticmethod
    def compute_costs(samples: list, provider: str, model_name: str) -> Optional[dict]:
        """Compute the costs of ``samples`` for the given provider and model.

        Args:
            samples: Samples forwarded unchanged to the provider strategy.
            provider: Key used to look up the strategy, e.g. ``"google"`` or
                ``"openai-chatcompletion"``.
            model_name: Model identifier forwarded to the provider strategy.

        Returns:
            ``None`` when no strategy is registered for ``provider``;
            otherwise whatever the strategy's ``compute_costs`` returns.
        """
        strategy = CostCalculator.__COST_STRATEGIES.get(provider)
        if strategy is None:
            return None
        return strategy.compute_costs(samples, model_name)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from typing import Optional | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
|
||
class CostStrategy(ABC):
    """Abstract base class for cost computation strategies.

    Concrete subclasses must override the static ``compute_costs`` method;
    the base class itself cannot be instantiated.
    """

    def __init__(self, model_name: str):
        """Store the model name on the instance."""
        self.model_name = model_name

    @staticmethod
    @abstractmethod
    def compute_costs(samples: list, model_name: str) -> Optional[dict]:
        """Compute the costs of ``samples`` for ``model_name``.

        Args:
            samples: The samples whose cost should be computed.
            model_name: Identifier of the model that produced the samples.

        Returns:
            A dictionary of costs, or ``None``. This base implementation
            does nothing and returns ``None``.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from typing import Optional | ||
from .cost_strategy import CostStrategy | ||
|
||
import tqdm | ||
|
||
|
||
class GoogleCostStrategy(CostStrategy):
    """Compute generation costs from the usage metadata of Google models."""

    # Price per one million tokens when the prompt has at most 128000 tokens.
    __COST_PER_MILLION_TOKENS = {
        "gemini-1.5-pro": {
            "prompt": 3.50,
            "completion": 10.50,
        }
    }

    # Price per one million tokens when the prompt has more than 128000 tokens.
    __COST_PER_MILLION_TOKENS_OVER_128K = {
        "gemini-1.5-pro": {
            "prompt": 7.00,
            "completion": 21.00,
        }
    }

    # Prompt token count above which the higher price table applies.
    __LONG_PROMPT_THRESHOLD = 128000

    @staticmethod
    def compute_costs(samples: list, model_name: str) -> Optional[dict]:
        """Sum prompt and completion costs over every generation in ``samples``.

        Args:
            samples: Mappings with a ``"generation"`` entry holding an iterable
                of generation mappings (or a falsy value when there is none).
                Token counts are read from each generation's
                ``"usage_metadata"``.
            model_name: Model whose price table should be applied.

        Returns:
            ``None`` if no prices are known for ``model_name``; otherwise a
            dict with ``"prompt_cost"``, ``"completion_cost"`` and
            ``"total_cost"``.
        """
        standard_prices = GoogleCostStrategy.__COST_PER_MILLION_TOKENS.get(model_name)
        if standard_prices is None:
            return None

        costs = {
            "prompt_cost": 0.0,
            "completion_cost": 0.0,
            "total_cost": 0.0,
        }

        for sample in tqdm.tqdm(samples, f"Computing costs for {model_name}..."):
            # A falsy "generation" (None, empty list) contributes nothing.
            for generation in sample["generation"] or ():
                usage = generation.get("usage_metadata")
                if not usage:
                    # No usage data recorded for this generation.
                    continue

                # A count absent from the metadata is treated as zero tokens
                # instead of aborting the whole computation with a KeyError.
                prompt_token_count = usage.get("prompt_token_count", 0)
                candidates_token_count = usage.get("candidates_token_count", 0)

                # The price tier is selected by the prompt size alone and is
                # applied to both the prompt and the completion tokens.
                if prompt_token_count > GoogleCostStrategy.__LONG_PROMPT_THRESHOLD:
                    prices = GoogleCostStrategy.__COST_PER_MILLION_TOKENS_OVER_128K[
                        model_name
                    ]
                else:
                    prices = standard_prices

                costs["prompt_cost"] += prices["prompt"] * prompt_token_count / 1000000
                costs["completion_cost"] += (
                    prices["completion"] * candidates_token_count / 1000000
                )

        costs["total_cost"] = costs["prompt_cost"] + costs["completion_cost"]
        return costs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from typing import Optional | ||
from .cost_strategy import CostStrategy | ||
|
||
import tqdm | ||
|
||
|
||
class OpenAICostStrategy(CostStrategy):
    """Compute generation costs from the usage data of OpenAI models."""

    # Price per one thousand tokens, keyed by model name.
    __COST_PER_THOUSAND_TOKENS = {
        "gpt-4o-2024-08-06": {
            "prompt": 0.00250,
            "completion": 0.01000,
        }
    }

    @staticmethod
    def compute_costs(samples: list, model_name: str) -> Optional[dict]:
        """Sum prompt and completion costs over the generations in ``samples``.

        Args:
            samples: Mappings with a ``"generation"`` entry holding a mapping
                (or a falsy value when there is none). Token counts are read
                from the generation's ``"usage"`` entry.
            model_name: Model whose price table should be applied.

        Returns:
            ``None`` if no prices are known for ``model_name``; otherwise a
            dict with ``"prompt_cost"``, ``"completion_cost"`` and
            ``"total_cost"``.
        """
        # The price table does not change per sample, so look it up once.
        prices = OpenAICostStrategy.__COST_PER_THOUSAND_TOKENS.get(model_name)
        if prices is None:
            return None

        costs = {
            "prompt_cost": 0.0,
            "completion_cost": 0.0,
            "total_cost": 0.0,
        }

        for sample in tqdm.tqdm(samples, f"Computing costs for {model_name}..."):
            generation = sample["generation"]
            if not generation:
                continue

            # Skip generations that carry no usage data instead of raising a
            # KeyError and aborting the whole computation.
            usage = generation.get("usage")
            if not usage:
                continue

            prompt_token_count = usage.get("prompt_tokens", 0)
            completion_token_count = usage.get("completion_tokens", 0)

            costs["prompt_cost"] += prices["prompt"] * prompt_token_count / 1000
            costs["completion_cost"] += (
                prices["completion"] * completion_token_count / 1000
            )

        costs["total_cost"] = costs["prompt_cost"] + costs["completion_cost"]
        return costs
Oops, something went wrong.