Skip to content

Commit

Permalink
Use gdrive_retry as decorator;Use Tqdm with disable param; Rename met…
Browse files Browse the repository at this point in the history
…hods.
  • Loading branch information
maxhora committed Nov 20, 2019
1 parent 5294cff commit b883bb9
Showing 1 changed file with 39 additions and 47 deletions.
86 changes: 39 additions & 47 deletions dvc/remote/gdrive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,9 @@ def init_drive(self):
)
)

self.root_id = self.get_path_id(self.path_info, create=True)
self.root_id = self.get_remote_id(self.path_info, create=True)
self.cached_dirs, self.cached_ids = self.cache_root_dirs()

def gdrive_list_file(self, query):
return self.drive.ListFile({"q": query, "maxResults": 1}).GetList()

def gdrive_create_folder(self, title, parent_id):
item = self.drive.CreateFile(
{
"title": title,
"parents": [{"id": parent_id}],
"mimeType": FOLDER_MIME_TYPE,
}
)
item.Upload()
return item

def gdrive_upload_file(
self, args, no_progress_bar=True, from_file="", progress_name=""
):
Expand All @@ -124,11 +110,8 @@ def gdrive_download_file(
from dvc.progress import Tqdm

gdrive_file = self.drive.CreateFile({"id": file_id})
if not no_progress_bar:
tqdm = Tqdm(desc=progress_name, total=int(gdrive_file["fileSize"]))
gdrive_file.GetContentFile(to_file)
if not no_progress_bar:
tqdm.close()
with Tqdm(desc=progress_name, total=int(gdrive_file["fileSize"]), disable=no_progress_bar):
gdrive_file.GetContentFile(to_file)

def gdrive_list_item(self, query):
file_list = self.drive.ListFile({"q": query, "maxResults": 1000})
Expand Down Expand Up @@ -192,12 +175,20 @@ def drive(self):
gdrive = GoogleDrive(gauth)
return gdrive

def create_drive_item(self, parent_id, title):
return gdrive_retry(
lambda: self.gdrive_create_folder(title, parent_id)
)()
@gdrive_retry
def create_remote_dir(self, parent_id, title):
item = self.drive.CreateFile(
{
"title": title,
"parents": [{"id": parent_id}],
"mimeType": FOLDER_MIME_TYPE,
}
)
item.Upload()
return item

def get_drive_item(self, name, parents_ids):
@gdrive_retry
def get_remote_item(self, name, parents_ids):
if not parents_ids:
return None
query = " or ".join(
Expand All @@ -206,14 +197,15 @@ def get_drive_item(self, name, parents_ids):

query += " and trashed=false and title='{}'".format(name)

item_list = gdrive_retry(lambda: self.gdrive_list_file(query))()
# Limit found remote items count to 1 in response
item_list = self.drive.ListFile({"q": query, "maxResults": 1}).GetList()
return next(iter(item_list), None)

def resolve_remote_file(self, parents_ids, path_parts, create):
def resolve_remote_item_from_path(self, parents_ids, path_parts, create):
for path_part in path_parts:
item = self.get_drive_item(path_part, parents_ids)
item = self.get_remote_item(path_part, parents_ids)
if not item and create:
item = self.create_drive_item(parents_ids[0], path_part)
item = self.create_remote_dir(parents_ids[0], path_part)
elif not item:
return None
parents_ids = [item["id"]]
Expand All @@ -230,7 +222,7 @@ def subtract_root_path(self, parts):
break
return parts, [self.root_id]

def get_path_id_from_cache(self, path_info):
def get_remote_id_from_cache(self, path_info):
files_ids = []
parts, parents_ids = self.subtract_root_path(path_info.path.split("/"))
if (
Expand All @@ -245,22 +237,22 @@ def get_path_id_from_cache(self, path_info):

return files_ids, parents_ids, parts

def get_path_id(self, path_info, create=False):
files_ids, parents_ids, parts = self.get_path_id_from_cache(path_info)
def get_remote_id(self, path_info, create=False):
files_ids, parents_ids, parts = self.get_remote_id_from_cache(path_info)

if not parts and files_ids:
return files_ids[0]

file1 = self.resolve_remote_file(parents_ids, parts, create)
file1 = self.resolve_remote_item_from_path(parents_ids, parts, create)
return file1["id"] if file1 else ""

def exists(self, path_info):
return self.get_path_id(path_info) != ""
return self.get_remote_id(path_info) != ""

def _upload(self, from_file, to_info, name, no_progress_bar):
dirname = to_info.parent
if dirname:
parent_id = self.get_path_id(dirname, True)
parent_id = self.get_remote_id(dirname, True)
else:
parent_id = to_info.bucket

Expand All @@ -274,33 +266,33 @@ def _upload(self, from_file, to_info, name, no_progress_bar):
)()

def _download(self, from_info, to_file, name, no_progress_bar):
file_id = self.get_path_id(from_info)
file_id = self.get_remote_id(from_info)
gdrive_retry(
lambda: self.gdrive_download_file(
file_id, to_file, name, no_progress_bar
)
)()

def list_cache_paths(self):
file_id = self.get_path_id(self.path_info)
file_id = self.get_remote_id(self.path_info)
prefix = self.path_info.path
for path in self.list_path(file_id):
for path in self.list_children(file_id):
yield posixpath.join(prefix, path)

def list_file_path(self, drive_file):
if drive_file["mimeType"] == FOLDER_MIME_TYPE:
for i in self.list_path(drive_file["id"]):
yield posixpath.join(drive_file["title"], i)
else:
yield drive_file["title"]

def list_path(self, parent_id):
def list_children(self, parent_id):
for file1 in self.gdrive_list_item(
"'{}' in parents and trashed=false".format(parent_id)
):
for path in self.list_file_path(file1):
for path in self.list_remote_item(file1):
yield path

def list_remote_item(self, drive_file):
if drive_file["mimeType"] == FOLDER_MIME_TYPE:
for i in self.list_children(drive_file["id"]):
yield posixpath.join(drive_file["title"], i)
else:
yield drive_file["title"]

def all(self):
if not hasattr(self, "cached_ids") or not self.cached_ids:
return
Expand Down

0 comments on commit b883bb9

Please sign in to comment.