Skip to content

Commit

Permalink
Fixes for Modular Converter on Windows (#34266)
Browse files Browse the repository at this point in the history
* Separator in regex

* Standardize separator for relative path in auto generated message

* open() encoding

* Replace `\` on `os.path.abspath`

---------

Co-authored-by: Arthur <[email protected]>
  • Loading branch information
hlky and ArthurZucker authored Oct 29, 2024
1 parent 626c610 commit 9e3d704
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions utils/modular_model_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_module_source_from_name(module_name: str) -> str:
if spec is None or spec.origin is None:
return f"Module {module_name} not found"

with open(spec.origin, "r") as file:
with open(spec.origin, "r", encoding="utf-8") as file:
source_code = file.read()
return source_code

Expand Down Expand Up @@ -1132,7 +1132,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None,
if pattern is not None:
model_name = pattern.groups()[0]
# Parse the Python file
with open(modular_file, "r") as file:
with open(modular_file, "r", encoding="utf-8") as file:
code = file.read()
module = cst.parse_module(code)
wrapper = MetadataWrapper(module)
Expand All @@ -1143,7 +1143,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None,
if node != {}:
# Get relative path starting from src/transformers/
relative_path = re.search(
rf"(src{os.sep}transformers{os.sep}.*|examples{os.sep}.*)", os.path.abspath(modular_file)
r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/")
).group(1)

header = AUTO_GENERATED_MESSAGE.format(
Expand All @@ -1164,15 +1164,15 @@ def save_modeling_file(modular_file, converted_file):
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
)
if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0:
with open(modular_file.replace("modular_", f"{file_type}_"), "w") as f:
with open(modular_file.replace("modular_", f"{file_type}_"), "w", encoding="utf-8") as f:
f.write(converted_file[file_type][0])
else:
non_comment_lines = len(
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
)
if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0:
logger.warning("The modeling code contains errors, it's written without formatting")
with open(modular_file.replace("modular_", f"{file_type}_"), "w") as f:
with open(modular_file.replace("modular_", f"{file_type}_"), "w", encoding="utf-8") as f:
f.write(converted_file[file_type][1])


Expand Down

0 comments on commit 9e3d704

Please sign in to comment.