forked from justinmajetich/AirBnB_clone
-
Notifications
You must be signed in to change notification settings - Fork 1
/
db_storage.py
executable file
·98 lines (86 loc) · 3.16 KB
/
db_storage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/python3
"""Define storage engine using MySQL database
"""
from models.base_model import BaseModel, Base
from models.user import User
from models.state import State
from models.city import City
from models.amenity import Amenity
from models.place import Place
from models.review import Review
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm.session import sessionmaker, Session
from os import getenv
# Registry of the mapped model classes persisted in MySQL, keyed by class
# name. Insertion order is the order in which DBStorage queries them.
all_classes = dict(State=State, City=City,
                   User=User, Place=Place,
                   Review=Review, Amenity=Amenity)
class DBStorage:
    """Manage persistent storage in a MySQL database through SQLAlchemy.

    Attributes:
        __engine: SQLAlchemy engine created in ``__init__``.
        __session: scoped session created by ``reload``; ``None`` until
            ``reload`` has been called.
    """
    __engine = None
    __session = None

    def __init__(self):
        """Create the SQLAlchemy engine from HBNB_MYSQL_* env variables.

        When the ``HBNB_ENV`` environment variable equals ``'test'``, every
        table registered on ``Base.metadata`` is dropped.
        """
        # The URL scheme is "dialect+driver"; a space instead of '+' makes
        # the URL unparseable by create_engine.
        self.__engine = create_engine(
            'mysql+mysqldb://{}:{}@{}:3306/{}'.format(
                getenv('HBNB_MYSQL_USER'),
                getenv('HBNB_MYSQL_PWD'),
                getenv('HBNB_MYSQL_HOST'),
                getenv('HBNB_MYSQL_DB')),
            pool_pre_ping=True)
        # drop tables if test environment
        if getenv('HBNB_ENV') == 'test':
            Base.metadata.drop_all(self.__engine)

    def all(self, cls=None):
        """Query stored objects, optionally restricted to a single class.

        Args:
            cls: a mapped model class, or its name as a string (resolved
                through ``all_classes``). When falsy, every class in
                ``all_classes`` is queried.

        Returns:
            dict mapping ``'<ClassName>.<object-id>'`` to the object. An
            unrecognised class name yields an empty dict.
        """
        if not cls:
            models = all_classes.values()
        elif isinstance(cls, str):
            if cls not in all_classes:
                return {}
            models = [all_classes[cls]]
        else:
            models = [cls]

        obj_dict = {}
        for model in models:
            for row in self.__session.query(model):
                # Use the row's own class name: type(cls) would be the
                # declarative metaclass, not the model class.
                obj_dict['{}.{}'.format(type(row).__name__, row.id)] = row
        return obj_dict

    def new(self, obj):
        """Add ``obj`` to the current database session."""
        self.__session.add(obj)

    def save(self):
        """Commit the current database session."""
        self.__session.commit()

    def delete(self, obj=None):
        """Delete ``obj``'s row through the current session.

        Does nothing when ``obj`` is falsy (e.g. ``None``). The deletion is
        made permanent by the next ``save()`` commit.

        Raises:
            KeyError: if ``type(obj).__name__`` is not in ``all_classes``.
        """
        if not obj:
            return
        # resolve the mapped class from the object's class name
        model = all_classes[type(obj).__name__]
        # bulk-delete the matching row by primary id
        self.__session.query(model).filter(model.id == obj.id).delete()

    def reload(self):
        """Create all tables and open a new scoped session.

        ``expire_on_commit=False`` keeps loaded attribute values accessible
        after ``save()`` commits.
        """
        # create the db tables registered on Base.metadata
        Base.metadata.create_all(self.__engine)
        # build a session factory bound to the engine, wrapped in a
        # scoped_session (thread-local registry by SQLAlchemy's default)
        session_factory = sessionmaker(bind=self.__engine,
                                       expire_on_commit=False)
        self.__session = scoped_session(session_factory)

    def close(self):
        """Remove the current scoped session, if ``reload`` created one."""
        if self.__session is not None:
            self.__session.remove()